Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5
  6from sqlglot import exp
  7from sqlglot._typing import E
  8from sqlglot.generator import Generator
  9from sqlglot.helper import flatten, seq_get
 10from sqlglot.parser import Parser
 11from sqlglot.time import format_time
 12from sqlglot.tokens import Token, Tokenizer, TokenType
 13from sqlglot.trie import new_trie
 14
 15B = t.TypeVar("B", bound=exp.Binary)
 16
 17
class Dialects(str, Enum):
    """Names of the SQL dialects known to this module.

    Members are strings; the `_Dialect` metaclass looks a new dialect class up
    here by its uppercased class name and, when found, registers the class
    under the member's value.
    """

    # The empty string stands for the base `Dialect` class itself
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"

 41
 42
 43class _Dialect(type):
 44    classes: t.Dict[str, t.Type[Dialect]] = {}
 45
 46    def __eq__(cls, other: t.Any) -> bool:
 47        if cls is other:
 48            return True
 49        if isinstance(other, str):
 50            return cls is cls.get(other)
 51        if isinstance(other, Dialect):
 52            return cls is type(other)
 53
 54        return False
 55
 56    def __hash__(cls) -> int:
 57        return hash(cls.__name__.lower())
 58
 59    @classmethod
 60    def __getitem__(cls, key: str) -> t.Type[Dialect]:
 61        return cls.classes[key]
 62
 63    @classmethod
 64    def get(
 65        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
 66    ) -> t.Optional[t.Type[Dialect]]:
 67        return cls.classes.get(key, default)
 68
 69    def __new__(cls, clsname, bases, attrs):
 70        klass = super().__new__(cls, clsname, bases, attrs)
 71        enum = Dialects.__members__.get(clsname.upper())
 72        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 73
 74        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
 75        klass.FORMAT_TRIE = (
 76            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
 77        )
 78        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
 79        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 80
 81        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
 82        klass.parser_class = getattr(klass, "Parser", Parser)
 83        klass.generator_class = getattr(klass, "Generator", Generator)
 84
 85        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
 86        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
 87            klass.tokenizer_class._IDENTIFIERS.items()
 88        )[0]
 89
 90        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
 91            return next(
 92                (
 93                    (s, e)
 94                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
 95                    if t == token_type
 96                ),
 97                (None, None),
 98            )
 99
100        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
101        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
102        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
103
104        dialect_properties = {
105            **{
106                k: v
107                for k, v in vars(klass).items()
108                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
109            },
110            "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
111            "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
112        }
113
114        if enum not in ("", "bigquery"):
115            dialect_properties["SELECT_KINDS"] = ()
116
117        # Pass required dialect properties to the tokenizer, parser and generator classes
118        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
119            for name, value in dialect_properties.items():
120                if hasattr(subclass, name):
121                    setattr(subclass, name, value)
122
123        if not klass.STRICT_STRING_CONCAT:
124            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe
125
126        klass.generator_class.can_identify = klass.can_identify
127
128        return klass
129
130
class Dialect(metaclass=_Dialect):
    """Base SQL dialect: dialect-wide settings plus the tokenizer, parser and generator to use."""

    # Base index offset used for arrays
    INDEX_OFFSET = 0

    # When true, unnest table aliases are treated only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Whether the table alias is placed after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Whether unquoted identifiers are resolved as uppercase.
    # None means the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Whether an unquoted identifier may start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Whether CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # How function names are normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Default null ordering method when none is set explicitly.
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings: keys are dialect time formats,
    # values are the corresponding python time formats
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # Tries built from the time / format mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        """Compare via this instance's class, deferring to the metaclass comparison."""
        own_class = type(self)
        return own_class == other

    def __hash__(self) -> int:
        """Hash like the instance's class."""
        own_class = type(self)
        return hash(own_class)

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve a dialect name, class or instance to a dialect class.

        Args:
            dialect: A registered name, a dialect class, a dialect instance, or a falsy value.

        Returns:
            The matching dialect class; `cls` itself when `dialect` is falsy.

        Raises:
            ValueError: If `dialect` is a name with no registered class.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        found = cls.get(dialect)
        if found:
            return found
        raise ValueError(f"Unknown dialect '{dialect}'")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a time format through this dialect's TIME_MAPPING.

        Plain strings and string literal expressions are converted into a new string
        literal; anything else (including None) is returned unchanged.
        """
        if isinstance(expression, str):
            # the time formats are quoted
            unquoted = expression[1:-1]
            converted = format_time(unquoted, cls.TIME_MAPPING, cls.TIME_TRIE)
            return exp.Literal.string(converted)

        if expression and expression.is_string:
            converted = format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)
            return exp.Literal.string(converted)

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, which makes it
        effectively case-insensitive. For dialects that treat every identifier as
        case-insensitive, quoted identifiers are normalized as well.
        """
        if not isinstance(expression, exp.Identifier):
            return expression
        if expression.quoted and cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is not None:
            return expression

        name = expression.this
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE:
            normalized = name.upper()
        else:
            normalized = name.lower()
        expression.set("this", normalized)
        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        resolves_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if resolves_upper is None:
            return False

        # Characters in the case opposite to the one identifiers resolve to are unsafe
        if resolves_upper:
            return any(char.islower() for char in text)
        return any(char.isupper() for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            contains_sensitive = cls.case_sensitive(text)
            return not contains_sensitive

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Set the `quoted` flag on an identifier; other expressions are returned untouched.

        The flag is truthy when `identify` is set, when the name contains case
        sensitive characters, or when it does not match `exp.SAFE_IDENTIFIER_RE`.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            should_quote = (
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
            )
            expression.set("quoted", should_quote)

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse it with a parser built from `opts`."""
        parser = self.parser(**opts)
        tokens = self.tokenize(sql)
        return parser.parse(tokens, sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse it into `expression_type` with a parser built from `opts`."""
        parser = self.parser(**opts)
        tokens = self.tokenize(sql)
        return parser.parse_into(expression_type, tokens, sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Generate SQL for `expression` with a generator built from `opts`."""
        generator = self.generator(**opts)
        return generator.generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and generate one SQL string per parsed expression.

        `opts` are forwarded to the generator only.
        """
        statements = []
        for tree in self.parse(sql):
            statements.append(self.generate(tree, **opts))
        return statements

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenize `sql` with this instance's tokenizer."""
        tokenizer = self.tokenizer
        return tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        """The tokenizer instance, created on first access and reused afterwards."""
        try:
            return self._tokenizer
        except AttributeError:
            self._tokenizer = self.tokenizer_class()
            return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Build a new parser from `opts`."""
        parser_class = self.parser_class
        return parser_class(**opts)

    def generator(self, **opts) -> Generator:
        """Build a new generator from `opts`."""
        generator_class = self.generator_class
        return generator_class(**opts)
307
308
309DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
310
311
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Return a generator transform that emits the expression's args as a call to `name`."""

    def _renamed(self, expression):
        arguments = flatten(expression.args.values())
        return self.func(name, *arguments)

    return _renamed
314
315
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Generate APPROX_COUNT_DISTINCT(this), flagging an `accuracy` arg as unsupported."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    target = expression.this
    return self.func("APPROX_COUNT_DISTINCT", target)
320
321
def if_sql(self: Generator, expression: exp.If) -> str:
    """Generate an IF(condition, true, false) function call."""
    condition = expression.this
    when_true = expression.args.get("true")
    when_false = expression.args.get("false")
    return self.func("IF", condition, when_true, when_false)
326
327
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Generate the expression as a binary `->` operation."""
    operator = "->"
    return self.binary(expression, operator)
330
331
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Generate the expression as a binary `->>` operation."""
    operator = "->>"
    return self.binary(expression, operator)
336
337
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Generate an array as its elements wrapped in square brackets."""
    elements = self.expressions(expression)
    return "[{}]".format(elements)
340
341
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Generate ILIKE as LIKE with the left-hand side wrapped in LOWER."""
    lowered = exp.Lower(this=expression.this)
    like = exp.Like(this=lowered, expression=expression.expression)
    return self.like_sql(like)
346
347
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Generate CURRENT_DATE without parentheses, appending AT TIME ZONE when a zone is given."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
351
352
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Generate a WITH clause for dialects without recursive CTEs.

    When the clause is marked recursive, an unsupported warning is recorded and the
    clause is generated with the flag cleared.

    Args:
        expression: The WITH expression to generate.

    Returns:
        The SQL produced by the generator's `with_sql`.
    """
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        # Clear the flag on a copy so that generating SQL never alters the caller's
        # tree; the same tree may later be generated for a dialect that supports it.
        expression = expression.copy()
        expression.set("recursive", False)
    return self.with_sql(expression)
358
359
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Generate a safe division as an IF that yields NULL when the divisor is zero."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return "IF({d} <> 0, {n} / {d}, NULL)".format(n=numerator, d=denominator)
364
365
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Flag TABLESAMPLE as unsupported and generate only the sampled expression."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
369
370
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Flag PIVOT as unsupported and generate nothing for it."""
    message = "PIVOT unsupported"
    self.unsupported(message)
    return ""
374
375
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Generate a TRY_CAST through the generator's regular `cast_sql`."""
    cast = self.cast_sql(expression)
    return cast
378
379
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Flag properties as unsupported and generate nothing for them."""
    message = "Properties unsupported"
    self.unsupported(message)
    return ""
383
384
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Flag comment column constraints as unsupported and generate nothing for them."""
    message = "CommentColumnConstraint unsupported"
    self.unsupported(message)
    return ""
390
391
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Flag MAP_FROM_ENTRIES as unsupported and generate nothing for it."""
    message = "MAP_FROM_ENTRIES unsupported"
    self.unsupported(message)
    return ""
395
396
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Generate a string position lookup with STRPOS.

    When a start position is given, the search runs on SUBSTR(this, position) and the
    result is shifted by `position - 1`.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
404
405
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Generate struct field access as `<struct>.<quoted key>`."""
    struct_sql = self.sql(expression, "this")
    key_identifier = exp.Identifier(this=expression.expression, quoted=True)
    key_sql = self.sql(key_identifier)
    return ".".join((struct_sql, key_sql))
410
411
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Generate a map as a call whose arguments alternate key, value, key, value, ...

    If keys or values are not array expressions they cannot be interleaved; this is
    flagged as unsupported and both are passed to the function unchanged.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    both_arrays = isinstance(keys, exp.Array) and isinstance(values, exp.Array)
    if not both_arrays:
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave the generated SQL of each key with that of its value
    interleaved = [
        self.sql(node)
        for pair in zip(keys.expressions, values.expressions)
        for node in pair
    ]
    return self.func(map_func_name, *interleaved)
428
429
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Build a function-argument parser for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: name of the dialect whose time mapping converts the format.
        default: format used when the arguments carry none; True selects the
            dialect's TIME_FORMAT.

    Returns:
        A callable that turns an argument list into an `exp_class` instance with a
        converted `format`.
    """

    def _format_time(args: t.List):
        this = seq_get(args, 0)
        dialect_class = Dialect[dialect]

        time_format = seq_get(args, 1)
        if not time_format:
            # No explicit format in the arguments: fall back to the default
            if default is True:
                time_format = dialect_class.TIME_FORMAT
            else:
                time_format = default or None

        return exp_class(this=this, format=dialect_class.format_time(time_format))

    return _format_time
454
455
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property extends a table's schema. When the
    PARTITIONED BY value is a list of column names, those names are turned into a schema
    and the matching columns are removed from the create statement's own schema.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if not (has_schema and is_partitionable):
        return self.create_sql(expression)

    # Work on a copy: the schema and the property are rewritten below
    expression = expression.copy()
    prop = expression.find(exp.PartitionedByProperty)

    if prop and prop.this and not isinstance(prop.this, exp.Schema):
        schema = expression.this
        # Column names are matched case-insensitively
        partition_names = {column.name.upper() for column in prop.this.expressions}
        partitions = [
            column for column in schema.expressions if column.name.upper() in partition_names
        ]
        remaining = [column for column in schema.expressions if column not in partitions]
        schema.set("expressions", remaining)
        prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
        expression.set("this", schema)

    return self.create_sql(expression)
477
478
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a function-argument parser for date delta functions.

    With three arguments the layout is (unit, expression, this); otherwise it is
    (this, expression) and the unit defaults to the string literal 'DAY'. When
    `unit_mapping` is given, the unit name is translated through it (looked up in
    lowercase) and wrapped with `exp.var`.
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            this = args[2]
            unit = args[0]
        else:
            this = seq_get(args, 0)
            unit = exp.Literal.string("DAY")

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
490
491
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date delta functions whose second argument is an interval.

    The returned callable yields None when fewer than two arguments are given.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        this, interval = args[0], args[1]

        amount = interval.this
        if amount and amount.is_string:
            # A string amount is rebuilt as a number literal
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=this, expression=amount, unit=unit)

    return func
509
510
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse (unit, this) into DateTrunc when `this` is a cast to date, else TimestampTrunc."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    is_date_cast = isinstance(this, exp.Cast) and this.is_type("date")
    if not is_date_cast:
        return exp.TimestampTrunc(this=this, unit=unit)
    return exp.DateTrunc(unit=unit, this=this)
518
519
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Generate DATE_TRUNC(unit, this); the unit falls back to 'day' when absent."""
    unit_name = expression.text("unit") or "day"
    unit = exp.Literal.string(unit_name)
    return self.func("DATE_TRUNC", unit, expression.this)
524
525
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Parse LOCATE-style arguments (substr, this[, position]) into a StrPosition."""
    substr = seq_get(args, 0)
    this = seq_get(args, 1)
    position = seq_get(args, 2)
    return exp.StrPosition(this=this, substr=substr, position=position)
530
531
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Generate a string position lookup as LOCATE(substr, this, position)."""
    substr = expression.args.get("substr")
    this = expression.this
    position = expression.args.get("position")
    return self.func("LOCATE", substr, this, position)
536
537
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Generate LEFT(this, n) as a substring starting at 1 with length n."""
    # Copy first: the operands are reused inside a new expression
    expression = expression.copy()
    substring = exp.Substring(
        this=expression.this,
        start=exp.Literal.number(1),
        length=expression.expression,
    )
    return self.sql(substring)
545
546
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Generate RIGHT(this, n) as a substring starting at LENGTH(this) - (n - 1).

    Args:
        expression: The RIGHT expression; `this` is the string and `expression`
            the number of characters to keep.

    Returns:
        The SQL of the equivalent substring expression (no length is given, so
        it runs to the end of the string).
    """
    # Copy first: the operands are reused inside new expressions
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
555
556
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Generate the conversion as a CAST to TIMESTAMP."""
    value = self.sql(expression, "this")
    return f"CAST({value} AS TIMESTAMP)"
559
560
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Generate the conversion as a CAST to DATE."""
    value = self.sql(expression, "this")
    return f"CAST({value} AS DATE)"
563
564
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Generate LEAST when extra expressions are present, otherwise MIN."""
    if expression.expressions:
        name = "LEAST"
    else:
        name = "MIN"
    return self.func(name, *flatten(expression.args.values()))
568
569
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Generate GREATEST when extra expressions are present, otherwise MAX."""
    if expression.expressions:
        name = "GREATEST"
    else:
        name = "MAX"
    return self.func(name, *flatten(expression.args.values()))
573
574
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Generate COUNT_IF(cond) as sum(if(cond, 1, 0)).

    A DISTINCT wrapper is flagged as unsupported and only its first expression is
    used as the condition.
    """
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    one_or_zero = exp.func("if", condition, 1, 0)
    return self.func("sum", one_or_zero)
583
584
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Generate TRIM([position] [chars] FROM target [COLLATE collation]).

    Without characters to remove and without a collation, generation is delegated to
    the generator's own `trim_sql`.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (remove_chars or collation):
        return self.trim_sql(expression)

    pieces = []
    if trim_type:
        pieces.append(f"{trim_type} ")
    if remove_chars:
        pieces.append(f"{remove_chars} ")
    if pieces:
        # FROM is only needed when a position or characters precede the target
        pieces.append("FROM ")
    pieces.append(target)
    if collation:
        pieces.append(f" COLLATE {collation}")

    body = "".join(pieces)
    return f"TRIM({body})"
600
601
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Generate STRPTIME(this, format) with the format produced by the generator."""
    value = expression.this
    time_format = self.format_time(expression)
    return self.func("STRPTIME", value, time_format)
604
605
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator transform that casts a timestamp or date string to DATE.

    When the expression carries a format other than the dialect's TIME_FORMAT or
    DATE_FORMAT, the value is first parsed with `str_to_time_sql`.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)

        has_custom_format = time_format and time_format not in (
            _dialect.TIME_FORMAT,
            _dialect.DATE_FORMAT,
        )
        if has_custom_format:
            inner = str_to_time_sql(self, expression)
        else:
            inner = self.sql(expression, "this")
        return f"CAST({inner} AS DATE)"

    return _ts_or_ds_to_date_sql
615
616
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Generate a concatenation as a left-nested chain of DPipe expressions."""
    first, *others = expression.expressions

    chained = first
    for operand in others:
        chained = exp.DPipe(this=chained, expression=operand)

    return self.sql(chained)
623
624
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Generate REGEXP_EXTRACT(this, pattern, group), flagging args that cannot be expressed."""
    bad_args = [
        key for key in ("position", "occurrence", "parameters") if expression.args.get(key)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    group = expression.args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
633
634
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Generate REGEXP_REPLACE(this, pattern, replacement), flagging args that cannot be expressed."""
    bad_args = [
        key for key in ("position", "occurrence", "parameters") if expression.args.get(key)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    replacement = expression.args["replacement"]
    return self.func("REGEXP_REPLACE", expression.this, expression.expression, replacement)
643
644
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Return one column-name component per pivot aggregation.

    Aliased aggregations contribute their alias; the others contribute their own SQL,
    generated for `dialect` with unquoted identifiers and lowercased function names.
    """

    def _unquote(node):
        # Rebuild identifiers without quotes; leave every other node as it is
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # This case corresponds to aggregations without aliases being used as suffixes
        # (e.g. col_avg(foo)). Identifiers are unquoted here because they are going to
        # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
        # Otherwise, we'd end up with `col_avg(`foo`)` (notice the doubled quoting).
        unquoted = agg.transform(_unquote)
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
665
666
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser that turns the first two function arguments into an `expr_type` node."""

    def _build(args):
        left = seq_get(args, 0)
        right = seq_get(args, 1)
        return expr_type(this=left, expression=right)

    return _build
class Dialects(builtins.str, enum.Enum):
class Dialects(str, Enum):
    """Names of the SQL dialects known to this module.

    Members are strings; the `_Dialect` metaclass looks a new dialect class up
    here by its uppercased class name and, when found, registers the class
    under the member's value.
    """

    # The empty string stands for the base `Dialect` class itself
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"

An enumeration.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
class Dialect(metaclass=_Dialect):
    """Base SQL dialect: dialect-wide settings plus the tokenizer, parser and generator to use."""

    # Base index offset used for arrays
    INDEX_OFFSET = 0

    # When true, unnest table aliases are treated only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Whether the table alias is placed after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Whether unquoted identifiers are resolved as uppercase.
    # None means the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Whether an unquoted identifier may start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Whether CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # How function names are normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Default null ordering method when none is set explicitly.
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings: keys are dialect time formats,
    # values are the corresponding python time formats
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # Tries built from the time / format mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        """Compare via this instance's class, deferring to the metaclass comparison."""
        own_class = type(self)
        return own_class == other

    def __hash__(self) -> int:
        """Hash like the instance's class."""
        own_class = type(self)
        return hash(own_class)

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve a dialect name, class or instance to a dialect class.

        Args:
            dialect: A registered name, a dialect class, a dialect instance, or a falsy value.

        Returns:
            The matching dialect class; `cls` itself when `dialect` is falsy.

        Raises:
            ValueError: If `dialect` is a name with no registered class.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        found = cls.get(dialect)
        if found:
            return found
        raise ValueError(f"Unknown dialect '{dialect}'")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a time format through this dialect's TIME_MAPPING.

        Plain strings and string literal expressions are converted into a new string
        literal; anything else (including None) is returned unchanged.
        """
        if isinstance(expression, str):
            # the time formats are quoted
            unquoted = expression[1:-1]
            converted = format_time(unquoted, cls.TIME_MAPPING, cls.TIME_TRIE)
            return exp.Literal.string(converted)

        if expression and expression.is_string:
            converted = format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)
            return exp.Literal.string(converted)

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, which makes it
        effectively case-insensitive. For dialects that treat every identifier as
        case-insensitive, quoted identifiers are normalized as well.
        """
        if not isinstance(expression, exp.Identifier):
            return expression
        if expression.quoted and cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is not None:
            return expression

        name = expression.this
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE:
            normalized = name.upper()
        else:
            normalized = name.lower()
        expression.set("this", normalized)
        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        resolves_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if resolves_upper is None:
            return False

        # Characters in the case opposite to the one identifiers resolve to are unsafe
        if resolves_upper:
            return any(char.islower() for char in text)
        return any(char.isupper() for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            contains_sensitive = cls.case_sensitive(text)
            return not contains_sensitive

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Set the `quoted` flag on an identifier; other expressions are returned untouched.

        The flag is truthy when `identify` is set, when the name contains case
        sensitive characters, or when it does not match `exp.SAFE_IDENTIFIER_RE`.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            should_quote = (
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
            )
            expression.set("quoted", should_quote)

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse it with a parser built from `opts`."""
        parser = self.parser(**opts)
        tokens = self.tokenize(sql)
        return parser.parse(tokens, sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse it into `expression_type` with a parser built from `opts`."""
        parser = self.parser(**opts)
        tokens = self.tokenize(sql)
        return parser.parse_into(expression_type, tokens, sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Generate SQL for `expression` with a generator built from `opts`."""
        generator = self.generator(**opts)
        return generator.generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and generate one SQL string per parsed expression.

        `opts` are forwarded to the generator only.
        """
        statements = []
        for tree in self.parse(sql):
            statements.append(self.generate(tree, **opts))
        return statements

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenize `sql` with this instance's tokenizer."""
        tokenizer = self.tokenizer
        return tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        """The tokenizer instance, created on first access and reused afterwards."""
        try:
            return self._tokenizer
        except AttributeError:
            self._tokenizer = self.tokenizer_class()
            return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Build a new parser from `opts`."""
        parser_class = self.parser_class
        return parser_class(**opts)

    def generator(self, **opts) -> Generator:
        """Build a new generator from `opts`."""
        generator_class = self.generator_class
        return generator_class(**opts)
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
RESOLVES_IDENTIFIERS_AS_UPPERCASE: Optional[bool] = False
IDENTIFIERS_CAN_START_WITH_DIGIT = False
STRICT_STRING_CONCAT = False
NORMALIZE_FUNCTIONS: bool | str = 'upper'
NULL_ORDERING = 'nulls_are_small'
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}
FORMAT_MAPPING: Dict[str, str] = {}
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
@classmethod
def get_or_raise( cls, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> Type[sqlglot.dialects.dialect.Dialect]:
190    @classmethod
191    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
192        if not dialect:
193            return cls
194        if isinstance(dialect, _Dialect):
195            return dialect
196        if isinstance(dialect, Dialect):
197            return dialect.__class__
198
199        result = cls.get(dialect)
200        if not result:
201            raise ValueError(f"Unknown dialect '{dialect}'")
202
203        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
205    @classmethod
206    def format_time(
207        cls, expression: t.Optional[str | exp.Expression]
208    ) -> t.Optional[exp.Expression]:
209        if isinstance(expression, str):
210            return exp.Literal.string(
211                # the time formats are quoted
212                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
213            )
214
215        if expression and expression.is_string:
216            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
217
218        return expression
@classmethod
def normalize_identifier(cls, expression: ~E) -> ~E:
220    @classmethod
221    def normalize_identifier(cls, expression: E) -> E:
222        """
223        Normalizes an unquoted identifier to either lower or upper case, thus essentially
224        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
225        they will be normalized regardless of being quoted or not.
226        """
227        if isinstance(expression, exp.Identifier) and (
228            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
229        ):
230            expression.set(
231                "this",
232                expression.this.upper()
233                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
234                else expression.this.lower(),
235            )
236
237        return expression

Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized regardless of being quoted or not.

@classmethod
def case_sensitive(cls, text: str) -> bool:
239    @classmethod
240    def case_sensitive(cls, text: str) -> bool:
241        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
242        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
243            return False
244
245        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
246        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

@classmethod
def can_identify(cls, text: str, identify: str | bool = 'safe') -> bool:
248    @classmethod
249    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
250        """Checks if text can be identified given an identify option.
251
252        Args:
253            text: The text to check.
254            identify:
255                "always" or `True`: Always returns true.
256                "safe": True if the identifier is case-insensitive.
257
258        Returns:
259            Whether or not the given text can be identified.
260        """
261        if identify is True or identify == "always":
262            return True
263
264        if identify == "safe":
265            return not cls.case_sensitive(text)
266
267        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns true. "safe": True if the identifier is case-insensitive.
Returns:

Whether or not the given text can be identified.

@classmethod
def quote_identifier(cls, expression: ~E, identify: bool = True) -> ~E:
269    @classmethod
270    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
271        if isinstance(expression, exp.Identifier):
272            name = expression.this
273            expression.set(
274                "quoted",
275                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
276            )
277
278        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
280    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
281        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
283    def parse_into(
284        self, expression_type: exp.IntoType, sql: str, **opts
285    ) -> t.List[t.Optional[exp.Expression]]:
286        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: Optional[sqlglot.expressions.Expression], **opts) -> str:
288    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
289        return self.generator(**opts).generate(expression)
def transpile(self, sql: str, **opts) -> List[str]:
291    def transpile(self, sql: str, **opts) -> t.List[str]:
292        return [self.generate(expression, **opts) for expression in self.parse(sql)]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
294    def tokenize(self, sql: str) -> t.List[Token]:
295        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
303    def parser(self, **opts) -> Parser:
304        return self.parser_class(**opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
306    def generator(self, **opts) -> Generator:
307        return self.generator_class(**opts)
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START = None
BIT_END = None
HEX_START = None
HEX_END = None
BYTE_START = None
BYTE_END = None
DialectType = typing.Union[str, sqlglot.dialects.dialect.Dialect, typing.Type[sqlglot.dialects.dialect.Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator callback that renders an expression as a call to `name`.

    Args:
        name: the function name to emit.

    Returns:
        A callable taking `(self, expression)` that calls `self.func` with `name`
        and the flattened values of `expression.args`.
    """

    def _renamed(self: Generator, expression: exp.Expression) -> str:
        return self.func(name, *flatten(expression.args.values()))

    return _renamed
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render `APPROX_COUNT_DISTINCT(this)`, reporting a set `accuracy` arg as unsupported."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.If) -> str:
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render a conditional as `IF(condition, true, false)`."""
    branches = expression.args
    return self.func("IF", expression.this, branches.get("true"), branches.get("false"))
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render a JSON extraction as a binary expression using the `->` operator."""
    operator = "->"
    return self.binary(expression, operator)
def arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render a scalar JSON extraction as a binary expression using the `->>` operator."""
    operator = "->>"
    return self.binary(expression, operator)
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array as its generated elements wrapped in square brackets."""
    elements = self.expressions(expression)
    return f"[{elements}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Render ILIKE as LIKE with the left operand wrapped in LOWER.

    Only the left operand is lowered; the pattern is passed through unchanged.
    """
    lowered = exp.Lower(this=expression.this)
    like = exp.Like(this=lowered, expression=expression.expression)
    return self.like_sql(like)
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, adding AT TIME ZONE when a zone is given."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a WITH clause for a target that has no recursive CTEs.

    When the clause is flagged as recursive, this is reported through
    `self.unsupported` and the clause is rendered with the flag cleared.

    Args:
        expression: the WITH expression to render; it is never modified.

    Returns:
        The SQL produced by `self.with_sql`.
    """
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        # Clear the flag on a copy so the caller's tree keeps its RECURSIVE
        # marker for any later generation from the same tree.
        expression = expression.copy()
        expression.args["recursive"] = False
    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Render a safe division as `IF(d <> 0, n / d, NULL)`."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Report TABLESAMPLE as unsupported and render only `expression.this`."""
    message = "TABLESAMPLE unsupported"
    self.unsupported(message)
    sampled = expression.this
    return self.sql(sampled)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Report PIVOT as unsupported and render nothing for it."""
    message = "PIVOT unsupported"
    self.unsupported(message)
    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render a TRY_CAST through the generator's regular `cast_sql`."""
    rendered = self.cast_sql(expression)
    return rendered
def no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Report properties as unsupported and render nothing for them."""
    message = "Properties unsupported"
    self.unsupported(message)
    return ""
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Report the column comment constraint as unsupported and render nothing for it."""
    message = "CommentColumnConstraint unsupported"
    self.unsupported(message)
    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Report MAP_FROM_ENTRIES as unsupported and render nothing for it."""
    message = "MAP_FROM_ENTRIES unsupported"
    self.unsupported(message)
    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a string-position lookup with STRPOS.

    Without a start position this is `STRPOS(this, substr)`. With one, the search
    runs on `SUBSTR(this, position)` and the result is shifted by `position - 1`.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"

    # NOTE(review): the shift is applied unconditionally; if STRPOS signals
    # "not found" with 0, this form yields position - 1 instead. Confirm.
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as `this.key`, with the key emitted as a quoted identifier."""
    owner = self.sql(expression, "this")
    key = exp.Identifier(this=expression.expression, quoted=True)
    field = self.sql(key)
    return f"{owner}.{field}"
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as `map_func_name(key1, value1, key2, value2, ...)`.

    Args:
        expression: the map expression; its "keys" and "values" args are read.
        map_func_name: the function name to emit.

    Returns:
        The interleaved key/value call when both args are array expressions.
        Otherwise the situation is reported as unsupported and
        `map_func_name(keys, values)` is rendered instead.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave the rendered keys and values pairwise.
    interleaved = [
        self.sql(item)
        for pair in zip(keys.expressions, values.expressions)
        for item in pair
    ]
    return self.func(map_func_name, *interleaved)
def format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: name of the dialect whose time format rules apply; it is looked
            up via `Dialect[dialect]` each time the returned callable runs.
        default: the fallback format when the second argument is missing or
            falsy; True selects the dialect's `TIME_FORMAT`.

    Returns:
        A callable that takes an argument list and returns `exp_class` built from
        the first argument and the dialect-converted format.
    """

    def _format_time(args: t.List):
        this = seq_get(args, 0)
        dialect_class = Dialect[dialect]

        fmt = seq_get(args, 1)
        if not fmt:
            fmt = dialect_class.TIME_FORMAT if default is True else default or None

        return exp_class(this=this, format=dialect_class.format_time(fmt))

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.

    The rewrite is applied to a copy of `expression`, and only for TABLE/VIEW creates whose
    target is a schema.
    """
    targets_schema = isinstance(expression.this, exp.Schema)
    if not targets_schema or expression.args.get("kind") not in ("TABLE", "VIEW"):
        return self.create_sql(expression)

    expression = expression.copy()
    partitioned_by = expression.find(exp.PartitionedByProperty)

    # Nothing to rewrite when the property is absent, empty, or already a schema.
    if (
        not partitioned_by
        or not partitioned_by.this
        or isinstance(partitioned_by.this, exp.Schema)
    ):
        return self.create_sql(expression)

    schema = expression.this
    partition_names = {v.name.upper() for v in partitioned_by.this.expressions}
    partitions = [col for col in schema.expressions if col.name.upper() in partition_names]

    # Move the partition columns out of the table schema and into the property.
    schema.set("expressions", [e for e in schema.expressions if e not in partitions])
    partitioned_by.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
    expression.set("this", schema)

    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser callback for date-delta functions.

    Args:
        exp_class: the expression class to instantiate.
        unit_mapping: optional mapping from lowercased unit names to replacement
            names; unmapped names are kept as they are.

    Returns:
        A callable taking the argument list. With exactly three arguments they are
        read as `(unit, expression, this)`; otherwise as `(this, expression)` with
        a default unit of the string literal "DAY".
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            unit, this = args[0], args[2]
        else:
            this = seq_get(args, 0)
            unit = exp.Literal.string("DAY")

        if unit_mapping:
            unit_name = unit.name
            unit = exp.var(unit_mapping.get(unit_name.lower(), unit_name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser callback for date-delta functions whose second argument is an interval.

    Args:
        expression_class: the expression class to instantiate.

    Returns:
        A callable taking the argument list. It returns None for fewer than two
        arguments; otherwise it builds `expression_class` from the first argument,
        the interval's value and the interval's unit as a string literal.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        amount = interval.this

        # A string-valued interval amount is re-created as a number literal.
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=args[0], expression=amount, unit=unit)

    return func
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Build a truncation expression from `(unit, this)` arguments.

    Returns a `DateTrunc` when `this` is a cast to the date type, and a
    `TimestampTrunc` otherwise.
    """
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    is_date_cast = isinstance(this, exp.Cast) and this.is_type("date")
    if not is_date_cast:
        return exp.TimestampTrunc(this=this, unit=unit)
    return exp.DateTrunc(unit=unit, this=this)
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render a timestamp truncation as `DATE_TRUNC('<unit>', this)`; the unit defaults to "day"."""
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_TRUNC", unit, expression.this)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Build a `StrPosition` from LOCATE-style arguments `(substr, this[, position])`."""
    substr = seq_get(args, 0)
    this = seq_get(args, 1)
    position = seq_get(args, 2)
    return exp.StrPosition(this=this, substr=substr, position=position)
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a string-position lookup as `LOCATE(substr, this, position)`."""
    args = expression.args
    return self.func("LOCATE", args.get("substr"), expression.this, args.get("position"))
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Render LEFT(this, n) as a substring starting at 1 with length n.

    The substring is built from a copy of `expression`.
    """
    expression = expression.copy()
    substring = exp.Substring(
        this=expression.this, start=exp.Literal.number(1), length=expression.expression
    )
    return self.sql(substring)
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Render RIGHT(this, n) as a substring that starts at `LENGTH(this) - (n - 1)`.

    No length is given to the substring, so it runs to the end of the string.
    The substring is built from a copy of `expression`.

    Args:
        expression: the RIGHT expression; `this` is the string and `expression`
            is the character count.

    Returns:
        The SQL generated for the equivalent substring expression.
    """
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string conversion as `CAST(this AS TIMESTAMP)`."""
    value = self.sql(expression, "this")
    return f"CAST({value} AS TIMESTAMP)"
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string conversion as `CAST(this AS DATE)`."""
    value = self.sql(expression, "this")
    return f"CAST({value} AS DATE)"
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render as LEAST when extra `expressions` are present, otherwise as MIN."""
    if expression.expressions:
        name = "LEAST"
    else:
        name = "MIN"

    render = rename_func(name)
    return render(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render as GREATEST when extra `expressions` are present, otherwise as MAX."""
    if expression.expressions:
        name = "GREATEST"
    else:
        name = "MAX"

    render = rename_func(name)
    return render(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render COUNT_IF(cond) as `sum(if(cond, 1, 0))`.

    A DISTINCT wrapper is reported as unsupported and its first expression is
    used as the condition.
    """
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    one_or_zero = exp.func("if", condition, 1, 0)
    return self.func("sum", one_or_zero)
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM in the `TRIM([position] [chars] FROM target [COLLATE c])` form.

    When neither removal characters nor a collation are present, rendering is
    delegated to the generator's own `trim_sql`.
    """
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (chars or collation):
        return self.trim_sql(expression)

    position_part = f"{position} " if position else ""
    chars_part = f"{chars} " if chars else ""
    from_part = "FROM " if position_part or chars_part else ""
    collate_part = f" COLLATE {collation}" if collation else ""

    return f"TRIM({position_part}{chars_part}{from_part}{target}{collate_part})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as `STRPTIME(this, format)`."""
    value = expression.this
    time_format = self.format_time(expression)
    return self.func("STRPTIME", value, time_format)
def ts_or_ds_to_date_sql(dialect: str) -> Callable:
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator callback that renders a TsOrDsToDate as a CAST to DATE.

    Args:
        dialect: the dialect whose `TIME_FORMAT` and `DATE_FORMAT` count as the
            standard formats; resolved with `Dialect.get_or_raise` on each call.

    Returns:
        A callable taking `(self, expression)`. When the expression carries a
        format other than the two standard ones, the value is first parsed with
        `str_to_time_sql`; otherwise it is cast directly.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        dialect_class = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        standard_formats = (dialect_class.TIME_FORMAT, dialect_class.DATE_FORMAT)

        if time_format and time_format not in standard_formats:
            value = str_to_time_sql(self, expression)
        else:
            value = self.sql(expression, "this")

        return f"CAST({value} AS DATE)"

    return _ts_or_ds_to_date_sql
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat | sqlglot.expressions.SafeConcat) -> str:
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Render a concatenation as a left-nested chain of `DPipe` expressions."""
    first, *others = expression.expressions

    chained = first
    for operand in others:
        chained = exp.DPipe(this=chained, expression=operand)

    return self.sql(chained)
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render `REGEXP_EXTRACT(this, expression, group)`.

    Any set "position", "occurrence" or "parameters" arg is reported as unsupported.
    """
    args = expression.args
    bad_args = [name for name in ("position", "occurrence", "parameters") if args.get(name)]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    group = args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render `REGEXP_REPLACE(this, expression, replacement)`.

    Any set "position", "occurrence" or "parameters" arg is reported as unsupported.
    The "replacement" arg is read with a plain key lookup, so it must be present.
    """
    args = expression.args
    bad_args = [name for name in ("position", "occurrence", "parameters") if args.get(name)]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    replacement = args["replacement"]
    return self.func("REGEXP_REPLACE", expression.this, expression.expression, replacement)
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> List[str]:
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the column-name suffixes for a list of pivot aggregations.

    Args:
        aggregations: the aggregation expressions of the pivot.
        dialect: the dialect used to render aggregations that have no alias.

    Returns:
        One name per aggregation: the alias when there is one, otherwise the
        aggregation's SQL with unquoted identifiers and lowercased function names.
    """

    def _unquote(node):
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # This case corresponds to aggregations without aliases being used as suffixes
        # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
        # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
        # Otherwise, we'd end up with `col_avg(`foo`)` (notice the nested quoting).
        unquoted = agg.transform(_unquote)
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser callback that turns a two-argument call into a binary expression.

    Args:
        expr_type: the binary expression class to instantiate.

    Returns:
        A callable taking the argument list; the first argument becomes `this`
        and the second becomes `expression` (missing ones are whatever `seq_get`
        yields for an out-of-range index).
    """

    def _build(args: t.List) -> B:
        return expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

    return _build