Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5
  6from sqlglot import exp
  7from sqlglot.generator import Generator
  8from sqlglot.helper import flatten, seq_get
  9from sqlglot.parser import Parser
 10from sqlglot.time import format_time
 11from sqlglot.tokens import Token, Tokenizer, TokenType
 12from sqlglot.trie import new_trie
 13
 14if t.TYPE_CHECKING:
 15    from sqlglot._typing import E
 16
 17
class Dialects(str, Enum):
    """Names of all supported SQL dialects; each value is the registry key for its Dialect class."""

    DIALECT = ""  # the base, dialect-agnostic entry

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TRINO = "trino"
    TSQL = "tsql"
    DATABRICKS = "databricks"
    DRILL = "drill"
    TERADATA = "teradata"
 41
 42
class _Dialect(type):
    """Metaclass for Dialect.

    Registers every subclass in a shared name -> class registry and autofills
    class-level metadata (time-format tries, quote/identifier markers, and the
    tokenizer/parser/generator classes) at class-creation time.
    """

    # Registry of all known dialect classes, keyed by dialect name.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        # A dialect class compares equal to itself, its registry name, and its instances.
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        # Hash by lowercased class name so it agrees with string-key equality.
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        # Raises KeyError for unknown dialect names.
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Register under the matching Dialects enum value when one exists,
        # otherwise under the lowercased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Precompute time-format lookup structures in both directions.
        klass.time_trie = new_trie(klass.time_mapping)
        klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
        klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)

        # Resolve the nested Tokenizer/Parser/Generator classes, falling back to the defaults.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first configured quote/identifier pair is treated as canonical for generation.
        klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.identifier_start, klass.identifier_end = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # Delimiters of the first format string mapped to `token_type`, or (None, None).
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING)
        klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING)
        klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
        klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)

        return klass
103
104
class Dialect(metaclass=_Dialect):
    """Base class for SQL dialects; wires together a Tokenizer, Parser and Generator."""

    # Parse/generate behavior knobs, overridable per dialect.
    index_offset = 0
    unnest_column_only = False
    alias_post_tablesample = False
    normalize_functions: t.Optional[str] = "upper"
    null_ordering = "nulls_are_small"

    # Canonical literal formats for dates and times in this dialect.
    date_format = "'%Y-%m-%d'"
    dateint_format = "'%Y%m%d'"
    time_format = "'%Y-%m-%d %H:%M:%S'"
    time_mapping: t.Dict[str, str] = {}

    # autofilled by the _Dialect metaclass at class-creation time
    quote_start = None
    quote_end = None
    identifier_start = None
    identifier_end = None

    time_trie = None
    inverse_time_mapping = None
    inverse_time_trie = None
    tokenizer_class = None
    parser_class = None
    generator_class = None

    def __eq__(self, other: t.Any) -> bool:
        # Delegates to _Dialect.__eq__, so instances compare equal to their class and name.
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve `dialect` (name, class, instance or None) to a Dialect class.

        Raises:
            ValueError: if `dialect` is an unknown dialect name.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Translate a time-format string or string literal through this dialect's time mapping."""
        if isinstance(expression, str):
            return exp.Literal.string(
                format_time(
                    expression[1:-1],  # the time formats are quoted
                    cls.time_mapping,
                    cls.time_trie,
                )
            )
        if expression and expression.is_string:
            return exp.Literal.string(
                format_time(
                    expression.this,
                    cls.time_mapping,
                    cls.time_trie,
                )
            )
        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql` into a list of expression trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse `sql` specifically into nodes of `expression_type`."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Generate SQL text for an expression tree using this dialect's generator."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and regenerate each statement in this dialect."""
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily instantiate and cache a single tokenizer per Dialect instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()  # type: ignore
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Build a new Parser configured with this dialect's settings (overridable via opts)."""
        return self.parser_class(  # type: ignore
            **{
                "index_offset": self.index_offset,
                "unnest_column_only": self.unnest_column_only,
                "alias_post_tablesample": self.alias_post_tablesample,
                "null_ordering": self.null_ordering,
                **opts,
            },
        )

    def generator(self, **opts) -> Generator:
        """Build a new Generator configured with this dialect's settings (overridable via opts)."""
        return self.generator_class(  # type: ignore
            **{
                "quote_start": self.quote_start,
                "quote_end": self.quote_end,
                "bit_start": self.bit_start,
                "bit_end": self.bit_end,
                "hex_start": self.hex_start,
                "hex_end": self.hex_end,
                "byte_start": self.byte_start,
                "byte_end": self.byte_end,
                "raw_start": self.raw_start,
                "raw_end": self.raw_end,
                "identifier_start": self.identifier_start,
                "identifier_end": self.identifier_end,
                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
                "index_offset": self.index_offset,
                "time_mapping": self.inverse_time_mapping,
                "time_trie": self.inverse_time_trie,
                "unnest_column_only": self.unnest_column_only,
                "alias_post_tablesample": self.alias_post_tablesample,
                "normalize_functions": self.normalize_functions,
                "null_ordering": self.null_ordering,
                **opts,
            }
        )
234
235
# Anything that can designate a dialect: its name, an instance, the class itself, or None.
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
237
238
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator transform that renders an expression as ``name(<all args>)``."""

    def _rename(self: Generator, expression: exp.Expression) -> str:
        return self.func(name, *flatten(expression.args.values()))

    return _rename
241
242
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render ApproxDistinct as APPROX_COUNT_DISTINCT, warning when an accuracy
    argument is present (it cannot be expressed in this form)."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)
247
248
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render an exp.If node as an IF(condition, true, false) function call."""
    true_branch = expression.args.get("true")
    false_branch = expression.args.get("false")
    return self.func("IF", expression.this, true_branch, false_branch)
253
254
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON extraction with the arrow (`->`) binary operator."""
    return self.binary(expression, "->")
257
258
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON extraction with the double-arrow (`->>`) binary operator."""
    return self.binary(expression, "->>")
263
264
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array constructor as a bracketed literal, e.g. [1, 2, 3]."""
    return f"[{self.expressions(expression)}]"
267
268
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Rewrite ILIKE as LIKE over a lowercased left operand, for dialects without ILIKE."""
    lowered = exp.Lower(this=expression.this)
    like = exp.Like(this=lowered, expression=expression.args["expression"])
    return self.like_sql(like)
276
277
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, with an optional AT TIME ZONE clause."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
281
282
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Strip the RECURSIVE flag (with a warning) for dialects without recursive CTEs."""
    recursive = expression.args.get("recursive")
    if recursive:
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)
288
289
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE with an IF guard that yields NULL on a zero divisor."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
294
295
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop TABLESAMPLE (with a warning) and render just the sampled table."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)
299
300
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Drop PIVOT entirely (with a warning) for dialects that don't support it."""
    self.unsupported("PIVOT unsupported")
    return ""
304
305
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TRY_CAST as a plain CAST for dialects without TRY_CAST."""
    return self.cast_sql(expression)
308
309
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Drop table properties entirely (with a warning) for dialects that don't support them."""
    self.unsupported("Properties unsupported")
    return ""
313
314
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Drop COMMENT column constraints (with a warning) for dialects that don't support them."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""
320
321
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition via STRPOS, emulating an optional start position with SUBSTR."""
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if not position:
        return f"STRPOS({this}, {substr})"
    # Search only the suffix, then shift the hit back into 1-based full-string coordinates.
    return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
329
330
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as dotted notation with a quoted field identifier."""
    key = exp.Identifier(this=expression.expression, quoted=True)
    struct_key = self.sql(key)
    this = self.sql(expression, "this")
    return f"{this}.{struct_key}"
335
336
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map constructor as MAP(k1, v1, k2, v2, ...) from parallel key/value arrays."""
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        # Non-literal arrays (e.g. column references) can't be interleaved at transpile time.
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    interleaved = [
        self.sql(node)
        for pair in zip(keys.expressions, values.expressions)
        for node in pair
    ]
    return self.func(map_func_name, *interleaved)
352
353
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        target = Dialect[dialect]
        fmt = seq_get(args, 1)
        if not fmt:
            # Fall back to the dialect's time format (default is True) or the given default.
            fmt = target.time_format if default is True else default or None
        return exp_class(this=seq_get(args, 0), format=target.format_time(fmt))

    return _format_time
378
379
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        # Copy first: the rewrite below mutates the schema and the property in place.
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Column names listed in PARTITIONED BY, matched case-insensitively.
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            # Move the partition columns out of the table schema and into the property.
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
401
402
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date-delta functions.

    Handles both the 2-arg form (this, delta) — with an implicit DAY unit —
    and the 3-arg form (unit, delta, this), optionally remapping unit names.
    """

    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
414
415
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for DATE_ADD-style functions whose second argument is an INTERVAL."""

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        value = interval.this
        if value and value.is_string:
            # Normalize e.g. INTERVAL '5' DAY so the delta becomes a numeric literal.
            value = exp.Literal.number(value.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=args[0], expression=value, unit=unit)

    return func
435
436
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, value): DateTrunc when value is cast to DATE, else TimestampTrunc."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    truncates_date = isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE)
    if truncates_date:
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)
444
445
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('unit', value), defaulting the unit to 'day'."""
    unit = expression.text("unit") or "day"
    return self.func("DATE_TRUNC", exp.Literal.string(unit), expression.this)
450
451
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Map LOCATE(substr, string[, position]) arguments onto a StrPosition node."""
    substr = seq_get(args, 0)
    haystack = seq_get(args, 1)
    position = seq_get(args, 2)
    return exp.StrPosition(this=haystack, substr=substr, position=position)
458
459
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition back into LOCATE(substr, string[, position])."""
    substr = expression.args.get("substr")
    position = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
464
465
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render TimeStrToTime as a plain CAST to TIMESTAMP."""
    return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
468
469
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render DateStrToDate as a plain CAST to DATE."""
    return f"CAST({self.sql(expression, 'this')} AS DATE)"
472
473
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render Min as LEAST when it has extra operands, else as plain MIN."""
    if expression.expressions:
        return rename_func("LEAST")(self, expression)
    return rename_func("MIN")(self, expression)
477
478
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render Max as GREATEST when it has extra operands, else as plain MAX."""
    if expression.expressions:
        return rename_func("GREATEST")(self, expression)
    return rename_func("MAX")(self, expression)
482
483
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0)) for dialects without COUNT_IF."""
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        # DISTINCT has no SUM-based equivalent; warn and emit over the bare condition.
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
        cond = cond.expressions[0]

    return self.func("sum", exp.func("if", cond, 1, 0))
492
493
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM using the ANSI TRIM([LEADING|TRAILING|BOTH] [chars] FROM target) form."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    pieces = [
        f"{trim_type} " if trim_type else "",
        f"{remove_chars} " if remove_chars else "",
        "FROM " if trim_type or remove_chars else "",
        target,
        f" COLLATE {collation}" if collation else "",
    ]
    return f"TRIM({''.join(pieces)})"
509
510
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(value, format)."""
    fmt = self.format_time(expression)
    return self.func("STRPTIME", expression.this, fmt)
513
514
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a TsOrDsToDate renderer for `dialect`.

    Values with a non-default time format are parsed via STRPTIME before the
    CAST to DATE; otherwise the value is cast directly.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (target.time_format, target.date_format):
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _ts_or_ds_to_date_sql
524
525
# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute PIVOT output column names from its aggregation expressions.

    Aliased aggregations contribute their alias; unaliased ones contribute
    their (unquoted) SQL text, which the base parser later uses as a suffix.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # Fix: the original used a bare triple-quoted string here as a "comment" —
            # a no-op expression statement evaluated on every iteration; use real comments.
            #
            # Unaliased aggregations are used as suffixes (e.g. col_avg(foo)). We unquote
            # identifiers because they're going to be re-quoted in the base parser's
            # `_parse_pivot` method, due to `to_identifier`; otherwise we'd end up with
            # doubly quoted names like col_avg(`foo`).
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
class Dialects(builtins.str, enum.Enum):
19class Dialects(str, Enum):
20    DIALECT = ""
21
22    BIGQUERY = "bigquery"
23    CLICKHOUSE = "clickhouse"
24    DUCKDB = "duckdb"
25    HIVE = "hive"
26    MYSQL = "mysql"
27    ORACLE = "oracle"
28    POSTGRES = "postgres"
29    PRESTO = "presto"
30    REDSHIFT = "redshift"
31    SNOWFLAKE = "snowflake"
32    SPARK = "spark"
33    SPARK2 = "spark2"
34    SQLITE = "sqlite"
35    STARROCKS = "starrocks"
36    TABLEAU = "tableau"
37    TRINO = "trino"
38    TSQL = "tsql"
39    DATABRICKS = "databricks"
40    DRILL = "drill"
41    TERADATA = "teradata"

An enumeration.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
TERADATA = <Dialects.TERADATA: 'teradata'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
106class Dialect(metaclass=_Dialect):
107    index_offset = 0
108    unnest_column_only = False
109    alias_post_tablesample = False
110    normalize_functions: t.Optional[str] = "upper"
111    null_ordering = "nulls_are_small"
112
113    date_format = "'%Y-%m-%d'"
114    dateint_format = "'%Y%m%d'"
115    time_format = "'%Y-%m-%d %H:%M:%S'"
116    time_mapping: t.Dict[str, str] = {}
117
118    # autofilled
119    quote_start = None
120    quote_end = None
121    identifier_start = None
122    identifier_end = None
123
124    time_trie = None
125    inverse_time_mapping = None
126    inverse_time_trie = None
127    tokenizer_class = None
128    parser_class = None
129    generator_class = None
130
131    def __eq__(self, other: t.Any) -> bool:
132        return type(self) == other
133
134    def __hash__(self) -> int:
135        return hash(type(self))
136
137    @classmethod
138    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
139        if not dialect:
140            return cls
141        if isinstance(dialect, _Dialect):
142            return dialect
143        if isinstance(dialect, Dialect):
144            return dialect.__class__
145
146        result = cls.get(dialect)
147        if not result:
148            raise ValueError(f"Unknown dialect '{dialect}'")
149
150        return result
151
152    @classmethod
153    def format_time(
154        cls, expression: t.Optional[str | exp.Expression]
155    ) -> t.Optional[exp.Expression]:
156        if isinstance(expression, str):
157            return exp.Literal.string(
158                format_time(
159                    expression[1:-1],  # the time formats are quoted
160                    cls.time_mapping,
161                    cls.time_trie,
162                )
163            )
164        if expression and expression.is_string:
165            return exp.Literal.string(
166                format_time(
167                    expression.this,
168                    cls.time_mapping,
169                    cls.time_trie,
170                )
171            )
172        return expression
173
174    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
175        return self.parser(**opts).parse(self.tokenize(sql), sql)
176
177    def parse_into(
178        self, expression_type: exp.IntoType, sql: str, **opts
179    ) -> t.List[t.Optional[exp.Expression]]:
180        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
181
182    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
183        return self.generator(**opts).generate(expression)
184
185    def transpile(self, sql: str, **opts) -> t.List[str]:
186        return [self.generate(expression, **opts) for expression in self.parse(sql)]
187
188    def tokenize(self, sql: str) -> t.List[Token]:
189        return self.tokenizer.tokenize(sql)
190
191    @property
192    def tokenizer(self) -> Tokenizer:
193        if not hasattr(self, "_tokenizer"):
194            self._tokenizer = self.tokenizer_class()  # type: ignore
195        return self._tokenizer
196
197    def parser(self, **opts) -> Parser:
198        return self.parser_class(  # type: ignore
199            **{
200                "index_offset": self.index_offset,
201                "unnest_column_only": self.unnest_column_only,
202                "alias_post_tablesample": self.alias_post_tablesample,
203                "null_ordering": self.null_ordering,
204                **opts,
205            },
206        )
207
208    def generator(self, **opts) -> Generator:
209        return self.generator_class(  # type: ignore
210            **{
211                "quote_start": self.quote_start,
212                "quote_end": self.quote_end,
213                "bit_start": self.bit_start,
214                "bit_end": self.bit_end,
215                "hex_start": self.hex_start,
216                "hex_end": self.hex_end,
217                "byte_start": self.byte_start,
218                "byte_end": self.byte_end,
219                "raw_start": self.raw_start,
220                "raw_end": self.raw_end,
221                "identifier_start": self.identifier_start,
222                "identifier_end": self.identifier_end,
223                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
224                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
225                "index_offset": self.index_offset,
226                "time_mapping": self.inverse_time_mapping,
227                "time_trie": self.inverse_time_trie,
228                "unnest_column_only": self.unnest_column_only,
229                "alias_post_tablesample": self.alias_post_tablesample,
230                "normalize_functions": self.normalize_functions,
231                "null_ordering": self.null_ordering,
232                **opts,
233            }
234        )
@classmethod
def get_or_raise( cls, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> Type[sqlglot.dialects.dialect.Dialect]:
137    @classmethod
138    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
139        if not dialect:
140            return cls
141        if isinstance(dialect, _Dialect):
142            return dialect
143        if isinstance(dialect, Dialect):
144            return dialect.__class__
145
146        result = cls.get(dialect)
147        if not result:
148            raise ValueError(f"Unknown dialect '{dialect}'")
149
150        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
152    @classmethod
153    def format_time(
154        cls, expression: t.Optional[str | exp.Expression]
155    ) -> t.Optional[exp.Expression]:
156        if isinstance(expression, str):
157            return exp.Literal.string(
158                format_time(
159                    expression[1:-1],  # the time formats are quoted
160                    cls.time_mapping,
161                    cls.time_trie,
162                )
163            )
164        if expression and expression.is_string:
165            return exp.Literal.string(
166                format_time(
167                    expression.this,
168                    cls.time_mapping,
169                    cls.time_trie,
170                )
171            )
172        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
174    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
175        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
177    def parse_into(
178        self, expression_type: exp.IntoType, sql: str, **opts
179    ) -> t.List[t.Optional[exp.Expression]]:
180        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: Optional[sqlglot.expressions.Expression], **opts) -> str:
182    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
183        return self.generator(**opts).generate(expression)
def transpile(self, sql: str, **opts) -> List[str]:
185    def transpile(self, sql: str, **opts) -> t.List[str]:
186        return [self.generate(expression, **opts) for expression in self.parse(sql)]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
188    def tokenize(self, sql: str) -> t.List[Token]:
189        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
197    def parser(self, **opts) -> Parser:
198        return self.parser_class(  # type: ignore
199            **{
200                "index_offset": self.index_offset,
201                "unnest_column_only": self.unnest_column_only,
202                "alias_post_tablesample": self.alias_post_tablesample,
203                "null_ordering": self.null_ordering,
204                **opts,
205            },
206        )
def generator(self, **opts) -> sqlglot.generator.Generator:
208    def generator(self, **opts) -> Generator:
209        return self.generator_class(  # type: ignore
210            **{
211                "quote_start": self.quote_start,
212                "quote_end": self.quote_end,
213                "bit_start": self.bit_start,
214                "bit_end": self.bit_end,
215                "hex_start": self.hex_start,
216                "hex_end": self.hex_end,
217                "byte_start": self.byte_start,
218                "byte_end": self.byte_end,
219                "raw_start": self.raw_start,
220                "raw_end": self.raw_end,
221                "identifier_start": self.identifier_start,
222                "identifier_end": self.identifier_end,
223                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
224                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
225                "index_offset": self.index_offset,
226                "time_mapping": self.inverse_time_mapping,
227                "time_trie": self.inverse_time_trie,
228                "unnest_column_only": self.unnest_column_only,
229                "alias_post_tablesample": self.alias_post_tablesample,
230                "normalize_functions": self.normalize_functions,
231                "null_ordering": self.null_ordering,
232                **opts,
233            }
234        )
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator transform that renders an expression as a call to *name*,
    forwarding all of the expression's flattened argument values."""

    def _rename(self: Generator, expression: exp.Expression) -> str:
        args = flatten(expression.args.values())
        return self.func(name, *args)

    return _rename
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT, warning when an `accuracy` argument is present
    since it cannot be expressed."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render an If node as IF(condition, true_value, false_value)."""
    true_part = expression.args.get("true")
    false_part = expression.args.get("false")
    return self.func("IF", expression.this, true_part, false_part)
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON/JSONB extraction with the arrow operator."""
    operator = "->"
    return self.binary(expression, operator)
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON/JSONB extraction with the double-arrow operator."""
    operator = "->>"
    return self.binary(expression, operator)
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an Array node as an inline bracketed literal: [e1, e2, ...]."""
    elements = self.expressions(expression)
    return "[" + elements + "]"
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Rewrite ILIKE as LIKE over a lower-cased subject, for dialects without ILIKE.

    NOTE(review): only the left-hand side is wrapped in LOWER; the pattern is
    passed through unchanged. That asymmetry is preserved here verbatim.
    """
    lowered_subject = exp.Lower(this=expression.this)
    like = exp.Like(this=lowered_subject, expression=expression.args["expression"])
    return self.like_sql(like)
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, adding AT TIME ZONE when one is given."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Strip the RECURSIVE flag from a WITH clause (with a warning) and render it."""
    is_recursive = expression.args.get("recursive")
    if is_recursive:
        self.unsupported("Recursive CTEs are unsupported")
        # Clear the flag in place so the default WITH rendering omits RECURSIVE.
        expression.args["recursive"] = False
    return self.with_sql(expression)
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE via IF: divide normally, but yield NULL for a zero denominator."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop a TABLESAMPLE clause (with a warning), rendering only the sampled table."""
    self.unsupported("TABLESAMPLE unsupported")
    table = expression.this
    return self.sql(table)
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Drop a PIVOT clause entirely (with a warning) for dialects without PIVOT."""
    self.unsupported("PIVOT unsupported")
    empty = ""
    return empty
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TRY_CAST as a plain CAST for dialects that lack TRY_CAST."""
    result = self.cast_sql(expression)
    return result
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Drop a Properties clause entirely (with a warning) for dialects without it."""
    self.unsupported("Properties unsupported")
    empty = ""
    return empty
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Drop a COMMENT column constraint (with a warning) for dialects without it."""
    self.unsupported("CommentColumnConstraint unsupported")
    empty = ""
    return empty
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as STRPOS, emulating a start position via SUBSTR arithmetic."""
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")

    if not position:
        return f"STRPOS({this}, {substr})"

    # Search only the suffix starting at `position`, then shift the hit back
    # into the coordinates of the full string.
    return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct-field access as dotted syntax with a quoted field identifier."""
    this = self.sql(expression, "this")
    field = exp.Identifier(this=expression.expression, quoted=True)
    return f"{this}.{self.sql(field)}"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map constructor as MAP(k1, v1, k2, v2, ...).

    When the keys/values are not array literals (e.g. they are columns) they
    cannot be interleaved, so fall back to MAP(keys, values) with a warning.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    interleaved = []
    for key, value in zip(keys.expressions, values.expressions):
        interleaved.extend((self.sql(key), self.sql(value)))

    return self.func(map_func_name, *interleaved)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        # Fall back to the dialect's time format (default=True) or the given
        # default string when no explicit format argument was parsed.
        fmt = seq_get(args, 1) or (
            Dialect[dialect].time_format if default is True else default or None
        )
        return exp_class(this=seq_get(args, 0), format=Dialect[dialect].format_time(fmt))

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:
  • A callable that can be used to return the appropriately formatted time expression.

def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's
    schema. When the PARTITIONED BY value is an array of column names, they are
    transformed into a schema. The corresponding columns are removed from the create
    statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        # Work on a copy so the caller's tree is left untouched.
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            partition_names = {v.name.upper() for v in prop.this.expressions}
            partition_cols = [
                col for col in schema.expressions if col.name.upper() in partition_names
            ]
            remaining = [e for e in schema.expressions if e not in partition_cols]
            # Move the partition columns out of the main schema and into the
            # PARTITIONED BY property as their own schema.
            schema.set("expressions", remaining)
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_cols)))
            expression.set("this", schema)

    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date-delta functions.

    Handles both the 3-argument form (unit, delta, this) and the shorter form
    (this, delta) where the unit defaults to DAY. Unit names are optionally
    remapped through *unit_mapping*.
    """

    def _parse(args: t.List) -> E:
        if len(args) == 3:
            this, unit = args[2], args[0]
        else:
            this, unit = seq_get(args, 0), exp.Literal.string("DAY")

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _parse
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date-delta functions whose second argument is an INTERVAL."""

    def _parse(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        delta = interval.this
        # Normalize string deltas inside the interval into numeric literals.
        if delta and delta.is_string:
            delta = exp.Literal.number(delta.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=args[0], expression=delta, unit=unit)

    return _parse
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse DATE_TRUNC arguments: a DATE-cast operand yields DateTrunc,
    anything else yields TimestampTrunc."""
    unit, this = seq_get(args, 0), seq_get(args, 1)

    is_date_cast = isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE)
    if is_date_cast:
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('unit', this), defaulting the unit to 'day'."""
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_TRUNC", unit, expression.this)
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Map LOCATE(substr, this[, position]) arguments onto a StrPosition node
    (note the first two arguments are swapped relative to STRPOS)."""
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, this[, position])."""
    substr = expression.args.get("substr")
    position = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render TimeStrToTime as an explicit CAST to TIMESTAMP."""
    inner = self.sql(expression, "this")
    return f"CAST({inner} AS TIMESTAMP)"
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render DateStrToDate as an explicit CAST to DATE."""
    inner = self.sql(expression, "this")
    return f"CAST({inner} AS DATE)"
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render Min as LEAST when it has extra operands, otherwise as plain MIN."""
    func_name = "LEAST" if expression.expressions else "MIN"
    render = rename_func(func_name)
    return render(self, expression)
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render Max as GREATEST when it has extra operands, otherwise as plain MAX."""
    func_name = "GREATEST" if expression.expressions else "MAX"
    render = rename_func(func_name)
    return render(self, expression)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0)); DISTINCT cannot be preserved."""
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        # Unwrap the condition but warn, since the DISTINCT semantics are lost.
        cond = cond.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM, using the verbose TRIM([pos ][chars ]FROM target[ COLLATE c])
    form only when remove-characters or a collation make it necessary."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # No dialect-specific parts: defer to the generator's own trim_sql method,
    # which emits standard TRIM/LTRIM/RTRIM syntax.
    if not (remove_chars or collation):
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(this, format)."""
    fmt = self.format_time(expression)
    return self.func("STRPTIME", expression.this, fmt)
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a transform that casts a timestamp-or-datestring value to DATE,
    parsing it first when its format is not one *dialect* casts natively."""

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        # Non-default formats need an explicit STRPTIME before the cast.
        if time_format and time_format not in (target.time_format, target.date_format):
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _ts_or_ds_to_date_sql
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Produce the column-name suffixes implied by pivot aggregations.

    Aliased aggregations contribute their alias. Unaliased aggregations are
    rendered as SQL (e.g. col_avg(foo)) with all identifiers unquoted, because
    the base parser's `_parse_pivot` will re-quote them via `to_identifier`;
    leaving them quoted would produce doubled quotes like col_avg(`foo`).
    """
    names = []

    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        unquoted = agg.transform(
            lambda node: exp.Identifier(this=node.name, quoted=False)
            if isinstance(node, exp.Identifier)
            else node
        )
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names