Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5
  6from sqlglot import exp
  7from sqlglot.generator import Generator
  8from sqlglot.helper import flatten, seq_get
  9from sqlglot.parser import Parser
 10from sqlglot.time import format_time
 11from sqlglot.tokens import Token, Tokenizer
 12from sqlglot.trie import new_trie
 13
# Type variable bound to exp.Expression; used by helpers that build and return
# a caller-specified expression node type.
E = t.TypeVar("E", bound=exp.Expression)
 15
 16
class Dialects(str, Enum):
    """Canonical names of the SQL dialects known to this module.

    A ``Dialect`` subclass whose upper-cased class name matches a member here
    is registered under that member's value by the ``_Dialect`` metaclass.
    """

    DIALECT = ""  # the base/default dialect

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TRINO = "trino"
    TSQL = "tsql"
    DATABRICKS = "databricks"
    DRILL = "drill"
    TERADATA = "teradata"
 39
 40
class _Dialect(type):
    """Metaclass for ``Dialect``.

    On subclass creation it registers the class in a shared name -> class
    registry and autofills derived attributes: time-format tries, the
    tokenizer/parser/generator classes, the canonical quote/identifier
    delimiters, and default bit/hex/byte-string transforms.
    """

    # Shared registry of all known dialect classes, keyed by dialect name.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        # Enables `Dialect["name"]` lookups; raises KeyError for unknown names.
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        # Non-raising counterpart of __getitem__.
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        # Register under the matching Dialects enum value when one exists,
        # otherwise under the lower-cased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Precompute tries for matching time-format tokens in both directions.
        klass.time_trie = new_trie(klass.time_mapping)
        klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
        klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)

        # Pick up nested Tokenizer/Parser/Generator overrides, falling back to
        # the base implementations.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first configured pair is treated as the canonical delimiters.
        klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.identifier_start, klass.identifier_end = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        # Install default bit/hex/byte-string renderers derived from the
        # tokenizer's delimiters, unless the generator already overrides them.
        # NOTE(review): this mutates the generator class's TRANSFORMS dict in
        # place; if a generator class were shared between dialects the first
        # registration would win — confirm that assumption holds.
        if (
            klass.tokenizer_class._BIT_STRINGS
            and exp.BitString not in klass.generator_class.TRANSFORMS
        ):
            bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
            klass.generator_class.TRANSFORMS[
                exp.BitString
            ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
        if (
            klass.tokenizer_class._HEX_STRINGS
            and exp.HexString not in klass.generator_class.TRANSFORMS
        ):
            hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
            klass.generator_class.TRANSFORMS[
                exp.HexString
            ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
        if (
            klass.tokenizer_class._BYTE_STRINGS
            and exp.ByteString not in klass.generator_class.TRANSFORMS
        ):
            be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
            klass.generator_class.TRANSFORMS[
                exp.ByteString
            ] = lambda self, e: f"{be_start}{self.sql(e, 'this')}{be_end}"

        return klass
 98
 99
class Dialect(metaclass=_Dialect):
    """Base class for SQL dialects.

    Defining a subclass registers it with the ``_Dialect`` metaclass, which
    also autofills the attributes marked below. Instances expose convenience
    wrappers for tokenizing, parsing, generating and transpiling SQL.
    """

    # Parser/generator behavior knobs.
    index_offset = 0
    unnest_column_only = False
    alias_post_tablesample = False
    normalize_functions: t.Optional[str] = "upper"
    null_ordering = "nulls_are_small"

    # Canonical literal formats for this dialect.
    date_format = "'%Y-%m-%d'"
    dateint_format = "'%Y%m%d'"
    time_format = "'%Y-%m-%d %H:%M:%S'"
    time_mapping: t.Dict[str, str] = {}

    # autofilled by the _Dialect metaclass
    quote_start = None
    quote_end = None
    identifier_start = None
    identifier_end = None

    time_trie = None
    inverse_time_mapping = None
    inverse_time_trie = None
    tokenizer_class = None
    parser_class = None
    generator_class = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve *dialect* (name, class, instance or falsy) to a Dialect class.

        Raises:
            ValueError: if *dialect* is a string naming no registered dialect.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        klass = cls.get(dialect)
        if klass is None:
            raise ValueError(f"Unknown dialect '{dialect}'")
        return klass

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Map a time-format string (or string literal expression) through the
        dialect's time mapping, returning a string literal expression."""
        if isinstance(expression, str):
            # Raw strings arrive quoted; strip the surrounding quotes.
            mapped = format_time(expression[1:-1], cls.time_mapping, cls.time_trie)
            return exp.Literal.string(mapped)

        if expression and expression.is_string:
            mapped = format_time(expression.this, cls.time_mapping, cls.time_trie)
            return exp.Literal.string(mapped)

        # Non-string expressions (and None) pass through untouched.
        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse *sql* into a list of expression trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse *sql*, coercing the result into *expression_type*."""
        tokens = self.tokenize(sql)
        return self.parser(**opts).parse_into(expression_type, tokens, sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Render an expression tree back into SQL text."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse *sql* and re-generate each statement in this dialect."""
        return [self.generate(tree, **opts) for tree in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenize *sql* with this dialect's (cached) tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily build and cache a single tokenizer instance per Dialect object.
        try:
            return self._tokenizer
        except AttributeError:
            self._tokenizer = self.tokenizer_class()  # type: ignore
            return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Build a fresh Parser configured with this dialect's settings;
        *opts* override the defaults."""
        settings = {
            "index_offset": self.index_offset,
            "unnest_column_only": self.unnest_column_only,
            "alias_post_tablesample": self.alias_post_tablesample,
            "null_ordering": self.null_ordering,
        }
        settings.update(opts)
        return self.parser_class(**settings)  # type: ignore

    def generator(self, **opts) -> Generator:
        """Build a fresh Generator configured with this dialect's settings;
        *opts* override the defaults."""
        settings = {
            "quote_start": self.quote_start,
            "quote_end": self.quote_end,
            "identifier_start": self.identifier_start,
            "identifier_end": self.identifier_end,
            "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
            "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
            "index_offset": self.index_offset,
            "time_mapping": self.inverse_time_mapping,
            "time_trie": self.inverse_time_trie,
            "unnest_column_only": self.unnest_column_only,
            "alias_post_tablesample": self.alias_post_tablesample,
            "normalize_functions": self.normalize_functions,
            "null_ordering": self.null_ordering,
        }
        settings.update(opts)
        return self.generator_class(**settings)  # type: ignore
215
216
# Anything that identifies a dialect: its name, an instance, a class, or None
# (None selects the base/default dialect in Dialect.get_or_raise).
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
218
219
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Return a generator transform that renders the expression as a call to
    *name* over all of its flattened argument values."""

    def _rename(self: Generator, expression: exp.Expression) -> str:
        return self.func(name, *flatten(expression.args.values()))

    return _rename
222
223
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT, warning that an explicit accuracy
    argument cannot be expressed."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)
228
229
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render a three-argument IF(condition, true_value, false_value) call."""
    condition = expression.this
    true_value = expression.args.get("true")
    false_value = expression.args.get("false")
    return self.func("IF", condition, true_value, false_value)
234
235
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON extraction with the arrow operator, e.g. ``col -> 'key'``."""
    operator = "->"
    return self.binary(expression, operator)
238
239
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON extraction with the double-arrow operator ``->>``."""
    operator = "->>"
    return self.binary(expression, operator)
244
245
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array literal using bracket syntax: ``[a, b, ...]``."""
    elements = self.expressions(expression)
    return f"[{elements}]"
248
249
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate case-insensitive ILIKE by lowercasing the left-hand side and
    rendering a plain LIKE."""
    lowered = exp.Lower(this=expression.this)
    like = exp.Like(this=lowered, expression=expression.args["expression"])
    return self.like_sql(like)
257
258
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, honoring an optional zone."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
262
263
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render WITH for dialects without recursive CTEs: warn and drop the
    RECURSIVE flag before delegating to the plain WITH renderer."""
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        # Clear the flag so with_sql does not emit RECURSIVE.
        expression.args["recursive"] = False
    return self.with_sql(expression)
269
270
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE with an IF guard that yields NULL on zero divisor."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
275
276
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Warn that TABLESAMPLE is unsupported and render only the sampled table."""
    self.unsupported("TABLESAMPLE unsupported")
    table = expression.this
    return self.sql(table)
280
281
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Warn that PIVOT is unsupported and render the expression unchanged."""
    self.unsupported("PIVOT unsupported")
    return self.sql(expression)
285
286
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Fall back to a plain CAST for dialects without TRY_CAST."""
    return self.cast_sql(expression)
289
290
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Warn that table properties are unsupported and emit nothing for them."""
    self.unsupported("Properties unsupported")
    return ""
294
295
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition via STRPOS, emulating an optional start position with
    SUBSTR plus an index adjustment back to absolute 1-based position."""
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")
    if not start:
        return f"STRPOS({haystack}, {needle})"
    # Search the suffix, then translate the hit back to an absolute offset.
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
303
304
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access with dotted, quoted-identifier syntax."""
    struct = self.sql(expression, "this")
    key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
    return f"{struct}.{key}"
309
310
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map built from parallel key/value arrays as an interleaved
    ``MAP(k1, v1, k2, v2, ...)`` argument list."""
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        # Column-valued (non-literal) arrays cannot be interleaved statically.
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    interleaved = [
        self.sql(item)
        for pair in zip(keys.expressions, values.expressions)
        for item in pair
    ]
    return self.func(map_func_name, *interleaved)
326
327
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.Sequence], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format; True means the dialect's time format.

    Returns:
        A callable that returns the appropriately formatted time expression.
    """

    def _format_time(args: t.Sequence):
        fmt = seq_get(args, 1)
        if not fmt:
            # No explicit format: fall back to the dialect's time format
            # (default=True) or the caller-provided default, if any.
            fmt = Dialect[dialect].time_format if default is True else default or None
        return exp_class(this=seq_get(args, 0), format=Dialect[dialect].format_time(fmt))

    return _format_time
352
353
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a
    table's schema. When the PARTITIONED BY value is an array of column names,
    they are transformed into a schema and the corresponding columns are
    removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        # Work on a copy so the caller's tree is not mutated.
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            partition_names = {column.name.upper() for column in prop.this.expressions}
            partition_columns = [
                column for column in schema.expressions if column.name.upper() in partition_names
            ]
            remaining = [column for column in schema.expressions if column not in partition_columns]
            schema.set("expressions", remaining)
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
            expression.set("this", schema)

    return self.create_sql(expression)
375
376
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.Sequence], E]:
    """Build a parser for date-delta style functions.

    Supports both the 3-argument unit-based form ``FUNC(unit, delta, this)``
    and the 2-argument form ``FUNC(this, delta)``, which defaults the unit
    to ``'DAY'``.

    Args:
        exp_class: the expression class to instantiate.
        unit_mapping: optional map canonicalizing unit names (lower-cased).

    Returns:
        A callable mapping the raw argument list to an ``exp_class`` node.
    """

    def inner_func(args: t.Sequence) -> E:
        unit_based = len(args) == 3
        this = seq_get(args, 2) if unit_based else seq_get(args, 0)
        # The delta amount is the second argument in either form (the original
        # `seq_get(args, 1) if unit_based else seq_get(args, 1)` was a dead
        # conditional with identical branches).
        expression = seq_get(args, 1)
        unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
        unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit  # type: ignore
        return exp_class(this=this, expression=expression, unit=unit)

    return inner_func
389
390
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
    """Map DATE_TRUNC(unit, value) arguments to a DateTrunc node when the
    value is cast to DATE, otherwise to a TimestampTrunc node."""
    unit, this = seq_get(args, 0), seq_get(args, 1)

    truncates_date = isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE)
    if truncates_date:
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)
398
399
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('unit', value), defaulting the
    unit to 'day'."""
    unit = expression.text("unit") or "day"
    return self.func("DATE_TRUNC", exp.Literal.string(unit), expression.this)
404
405
def locate_to_strposition(args: t.Sequence) -> exp.Expression:
    """Map LOCATE(substr, string[, position]) arguments onto a StrPosition
    node (which stores the searched string as `this`)."""
    substr = seq_get(args, 0)
    this = seq_get(args, 1)
    position = seq_get(args, 2)
    return exp.StrPosition(this=this, substr=substr, position=position)
412
413
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, string[, position])."""
    substr = expression.args.get("substr")
    position = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
418
419
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render TimeStrToTime as a plain CAST to TIMESTAMP."""
    value = self.sql(expression, "this")
    return f"CAST({value} AS TIMESTAMP)"
422
423
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render DateStrToDate as a plain CAST to DATE."""
    value = self.sql(expression, "this")
    return f"CAST({value} AS DATE)"
426
427
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Use LEAST for the variadic scalar form, MIN for the aggregate form."""
    func_name = "LEAST" if expression.expressions else "MIN"
    return rename_func(func_name)(self, expression)
431
432
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Use GREATEST for the variadic scalar form, MAX for the aggregate form."""
    func_name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(func_name)(self, expression)
436
437
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Emulate COUNT_IF(cond) as SUM(IF(cond, 1, 0))."""
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        # DISTINCT cannot survive the SUM rewrite; keep the inner condition.
        cond = cond.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
446
447
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM, using the extended form
    ``TRIM([position] [chars] FROM target [COLLATE c])`` only when character
    or collation options require it."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    if not (remove_chars or collation):
        # Plain TRIM/LTRIM/RTRIM suffices when nothing dialect-specific is set.
        return self.trim_sql(expression)

    parts = []
    if trim_type:
        parts.append(f"{trim_type} ")
    if remove_chars:
        parts.append(f"{remove_chars} ")
    if parts:
        # FROM is only needed when a position and/or char set precedes target.
        parts.append("FROM ")
    parts.append(target)
    if collation:
        parts.append(f" COLLATE {collation}")
    return f"TRIM({''.join(parts)})"
463
464
def str_to_time_sql(self, expression: exp.Expression) -> str:
    """Render string-to-time parsing as STRPTIME(value, format)."""
    fmt = self.format_time(expression)
    return self.func("STRPTIME", expression.this, fmt)
467
468
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a TsOrDsToDate renderer for *dialect*: parse with STRPTIME when a
    non-default time format is present, otherwise cast the value to DATE."""

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        default_formats = (target.time_format, target.date_format)
        if time_format and time_format not in default_formats:
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _ts_or_ds_to_date_sql
class Dialects(builtins.str, enum.Enum):
18class Dialects(str, Enum):
19    DIALECT = ""
20
21    BIGQUERY = "bigquery"
22    CLICKHOUSE = "clickhouse"
23    DUCKDB = "duckdb"
24    HIVE = "hive"
25    MYSQL = "mysql"
26    ORACLE = "oracle"
27    POSTGRES = "postgres"
28    PRESTO = "presto"
29    REDSHIFT = "redshift"
30    SNOWFLAKE = "snowflake"
31    SPARK = "spark"
32    SQLITE = "sqlite"
33    STARROCKS = "starrocks"
34    TABLEAU = "tableau"
35    TRINO = "trino"
36    TSQL = "tsql"
37    DATABRICKS = "databricks"
38    DRILL = "drill"
39    TERADATA = "teradata"

Canonical names of the SQL dialects known to sqlglot; dialect classes are registered under these values.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
TERADATA = <Dialects.TERADATA: 'teradata'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
101class Dialect(metaclass=_Dialect):
102    index_offset = 0
103    unnest_column_only = False
104    alias_post_tablesample = False
105    normalize_functions: t.Optional[str] = "upper"
106    null_ordering = "nulls_are_small"
107
108    date_format = "'%Y-%m-%d'"
109    dateint_format = "'%Y%m%d'"
110    time_format = "'%Y-%m-%d %H:%M:%S'"
111    time_mapping: t.Dict[str, str] = {}
112
113    # autofilled
114    quote_start = None
115    quote_end = None
116    identifier_start = None
117    identifier_end = None
118
119    time_trie = None
120    inverse_time_mapping = None
121    inverse_time_trie = None
122    tokenizer_class = None
123    parser_class = None
124    generator_class = None
125
126    @classmethod
127    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
128        if not dialect:
129            return cls
130        if isinstance(dialect, _Dialect):
131            return dialect
132        if isinstance(dialect, Dialect):
133            return dialect.__class__
134
135        result = cls.get(dialect)
136        if not result:
137            raise ValueError(f"Unknown dialect '{dialect}'")
138
139        return result
140
141    @classmethod
142    def format_time(
143        cls, expression: t.Optional[str | exp.Expression]
144    ) -> t.Optional[exp.Expression]:
145        if isinstance(expression, str):
146            return exp.Literal.string(
147                format_time(
148                    expression[1:-1],  # the time formats are quoted
149                    cls.time_mapping,
150                    cls.time_trie,
151                )
152            )
153        if expression and expression.is_string:
154            return exp.Literal.string(
155                format_time(
156                    expression.this,
157                    cls.time_mapping,
158                    cls.time_trie,
159                )
160            )
161        return expression
162
163    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
164        return self.parser(**opts).parse(self.tokenize(sql), sql)
165
166    def parse_into(
167        self, expression_type: exp.IntoType, sql: str, **opts
168    ) -> t.List[t.Optional[exp.Expression]]:
169        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
170
171    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
172        return self.generator(**opts).generate(expression)
173
174    def transpile(self, sql: str, **opts) -> t.List[str]:
175        return [self.generate(expression, **opts) for expression in self.parse(sql)]
176
177    def tokenize(self, sql: str) -> t.List[Token]:
178        return self.tokenizer.tokenize(sql)
179
180    @property
181    def tokenizer(self) -> Tokenizer:
182        if not hasattr(self, "_tokenizer"):
183            self._tokenizer = self.tokenizer_class()  # type: ignore
184        return self._tokenizer
185
186    def parser(self, **opts) -> Parser:
187        return self.parser_class(  # type: ignore
188            **{
189                "index_offset": self.index_offset,
190                "unnest_column_only": self.unnest_column_only,
191                "alias_post_tablesample": self.alias_post_tablesample,
192                "null_ordering": self.null_ordering,
193                **opts,
194            },
195        )
196
197    def generator(self, **opts) -> Generator:
198        return self.generator_class(  # type: ignore
199            **{
200                "quote_start": self.quote_start,
201                "quote_end": self.quote_end,
202                "identifier_start": self.identifier_start,
203                "identifier_end": self.identifier_end,
204                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
205                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
206                "index_offset": self.index_offset,
207                "time_mapping": self.inverse_time_mapping,
208                "time_trie": self.inverse_time_trie,
209                "unnest_column_only": self.unnest_column_only,
210                "alias_post_tablesample": self.alias_post_tablesample,
211                "normalize_functions": self.normalize_functions,
212                "null_ordering": self.null_ordering,
213                **opts,
214            }
215        )
@classmethod
def get_or_raise( cls, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> Type[sqlglot.dialects.dialect.Dialect]:
126    @classmethod
127    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
128        if not dialect:
129            return cls
130        if isinstance(dialect, _Dialect):
131            return dialect
132        if isinstance(dialect, Dialect):
133            return dialect.__class__
134
135        result = cls.get(dialect)
136        if not result:
137            raise ValueError(f"Unknown dialect '{dialect}'")
138
139        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
141    @classmethod
142    def format_time(
143        cls, expression: t.Optional[str | exp.Expression]
144    ) -> t.Optional[exp.Expression]:
145        if isinstance(expression, str):
146            return exp.Literal.string(
147                format_time(
148                    expression[1:-1],  # the time formats are quoted
149                    cls.time_mapping,
150                    cls.time_trie,
151                )
152            )
153        if expression and expression.is_string:
154            return exp.Literal.string(
155                format_time(
156                    expression.this,
157                    cls.time_mapping,
158                    cls.time_trie,
159                )
160            )
161        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
163    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
164        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
166    def parse_into(
167        self, expression_type: exp.IntoType, sql: str, **opts
168    ) -> t.List[t.Optional[exp.Expression]]:
169        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: Optional[sqlglot.expressions.Expression], **opts) -> str:
171    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
172        return self.generator(**opts).generate(expression)
def transpile(self, sql: str, **opts) -> List[str]:
174    def transpile(self, sql: str, **opts) -> t.List[str]:
175        return [self.generate(expression, **opts) for expression in self.parse(sql)]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
177    def tokenize(self, sql: str) -> t.List[Token]:
178        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
186    def parser(self, **opts) -> Parser:
187        return self.parser_class(  # type: ignore
188            **{
189                "index_offset": self.index_offset,
190                "unnest_column_only": self.unnest_column_only,
191                "alias_post_tablesample": self.alias_post_tablesample,
192                "null_ordering": self.null_ordering,
193                **opts,
194            },
195        )
def generator(self, **opts) -> sqlglot.generator.Generator:
197    def generator(self, **opts) -> Generator:
198        return self.generator_class(  # type: ignore
199            **{
200                "quote_start": self.quote_start,
201                "quote_end": self.quote_end,
202                "identifier_start": self.identifier_start,
203                "identifier_end": self.identifier_end,
204                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
205                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
206                "index_offset": self.index_offset,
207                "time_mapping": self.inverse_time_mapping,
208                "time_trie": self.inverse_time_trie,
209                "unnest_column_only": self.unnest_column_only,
210                "alias_post_tablesample": self.alias_post_tablesample,
211                "normalize_functions": self.normalize_functions,
212                "null_ordering": self.null_ordering,
213                **opts,
214            }
215        )
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
221def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
222    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
225def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
226    if expression.args.get("accuracy"):
227        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
228    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.If) -> str:
231def if_sql(self: Generator, expression: exp.If) -> str:
232    return self.func(
233        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
234    )
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
237def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
238    return self.binary(expression, "->")
def arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
241def arrow_json_extract_scalar_sql(
242    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
243) -> str:
244    return self.binary(expression, "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
247def inline_array_sql(self: Generator, expression: exp.Array) -> str:
248    return f"[{self.expressions(expression)}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
251def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
252    return self.like_sql(
253        exp.Like(
254            this=exp.Lower(this=expression.this),
255            expression=expression.args["expression"],
256        )
257    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
260def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
261    zone = self.sql(expression, "this")
262    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
265def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
266    if expression.args.get("recursive"):
267        self.unsupported("Recursive CTEs are unsupported")
268        expression.args["recursive"] = False
269    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
272def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
273    n = self.sql(expression, "this")
274    d = self.sql(expression, "expression")
275    return f"IF({d} <> 0, {n} / {d}, NULL)"
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop an unsupported TABLESAMPLE clause, rendering only the sampled table."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Warn that PIVOT is unsupported and render the node anyway."""
    self.unsupported("PIVOT unsupported")
    # NOTE(review): this re-dispatches generation on the Pivot node itself; if this
    # helper is registered as the generator's Pivot transform that could recurse —
    # confirm against the generator wiring.
    return self.sql(expression)
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TRY_CAST as a plain CAST for dialects without TRY_CAST."""
    return self.cast_sql(expression)
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Drop a Properties node entirely, warning that it is unsupported."""
    self.unsupported("Properties unsupported")
    return ""
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition with STRPOS, emulating the optional start position."""
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")
    if not start:
        return f"STRPOS({haystack}, {needle})"
    # Search only the suffix, then shift the match back into full-string coordinates.
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as dot notation with a quoted identifier key."""
    struct = self.sql(expression, "this")
    key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
    return f"{struct}.{key}"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as a variadic call ``MAP(k1, v1, k2, v2, ...)``.

    Falls back to ``MAP(keys, values)`` (with an unsupported warning) when the
    keys/values are not array literals and therefore cannot be interleaved.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    interleaved = [
        self.sql(node)
        for pair in zip(keys.expressions, values.expressions)
        for node in pair
    ]
    return self.func(map_func_name, *interleaved)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.Sequence], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.Sequence):
        target = Dialect[dialect]
        fmt = seq_get(args, 1)
        if not fmt:
            # No explicit format: fall back to the dialect's time format
            # (default is True) or the caller-provided default.
            fmt = target.time_format if default is True else default or None
        return exp_class(this=seq_get(args, 0), format=target.format_time(fmt))

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

  A callable that can be used to return the appropriately formatted time expression.

def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    if isinstance(expression.this, exp.Schema) and expression.args.get("kind") in ("TABLE", "VIEW"):
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Match partition columns against the schema by upper-cased name.
            partition_names = {v.name.upper() for v in prop.this.expressions}
            partition_cols = [c for c in schema.expressions if c.name.upper() in partition_names]
            # Move the matched columns out of the schema and into the property's schema.
            schema.set("expressions", [e for e in schema.expressions if e not in partition_cols])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_cols)))
            expression.set("this", schema)

    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.Sequence], E]:
    """Build a parser callable for date-delta style argument lists.

    Handles both the 2-argument form ``(this, expression)`` and the 3-argument
    form ``(unit, expression, this)``; in the 2-argument form the unit defaults
    to the literal ``'DAY'``.

    Args:
        exp_class: the expression class to instantiate.
        unit_mapping: optional mapping used to normalize the unit name.

    Returns:
        A callable that builds an ``exp_class`` node from the raw argument list.
    """

    def inner_func(args: t.Sequence) -> E:
        unit_based = len(args) == 3
        this = seq_get(args, 2) if unit_based else seq_get(args, 0)
        # The delta expression is the second argument in both forms (the
        # original conditional had two identical branches).
        expression = seq_get(args, 1)
        unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
        unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit  # type: ignore
        return exp_class(this=this, expression=expression, unit=unit)

    return inner_func
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
    """Build DateTrunc when the operand is a value cast to DATE, else TimestampTrunc."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    truncates_date = isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE)
    klass = exp.DateTrunc if truncates_date else exp.TimestampTrunc
    return klass(this=this, unit=unit)
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as ``DATE_TRUNC('<unit>', this)``, defaulting the unit to 'day'."""
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_TRUNC", unit, expression.this)
def locate_to_strposition(args: t.Sequence) -> exp.Expression:
    """Convert LOCATE-style arguments ``(substr, this, position)`` into a StrPosition node."""
    substr = seq_get(args, 0)
    this = seq_get(args, 1)
    position = seq_get(args, 2)
    return exp.StrPosition(this=this, substr=substr, position=position)
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as ``LOCATE(substr, this[, position])``."""
    substr = expression.args.get("substr")
    position = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string conversion as a CAST to TIMESTAMP."""
    inner = self.sql(expression, "this")
    return f"CAST({inner} AS TIMESTAMP)"
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string conversion as a CAST to DATE."""
    inner = self.sql(expression, "this")
    return f"CAST({inner} AS DATE)"
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render MIN, or LEAST when the node carries additional operands."""
    if expression.expressions:
        return rename_func("LEAST")(self, expression)
    return rename_func("MIN")(self, expression)
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render MAX, or GREATEST when the node carries additional operands."""
    if expression.expressions:
        return rename_func("GREATEST")(self, expression)
    return rename_func("MAX")(self, expression)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as ``SUM(IF(cond, 1, 0))``; DISTINCT is unsupported."""
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        # Unwrap the DISTINCT operand but warn — the rewrite cannot preserve it.
        cond = cond.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM.

    The simple form (no character set, no collation) is delegated back to the
    generator's own ``trim_sql``; otherwise the extended
    ``TRIM([position ][chars ]FROM target[ COLLATE ...])`` syntax is emitted.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    specifier = f"{trim_type} " if trim_type else ""
    chars = f"{remove_chars} " if remove_chars else ""
    from_kw = "FROM " if specifier or chars else ""
    collate = f" COLLATE {collation}" if collation else ""
    return f"TRIM({specifier}{chars}{from_kw}{target}{collate})"
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as ``STRPTIME(this, format)``.

    Fix: annotate ``self`` as ``Generator`` for consistency with every other
    helper in this module (the parameter was previously unannotated).
    """
    return self.func("STRPTIME", expression.this, self.format_time(expression))
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator helper that casts a TsOrDsToDate operand to DATE.

    When the expression carries a time format other than the dialect's default
    time/date formats, the operand is first converted via ``str_to_time_sql``.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (target.time_format, target.date_format):
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _ts_or_ds_to_date_sql