Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5
  6from sqlglot import exp
  7from sqlglot._typing import E
  8from sqlglot.errors import ParseError
  9from sqlglot.generator import Generator
 10from sqlglot.helper import flatten, seq_get
 11from sqlglot.parser import Parser
 12from sqlglot.time import format_time
 13from sqlglot.tokens import Token, Tokenizer, TokenType
 14from sqlglot.trie import new_trie
 15
 16B = t.TypeVar("B", bound=exp.Binary)
 17
 18
class Dialects(str, Enum):
    """Canonical names of all SQL dialects known to sqlglot.

    Mixing in ``str`` lets members compare equal to their lowercase string
    values (e.g. ``Dialects.MYSQL == "mysql"``), which is how dialects are
    looked up by name throughout the library.
    """

    # The empty string denotes the base (generic) dialect.
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
 42
 43
class _Dialect(type):
    """Metaclass for :class:`Dialect`.

    Every ``Dialect`` subclass is registered in ``classes`` (keyed by its
    dialect name), and derived attributes — time-format tries, the
    tokenizer/parser/generator classes, quote and identifier delimiters —
    are autofilled at class-creation time in ``__new__``.
    """

    # Registry of all known dialect classes, keyed by lowercase name.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        # A dialect class compares equal to itself, to its registered name,
        # and to instances of itself (pairs with Dialect.__eq__).
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        # Dialect["name"] — raises KeyError for unknown names.
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        # Dialect.get("name") — returns `default` for unknown names.
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Register under the Dialects enum value when the class name matches
        # a known dialect, otherwise under the lowercased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Precompute tries for longest-prefix matching of time format tokens.
        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        # Pick up nested Tokenizer/Parser/Generator overrides, if any.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first configured pair is taken as this dialect's canonical
        # quote / identifier delimiters.
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # (start, end) delimiters of the first format string registered
            # for `token_type`, or (None, None) when there is none.
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        # Gather the dialect's non-callable, public class attributes so they
        # can be mirrored onto the tokenizer/parser/generator classes below.
        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
            "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
        }

        # NOTE(review): apparently only the base dialect and BigQuery keep
        # SELECT_KINDS enabled — confirm against the generator's usage.
        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        if not klass.STRICT_STRING_CONCAT:
            # Lenient dialects parse || into the coercing SafeDPipe variant.
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        klass.generator_class.can_identify = klass.can_identify

        return klass
130
131
class Dialect(metaclass=_Dialect):
    """Base class for all SQL dialects.

    Subclasses are auto-registered and autofilled by the ``_Dialect``
    metaclass; class-level flags below configure parsing and generation.
    """

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Default (quoted) strftime-style formats for this dialect.
    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled (by the _Dialect metaclass)
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        # Instances compare equal to their class (pairs with _Dialect.__eq__).
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve `dialect` (name, class, instance, or None) to a Dialect class.

        Raises:
            ValueError: if `dialect` is a name that is not registered.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Map a dialect time-format string (or string literal) to a python one."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        # A character is "unsafe" when it differs from the dialect's
        # resolution case and would therefore need quoting to survive.
        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Mark an Identifier as quoted when requested or required for safety."""
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql` into a list of syntax trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse `sql` into trees of the given expression type."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Render a syntax tree back into SQL for this dialect."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and re-generate each statement in this dialect."""
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily instantiate and cache a tokenizer per dialect instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)
312
313
# Anything that can designate a dialect: its name, an instance, the class
# itself, or None (None resolves to the base `Dialect`).
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
315
316
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator transform that renders the expression as ``name(...)``."""

    def _rename(self: Generator, expression: exp.Expression) -> str:
        return self.func(name, *flatten(expression.args.values()))

    return _rename
319
320
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render ApproxDistinct as APPROX_COUNT_DISTINCT, dropping any accuracy arg."""
    if expression.args.get("accuracy"):
        # The target dialect's function takes no accuracy parameter.
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    return self.func("APPROX_COUNT_DISTINCT", expression.this)
325
326
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render an If node as IF(condition, true_value, false_value)."""
    true_value = expression.args.get("true")
    false_value = expression.args.get("false")
    return self.func("IF", expression.this, true_value, false_value)
331
332
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON/JSONB extraction with the `->` arrow operator."""
    operator = "->"
    return self.binary(expression, operator)
335
336
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON/JSONB extraction with the `->>` arrow operator."""
    operator = "->>"
    return self.binary(expression, operator)
341
342
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an Array node as a bracketed literal, e.g. [1, 2, 3]."""
    elements = self.expressions(expression)
    return f"[{elements}]"
345
346
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE by lowering the left side and emitting a plain LIKE."""
    lowered = exp.Lower(this=expression.this)
    return self.like_sql(exp.Like(this=lowered, expression=expression.expression))
351
352
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, with an optional time zone."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
356
357
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Strip the RECURSIVE flag from a WITH clause, warning about the loss."""
    is_recursive = expression.args.get("recursive")
    if is_recursive:
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False

    return self.with_sql(expression)
363
364
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE with an IF that guards against a zero denominator."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
369
370
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop TABLESAMPLE with a warning and render only the sampled relation."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
374
375
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Drop PIVOT entirely, emitting a warning and no SQL."""
    self.unsupported("PIVOT unsupported")
    return ""
379
380
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Fall back to a plain CAST for dialects without TRY_CAST."""
    return self.cast_sql(expression)
383
384
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Drop table properties entirely, emitting a warning and no SQL."""
    self.unsupported("Properties unsupported")
    return ""
388
389
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Drop COMMENT column constraints, emitting a warning and no SQL."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""
395
396
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Drop MAP_FROM_ENTRIES, emitting a warning and no SQL."""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""
400
401
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as STRPOS, emulating the optional start position."""
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"

    # Search only the suffix beginning at `start`, then shift the hit
    # back into coordinates of the full string.
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
409
410
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as `this`.`key`, quoting the key."""
    quoted_key = exp.Identifier(this=expression.expression, quoted=True)
    return f"{self.sql(expression, 'this')}.{self.sql(quoted_key)}"
415
416
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map construction as MAP(k1, v1, k2, v2, ...).

    Both the keys and the values must be literal arrays so the pairs can be
    interleaved; otherwise the arrays are passed through with a warning.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave key1, value1, key2, value2, ...
    interleaved = [
        self.sql(item)
        for pair in zip(keys.expressions, values.expressions)
        for item in pair
    ]
    return self.func(map_func_name, *interleaved)
433
434
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        target_dialect = Dialect[dialect]
        # Fall back to the dialect's TIME_FORMAT (default=True), an explicit
        # default format string, or no format at all.
        fmt = seq_get(args, 1) or (
            target_dialect.TIME_FORMAT if default is True else default or None
        )
        return exp_class(this=seq_get(args, 0), format=target_dialect.format_time(fmt))

    return _builder
459
460
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        # Work on a copy so the caller's AST stays untouched.
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Match partition columns against schema columns case-insensitively.
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            # Move the partition columns out of the main schema and into a
            # schema attached to the PARTITIONED BY property.
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
482
483
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date-delta functions with an optional leading unit arg."""

    def _builder(args: t.List) -> E:
        # Three args means the unit comes first: f(unit, delta, this).
        has_unit = len(args) == 3
        this = args[2] if has_unit else seq_get(args, 0)
        unit = args[0] if has_unit else exp.Literal.string("DAY")

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
495
496
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for DATE_ADD-style functions whose second arg is an INTERVAL."""

    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            # Normalize quoted interval amounts (e.g. '5') into numeric literals.
            amount = exp.Literal.number(amount.this)

        return expression_class(
            this=args[0], expression=amount, unit=exp.Literal.string(interval.text("unit"))
        )

    return _builder
518
519
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, x) into DateTrunc for date casts, else TimestampTrunc."""
    unit = seq_get(args, 0)
    target = seq_get(args, 1)

    is_date_cast = isinstance(target, exp.Cast) and target.is_type("date")
    if is_date_cast:
        return exp.DateTrunc(unit=unit, this=target)
    return exp.TimestampTrunc(this=target, unit=unit)
527
528
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('unit', x), defaulting the unit to 'day'."""
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_TRUNC", unit, expression.this)
533
534
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Map LOCATE(substr, str[, pos]) arguments onto a StrPosition node."""
    needle = seq_get(args, 0)
    haystack = seq_get(args, 1)
    start = seq_get(args, 2)
    return exp.StrPosition(this=haystack, substr=needle, position=start)
539
540
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, this[, position])."""
    needle = expression.args.get("substr")
    start = expression.args.get("position")
    return self.func("LOCATE", needle, expression.this, start)
545
546
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Rewrite LEFT(x, n) as SUBSTRING(x, 1, n) for dialects without LEFT."""
    expression = expression.copy()
    substring = exp.Substring(
        this=expression.this,
        start=exp.Literal.number(1),
        length=expression.expression,
    )
    return self.sql(substring)
554
555
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Rewrite RIGHT(x, n) as SUBSTRING(x, LENGTH(x) - (n - 1)) for dialects without RIGHT.

    Note: the parameter was previously annotated `exp.Left` — a copy-paste
    from `left_to_substring_sql`; this transform handles `exp.Right` nodes.
    """
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            # Start (n - 1) characters before the end, i.e. keep the last n chars.
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
564
565
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string conversion as CAST(x AS TIMESTAMP)."""
    cast = exp.cast(expression.this, "timestamp")
    return self.sql(cast)
568
569
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string conversion as CAST(x AS DATE)."""
    cast = exp.cast(expression.this, "date")
    return self.sql(cast)
572
573
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render multi-argument MIN as LEAST, plain aggregate MIN otherwise."""
    if expression.expressions:
        return rename_func("LEAST")(self, expression)
    return rename_func("MIN")(self, expression)
577
578
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render multi-argument MAX as GREATEST, plain aggregate MAX otherwise."""
    if expression.expressions:
        return rename_func("GREATEST")(self, expression)
    return rename_func("MAX")(self, expression)
582
583
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0))."""
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        # SUM(IF(...)) cannot express DISTINCT counting; keep the inner
        # condition and warn.
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
        condition = condition.expressions[0]

    return self.func("sum", exp.func("if", condition, 1, 0))
592
593
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM with explicit position/characters/collation, falling back
    to the generator's plain TRIM/LTRIM/RTRIM when neither characters nor a
    collation are specified."""
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not chars and not collation:
        return self.trim_sql(expression)

    parts = []
    if position:
        parts.append(f"{position} ")
    if chars:
        parts.append(f"{chars} ")
    if parts:
        # FROM is only needed when a position and/or characters precede the target.
        parts.append("FROM ")
    parts.append(target)
    if collation:
        parts.append(f" COLLATE {collation}")

    return f"TRIM({''.join(parts)})"
609
610
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(x, format)."""
    fmt = self.format_time(expression)
    return self.func("STRPTIME", expression.this, fmt)
613
614
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator transform that casts to DATE, going through STRPTIME
    when a non-default time format is attached to the expression."""

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target = Dialect.get_or_raise(dialect)
        fmt = self.format_time(expression)

        if fmt and fmt not in (target.TIME_FORMAT, target.DATE_FORMAT):
            # A custom format requires parsing before the cast.
            return self.sql(exp.cast(str_to_time_sql(self, expression), "date"))

        return self.sql(exp.cast(self.sql(expression, "this"), "date"))

    return _ts_or_ds_to_date_sql
625
626
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Render CONCAT(a, b, c, ...) as a || b || c || ... ."""
    acc, *operands = expression.expressions
    for operand in operands:
        # Left-fold the remaining operands into a chain of DPipe nodes.
        acc = exp.DPipe(this=acc, expression=operand)

    return self.sql(acc)
633
634
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT, warning on optional args the target can't take."""
    bad_args = [arg for arg in ("position", "occurrence", "parameters") if expression.args.get(arg)]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
643
644
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE, warning on optional args the target can't take."""
    bad_args = [arg for arg in ("position", "occurrence", "parameters") if expression.args.get(arg)]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
653
654
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Derive the column-name suffixes produced by a pivot's aggregations."""
    names: t.List[str] = []

    for aggregation in aggregations:
        if isinstance(aggregation, exp.Alias):
            names.append(aggregation.alias)
            continue

        # Unaliased aggregations (e.g. col_avg(foo)) become suffixes verbatim.
        # Identifiers are unquoted here because the base parser's `_parse_pivot`
        # re-quotes them via `to_identifier`; otherwise we'd end up with
        # doubly-quoted names like col_avg(`foo`).
        unquoted = aggregation.transform(
            lambda node: exp.Identifier(this=node.name, quoted=False)
            if isinstance(node, exp.Identifier)
            else node
        )
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
675
676
def simplify_literal(expression: E, copy: bool = True) -> E:
    """Constant-fold `expression.expression` unless it is already a literal."""
    if isinstance(expression.expression, exp.Literal):
        return expression

    # Imported lazily to avoid a circular dependency with the optimizer.
    from sqlglot.optimizer.simplify import simplify

    expression = exp.maybe_copy(expression, copy)
    simplify(expression.expression)

    return expression
685
686
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser that maps f(a, b) onto the binary expression `expr_type`."""

    def _parse(args: t.List) -> B:
        return expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

    return _parse
class Dialects(builtins.str, enum.Enum):
20class Dialects(str, Enum):
21    DIALECT = ""
22
23    BIGQUERY = "bigquery"
24    CLICKHOUSE = "clickhouse"
25    DATABRICKS = "databricks"
26    DRILL = "drill"
27    DUCKDB = "duckdb"
28    HIVE = "hive"
29    MYSQL = "mysql"
30    ORACLE = "oracle"
31    POSTGRES = "postgres"
32    PRESTO = "presto"
33    REDSHIFT = "redshift"
34    SNOWFLAKE = "snowflake"
35    SPARK = "spark"
36    SPARK2 = "spark2"
37    SQLITE = "sqlite"
38    STARROCKS = "starrocks"
39    TABLEAU = "tableau"
40    TERADATA = "teradata"
41    TRINO = "trino"
42    TSQL = "tsql"

An enumeration.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect(metaclass=_Dialect):
    """Base class for SQL dialects.

    Bundles per-dialect configuration flags with the tokenizer, parser and
    generator classes used to round-trip SQL for that dialect.
    """

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled — presumably populated per-subclass by the _Dialect metaclass
    # (its __new__ body is outside this view; TODO confirm).
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        # An instance compares equal to its own class, mirroring _Dialect.__eq__.
        return type(self) == other

    def __hash__(self) -> int:
        # Consistent with __eq__: hash by class identity.
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve `dialect` (name, class, instance or falsy) to a Dialect class.

        Raises:
            ValueError: if `dialect` is a string naming no registered dialect.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Translate a dialect time-format string/literal into a python-format literal.

        Note: the bare `format_time(...)` calls below resolve to the module-level
        helper imported from sqlglot.time, not to this classmethod.
        """
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        # A character is "unsafe" when its case differs from the dialect's
        # resolution case, so unquoting would change its meaning.
        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Mark an Identifier as quoted when forced, case-sensitive, or unsafe."""
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql` into a list of expression trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse `sql`, coercing the result into `expression_type`."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Render an expression tree back into SQL for this dialect."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and re-generate each statement in this dialect."""
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenize `sql` using this dialect's tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily instantiate and cache a single tokenizer per dialect instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Build a fresh parser instance with the given options."""
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        """Build a fresh generator instance with the given options."""
        return self.generator_class(**opts)
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
RESOLVES_IDENTIFIERS_AS_UPPERCASE: Optional[bool] = False
IDENTIFIERS_CAN_START_WITH_DIGIT = False
STRICT_STRING_CONCAT = False
NORMALIZE_FUNCTIONS: bool | str = 'upper'
NULL_ORDERING = 'nulls_are_small'
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}
FORMAT_MAPPING: Dict[str, str] = {}
PSEUDOCOLUMNS: Set[str] = set()
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
@classmethod
def get_or_raise( cls, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> Type[sqlglot.dialects.dialect.Dialect]:
195    @classmethod
196    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
197        if not dialect:
198            return cls
199        if isinstance(dialect, _Dialect):
200            return dialect
201        if isinstance(dialect, Dialect):
202            return dialect.__class__
203
204        result = cls.get(dialect)
205        if not result:
206            raise ValueError(f"Unknown dialect '{dialect}'")
207
208        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
210    @classmethod
211    def format_time(
212        cls, expression: t.Optional[str | exp.Expression]
213    ) -> t.Optional[exp.Expression]:
214        if isinstance(expression, str):
215            return exp.Literal.string(
216                # the time formats are quoted
217                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
218            )
219
220        if expression and expression.is_string:
221            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
222
223        return expression
@classmethod
def normalize_identifier(cls, expression: ~E) -> ~E:
225    @classmethod
226    def normalize_identifier(cls, expression: E) -> E:
227        """
228        Normalizes an unquoted identifier to either lower or upper case, thus essentially
229        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
230        they will be normalized regardless of being quoted or not.
231        """
232        if isinstance(expression, exp.Identifier) and (
233            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
234        ):
235            expression.set(
236                "this",
237                expression.this.upper()
238                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
239                else expression.this.lower(),
240            )
241
242        return expression

Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized regardless of being quoted or not.

@classmethod
def case_sensitive(cls, text: str) -> bool:
244    @classmethod
245    def case_sensitive(cls, text: str) -> bool:
246        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
247        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
248            return False
249
250        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
251        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

@classmethod
def can_identify(cls, text: str, identify: str | bool = 'safe') -> bool:
253    @classmethod
254    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
255        """Checks if text can be identified given an identify option.
256
257        Args:
258            text: The text to check.
259            identify:
260                "always" or `True`: Always returns true.
261                "safe": True if the identifier is case-insensitive.
262
263        Returns:
264            Whether or not the given text can be identified.
265        """
266        if identify is True or identify == "always":
267            return True
268
269        if identify == "safe":
270            return not cls.case_sensitive(text)
271
272        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns true. "safe": True if the identifier is case-insensitive.
Returns:

Whether or not the given text can be identified.

@classmethod
def quote_identifier(cls, expression: ~E, identify: bool = True) -> ~E:
274    @classmethod
275    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
276        if isinstance(expression, exp.Identifier):
277            name = expression.this
278            expression.set(
279                "quoted",
280                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
281            )
282
283        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
285    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
286        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
288    def parse_into(
289        self, expression_type: exp.IntoType, sql: str, **opts
290    ) -> t.List[t.Optional[exp.Expression]]:
291        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: Optional[sqlglot.expressions.Expression], **opts) -> str:
293    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
294        return self.generator(**opts).generate(expression)
def transpile(self, sql: str, **opts) -> List[str]:
296    def transpile(self, sql: str, **opts) -> t.List[str]:
297        return [self.generate(expression, **opts) for expression in self.parse(sql)]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
299    def tokenize(self, sql: str) -> t.List[Token]:
300        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
308    def parser(self, **opts) -> Parser:
309        return self.parser_class(**opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
311    def generator(self, **opts) -> Generator:
312        return self.generator_class(**opts)
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START = None
BIT_END = None
HEX_START = None
HEX_END = None
BYTE_START = None
BYTE_END = None
DialectType = typing.Union[str, sqlglot.dialects.dialect.Dialect, typing.Type[sqlglot.dialects.dialect.Dialect], NoneType]
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator transform that renders `name` over all of an expression's args."""

    def _rename(self: Generator, expression: exp.Expression) -> str:
        return self.func(name, *flatten(expression.args.values()))

    return _rename
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT, warning that the `accuracy` arg is unsupported."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render a conditional as an IF(condition, true, false) function call."""
    args = expression.args
    return self.func("IF", expression.this, args.get("true"), args.get("false"))
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON(B) extraction with the arrow (`->`) operator."""
    return self.binary(expression, "->")
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON(B) extraction with the double-arrow (`->>`) operator."""
    return self.binary(expression, "->>")
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array literal using bracket syntax, e.g. [1, 2, 3]."""
    return f"[{self.expressions(expression)}]"
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE by lower-casing the left operand and rendering a plain LIKE."""
    lowered = exp.Lower(this=expression.this)
    return self.like_sql(exp.Like(this=lowered, expression=expression.expression))
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, with an optional AT TIME ZONE clause."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Strip the RECURSIVE flag from a WITH clause, warning that it is unsupported."""
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        # Clear the flag in place so with_sql renders a plain WITH.
        expression.args["recursive"] = False
    return self.with_sql(expression)
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE as IF(denominator <> 0, numerator / denominator, NULL)."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop TABLESAMPLE, rendering only the sampled table, with a warning."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Render nothing for PIVOT, with a warning that it is unsupported."""
    self.unsupported("PIVOT unsupported")
    return ""
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TRY_CAST as a plain CAST for dialects without a safe cast."""
    return self.cast_sql(expression)
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Render nothing for table properties, with a warning."""
    self.unsupported("Properties unsupported")
    return ""
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Render nothing for a column COMMENT constraint, with a warning."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Render nothing for MAP_FROM_ENTRIES, with a warning."""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as STRPOS, emulating the optional start position via SUBSTR."""
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")
    if not start:
        return f"STRPOS({haystack}, {needle})"
    # Search the suffix, then shift the match offset back into full-string coordinates.
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as `struct."key"` with a quoted field identifier."""
    struct_sql = self.sql(expression, "this")
    key_sql = self.sql(exp.Identifier(this=expression.expression, quoted=True))
    return f"{struct_sql}.{key_sql}"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a MAP construct by interleaving its key and value expressions."""
    keys = expression.args["keys"]
    values = expression.args["values"]

    if isinstance(keys, exp.Array) and isinstance(values, exp.Array):
        interleaved = []
        for key, value in zip(keys.expressions, values.expressions):
            interleaved.append(self.sql(key))
            interleaved.append(self.sql(value))
        return self.func(map_func_name, *interleaved)

    # Non-array operands (e.g. columns) cannot be interleaved; render them as-is.
    self.unsupported("Cannot convert array columns into map.")
    return self.func(map_func_name, keys, values)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            # When no format arg is given: True falls back to the dialect's
            # TIME_FORMAT, otherwise use the explicit default format (if any).
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        # Work on a copy so the caller's expression tree is left untouched.
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Match partition columns to schema columns case-insensitively by name.
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser callback for date-delta functions.

    Handles both the 3-arg form (unit, increment, this) and the 2-arg form
    (this, increment), defaulting the unit to DAY. `unit_mapping` optionally
    renames units, keyed by the lower-cased unit name.
    """

    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser callback for date arithmetic whose second arg is an INTERVAL.

    Returns None when fewer than two args are present; raises ParseError when
    the second arg is not an Interval.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        # Interval quantities parsed as strings are coerced to numeric literals.
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Build DateTrunc when the operand is a DATE cast, otherwise TimestampTrunc."""
    unit, operand = seq_get(args, 0), seq_get(args, 1)

    if isinstance(operand, exp.Cast) and operand.is_type("date"):
        return exp.DateTrunc(unit=unit, this=operand)
    return exp.TimestampTrunc(this=operand, unit=unit)
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('unit', ts), defaulting the unit to 'day'."""
    unit = expression.text("unit") or "day"
    return self.func("DATE_TRUNC", exp.Literal.string(unit), expression.this)
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Map LOCATE(substr, string[, position]) args onto a StrPosition node.

    Note the argument swap: LOCATE takes the needle first, StrPosition the haystack.
    """
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, string[, position])."""
    args = expression.args
    return self.func("LOCATE", args.get("substr"), expression.this, args.get("position"))
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Rewrite LEFT(s, n) as SUBSTRING(s, 1, n)."""
    expression = expression.copy()
    substring = exp.Substring(
        this=expression.this,
        start=exp.Literal.number(1),
        length=expression.expression,
    )
    return self.sql(substring)
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Rewrite RIGHT(s, n) as SUBSTRING(s, LENGTH(s) - (n - 1)).

    Fix: the parameter was previously annotated `exp.Left`, but this transform
    is applied to `exp.Right` nodes (the same structural shape, wrong type).
    Annotation-only change; runtime behavior is identical.
    """
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string conversion as CAST(... AS TIMESTAMP)."""
    return self.sql(exp.cast(expression.this, "timestamp"))
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string conversion as CAST(... AS DATE)."""
    return self.sql(exp.cast(expression.this, "date"))
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render LEAST when extra operands are present, otherwise the MIN aggregate."""
    if expression.expressions:
        return rename_func("LEAST")(self, expression)
    return rename_func("MIN")(self, expression)
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render GREATEST when extra operands are present, otherwise the MAX aggregate."""
    if expression.expressions:
        return rename_func("GREATEST")(self, expression)
    return rename_func("MAX")(self, expression)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0)), dropping any DISTINCT."""
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        # DISTINCT cannot survive the rewrite; warn and use the bare condition.
        cond = cond.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM in the ANSI form TRIM([BOTH|LEADING|TRAILING] [chars] FROM target[ COLLATE c]).

    Falls back to the generator's own trim rendering (TRIM/LTRIM/RTRIM) when
    neither custom characters nor a collation are present.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    # NOTE(review): `self.trim_sql` here resolves to the Generator method, which
    # presumably differs from this module-level transform — confirm when wiring
    # this function into a dialect's TRANSFORMS to avoid infinite recursion.
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(value, format)."""
    return self.func("STRPTIME", expression.this, self.format_time(expression))
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a TsOrDsToDate renderer bound to `dialect`'s standard date/time formats."""

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        # A non-standard format needs an explicit STRPTIME before casting to DATE.
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return self.sql(exp.cast(str_to_time_sql(self, expression), "date"))

        return self.sql(exp.cast(self.sql(expression, "this"), "date"))

    return _ts_or_ds_to_date_sql
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Rewrite CONCAT(a, b, ...) as a left-nested chain of `||` operators."""
    first, *rest = expression.expressions
    chained = first
    for operand in rest:
        chained = exp.DPipe(this=chained, expression=operand)

    return self.sql(chained)
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT, warning about unsupported optional arguments."""
    bad_args = [name for name in ("position", "occurrence", "parameters") if expression.args.get(name)]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE, warning about unsupported optional arguments."""
    bad_args = [name for name in ("position", "occurrence", "parameters") if expression.args.get(name)]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute output column names for a PIVOT's aggregation expressions.

    Aliased aggregations contribute their alias; unaliased ones contribute
    their generated SQL with all identifiers unquoted.

    Fix: the original used a bare triple-quoted string as a comment inside the
    else branch — a no-op expression statement evaluated on every iteration;
    it is now a real comment.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def simplify_literal(expression: E, copy: bool = True) -> E:
    """Constant-fold `expression.expression` when it is not already a literal.

    Operates on a copy unless `copy=False`; the simplification itself mutates
    the (possibly copied) child in place.
    """
    if not isinstance(expression.expression, exp.Literal):
        # Imported lazily to avoid a circular dependency with the optimizer.
        from sqlglot.optimizer.simplify import simplify

        expression = exp.maybe_copy(expression, copy)
        simplify(expression.expression)

    return expression
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser callback mapping two positional args onto a binary expression."""

    def _builder(args: t.List) -> B:
        return expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

    return _builder