Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5
  6from sqlglot import exp
  7from sqlglot.generator import Generator
  8from sqlglot.helper import flatten, seq_get
  9from sqlglot.parser import Parser
 10from sqlglot.time import format_time
 11from sqlglot.tokens import Token, Tokenizer
 12from sqlglot.trie import new_trie
 13
# Type variable bound to exp.Expression; lets the factory helpers in this module
# declare that they return the same expression class they were given.
E = t.TypeVar("E", bound=exp.Expression)
 15
 16
class Dialects(str, Enum):
    """Names of the known SQL dialects.

    Every member is also a ``str``; its value is the lowercase key under which
    the ``_Dialect`` metaclass registers the class of the same (upper-cased) name.
    """

    # Registration key for a class literally named ``Dialect`` (the empty string).
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TRINO = "trino"
    TSQL = "tsql"
    DATABRICKS = "databricks"
    DRILL = "drill"
    TERADATA = "teradata"
 40
 41
class _Dialect(type):
    """Metaclass that registers each dialect class and fills in its derived attributes."""

    # Registry of every class created through this metaclass, keyed by dialect name.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        """Return the dialect class registered under `key` (KeyError if there is none)."""
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        """Return the dialect class registered under `key`, or `default` if there is none."""
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        """Create the class, register it, and compute its autofilled attributes."""
        klass = super().__new__(cls, clsname, bases, attrs)
        # Register under the Dialects enum value when the class name matches a member,
        # otherwise under the lower-cased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Precompute the lookup structures for time-format conversion in both directions.
        klass.time_trie = new_trie(klass.time_mapping)
        klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
        klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)

        # Use the nested Tokenizer/Parser/Generator classes if the dialect defines them,
        # falling back to the generic implementations.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first quote / identifier delimiter pair of the tokenizer is used for output.
        klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.identifier_start, klass.identifier_end = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        # For bit, hex and byte strings: when the tokenizer declares delimiters and the
        # generator has no transform of its own, install one using the first delimiter pair.
        # NOTE(review): these assignments write into generator_class.TRANSFORMS in place; if a
        # generator inherits that dict instead of defining its own, the entry would be shared
        # with whoever else uses the same dict -- confirm against the Generator classes.
        if (
            klass.tokenizer_class._BIT_STRINGS
            and exp.BitString not in klass.generator_class.TRANSFORMS
        ):
            bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
            # The rendered value is parsed as an int and written in binary.
            klass.generator_class.TRANSFORMS[
                exp.BitString
            ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
        if (
            klass.tokenizer_class._HEX_STRINGS
            and exp.HexString not in klass.generator_class.TRANSFORMS
        ):
            hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
            # The rendered value is parsed as an int and written in upper-case hexadecimal.
            klass.generator_class.TRANSFORMS[
                exp.HexString
            ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
        if (
            klass.tokenizer_class._BYTE_STRINGS
            and exp.ByteString not in klass.generator_class.TRANSFORMS
        ):
            be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
            # Byte strings are emitted unchanged between the delimiters.
            klass.generator_class.TRANSFORMS[
                exp.ByteString
            ] = lambda self, e: f"{be_start}{self.sql(e, 'this')}{be_end}"

        return klass
 99
100
class Dialect(metaclass=_Dialect):
    """Base dialect: ties a tokenizer, a parser and a generator to shared settings.

    Subclasses override the class attributes below; the attributes marked as
    autofilled are assigned by the ``_Dialect`` metaclass at class creation.
    """

    index_offset = 0
    unnest_column_only = False
    alias_post_tablesample = False
    normalize_functions: t.Optional[str] = "upper"
    null_ordering = "nulls_are_small"

    # Default formats are stored as quoted SQL string literals.
    date_format = "'%Y-%m-%d'"
    dateint_format = "'%Y%m%d'"
    time_format = "'%Y-%m-%d %H:%M:%S'"
    time_mapping: t.Dict[str, str] = {}

    # autofilled
    quote_start = None
    quote_end = None
    identifier_start = None
    identifier_end = None

    time_trie = None
    inverse_time_mapping = None
    inverse_time_trie = None
    tokenizer_class = None
    parser_class = None
    generator_class = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve `dialect` to a dialect class.

        A falsy value resolves to `cls`, a dialect class to itself and a dialect
        instance to its class; anything else is looked up by name.

        Raises:
            ValueError: if no dialect is registered under the given name.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        found = cls.get(dialect)
        if found:
            return found
        raise ValueError(f"Unknown dialect '{dialect}'")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a time format through this dialect's time mapping.

        Plain strings are expected to be quoted (first and last character are
        dropped); string expressions contribute their `this`. Any other input
        is returned unchanged.
        """
        if isinstance(expression, str):
            raw_format = expression[1:-1]  # the time formats are quoted
        elif expression and expression.is_string:
            raw_format = expression.this
        else:
            return expression

        return exp.Literal.string(format_time(raw_format, cls.time_mapping, cls.time_trie))

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse it with a parser configured by `opts`."""
        sql_parser = self.parser(**opts)
        return sql_parser.parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse it into the given expression type."""
        sql_parser = self.parser(**opts)
        return sql_parser.parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Render `expression` with a generator configured by `opts`."""
        sql_generator = self.generator(**opts)
        return sql_generator.generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and render every resulting statement with this dialect."""
        rendered = []
        for statement in self.parse(sql):
            rendered.append(self.generate(statement, **opts))
        return rendered

    def tokenize(self, sql: str) -> t.List[Token]:
        """Split `sql` into tokens using this dialect's tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        """Tokenizer instance, created on first access and cached on the instance."""
        try:
            return self._tokenizer
        except AttributeError:
            self._tokenizer = self.tokenizer_class()  # type: ignore
            return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Build a parser from the dialect settings; `opts` override the defaults."""
        settings = {
            "index_offset": self.index_offset,
            "unnest_column_only": self.unnest_column_only,
            "alias_post_tablesample": self.alias_post_tablesample,
            "null_ordering": self.null_ordering,
        }
        settings.update(opts)
        return self.parser_class(**settings)  # type: ignore

    def generator(self, **opts) -> Generator:
        """Build a generator from the dialect settings; `opts` override the defaults."""
        settings = {
            "quote_start": self.quote_start,
            "quote_end": self.quote_end,
            "identifier_start": self.identifier_start,
            "identifier_end": self.identifier_end,
            "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
            "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
            "index_offset": self.index_offset,
            "time_mapping": self.inverse_time_mapping,
            "time_trie": self.inverse_time_trie,
            "unnest_column_only": self.unnest_column_only,
            "alias_post_tablesample": self.alias_post_tablesample,
            "normalize_functions": self.normalize_functions,
            "null_ordering": self.null_ordering,
        }
        settings.update(opts)
        return self.generator_class(**settings)  # type: ignore
216
217
# Everything accepted where a dialect is expected: a dialect name, a Dialect
# instance, a Dialect class, or None.
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
219
220
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Return a generator callback that renders an expression as a call to `name`.

    The call's arguments are the flattened values of the expression's `args`.
    """

    def _renamed(self, expression):
        call_args = flatten(expression.args.values())
        return self.func(name, *call_args)

    return _renamed
223
224
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render as APPROX_COUNT_DISTINCT(this), flagging an `accuracy` argument as unsupported."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    operand = expression.this
    return self.func("APPROX_COUNT_DISTINCT", operand)
229
230
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render a conditional as IF(condition, true, false)."""
    condition = expression.this
    when_true = expression.args.get("true")
    when_false = expression.args.get("false")
    return self.func("IF", condition, when_true, when_false)
235
236
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render a JSON extraction as a binary expression with the `->` operator."""
    arrow = "->"
    return self.binary(expression, arrow)
239
240
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render a scalar JSON extraction as a binary expression with the `->>` operator."""
    double_arrow = "->>"
    return self.binary(expression, double_arrow)
245
246
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array as its elements enclosed in square brackets."""
    elements = self.expressions(expression)
    return f"[{elements}]"
249
250
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE with LIKE by lower-casing both sides of the comparison.

    Args:
        expression: the ILIKE node; `this` is the operand and the "expression"
            argument is the pattern.

    Returns:
        The SQL for ``LOWER(operand) LIKE LOWER(pattern)``.
    """
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this),
            # The pattern must be lower-cased as well: an operand that has been
            # lower-cased can never match a pattern containing uppercase characters.
            expression=exp.Lower(this=expression.args["expression"]),
        )
    )
258
259
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, appending AT TIME ZONE when a zone is given."""
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"
    return f"CURRENT_DATE AT TIME ZONE {zone}"
263
264
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a WITH clause for dialects without recursive CTEs.

    When the clause is marked recursive, this is reported as unsupported and the
    clause is rendered with the flag cleared.

    Args:
        expression: the WITH node to render; it is not modified.

    Returns:
        The SQL produced by the generator's `with_sql`.
    """
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        # Clear the flag on a copy so that generating SQL does not alter the caller's
        # tree (create_with_partitions_sql in this module copies for the same reason).
        expression = expression.copy()
        expression.args["recursive"] = False
    return self.with_sql(expression)
270
271
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate a safe division with ``IF(d <> 0, n / d, NULL)``.

    Args:
        expression: the node whose `this` is the numerator and whose
            "expression" argument is the denominator.

    Returns:
        The IF(...) SQL string.
    """

    def _operand(key: str) -> str:
        # Binary operands are parenthesized so that splicing them next to `/` and `<>`
        # cannot regroup them (e.g. a + b over c - d must not become a + b / c - d).
        sql = self.sql(expression, key)
        return f"({sql})" if isinstance(expression.args.get(key), exp.Binary) else sql

    n = _operand("this")
    d = _operand("expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"
276
277
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Report TABLESAMPLE as unsupported and render only the sampled expression."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
281
282
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Report PIVOT as unsupported and emit nothing for it.

    Returns:
        An empty string, like the other ``no_*`` helpers that drop a construct.
    """
    self.unsupported("PIVOT unsupported")
    # Do not hand `expression` back to self.sql: when this function is the transform
    # for Pivot, that call would presumably dispatch straight back here and never end.
    return ""
286
287
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render a TRY_CAST node through the generator's regular cast rendering."""
    rendered = self.cast_sql(expression)
    return rendered
290
291
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Report properties as unsupported and emit an empty string for them."""
    message = "Properties unsupported"
    self.unsupported(message)
    return ""
295
296
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Report column comments as unsupported and emit an empty string for them."""
    message = "CommentColumnConstraint unsupported"
    self.unsupported(message)
    return ""
302
303
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a substring search with STRPOS.

    With a start position, the search runs on SUBSTR(this, position) and the
    result is shifted by ``position - 1``.
    """
    haystack, needle, start = (
        self.sql(expression, key) for key in ("this", "substr", "position")
    )
    if not start:
        return f"STRPOS({haystack}, {needle})"
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
311
312
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render a struct field access as ``<struct>.<quoted field identifier>``."""
    owner = self.sql(expression, "this")
    field = exp.Identifier(this=expression.expression, quoted=True)
    field_sql = self.sql(field)
    return f"{owner}.{field_sql}"
317
318
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as ``MAP(k1, v1, k2, v2, ...)``.

    Keys and values can only be interleaved when both are array literals;
    otherwise this is reported as unsupported and the two arguments are passed
    to the function unchanged.
    """
    keys, values = expression.args["keys"], expression.args["values"]

    if isinstance(keys, exp.Array) and isinstance(values, exp.Array):
        interleaved = [
            self.sql(item)
            for pair in zip(keys.expressions, values.expressions)
            for item in pair
        ]
        return self.func(map_func_name, *interleaved)

    self.unsupported("Cannot convert array columns into map.")
    return self.func(map_func_name, keys, values)
334
335
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.Sequence], E]:
    """Build a parser callback for expressions that carry a time format.

    Args:
        exp_class: the expression class to instantiate.
        dialect: name of the dialect whose time mapping is applied.
        default: format used when none is passed; True selects the dialect's
            `time_format`.

    Returns:
        A callable that takes the argument list and returns an `exp_class`
        instance whose format has been converted by the dialect.
    """

    def _format_time(args: t.Sequence):
        dialect_class = Dialect[dialect]

        time_format = seq_get(args, 1)
        if not time_format:
            if default is True:
                time_format = dialect_class.time_format
            else:
                time_format = default or None

        return exp_class(
            this=seq_get(args, 0),
            format=dialect_class.format_time(time_format),
        )

    return _format_time
360
361
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.

    The rewrite is applied to a copy, so the expression passed in is left untouched.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        # Copy first: everything below mutates the tree.
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        # Only rewrite when the property is not already expressed as a schema.
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Column names are matched case-insensitively.
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            # Move the partition columns out of the table schema and into the property.
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
383
384
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.Sequence], E]:
    """Build a parser callback for date-delta functions.

    With exactly three arguments they are read as (unit, expression, this);
    otherwise as (this, expression) with a default unit of 'DAY'. When
    `unit_mapping` is given, the lower-cased unit name is translated through it.
    """

    def inner_func(args: t.Sequence) -> E:
        if len(args) == 3:
            unit, this = args[0], args[2]
        else:
            unit, this = exp.Literal.string("DAY"), seq_get(args, 0)

        if unit_mapping:
            unit = unit_mapping.get(unit.name.lower(), unit)

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
396
397
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.Sequence], t.Optional[E]]:
    """Build a parser callback for date functions whose second argument is an interval.

    The callback returns None when fewer than two arguments are given. A string
    interval amount is converted into a number literal, and the interval's unit
    text becomes a string literal.
    """

    def func(args: t.Sequence) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=args[0], expression=amount, unit=unit)

    return func
417
418
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse ``(unit, value)`` into DateTrunc when `value` is a cast to DATE, else TimestampTrunc."""
    unit, this = seq_get(args, 0), seq_get(args, 1)

    if not (isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE)):
        return exp.TimestampTrunc(this=this, unit=unit)
    return exp.DateTrunc(unit=unit, this=this)
426
427
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render a timestamp truncation as DATE_TRUNC('<unit>', this); the unit defaults to 'day'."""
    unit_name = expression.text("unit") or "day"
    unit_literal = exp.Literal.string(unit_name)
    return self.func("DATE_TRUNC", unit_literal, expression.this)
432
433
def locate_to_strposition(args: t.Sequence) -> exp.Expression:
    """Parse ``(substr, string[, position])`` arguments into a StrPosition node."""
    substr, this, position = (seq_get(args, index) for index in range(3))
    return exp.StrPosition(this=this, substr=substr, position=position)
440
441
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a substring search as LOCATE(substr, this, position)."""
    needle = expression.args.get("substr")
    haystack = expression.this
    start = expression.args.get("position")
    return self.func("LOCATE", needle, haystack, start)
446
447
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time string conversion as a cast to TIMESTAMP."""
    operand = self.sql(expression, "this")
    return f"CAST({operand} AS TIMESTAMP)"
450
451
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date string conversion as a cast to DATE."""
    operand = self.sql(expression, "this")
    return f"CAST({operand} AS DATE)"
454
455
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render as LEAST when the node has `expressions`, otherwise as MIN."""
    if expression.expressions:
        render = rename_func("LEAST")
    else:
        render = rename_func("MIN")
    return render(self, expression)
459
460
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render as GREATEST when the node has `expressions`, otherwise as MAX."""
    if expression.expressions:
        render = rename_func("GREATEST")
    else:
        render = rename_func("MAX")
    return render(self, expression)
464
465
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render COUNT_IF(cond) as ``sum(if(cond, 1, 0))``.

    A DISTINCT wrapper is reported as unsupported and only its first expression
    is used as the condition.
    """
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    one_or_zero = exp.func("if", condition, 1, 0)
    return self.func("sum", one_or_zero)
474
475
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM, using the ``TRIM([position] [chars] FROM target [COLLATE c])`` form when needed.

    Without characters to remove and without a collation, rendering is left to
    the generator's own `trim_sql`.
    """
    target, position, chars, collation = (
        self.sql(expression, key) for key in ("this", "position", "expression", "collation")
    )

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (chars or collation):
        return self.trim_sql(expression)

    prefix = f"{position} " if position else ""
    if chars:
        prefix += f"{chars} "
    if prefix:
        prefix += "FROM "

    suffix = f" COLLATE {collation}" if collation else ""
    return f"TRIM({prefix}{target}{suffix})"
491
492
def str_to_time_sql(self, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(this, <format>)."""
    operand = expression.this
    time_format = self.format_time(expression)
    return self.func("STRPTIME", operand, time_format)
495
496
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator callback that renders TsOrDsToDate as a cast to DATE.

    When the expression carries a format other than the dialect's default time
    or date format, the value is first converted with `str_to_time_sql`.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)

        uses_default_format = not time_format or time_format in (
            _dialect.time_format,
            _dialect.date_format,
        )
        if uses_default_format:
            return f"CAST({self.sql(expression, 'this')} AS DATE)"
        return f"CAST({str_to_time_sql(self, expression)} AS DATE)"

    return _ts_or_ds_to_date_sql
class Dialects(builtins.str, enum.Enum):
18class Dialects(str, Enum):
19    DIALECT = ""
20
21    BIGQUERY = "bigquery"
22    CLICKHOUSE = "clickhouse"
23    DUCKDB = "duckdb"
24    HIVE = "hive"
25    MYSQL = "mysql"
26    ORACLE = "oracle"
27    POSTGRES = "postgres"
28    PRESTO = "presto"
29    REDSHIFT = "redshift"
30    SNOWFLAKE = "snowflake"
31    SPARK = "spark"
32    SPARK2 = "spark2"
33    SQLITE = "sqlite"
34    STARROCKS = "starrocks"
35    TABLEAU = "tableau"
36    TRINO = "trino"
37    TSQL = "tsql"
38    DATABRICKS = "databricks"
39    DRILL = "drill"
40    TERADATA = "teradata"

An enumeration.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
TERADATA = <Dialects.TERADATA: 'teradata'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
102class Dialect(metaclass=_Dialect):
103    index_offset = 0
104    unnest_column_only = False
105    alias_post_tablesample = False
106    normalize_functions: t.Optional[str] = "upper"
107    null_ordering = "nulls_are_small"
108
109    date_format = "'%Y-%m-%d'"
110    dateint_format = "'%Y%m%d'"
111    time_format = "'%Y-%m-%d %H:%M:%S'"
112    time_mapping: t.Dict[str, str] = {}
113
114    # autofilled
115    quote_start = None
116    quote_end = None
117    identifier_start = None
118    identifier_end = None
119
120    time_trie = None
121    inverse_time_mapping = None
122    inverse_time_trie = None
123    tokenizer_class = None
124    parser_class = None
125    generator_class = None
126
127    @classmethod
128    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
129        if not dialect:
130            return cls
131        if isinstance(dialect, _Dialect):
132            return dialect
133        if isinstance(dialect, Dialect):
134            return dialect.__class__
135
136        result = cls.get(dialect)
137        if not result:
138            raise ValueError(f"Unknown dialect '{dialect}'")
139
140        return result
141
142    @classmethod
143    def format_time(
144        cls, expression: t.Optional[str | exp.Expression]
145    ) -> t.Optional[exp.Expression]:
146        if isinstance(expression, str):
147            return exp.Literal.string(
148                format_time(
149                    expression[1:-1],  # the time formats are quoted
150                    cls.time_mapping,
151                    cls.time_trie,
152                )
153            )
154        if expression and expression.is_string:
155            return exp.Literal.string(
156                format_time(
157                    expression.this,
158                    cls.time_mapping,
159                    cls.time_trie,
160                )
161            )
162        return expression
163
164    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
165        return self.parser(**opts).parse(self.tokenize(sql), sql)
166
167    def parse_into(
168        self, expression_type: exp.IntoType, sql: str, **opts
169    ) -> t.List[t.Optional[exp.Expression]]:
170        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
171
172    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
173        return self.generator(**opts).generate(expression)
174
175    def transpile(self, sql: str, **opts) -> t.List[str]:
176        return [self.generate(expression, **opts) for expression in self.parse(sql)]
177
178    def tokenize(self, sql: str) -> t.List[Token]:
179        return self.tokenizer.tokenize(sql)
180
181    @property
182    def tokenizer(self) -> Tokenizer:
183        if not hasattr(self, "_tokenizer"):
184            self._tokenizer = self.tokenizer_class()  # type: ignore
185        return self._tokenizer
186
187    def parser(self, **opts) -> Parser:
188        return self.parser_class(  # type: ignore
189            **{
190                "index_offset": self.index_offset,
191                "unnest_column_only": self.unnest_column_only,
192                "alias_post_tablesample": self.alias_post_tablesample,
193                "null_ordering": self.null_ordering,
194                **opts,
195            },
196        )
197
198    def generator(self, **opts) -> Generator:
199        return self.generator_class(  # type: ignore
200            **{
201                "quote_start": self.quote_start,
202                "quote_end": self.quote_end,
203                "identifier_start": self.identifier_start,
204                "identifier_end": self.identifier_end,
205                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
206                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
207                "index_offset": self.index_offset,
208                "time_mapping": self.inverse_time_mapping,
209                "time_trie": self.inverse_time_trie,
210                "unnest_column_only": self.unnest_column_only,
211                "alias_post_tablesample": self.alias_post_tablesample,
212                "normalize_functions": self.normalize_functions,
213                "null_ordering": self.null_ordering,
214                **opts,
215            }
216        )
@classmethod
def get_or_raise( cls, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> Type[sqlglot.dialects.dialect.Dialect]:
127    @classmethod
128    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
129        if not dialect:
130            return cls
131        if isinstance(dialect, _Dialect):
132            return dialect
133        if isinstance(dialect, Dialect):
134            return dialect.__class__
135
136        result = cls.get(dialect)
137        if not result:
138            raise ValueError(f"Unknown dialect '{dialect}'")
139
140        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
142    @classmethod
143    def format_time(
144        cls, expression: t.Optional[str | exp.Expression]
145    ) -> t.Optional[exp.Expression]:
146        if isinstance(expression, str):
147            return exp.Literal.string(
148                format_time(
149                    expression[1:-1],  # the time formats are quoted
150                    cls.time_mapping,
151                    cls.time_trie,
152                )
153            )
154        if expression and expression.is_string:
155            return exp.Literal.string(
156                format_time(
157                    expression.this,
158                    cls.time_mapping,
159                    cls.time_trie,
160                )
161            )
162        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
164    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
165        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
167    def parse_into(
168        self, expression_type: exp.IntoType, sql: str, **opts
169    ) -> t.List[t.Optional[exp.Expression]]:
170        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: Optional[sqlglot.expressions.Expression], **opts) -> str:
172    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
173        return self.generator(**opts).generate(expression)
def transpile(self, sql: str, **opts) -> List[str]:
175    def transpile(self, sql: str, **opts) -> t.List[str]:
176        return [self.generate(expression, **opts) for expression in self.parse(sql)]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
178    def tokenize(self, sql: str) -> t.List[Token]:
179        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
187    def parser(self, **opts) -> Parser:
188        return self.parser_class(  # type: ignore
189            **{
190                "index_offset": self.index_offset,
191                "unnest_column_only": self.unnest_column_only,
192                "alias_post_tablesample": self.alias_post_tablesample,
193                "null_ordering": self.null_ordering,
194                **opts,
195            },
196        )
def generator(self, **opts) -> sqlglot.generator.Generator:
198    def generator(self, **opts) -> Generator:
199        return self.generator_class(  # type: ignore
200            **{
201                "quote_start": self.quote_start,
202                "quote_end": self.quote_end,
203                "identifier_start": self.identifier_start,
204                "identifier_end": self.identifier_end,
205                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
206                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
207                "index_offset": self.index_offset,
208                "time_mapping": self.inverse_time_mapping,
209                "time_trie": self.inverse_time_trie,
210                "unnest_column_only": self.unnest_column_only,
211                "alias_post_tablesample": self.alias_post_tablesample,
212                "normalize_functions": self.normalize_functions,
213                "null_ordering": self.null_ordering,
214                **opts,
215            }
216        )
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
222def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
223    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
226def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
227    if expression.args.get("accuracy"):
228        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
229    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.If) -> str:
232def if_sql(self: Generator, expression: exp.If) -> str:
233    return self.func(
234        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
235    )
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
238def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
239    return self.binary(expression, "->")
def arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
242def arrow_json_extract_scalar_sql(
243    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
244) -> str:
245    return self.binary(expression, "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
248def inline_array_sql(self: Generator, expression: exp.Array) -> str:
249    return f"[{self.expressions(expression)}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
252def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
253    return self.like_sql(
254        exp.Like(
255            this=exp.Lower(this=expression.this),
256            expression=expression.args["expression"],
257        )
258    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, adding AT TIME ZONE when `this` renders non-empty."""
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"
    return f"CURRENT_DATE AT TIME ZONE {zone}"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a With expression for dialects without recursive CTEs.

    A truthy `recursive` arg is reported as unsupported and reset to False on the
    expression itself (in place) before delegating to `self.with_sql`.
    """
    with_args = expression.args
    if with_args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        with_args["recursive"] = False

    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Render SafeDivide as an IF that yields NULL unless the divisor is non-zero."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    guard = f"{denominator} <> 0"
    quotient = f"{numerator} / {denominator}"
    return f"IF({guard}, {quotient}, NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Report TABLESAMPLE as unsupported and render only the wrapped `this` expression."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Report PIVOT as unsupported and drop it from the generated SQL.

    Returns an empty string rather than calling `self.sql(expression)`: rendering
    the same node again routes straight back to this handler whenever it is the
    registered transform for Pivot, which recurses until the interpreter's
    recursion limit is hit.
    """
    self.unsupported("PIVOT unsupported")
    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TryCast through the generator's `cast_sql`."""
    render_cast = self.cast_sql
    return render_cast(expression)
def no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Report properties as unsupported and emit nothing for them."""
    message = "Properties unsupported"
    self.unsupported(message)
    return ""
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Report a column COMMENT constraint as unsupported and emit nothing for it."""
    message = "CommentColumnConstraint unsupported"
    self.unsupported(message)
    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition with STRPOS, emulating the optional start position.

    Without a position this is STRPOS(<this>, <substr>). With one, the search runs
    on SUBSTR(<this>, <position>) and the result is shifted by `position - 1`.

    NOTE(review): if STRPOS yields 0 for "not found", the shifted form evaluates to
    `position - 1` instead of 0. Confirm whether that is acceptable.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render StructExtract as `<this>.<key>`, building the key as a quoted Identifier."""
    struct = self.sql(expression, "this")
    field = exp.Identifier(this=expression.expression, quoted=True)
    return f"{struct}.{self.sql(field)}"
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a Map/VarMap as `map_func_name(k1, v1, k2, v2, ...)`.

    Args:
        expression: the map expression; its `keys` and `values` args are read.
        map_func_name: name of the function to emit.

    Returns:
        The function call with keys and values interleaved. When either side is not
        an Array, this is reported as unsupported and both are passed through as the
        two arguments of the call.
    """
    keys, values = expression.args["keys"], expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # zip stops at the shorter side, so unmatched trailing keys/values are dropped.
    interleaved = [
        self.sql(item) for pair in zip(keys.expressions, values.expressions) for item in pair
    ]
    return self.func(map_func_name, *interleaved)
def format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[bool, str, NoneType] = None) -> Callable[[Sequence], ~E]:
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.Sequence], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: format used when no format argument is given; True selects the
            dialect's `time_format`.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.Sequence):
        this = seq_get(args, 0)
        dialect_class = Dialect[dialect]

        time_format = seq_get(args, 1)
        if not time_format:
            # No explicit format argument: fall back to the configured default.
            time_format = dialect_class.time_format if default is True else default or None

        return exp_class(this=this, format=dialect_class.format_time(time_format))

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format; True selects the dialect's time_format.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    Render a CREATE whose PARTITIONED BY property lists column names (Hive and Spark style).

    In those dialects PARTITIONED BY acts as an extension of the table's schema: the named
    columns are moved out of the schema and into a schema held by the property itself.
    Only applies when `this` is a Schema and the kind is TABLE or VIEW.
    """
    if not isinstance(expression.this, exp.Schema):
        return self.create_sql(expression)
    if expression.args.get("kind") not in ("TABLE", "VIEW"):
        return self.create_sql(expression)

    # The rewrites below are applied to a copy of the expression.
    create = expression.copy()
    partitioned_by = create.find(exp.PartitionedByProperty)

    if partitioned_by and partitioned_by.this and not isinstance(partitioned_by.this, exp.Schema):
        schema = create.this
        partition_names = {column.name.upper() for column in partitioned_by.this.expressions}

        partition_columns = [
            column for column in schema.expressions if column.name.upper() in partition_names
        ]
        remaining_columns = [
            column for column in schema.expressions if column not in partition_columns
        ]

        schema.set("expressions", remaining_columns)
        partitioned_by.replace(
            exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns))
        )
        create.set("this", schema)

    return self.create_sql(create)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[Sequence], ~E]:
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.Sequence], E]:
    """Build a parser callback for date-delta functions.

    Args:
        exp_class: the expression class to instantiate.
        unit_mapping: optional mapping looked up with the lowercased unit name; a unit
            that is not in the mapping is kept unchanged.

    Returns:
        A callable that builds `exp_class` from an argument list. Exactly three
        arguments are read as (unit, expression, this); anything else is read as
        (this, expression) with a "DAY" string literal as the unit.
    """

    def _parse(args: t.Sequence) -> E:
        if len(args) == 3:
            this, unit = args[2], args[0]
        else:
            this, unit = seq_get(args, 0), exp.Literal.string("DAY")

        if unit_mapping:
            unit = unit_mapping.get(unit.name.lower(), unit)

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _parse
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[Sequence], Optional[~E]]:
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.Sequence], t.Optional[E]]:
    """Build a parser callback for date-delta functions whose second argument is an interval.

    Args:
        expression_class: the expression class to instantiate.

    Returns:
        A callable that returns None for fewer than two arguments. Otherwise it builds
        `expression_class` from the first argument, the interval's value (a string value
        is turned into a number literal) and the interval's unit as a string literal.
    """

    def _parse(args: t.Sequence) -> t.Optional[E]:
        if len(args) < 2:
            return None

        # NOTE(review): assumes args[1] exposes `.this` and `.text` (presumably an Interval).
        this, interval = args[0], args[1]

        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=this, expression=amount, unit=unit)

    return _parse
def date_trunc_to_time( args: Sequence) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
    """Build a truncation expression from (unit, this) arguments.

    Returns a DateTrunc when `this` is a Cast to DATE, and a TimestampTrunc otherwise.
    """
    unit, this = seq_get(args, 0), seq_get(args, 1)

    truncates_date = isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE)
    if not truncates_date:
        return exp.TimestampTrunc(this=this, unit=unit)
    return exp.DateTrunc(unit=unit, this=this)
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC(<unit string literal>, <this>); the unit defaults to "day"."""
    unit_name = expression.text("unit") or "day"
    unit = exp.Literal.string(unit_name)
    return self.func("DATE_TRUNC", unit, expression.this)
def locate_to_strposition(args: Sequence) -> sqlglot.expressions.Expression:
def locate_to_strposition(args: t.Sequence) -> exp.Expression:
    """Build a StrPosition from arguments ordered as (substr, this[, position])."""
    substr, this, position = (seq_get(args, index) for index in range(3))
    return exp.StrPosition(this=this, substr=substr, position=position)
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(<substr>, <this>, <position>)."""
    node_args = expression.args
    substr = node_args.get("substr")
    position = node_args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render TimeStrToTime as a CAST of `this` to TIMESTAMP."""
    operand = self.sql(expression, "this")
    return f"CAST({operand} AS TIMESTAMP)"
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render DateStrToDate as a CAST of `this` to DATE."""
    operand = self.sql(expression, "this")
    return f"CAST({operand} AS DATE)"
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render Min as LEAST when `expressions` is non-empty, and as MIN otherwise."""
    if expression.expressions:
        name = "LEAST"
    else:
        name = "MIN"

    render = rename_func(name)
    return render(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render Max as GREATEST when `expressions` is non-empty, and as MAX otherwise."""
    if expression.expressions:
        name = "GREATEST"
    else:
        name = "MAX"

    render = rename_func(name)
    return render(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render CountIf as sum(if(<condition>, 1, 0)).

    When the condition is wrapped in a Distinct, its first expression is used as the
    condition and the DISTINCT is reported as unsupported.
    """
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    counted = exp.func("if", condition, 1, 0)
    return self.func("sum", counted)
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render Trim as TRIM([<position> ][<chars> ][FROM ]<this>[ COLLATE <collation>]).

    Delegates to the generator's own `trim_sql` when neither removal characters nor
    a collation are present.
    """
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (chars or collation):
        return self.trim_sql(expression)

    prefix = (f"{position} " if position else "") + (f"{chars} " if chars else "")
    if prefix:
        # FROM is only emitted when something precedes the target.
        prefix += "FROM "

    suffix = f" COLLATE {collation}" if collation else ""
    return f"TRIM({prefix}{target}{suffix})"
def str_to_time_sql(self, expression: sqlglot.expressions.Expression) -> str:
def str_to_time_sql(self, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(<this>, <formatted time>)."""
    value = expression.this
    time_format = self.format_time(expression)
    return self.func("STRPTIME", value, time_format)
def ts_or_ds_to_date_sql(dialect: str) -> Callable:
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator callback that renders TsOrDsToDate as a CAST to DATE.

    Args:
        dialect: name of the dialect whose `time_format` and `date_format` are treated
            as the default formats.

    Returns:
        A callable that casts `this` directly, unless the expression has a format
        other than those defaults, in which case the value is first rendered through
        `str_to_time_sql`.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)

        has_custom_format = bool(time_format) and time_format not in (
            _dialect.time_format,
            _dialect.date_format,
        )
        if has_custom_format:
            operand = str_to_time_sql(self, expression)
        else:
            operand = self.sql(expression, "this")

        return f"CAST({operand} AS DATE)"

    return _ts_or_ds_to_date_sql