Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5
  6from sqlglot import exp
  7from sqlglot.generator import Generator
  8from sqlglot.helper import flatten, seq_get
  9from sqlglot.parser import Parser
 10from sqlglot.time import format_time
 11from sqlglot.tokens import Token, Tokenizer
 12from sqlglot.trie import new_trie
 13
# Type variable for helpers that are generic over the expression class they instantiate.
E = t.TypeVar("E", bound=exp.Expression)
 15
 16
class Dialects(str, Enum):
    """Names of the SQL dialects known to this module.

    The ``_Dialect`` metaclass looks up each new dialect class here by its
    upper-cased class name and, when a member exists, registers the class
    under that member's value.
    """

    # Matches the base ``Dialect`` class itself, which is registered under "".
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TRINO = "trino"
    TSQL = "tsql"
    DATABRICKS = "databricks"
    DRILL = "drill"
    TERADATA = "teradata"
 39
 40
 41class _Dialect(type):
 42    classes: t.Dict[str, t.Type[Dialect]] = {}
 43
 44    @classmethod
 45    def __getitem__(cls, key: str) -> t.Type[Dialect]:
 46        return cls.classes[key]
 47
 48    @classmethod
 49    def get(
 50        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
 51    ) -> t.Optional[t.Type[Dialect]]:
 52        return cls.classes.get(key, default)
 53
 54    def __new__(cls, clsname, bases, attrs):
 55        klass = super().__new__(cls, clsname, bases, attrs)
 56        enum = Dialects.__members__.get(clsname.upper())
 57        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 58
 59        klass.time_trie = new_trie(klass.time_mapping)
 60        klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
 61        klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)
 62
 63        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
 64        klass.parser_class = getattr(klass, "Parser", Parser)
 65        klass.generator_class = getattr(klass, "Generator", Generator)
 66
 67        klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
 68        klass.identifier_start, klass.identifier_end = list(
 69            klass.tokenizer_class._IDENTIFIERS.items()
 70        )[0]
 71
 72        if (
 73            klass.tokenizer_class._BIT_STRINGS
 74            and exp.BitString not in klass.generator_class.TRANSFORMS
 75        ):
 76            bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
 77            klass.generator_class.TRANSFORMS[
 78                exp.BitString
 79            ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
 80        if (
 81            klass.tokenizer_class._HEX_STRINGS
 82            and exp.HexString not in klass.generator_class.TRANSFORMS
 83        ):
 84            hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
 85            klass.generator_class.TRANSFORMS[
 86                exp.HexString
 87            ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
 88        if (
 89            klass.tokenizer_class._BYTE_STRINGS
 90            and exp.ByteString not in klass.generator_class.TRANSFORMS
 91        ):
 92            be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
 93            klass.generator_class.TRANSFORMS[
 94                exp.ByteString
 95            ] = lambda self, e: f"{be_start}{self.sql(e, 'this')}{be_end}"
 96
 97        return klass
 98
 99
class Dialect(metaclass=_Dialect):
    """Base dialect: ties a tokenizer, parser and generator to shared settings."""

    index_offset = 0
    unnest_column_only = False
    alias_post_tablesample = False
    normalize_functions: t.Optional[str] = "upper"
    null_ordering = "nulls_are_small"

    date_format = "'%Y-%m-%d'"
    dateint_format = "'%Y%m%d'"
    time_format = "'%Y-%m-%d %H:%M:%S'"
    time_mapping: t.Dict[str, str] = {}

    # autofilled by the metaclass when the class is created
    quote_start = None
    quote_end = None
    identifier_start = None
    identifier_end = None

    time_trie = None
    inverse_time_mapping = None
    inverse_time_trie = None
    tokenizer_class = None
    parser_class = None
    generator_class = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve ``dialect`` to a dialect class.

        Args:
            dialect: a dialect name, a dialect instance, a dialect class, or a
                falsy value meaning "this class".

        Returns:
            The resolved dialect class.

        Raises:
            ValueError: if ``dialect`` is a name with no registered class.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        found = cls.get(dialect)
        if found:
            return found
        raise ValueError(f"Unknown dialect '{dialect}'")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a time format through this dialect's ``time_mapping``.

        Plain strings are expected to be quoted and lose their first and last
        character before conversion. String expressions are converted from
        their ``this`` value. Anything else is returned unchanged.
        """
        if isinstance(expression, str):
            # the time formats are quoted
            raw_format = expression[1:-1]
        elif expression and expression.is_string:
            raw_format = expression.this
        else:
            return expression

        converted = format_time(raw_format, cls.time_mapping, cls.time_trie)
        return exp.Literal.string(converted)

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse ``sql``; ``opts`` are forwarded to the parser."""
        parser = self.parser(**opts)
        return parser.parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize ``sql`` and parse it into ``expression_type``."""
        parser = self.parser(**opts)
        return parser.parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Render ``expression`` as SQL; ``opts`` are forwarded to the generator."""
        generator = self.generator(**opts)
        return generator.generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse ``sql`` and regenerate each statement; ``opts`` reach the generator only."""
        statements = []
        for expression in self.parse(sql):
            statements.append(self.generate(expression, **opts))
        return statements

    def tokenize(self, sql: str) -> t.List[Token]:
        """Split ``sql`` into tokens using this dialect's tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        """Tokenizer instance, created on first access and then reused."""
        try:
            return self._tokenizer
        except AttributeError:
            self._tokenizer = self.tokenizer_class()  # type: ignore
            return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Build a parser from this dialect's settings; ``opts`` override them."""
        settings = {
            "index_offset": self.index_offset,
            "unnest_column_only": self.unnest_column_only,
            "alias_post_tablesample": self.alias_post_tablesample,
            "null_ordering": self.null_ordering,
        }
        settings.update(opts)
        return self.parser_class(**settings)  # type: ignore

    def generator(self, **opts) -> Generator:
        """Build a generator from this dialect's settings; ``opts`` override them."""
        settings = {
            "quote_start": self.quote_start,
            "quote_end": self.quote_end,
            "identifier_start": self.identifier_start,
            "identifier_end": self.identifier_end,
            "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
            "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
            "index_offset": self.index_offset,
            "time_mapping": self.inverse_time_mapping,
            "time_trie": self.inverse_time_trie,
            "unnest_column_only": self.unnest_column_only,
            "alias_post_tablesample": self.alias_post_tablesample,
            "normalize_functions": self.normalize_functions,
            "null_ordering": self.null_ordering,
        }
        settings.update(opts)
        return self.generator_class(**settings)  # type: ignore
215
216
# Anything Dialect.get_or_raise accepts: a name, an instance, a class, or None.
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
218
219
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator transform that renders an expression as a call to ``name``.

    The call's arguments are the flattened values of ``expression.args``.
    """

    def _renamed(self, expression):
        return self.func(name, *flatten(expression.args.values()))

    return _renamed
222
223
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render as ``APPROX_COUNT_DISTINCT(this)``; an ``accuracy`` argument is reported as unsupported."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    target = expression.this
    return self.func("APPROX_COUNT_DISTINCT", target)
228
229
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render an ``If`` expression as the function call ``IF(condition, true, false)``."""
    condition = expression.this
    when_true = expression.args.get("true")
    when_false = expression.args.get("false")
    return self.func("IF", condition, when_true, when_false)
234
235
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render a JSON extraction as a binary expression using the ``->`` operator."""
    operator = "->"
    return self.binary(expression, operator)
238
239
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render a scalar JSON extraction as a binary expression using the ``->>`` operator."""
    operator = "->>"
    return self.binary(expression, operator)
244
245
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array as its rendered elements wrapped in square brackets."""
    elements = self.expressions(expression)
    return "[{}]".format(elements)
248
249
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ``ILIKE`` as ``LOWER(this) LIKE LOWER(pattern)``.

    Both operands are lower-cased. Lower-casing only the left-hand side would
    make any pattern containing an uppercase character impossible to match,
    which is not what a case-insensitive match means.

    Args:
        expression: the ``ILike`` node to rewrite.

    Returns:
        The SQL produced by the generator's ``like_sql`` for the rewritten node.
    """
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this),
            expression=exp.Lower(this=expression.args["expression"]),
        )
    )
257
258
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render ``CURRENT_DATE`` without parentheses, appending ``AT TIME ZONE`` when a zone is set."""
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"
    return f"CURRENT_DATE AT TIME ZONE {zone}"
262
263
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a ``WITH`` clause for dialects without recursive CTEs.

    A recursive clause is reported as unsupported and its ``recursive`` flag
    is cleared on the expression itself before rendering.
    """
    node_args = expression.args
    if node_args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        node_args["recursive"] = False
    return self.with_sql(expression)
269
270
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate a safe division with ``IF``: yield ``NULL`` when the divisor is zero."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    quotient = f"{numerator} / {denominator}"
    return f"IF({denominator} <> 0, {quotient}, NULL)"
275
276
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Report ``TABLESAMPLE`` as unsupported and render only the sampled expression."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
280
281
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Report ``PIVOT`` as unsupported and drop it from the output.

    Returns an empty string, like the other ``no_*`` helpers that discard a
    construct. Passing the same ``Pivot`` node back to ``self.sql`` would ask
    the generator to render the very node this transform was invoked for.

    Args:
        expression: the ``Pivot`` node being dropped.

    Returns:
        An empty string.
    """
    self.unsupported("PIVOT unsupported")
    return ""
285
286
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render a ``TryCast`` through the generator's regular ``cast_sql``."""
    rendered = self.cast_sql(expression)
    return rendered
289
290
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Report properties as unsupported and render nothing for them."""
    message = "Properties unsupported"
    self.unsupported(message)
    return ""
294
295
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Report column comment constraints as unsupported and render nothing for them."""
    message = "CommentColumnConstraint unsupported"
    self.unsupported(message)
    return ""
301
302
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a ``StrPosition`` with ``STRPOS``.

    When a start position is given, the search runs on ``SUBSTR(this, position)``
    and the result is shifted back by ``position - 1``.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"

    # NOTE(review): if STRPOS yields 0 for "not found", this offset form yields
    # position - 1 instead of 0 - confirm whether callers rely on that.
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
310
311
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render a struct field access as ``this.<quoted key>``."""
    base = self.sql(expression, "this")
    key_identifier = exp.Identifier(this=expression.expression, quoted=True)
    field = self.sql(key_identifier)
    return f"{base}.{field}"
316
317
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as ``MAP(key1, value1, key2, value2, ...)``.

    Args:
        expression: the map node; its ``keys`` and ``values`` args are read.
        map_func_name: name of the function to emit.

    Returns:
        The function call with keys and values interleaved. If either side is
        not an array literal, this is reported as unsupported and the call is
        emitted with the two arguments as they are.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    both_arrays = isinstance(keys, exp.Array) and isinstance(values, exp.Array)
    if not both_arrays:
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave in pair order: key, value, key, value, ...
    interleaved = [
        self.sql(node)
        for pair in zip(keys.expressions, values.expressions)
        for node in pair
    ]
    return self.func(map_func_name, *interleaved)
333
334
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.Sequence], E]:
    """Build an argument parser for time expressions that take a format.

    Args:
        exp_class: the expression class to instantiate.
        dialect: name of the dialect whose time mapping converts the format.
        default: format to use when the call has none; True selects the
            dialect's ``time_format``.

    Returns:
        A callable that turns a sequence of arguments into ``exp_class`` with
        a converted ``format``.
    """

    def _format_time(args: t.Sequence):
        this_arg = seq_get(args, 0)
        dialect_cls = Dialect[dialect]

        fmt = seq_get(args, 1)
        if not fmt:
            # No usable format argument: fall back to the configured default.
            if default is True:
                fmt = dialect_cls.time_format
            else:
                fmt = default or None

        return exp_class(this=this_arg, format=dialect_cls.format_time(fmt))

    return _format_time
359
360
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    Render a CREATE statement, moving PARTITIONED BY columns out of the table schema.

    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When
    the PARTITIONED BY value is a list of column names rather than a schema, the matching column
    definitions are taken out of the create statement's schema and become the property's schema.
    The rewrite is done on a copy of the expression.
    """
    schema_given = isinstance(expression.this, exp.Schema)
    partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
    if not (schema_given and partitionable):
        return self.create_sql(expression)

    expression = expression.copy()
    prop = expression.find(exp.PartitionedByProperty)

    if prop and prop.this and not isinstance(prop.this, exp.Schema):
        schema = expression.this
        # Column names are matched case-insensitively.
        partition_names = {v.name.upper() for v in prop.this.expressions}
        partitions = [col for col in schema.expressions if col.name.upper() in partition_names]
        remaining = [e for e in schema.expressions if e not in partitions]

        schema.set("expressions", remaining)
        prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
        expression.set("this", schema)

    return self.create_sql(expression)
382
383
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.Sequence], E]:
    """Build an argument parser for date-delta functions.

    With exactly three arguments they are read as ``(unit, expression, this)``.
    Otherwise they are read as ``(this, expression)`` and the unit is the
    string literal ``DAY``.

    Args:
        exp_class: the expression class to instantiate.
        unit_mapping: optional lookup from lower-cased unit name to a
            replacement unit; units without an entry are kept.

    Returns:
        A callable that turns a sequence of arguments into ``exp_class``.
    """

    def inner_func(args: t.Sequence) -> E:
        if len(args) == 3:
            this = args[2]
            unit = args[0]
        else:
            this = seq_get(args, 0)
            unit = exp.Literal.string("DAY")

        if unit_mapping:
            unit = unit_mapping.get(unit.name.lower(), unit)

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
395
396
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.Sequence], t.Optional[E]]:
    """Build an argument parser for date-delta functions whose second argument is an interval.

    Args:
        expression_class: the expression class to instantiate.

    Returns:
        A callable that returns None for fewer than two arguments. Otherwise
        it returns ``expression_class`` built from the first argument, the
        interval's value (a string value becomes a number literal) and the
        interval's unit as a string literal.
    """

    def func(args: t.Sequence) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=args[0], expression=amount, unit=unit)

    return func
416
417
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse ``(unit, value)`` arguments into a date or timestamp truncation.

    A value that is a cast to DATE yields ``DateTrunc``; anything else yields
    ``TimestampTrunc``.
    """
    unit = seq_get(args, 0)
    value = seq_get(args, 1)

    casts_to_date = isinstance(value, exp.Cast) and value.is_type(exp.DataType.Type.DATE)
    if casts_to_date:
        return exp.DateTrunc(unit=unit, this=value)
    return exp.TimestampTrunc(this=value, unit=unit)
425
426
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render as ``DATE_TRUNC(unit, this)``; the unit is a string literal defaulting to ``day``."""
    unit_name = expression.text("unit") or "day"
    unit_literal = exp.Literal.string(unit_name)
    return self.func("DATE_TRUNC", unit_literal, expression.this)
431
432
def locate_to_strposition(args: t.Sequence) -> exp.Expression:
    """Parse ``LOCATE``-style arguments ``(substr, this[, position])`` into a ``StrPosition``."""
    haystack = seq_get(args, 1)
    needle = seq_get(args, 0)
    start = seq_get(args, 2)
    return exp.StrPosition(this=haystack, substr=needle, position=start)
439
440
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a ``StrPosition`` as ``LOCATE(substr, this, position)``."""
    node_args = expression.args
    needle = node_args.get("substr")
    haystack = expression.this
    start = node_args.get("position")
    return self.func("LOCATE", needle, haystack, start)
445
446
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string conversion as a cast to ``TIMESTAMP``."""
    value = self.sql(expression, "this")
    return f"CAST({value} AS TIMESTAMP)"
449
450
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string conversion as a cast to ``DATE``."""
    value = self.sql(expression, "this")
    return f"CAST({value} AS DATE)"
453
454
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render ``Min`` as ``LEAST`` when it has extra expressions, otherwise as ``MIN``."""
    if expression.expressions:
        name = "LEAST"
    else:
        name = "MIN"
    render = rename_func(name)
    return render(self, expression)
458
459
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render ``Max`` as ``GREATEST`` when it has extra expressions, otherwise as ``MAX``."""
    if expression.expressions:
        name = "GREATEST"
    else:
        name = "MAX"
    render = rename_func(name)
    return render(self, expression)
463
464
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render ``COUNT_IF(cond)`` as ``sum(if(cond, 1, 0))``.

    A ``DISTINCT`` wrapper is reported as unsupported and only its first
    expression is used as the condition.
    """
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        cond = cond.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    one_or_zero = exp.func("if", cond, 1, 0)
    return self.func("sum", one_or_zero)
473
474
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render ``TRIM([position ][chars ][FROM ]target[ COLLATE collation])``.

    When neither trim characters nor a collation are present, rendering is
    delegated to the generator's own ``trim_sql``.
    """
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (chars or collation):
        return self.trim_sql(expression)

    pieces = []
    if position:
        pieces.append(f"{position} ")
    if chars:
        pieces.append(f"{chars} ")
    if pieces:
        # FROM is only needed when something precedes the target.
        pieces.append("FROM ")
    pieces.append(f"{target}")
    if collation:
        pieces.append(f" COLLATE {collation}")

    return "TRIM(" + "".join(pieces) + ")"
490
491
def str_to_time_sql(self, expression: exp.Expression) -> str:
    """Render as ``STRPTIME(this, format)`` using the generator's formatted time."""
    value = expression.this
    time_format = self.format_time(expression)
    return self.func("STRPTIME", value, time_format)
494
495
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator transform that renders ``TsOrDsToDate`` as a cast to ``DATE``.

    Args:
        dialect: dialect whose ``time_format`` and ``date_format`` count as
            the default formats.

    Returns:
        A transform that casts the value directly, or first parses it with
        ``STRPTIME`` when the expression has a format other than the defaults.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)

        has_custom_format = bool(time_format) and time_format not in (
            _dialect.time_format,
            _dialect.date_format,
        )
        if has_custom_format:
            value = str_to_time_sql(self, expression)
        else:
            value = self.sql(expression, "this")
        return f"CAST({value} AS DATE)"

    return _ts_or_ds_to_date_sql
class Dialects(builtins.str, enum.Enum):
18class Dialects(str, Enum):
19    DIALECT = ""
20
21    BIGQUERY = "bigquery"
22    CLICKHOUSE = "clickhouse"
23    DUCKDB = "duckdb"
24    HIVE = "hive"
25    MYSQL = "mysql"
26    ORACLE = "oracle"
27    POSTGRES = "postgres"
28    PRESTO = "presto"
29    REDSHIFT = "redshift"
30    SNOWFLAKE = "snowflake"
31    SPARK = "spark"
32    SQLITE = "sqlite"
33    STARROCKS = "starrocks"
34    TABLEAU = "tableau"
35    TRINO = "trino"
36    TSQL = "tsql"
37    DATABRICKS = "databricks"
38    DRILL = "drill"
39    TERADATA = "teradata"

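An enumeration of the SQL dialect names known to sqlglot. Each member's value is the lowercase dialect name; the base DIALECT member is the empty string.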

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
TERADATA = <Dialects.TERADATA: 'teradata'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
101class Dialect(metaclass=_Dialect):
102    index_offset = 0
103    unnest_column_only = False
104    alias_post_tablesample = False
105    normalize_functions: t.Optional[str] = "upper"
106    null_ordering = "nulls_are_small"
107
108    date_format = "'%Y-%m-%d'"
109    dateint_format = "'%Y%m%d'"
110    time_format = "'%Y-%m-%d %H:%M:%S'"
111    time_mapping: t.Dict[str, str] = {}
112
113    # autofilled
114    quote_start = None
115    quote_end = None
116    identifier_start = None
117    identifier_end = None
118
119    time_trie = None
120    inverse_time_mapping = None
121    inverse_time_trie = None
122    tokenizer_class = None
123    parser_class = None
124    generator_class = None
125
126    @classmethod
127    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
128        if not dialect:
129            return cls
130        if isinstance(dialect, _Dialect):
131            return dialect
132        if isinstance(dialect, Dialect):
133            return dialect.__class__
134
135        result = cls.get(dialect)
136        if not result:
137            raise ValueError(f"Unknown dialect '{dialect}'")
138
139        return result
140
141    @classmethod
142    def format_time(
143        cls, expression: t.Optional[str | exp.Expression]
144    ) -> t.Optional[exp.Expression]:
145        if isinstance(expression, str):
146            return exp.Literal.string(
147                format_time(
148                    expression[1:-1],  # the time formats are quoted
149                    cls.time_mapping,
150                    cls.time_trie,
151                )
152            )
153        if expression and expression.is_string:
154            return exp.Literal.string(
155                format_time(
156                    expression.this,
157                    cls.time_mapping,
158                    cls.time_trie,
159                )
160            )
161        return expression
162
163    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
164        return self.parser(**opts).parse(self.tokenize(sql), sql)
165
166    def parse_into(
167        self, expression_type: exp.IntoType, sql: str, **opts
168    ) -> t.List[t.Optional[exp.Expression]]:
169        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
170
171    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
172        return self.generator(**opts).generate(expression)
173
174    def transpile(self, sql: str, **opts) -> t.List[str]:
175        return [self.generate(expression, **opts) for expression in self.parse(sql)]
176
177    def tokenize(self, sql: str) -> t.List[Token]:
178        return self.tokenizer.tokenize(sql)
179
180    @property
181    def tokenizer(self) -> Tokenizer:
182        if not hasattr(self, "_tokenizer"):
183            self._tokenizer = self.tokenizer_class()  # type: ignore
184        return self._tokenizer
185
186    def parser(self, **opts) -> Parser:
187        return self.parser_class(  # type: ignore
188            **{
189                "index_offset": self.index_offset,
190                "unnest_column_only": self.unnest_column_only,
191                "alias_post_tablesample": self.alias_post_tablesample,
192                "null_ordering": self.null_ordering,
193                **opts,
194            },
195        )
196
197    def generator(self, **opts) -> Generator:
198        return self.generator_class(  # type: ignore
199            **{
200                "quote_start": self.quote_start,
201                "quote_end": self.quote_end,
202                "identifier_start": self.identifier_start,
203                "identifier_end": self.identifier_end,
204                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
205                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
206                "index_offset": self.index_offset,
207                "time_mapping": self.inverse_time_mapping,
208                "time_trie": self.inverse_time_trie,
209                "unnest_column_only": self.unnest_column_only,
210                "alias_post_tablesample": self.alias_post_tablesample,
211                "normalize_functions": self.normalize_functions,
212                "null_ordering": self.null_ordering,
213                **opts,
214            }
215        )
@classmethod
def get_or_raise( cls, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> Type[sqlglot.dialects.dialect.Dialect]:
126    @classmethod
127    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
128        if not dialect:
129            return cls
130        if isinstance(dialect, _Dialect):
131            return dialect
132        if isinstance(dialect, Dialect):
133            return dialect.__class__
134
135        result = cls.get(dialect)
136        if not result:
137            raise ValueError(f"Unknown dialect '{dialect}'")
138
139        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
141    @classmethod
142    def format_time(
143        cls, expression: t.Optional[str | exp.Expression]
144    ) -> t.Optional[exp.Expression]:
145        if isinstance(expression, str):
146            return exp.Literal.string(
147                format_time(
148                    expression[1:-1],  # the time formats are quoted
149                    cls.time_mapping,
150                    cls.time_trie,
151                )
152            )
153        if expression and expression.is_string:
154            return exp.Literal.string(
155                format_time(
156                    expression.this,
157                    cls.time_mapping,
158                    cls.time_trie,
159                )
160            )
161        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
163    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
164        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
166    def parse_into(
167        self, expression_type: exp.IntoType, sql: str, **opts
168    ) -> t.List[t.Optional[exp.Expression]]:
169        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: Optional[sqlglot.expressions.Expression], **opts) -> str:
171    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
172        return self.generator(**opts).generate(expression)
def transpile(self, sql: str, **opts) -> List[str]:
174    def transpile(self, sql: str, **opts) -> t.List[str]:
175        return [self.generate(expression, **opts) for expression in self.parse(sql)]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
177    def tokenize(self, sql: str) -> t.List[Token]:
178        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
186    def parser(self, **opts) -> Parser:
187        return self.parser_class(  # type: ignore
188            **{
189                "index_offset": self.index_offset,
190                "unnest_column_only": self.unnest_column_only,
191                "alias_post_tablesample": self.alias_post_tablesample,
192                "null_ordering": self.null_ordering,
193                **opts,
194            },
195        )
def generator(self, **opts) -> sqlglot.generator.Generator:
197    def generator(self, **opts) -> Generator:
198        return self.generator_class(  # type: ignore
199            **{
200                "quote_start": self.quote_start,
201                "quote_end": self.quote_end,
202                "identifier_start": self.identifier_start,
203                "identifier_end": self.identifier_end,
204                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
205                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
206                "index_offset": self.index_offset,
207                "time_mapping": self.inverse_time_mapping,
208                "time_trie": self.inverse_time_trie,
209                "unnest_column_only": self.unnest_column_only,
210                "alias_post_tablesample": self.alias_post_tablesample,
211                "normalize_functions": self.normalize_functions,
212                "null_ordering": self.null_ordering,
213                **opts,
214            }
215        )
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
221def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
222    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
225def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
226    if expression.args.get("accuracy"):
227        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
228    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.If) -> str:
231def if_sql(self: Generator, expression: exp.If) -> str:
232    return self.func(
233        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
234    )
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
237def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
238    return self.binary(expression, "->")
def arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
241def arrow_json_extract_scalar_sql(
242    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
243) -> str:
244    return self.binary(expression, "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
247def inline_array_sql(self: Generator, expression: exp.Array) -> str:
248    return f"[{self.expressions(expression)}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
251def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
252    return self.like_sql(
253        exp.Like(
254            this=exp.Lower(this=expression.this),
255            expression=expression.args["expression"],
256        )
257    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, appending AT TIME ZONE when a zone is set."""
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"
    return f"CURRENT_DATE AT TIME ZONE {zone}"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a WITH clause for dialects that lack recursive CTEs.

    When the clause is flagged `recursive`, the problem is reported through
    `self.unsupported` and the clause is generated with the flag cleared.

    Args:
        expression: the WITH expression to generate.

    Returns:
        The SQL produced by `self.with_sql` for the (possibly adjusted) clause.
    """
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        # Clear the flag on a copy: generating SQL should not alter the caller's
        # tree, which may still be transpiled to a dialect that supports RECURSIVE.
        expression = expression.copy()
        expression.set("recursive", False)
    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Render SAFE_DIVIDE as IF(<d> <> 0, <n> / <d>, NULL).

    Args:
        expression: the SafeDivide expression; `this` is the numerator and
            `expression` is the denominator.

    Returns:
        The IF(...) SQL string guarding the division against a zero denominator.
    """

    def _operand(key: str) -> str:
        node = expression.args.get(key)
        sql = self.sql(node)
        # A compound operand such as `a + b` must be parenthesized, otherwise
        # `a + b / d` binds as `a + (b / d)` and the result silently changes.
        return f"({sql})" if isinstance(node, exp.Binary) else sql

    n = _operand("this")
    d = _operand("expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Report TABLESAMPLE as unsupported and render only the wrapped expression."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Report PIVOT as unsupported and omit it from the generated SQL.

    Args:
        expression: the Pivot expression being dropped.

    Returns:
        An empty string, matching the other `no_*` helpers that drop a construct.
    """
    self.unsupported("PIVOT unsupported")
    # Do not call self.sql(expression) here: when this function is registered as
    # the transform for Pivot, that call dispatches straight back into it and
    # recurses until the interpreter's recursion limit is hit.
    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render a TryCast expression through the generator's `cast_sql`."""
    cast = self.cast_sql(expression)
    return cast
def no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Report properties as unsupported and emit nothing for them."""
    message = "Properties unsupported"
    self.unsupported(message)
    return ""
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Report column comment constraints as unsupported and emit nothing for them."""
    message = "CommentColumnConstraint unsupported"
    self.unsupported(message)
    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition with STRPOS.

    When a start position is given, the search runs on SUBSTR(<this>, <position>) and
    the result is shifted by `<position> - 1`.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as `<struct>.<key>`, with the key as a quoted identifier."""
    struct_sql = self.sql(expression, "this")
    key_identifier = exp.Identifier(this=expression.expression, quoted=True)
    key_sql = self.sql(key_identifier)
    return f"{struct_sql}.{key_sql}"
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as `map_func_name(k1, v1, k2, v2, ...)`.

    Args:
        expression: the map expression holding `keys` and `values` args.
        map_func_name: the name of the map constructor function to emit.

    Returns:
        The function call with interleaved keys and values. If either side is not
        an array literal, the problem is reported as unsupported and the two
        arguments are passed through as-is.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    both_arrays = isinstance(keys, exp.Array) and isinstance(values, exp.Array)
    if not both_arrays:
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave as key, value, key, value, ...
    pairs = zip(keys.expressions, values.expressions)
    interleaved = [self.sql(item) for pair in pairs for item in pair]
    return self.func(map_func_name, *interleaved)
def format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[bool, str, NoneType] = None) -> Callable[[Sequence], ~E]:
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.Sequence], E]:
    """Create a parser helper for time expressions that carry a format argument.

    Args:
        exp_class: the expression class to instantiate.
        dialect: name of the dialect whose time mapping is used to convert the format.
        default: fallback format when none is passed; `True` selects the dialect's
            own `time_format`.

    Returns:
        A callable that takes the parsed argument list and returns an `exp_class`
        with `this` set to the first argument and `format` set to the converted format.
    """

    def _format_time(args: t.Sequence):
        dialect_class = Dialect[dialect]

        # Prefer an explicit format argument, then fall back to the default.
        time_format = seq_get(args, 1)
        if not time_format:
            if default is True:
                time_format = dialect_class.time_format
            else:
                time_format = default or None

        return exp_class(
            this=seq_get(args, 0),
            format=dialect_class.format_time(time_format),
        )

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format to fall back on; True selects the dialect's own time format.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    Generate a CREATE statement, moving partition columns out of the table schema.

    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is a list of column names rather than a schema, the matching column definitions
    are taken out of the create statement's schema and placed into the property as a schema instead.
    """
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
    if not (isinstance(expression.this, exp.Schema) and is_partitionable):
        return self.create_sql(expression)

    # Rewrite a copy so the caller's expression is left as it was.
    expression = expression.copy()
    prop = expression.find(exp.PartitionedByProperty)

    if prop and prop.this and not isinstance(prop.this, exp.Schema):
        schema = expression.this
        partition_names = {name.name.upper() for name in prop.this.expressions}
        partitions = [
            column for column in schema.expressions if column.name.upper() in partition_names
        ]
        remaining = [column for column in schema.expressions if column not in partitions]

        schema.set("expressions", remaining)
        prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
        expression.set("this", schema)

    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[Sequence], ~E]:
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.Sequence], E]:
    """Create a parser helper for date-delta functions.

    Args:
        exp_class: the expression class to instantiate.
        unit_mapping: optional mapping from lowercased unit names to replacement units.

    Returns:
        A callable taking the parsed argument list. With exactly three arguments they
        are read as (unit, expression, this); otherwise as (this, expression) with
        the unit defaulting to the string literal "DAY".
    """

    def inner_func(args: t.Sequence) -> E:
        if len(args) == 3:
            this = args[2]
            unit = args[0]
        else:
            this = seq_get(args, 0)
            unit = exp.Literal.string("DAY")

        if unit_mapping:
            unit = unit_mapping.get(unit.name.lower(), unit)

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[Sequence], Optional[~E]]:
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.Sequence], t.Optional[E]]:
    """Create a parser helper for date-delta functions whose second argument is an interval.

    Args:
        expression_class: the expression class to instantiate.

    Returns:
        A callable taking the parsed argument list. It returns None when fewer than
        two arguments are given; otherwise it builds `expression_class` from the first
        argument plus the value and unit read off the second argument.
    """

    def func(args: t.Sequence) -> t.Optional[E]:
        if len(args) < 2:
            return None

        target, interval = args[0], args[1]

        # A string interval value is re-wrapped as a number literal.
        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=target, expression=amount, unit=unit)

    return func
def date_trunc_to_time( args: Sequence) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, value) arguments.

    Returns a DateTrunc when the value is a cast to DATE, and a TimestampTrunc otherwise.
    """
    unit, this = seq_get(args, 0), seq_get(args, 1)

    is_date_cast = isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE)
    if not is_date_cast:
        return exp.TimestampTrunc(this=this, unit=unit)
    return exp.DateTrunc(unit=unit, this=this)
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('<unit>', <this>), with the unit defaulting to 'day'."""
    unit_name = expression.text("unit") or "day"
    unit = exp.Literal.string(unit_name)
    return self.func("DATE_TRUNC", unit, expression.this)
def locate_to_strposition(args: Sequence) -> sqlglot.expressions.Expression:
def locate_to_strposition(args: t.Sequence) -> exp.Expression:
    """Parse arguments given as (substr, this[, position]) into a StrPosition."""
    substr = seq_get(args, 0)
    this = seq_get(args, 1)
    position = seq_get(args, 2)
    return exp.StrPosition(this=this, substr=substr, position=position)
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(<substr>, <this>, <position>)."""
    parts = expression.args
    substr = parts.get("substr")
    position = parts.get("position")
    return self.func("LOCATE", substr, expression.this, position)
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render TimeStrToTime as a cast of its operand to TIMESTAMP."""
    operand = self.sql(expression, "this")
    return f"CAST({operand} AS TIMESTAMP)"
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render DateStrToDate as a cast of its operand to DATE."""
    operand = self.sql(expression, "this")
    return f"CAST({operand} AS DATE)"
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render Min as LEAST when it has extra `expressions`, and as MIN otherwise."""
    if expression.expressions:
        name = "LEAST"
    else:
        name = "MIN"
    render = rename_func(name)
    return render(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render Max as GREATEST when it has extra `expressions`, and as MAX otherwise."""
    if expression.expressions:
        name = "GREATEST"
    else:
        name = "MAX"
    render = rename_func(name)
    return render(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render COUNT_IF(<cond>) as sum(if(<cond>, 1, 0)).

    A DISTINCT wrapper around the condition is dropped and reported as unsupported.
    """
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    one_or_zero = exp.func("if", condition, 1, 0)
    return self.func("sum", one_or_zero)
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render Trim in the TRIM([position] [chars] FROM target [COLLATE collation]) form.

    When neither trim characters nor a collation are present, generation is delegated
    to the generator's own `trim_sql`.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (remove_chars or collation):
        return self.trim_sql(expression)

    position_part = f"{trim_type} " if trim_type else ""
    chars_part = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if position_part or chars_part else ""
    collate_part = f" COLLATE {collation}" if collation else ""
    return f"TRIM({position_part}{chars_part}{from_part}{target}{collate_part})"
def str_to_time_sql(self, expression: sqlglot.expressions.Expression) -> str:
def str_to_time_sql(self, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(<this>, <format>)."""
    time_format = self.format_time(expression)
    return self.func("STRPTIME", expression.this, time_format)
def ts_or_ds_to_date_sql(dialect: str) -> Callable:
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Create a generator transform for TsOrDsToDate bound to the given dialect.

    Args:
        dialect: name of the dialect whose `time_format` and `date_format` count as defaults.

    Returns:
        A callable that renders the expression as CAST(... AS DATE). The operand is
        parsed with STRPTIME first when the expression carries a format other than
        those two dialect defaults.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)

        has_custom_format = time_format and time_format not in (
            _dialect.time_format,
            _dialect.date_format,
        )
        if has_custom_format:
            operand = str_to_time_sql(self, expression)
        else:
            operand = self.sql(expression, "this")
        return f"CAST({operand} AS DATE)"

    return _ts_or_ds_to_date_sql