Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5
  6from sqlglot import exp
  7from sqlglot._typing import E
  8from sqlglot.errors import ParseError
  9from sqlglot.generator import Generator
 10from sqlglot.helper import flatten, seq_get
 11from sqlglot.parser import Parser
 12from sqlglot.time import format_time
 13from sqlglot.tokens import Token, Tokenizer, TokenType
 14from sqlglot.trie import new_trie
 15
# Type variable bound to exp.Binary; used by `binary_from_function` so that the
# returned builder is typed with the concrete binary expression class it was given
B = t.TypeVar("B", bound=exp.Binary)
 17
 18
class Dialects(str, Enum):
    """Names of the SQL dialects enumerated by this module.

    Members are also `str` instances, so a member compares equal to its value
    (for example `Dialects.HIVE == "hive"`).
    """

    # The base dialect is keyed by the empty string
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    # NOTE(review): unlike its siblings this member name is not upper-case.
    # `_Dialect.__new__` looks members up with `clsname.upper()`, so a class named
    # `Doris` does not find this entry and is registered under `clsname.lower()`
    # instead, which happens to be the same "doris" key.
    Doris = "doris"
 43
 44
class _Dialect(type):
    """Metaclass of `Dialect`: registers each dialect class and fills in its derived attributes."""

    # Registry of every class created through this metaclass, keyed by dialect name
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        """Compare a dialect class with another class, a registry key or a `Dialect` instance."""
        if cls is other:
            return True
        if isinstance(other, str):
            # A string is treated as a registry key
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        # Defined explicitly: a class that overrides __eq__ without __hash__ is unhashable
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        """Return the class registered under `key`; an unknown key raises KeyError."""
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        """Return the class registered under `key`, or `default` when there is none."""
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        """Create the dialect class, register it and derive its helper attributes."""
        klass = super().__new__(cls, clsname, bases, attrs)
        # Register under the enum value when the class name matches a `Dialects`
        # member, otherwise under the lower-cased class name
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Tries and inverse mapping derived from the dialect's time/format mappings
        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        # Nested Tokenizer/Parser/Generator classes take precedence over the base ones
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first quote / identifier delimiter pair of the tokenizer is the one exposed
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # First (start, end) delimiter pair whose token type matches, else (None, None).
            # Note that the loop variable `t` shadows the `typing` alias inside the generator.
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        # Snapshot of the plain (non-callable, non-dunder) attributes defined directly
        # on this class, including the ones assigned above
        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "TOKENIZER_CLASS": klass.tokenizer_class,
        }

        # `enum` is a `Dialects` member (a str, so it compares equal to its value) or
        # None; every dialect except the base one and BigQuery gets an empty SELECT_KINDS
        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                # Only attributes the target class already declares are overwritten
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        # Non-strict dialects that treat `||` as concatenation parse it into exp.SafeDPipe.
        # NOTE(review): this writes into the parser class's BITWISE mapping in place;
        # presumably that mapping is shared with parent parser classes unless the
        # dialect's Parser defines its own - confirm.
        if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        klass.generator_class.can_identify = klass.can_identify

        return klass
130
131
class Dialect(metaclass=_Dialect):
    """Base SQL dialect: dialect-wide settings plus a tokenizer, parser and generator.

    The `_Dialect` metaclass registers every subclass and fills in the attributes
    marked as derived below when the class is created.
    """

    # Index of the first element of an array
    INDEX_OFFSET = 0

    # When true, an alias on UNNEST is only ever a column alias
    UNNEST_COLUMN_ONLY = False

    # Whether a table alias is written after TABLESAMPLE
    ALIAS_POST_TABLESAMPLE = False

    # True / False: unquoted identifiers resolve to upper / lower case.
    # None: all identifiers are case-insensitive, quoted or not.
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Whether an unquoted identifier may begin with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Whether the DPIPE token ('||') concatenates strings
    DPIPE_IS_STRING_CONCAT = True

    # Whether the arguments of CONCAT have to be strings
    STRICT_STRING_CONCAT = False

    # How function names get normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Null ordering applied when a query does not state one
    # One of: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Dialect time format (key) -> python time format (value)
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # mapping for the special syntax cast(x as date format 'yyyy'); falls back to TIME_MAPPING
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns the engine of this dialect generates on its own
    # They may, for example, be left out of SELECT * queries
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Derived: replaced by the metaclass with the dialect's nested classes, if any
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # Derived: tries over the keys of TIME_MAPPING / FORMAT_MAPPING
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    # Derived: TIME_MAPPING with keys and values swapped, and its trie
    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        # Delegates to the metaclass comparison of the instance's class
        own_class = type(self)
        return own_class == other

    def __hash__(self) -> int:
        own_class = type(self)
        return hash(own_class)

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve `dialect` to a dialect class.

        A falsy value yields `cls`, a dialect class is returned unchanged, an
        instance yields its class and anything else is looked up in the registry.

        Raises:
            ValueError: if the registry has no entry for `dialect`.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        found = cls.get(dialect)
        if found:
            return found

        raise ValueError(f"Unknown dialect '{dialect}'")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a time format of this dialect into a string literal in python format.

        Plain strings are expected to be quoted and lose their first and last
        character; string expressions are converted from their `this` value;
        anything else is returned untouched.
        """
        if isinstance(expression, str):
            # the time formats are quoted
            unquoted = expression[1:-1]
            return exp.Literal.string(format_time(unquoted, cls.TIME_MAPPING, cls.TIME_TRIE))

        if expression and expression.is_string:
            converted = format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)
            return exp.Literal.string(converted)

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Rewrites the name of an unquoted identifier in upper or lower case, which makes it
        case-insensitive in effect. Dialects whose identifiers are all case-insensitive
        have quoted identifiers rewritten as well.
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        if expression.quoted and cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is not None:
            return expression

        name = expression.this
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE:
            expression.set("this", name.upper())
        else:
            expression.set("this", name.lower())

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Tells whether text has any character that is case sensitive under the dialect's rules."""
        resolves_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if resolves_upper is None:
            return False

        is_unsafe = str.islower if resolves_upper else str.isupper
        for char in text:
            if is_unsafe(char):
                return True
        return False

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Tells whether text can be identified under the given identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: the answer is always true.
                "safe": true when the text has no case sensitive character.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify == "safe":
            return not cls.case_sensitive(text)

        return identify is True or identify == "always"

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Set the `quoted` flag of an identifier; other expressions pass through unchanged.

        The flag is truthy when `identify` is set, when the name is case sensitive
        or when the name does not match `exp.SAFE_IDENTIFIER_RE`.
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        name = expression.this
        quoted = identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
        expression.set("quoted", quoted)
        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse it with a parser built from `opts`."""
        parser = self.parser(**opts)
        tokens = self.tokenize(sql)
        return parser.parse(tokens, sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse it into `expression_type`."""
        parser = self.parser(**opts)
        tokens = self.tokenize(sql)
        return parser.parse_into(expression_type, tokens, sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Render `expression` with a generator built from `opts`."""
        generator = self.generator(**opts)
        return generator.generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and render every resulting expression; `opts` go to the generator only."""
        statements = []
        for expression in self.parse(sql):
            statements.append(self.generate(expression, **opts))
        return statements

    def tokenize(self, sql: str) -> t.List[Token]:
        """Split `sql` into tokens with this instance's tokenizer."""
        tokenizer = self.tokenizer
        return tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Created on first access and then kept on the instance
        if hasattr(self, "_tokenizer"):
            return self._tokenizer

        self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Build a new parser of this dialect."""
        parser_class = self.parser_class
        return parser_class(**opts)

    def generator(self, **opts) -> Generator:
        """Build a new generator of this dialect."""
        generator_class = self.generator_class
        return generator_class(**opts)
315
316
# Anything accepted where a dialect is expected: a registry key, a Dialect instance,
# a Dialect class, or None (which `Dialect.get_or_raise` resolves to the class it is called on)
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
318
319
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator function that renders an expression as a call to `name`.

    The call receives every value of `expression.args`, flattened, as arguments.
    """

    def _renamed(self, expression):
        arguments = flatten(expression.args.values())
        return self.func(name, *arguments)

    return _renamed
322
323
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render the expression as APPROX_COUNT_DISTINCT(<this>).

    An `accuracy` argument is reported as unsupported and left out of the call.
    """
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    operand = expression.this
    return self.func("APPROX_COUNT_DISTINCT", operand)
328
329
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render a conditional as IF(<condition>, <true>, <false>)."""
    condition = expression.this
    when_true = expression.args.get("true")
    when_false = expression.args.get("false")
    return self.func("IF", condition, when_true, when_false)
334
335
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render a JSON extraction as a binary expression with the `->` operator."""
    operator = "->"
    return self.binary(expression, operator)
338
339
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render a scalar JSON extraction as a binary expression with the `->>` operator."""
    operator = "->>"
    return self.binary(expression, operator)
344
345
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array as its flat element list wrapped in square brackets."""
    elements = self.expressions(expression, flat=True)
    return f"[{elements}]"
348
349
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE as LOWER(<this>) LIKE LOWER(<pattern>).

    Both operands are lower-cased. Lower-casing only the left operand would make
    a pattern containing upper-case letters never match under a case-sensitive LIKE.
    The operands are copied so the original expression keeps its children.
    """
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this.copy()),
            expression=exp.Lower(this=expression.expression.copy()),
        )
    )
356
357
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, followed by AT TIME ZONE when a zone is given."""
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"
    return f"CURRENT_DATE AT TIME ZONE {zone}"
361
362
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a WITH clause without the RECURSIVE keyword.

    A recursive clause is reported as unsupported. The flag is cleared on a copy,
    so the caller's expression still carries it afterwards (for example when the
    same tree is rendered again for another dialect).
    """
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        # Never edit the tree that was handed in
        expression = expression.copy()
        expression.set("recursive", False)
    return self.with_sql(expression)
368
369
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate a safe division as IF((<d>) <> 0, (<n>) / (<d>), NULL).

    Both operands are parenthesised because they are spliced in as text: without
    the parentheses a compound operand such as `a + b` would bind to the
    surrounding `/` and `<>` operators incorrectly.
    """
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF(({denominator}) <> 0, ({numerator}) / ({denominator}), NULL)"
374
375
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Report TABLESAMPLE as unsupported and render only the sampled expression."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
379
380
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Report PIVOT as unsupported and render nothing for it."""
    nothing = ""
    self.unsupported("PIVOT unsupported")
    return nothing
384
385
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render a TRY_CAST through the generator's regular cast rendering."""
    rendered = self.cast_sql(expression)
    return rendered
388
389
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Report properties as unsupported and render nothing for them."""
    nothing = ""
    self.unsupported("Properties unsupported")
    return nothing
393
394
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Report a column COMMENT constraint as unsupported and render nothing for it."""
    nothing = ""
    self.unsupported("CommentColumnConstraint unsupported")
    return nothing
400
401
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Report MAP_FROM_ENTRIES as unsupported and render nothing for it."""
    nothing = ""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return nothing
405
406
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a substring search with STRPOS.

    With a start position the search runs on SUBSTR(<this>, <position>) and the
    result is shifted back by `position - 1`.

    NOTE(review): presumably STRPOS yields 0 when nothing is found, in which case
    the shifted form evaluates to `position - 1` instead of 0 - confirm.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
414
415
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render a struct field access as <struct>.<field>, the field being an identifier."""
    struct = self.sql(expression, "this")
    field = self.sql(exp.to_identifier(expression.expression.name))
    return f"{struct}.{field}"
420
421
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as `map_func_name(key1, value1, key2, value2, ...)`.

    When keys or values are not `exp.Array` nodes the pairs cannot be interleaved:
    this is reported as unsupported and both are passed to the function as they are.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Each key is rendered right before its value
    interleaved = [
        self.sql(item)
        for pair in zip(keys.expressions, values.expressions)
        for item in pair
    ]
    return self.func(map_func_name, *interleaved)
438
439
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Builds a parser callback for time expressions.

    Args:
        exp_class: the expression class the callback instantiates.
        dialect: key of the sql dialect whose time mapping is applied.
        default: the format used when none is passed, True standing for the dialect's TIME_FORMAT.

    Returns:
        A callable that turns an argument list into the time expression with a converted format.
    """

    def _format_time(args: t.List):
        # Looked up on every call: the dialect may not be registered yet when the
        # callback is built
        dialect_class = Dialect[dialect]

        time_fmt = seq_get(args, 1)
        if not time_fmt:
            time_fmt = dialect_class.TIME_FORMAT if default is True else (default or None)

        return exp_class(this=seq_get(args, 0), format=dialect_class.format_time(time_fmt))

    return _format_time
464
465
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a function that yields an expression's time format unless it is the dialect default."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Gives back the time format of the expression, or None when it equals
        the TIME_FORMAT of the dialect of interest.
        """
        fmt = self.format_time(expression)
        default_fmt = Dialect.get_or_raise(dialect).TIME_FORMAT
        if fmt != default_fmt:
            return fmt
        return None

    return _time_format
478
479
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property extends the schema of a table. A PARTITIONED BY
    value that lists column names is turned into a schema holding those columns, and the columns
    are taken out of the schema of the create statement.
    """
    kind = expression.args.get("kind")

    if isinstance(expression.this, exp.Schema) and kind in ("TABLE", "VIEW"):
        # The statement is rewritten on a copy
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)

        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            partition_names = {column.name.upper() for column in prop.this.expressions}

            partitions = [
                column for column in schema.expressions if column.name.upper() in partition_names
            ]
            remaining = [column for column in schema.expressions if column not in partitions]

            schema.set("expressions", remaining)
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
501
502
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser callback for date-delta functions.

    With exactly three arguments they are read as (unit, expression, this); otherwise
    as (this, expression) with the unit being the string literal "DAY". When a
    `unit_mapping` is given the unit name is translated through it (by lower-cased
    name) and turned into a var.
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            this = args[2]
            unit = args[0]
        else:
            this = seq_get(args, 0)
            unit = exp.Literal.string("DAY")

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
514
515
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser callback for date-delta functions whose second argument is an INTERVAL.

    The callback yields None for fewer than two arguments and raises ParseError
    when the second argument is not an `exp.Interval`.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        this, interval = args[0], args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        # A quoted amount is turned into a number literal
        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=this, expression=amount, unit=unit)

    return func
537
538
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse (unit, value) arguments into a DateTrunc when the value is a cast to date, else a TimestampTrunc."""
    unit, this = seq_get(args, 0), seq_get(args, 1)

    casts_to_date = isinstance(this, exp.Cast) and this.is_type("date")
    if not casts_to_date:
        return exp.TimestampTrunc(this=this, unit=unit)
    return exp.DateTrunc(unit=unit, this=this)
546
547
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render a timestamp truncation as DATE_TRUNC('<unit>', <this>), the unit defaulting to 'day'."""
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_TRUNC", unit, expression.this)
552
553
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Parse LOCATE-style arguments (substr, string[, position]) into a StrPosition."""
    substr = seq_get(args, 0)
    string = seq_get(args, 1)
    position = seq_get(args, 2)
    return exp.StrPosition(this=string, substr=substr, position=position)
558
559
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a substring search as LOCATE(<substr>, <this>, <position>)."""
    substr = expression.args.get("substr")
    string = expression.this
    position = expression.args.get("position")
    return self.func("LOCATE", substr, string, position)
564
565
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Render LEFT(<this>, <n>) as a substring that starts at 1 and has length <n>."""
    source = expression.copy()
    substring = exp.Substring(
        this=source.this, start=exp.Literal.number(1), length=source.expression
    )
    return self.sql(substring)
573
574
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Render RIGHT(<this>, <n>) as a substring starting at LENGTH(<this>) - (<n> - 1).

    The string operand appears twice in the output, so the occurrence inside
    LENGTH is a separate copy: a single node attached under two parents would end
    up with a parent pointer that matches only one of them.
    """
    expression = expression.copy()
    string = expression.this
    return self.sql(
        exp.Substring(
            this=string,
            start=exp.Length(this=string.copy()) - exp.paren(expression.expression - 1),
        )
    )
583
584
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time string conversion as a cast to timestamp."""
    as_timestamp = exp.cast(expression.this, "timestamp")
    return self.sql(as_timestamp)
587
588
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date string conversion as a cast to date."""
    as_date = exp.cast(expression.this, "date")
    return self.sql(as_date)
591
592
# For Presto and Duckdb, whose functions take no charset and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render an encode/decode expression as a call to `name`.

    A charset other than utf-8 is reported as unsupported. The expression's
    `replace` argument is forwarded only when `replace` is true.
    """
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replacement = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replacement)
602
603
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render as LEAST when the expression carries extra `expressions`, otherwise as MIN."""
    if expression.expressions:
        func_name = "LEAST"
    else:
        func_name = "MIN"

    render = rename_func(func_name)
    return render(self, expression)
607
608
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render as GREATEST when the expression carries extra `expressions`, otherwise as MAX."""
    if expression.expressions:
        func_name = "GREATEST"
    else:
        func_name = "MAX"

    render = rename_func(func_name)
    return render(self, expression)
612
613
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render COUNT_IF(<cond>) as sum(if(<cond>, 1, 0)).

    A DISTINCT wrapper is dropped: its first expression becomes the condition and
    the loss is reported as unsupported.
    """
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    one_or_zero = exp.func("if", condition.copy(), 1, 0)
    return self.func("sum", one_or_zero)
622
623
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render a trim as TRIM([<position> ][<chars> ][FROM ]<target>[ COLLATE <collation>]).

    Without characters to remove and without a collation, rendering is left to
    the generator's own `trim_sql`.
    """
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Fall back to the TRIM/LTRIM/RTRIM syntax when nothing database-specific is present
    if not (chars or collation):
        return self.trim_sql(expression)

    prefix = ""
    if position:
        prefix += f"{position} "
    if chars:
        prefix += f"{chars} "
    if prefix:
        # FROM is only needed when something precedes the target
        prefix += "FROM "

    suffix = f" COLLATE {collation}" if collation else ""
    return f"TRIM({prefix}{target}{suffix})"
639
640
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(<this>, <format>)."""
    value = expression.this
    fmt = self.format_time(expression)
    return self.func("STRPTIME", value, fmt)
643
644
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator function that renders TsOrDsToDate as a cast to date.

    When the expression has a time format other than the dialect's TIME_FORMAT or
    DATE_FORMAT, the value goes through STRPTIME first.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        dialect_class = Dialect.get_or_raise(dialect)
        fmt = self.format_time(expression)

        is_default = fmt in (dialect_class.TIME_FORMAT, dialect_class.DATE_FORMAT)
        if fmt and not is_default:
            value = str_to_time_sql(self, expression)
        else:
            value = self.sql(expression, "this")

        return self.sql(exp.cast(value, "date"))

    return _ts_or_ds_to_date_sql
655
656
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Render a concatenation as a left-nested chain of DPipe expressions built on a copy."""
    first, *others = expression.copy().expressions

    chained = first
    for operand in others:
        chained = exp.DPipe(this=chained, expression=operand)

    return self.sql(chained)
664
665
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render as REGEXP_EXTRACT(<this>, <pattern>, <group>).

    The arguments `position`, `occurrence` and `parameters` are reported as
    unsupported when set, and are not rendered.
    """
    bad_args = [
        arg for arg in ("position", "occurrence", "parameters") if expression.args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    group = expression.args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
674
675
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render as REGEXP_REPLACE(<this>, <pattern>, <replacement>).

    The arguments `position`, `occurrence` and `parameters` are reported as
    unsupported when set, and are not rendered.
    """
    bad_args = [
        arg for arg in ("position", "occurrence", "parameters") if expression.args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    replacement = expression.args["replacement"]
    return self.func("REGEXP_REPLACE", expression.this, expression.expression, replacement)
684
685
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Return one column-name part per pivot aggregation.

    An aliased aggregation contributes its alias; any other aggregation contributes
    its own SQL in `dialect`, with function names in lower case and identifiers unquoted.
    """

    def _unquote(node):
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Aggregations without an alias are used as suffixes (e.g. col_avg(foo)).
        # Identifiers are unquoted here because the base parser's `_parse_pivot`
        # method quotes them again through `to_identifier`; without this the result
        # would be `col_avg(`foo`)`, i.e. quoted twice.
        unquoted = agg.transform(_unquote)
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
706
707
def simplify_literal(expression: E) -> E:
    """Run the optimizer's `simplify` on the right operand unless it already is a literal.

    The result of `simplify` is not used here; the same `expression` object is
    returned either way.
    """
    operand = expression.expression
    if isinstance(operand, exp.Literal):
        return expression

    # Imported on demand, only when there is something to simplify
    from sqlglot.optimizer.simplify import simplify

    simplify(operand)
    return expression
715
716
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser callback that turns the first two arguments into an `expr_type` node."""

    def _build(args):
        left = seq_get(args, 0)
        right = seq_get(args, 1)
        return expr_type(this=left, expression=right)

    return _build
719
720
# Represents DATE_TRUNC in the Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Parse (unit, value) arguments into a TimestampTrunc."""
    unit = seq_get(args, 0)
    value = seq_get(args, 1)
    return exp.TimestampTrunc(this=value, unit=unit)
724
725
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render ANY_VALUE(<this>) as MAX(<this>)."""
    operand = expression.this
    return self.func("MAX", operand)
728
729
# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    """Render a JSON key/value pair as `<key>, <value>`.

    `self` is annotated as `Generator`, like every other generator helper in this module.
    """
    key = self.sql(expression, "this")
    value = self.sql(expression, "expression")
    return f"{key}, {value}"
class Dialects(builtins.str, enum.Enum):
class Dialects(str, Enum):
    """Names of the SQL dialects enumerated by this module.

    Members are also `str` instances, so a member compares equal to its value
    (for example `Dialects.HIVE == "hive"`).
    """

    # The base dialect is keyed by the empty string
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    # NOTE(review): unlike its siblings this member name is not upper-case.
    # `_Dialect.__new__` looks members up with `clsname.upper()`, so a class named
    # `Doris` does not find this entry and is registered under `clsname.lower()`
    # instead, which happens to be the same "doris" key.
    Doris = "doris"

An enumeration.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Doris = <Dialects.Doris: 'doris'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
class Dialect(metaclass=_Dialect):
    """
    Base dialect: bundles the dialect-level settings declared below with the tokenizer,
    parser and generator classes that are used to tokenize, parse and generate SQL.
    """

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Default formats; note that the values include the surrounding single quotes
    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        """Compare this instance's class with `other`, using the class-level (metaclass) equality."""
        return type(self) == other

    def __hash__(self) -> int:
        """Hash an instance like its class, keeping it consistent with `__eq__`."""
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """
        Resolves `dialect` to a dialect class.

        Args:
            dialect: a dialect name, a dialect class, a dialect instance, or a falsy value.

        Returns:
            `cls` for a falsy value, the class itself for a dialect class, the instance's
            class for a dialect instance, otherwise the class looked up through `cls.get`.

        Raises:
            ValueError: if the lookup through `cls.get` yields nothing.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """
        Converts a time format through this dialect's TIME_MAPPING / TIME_TRIE.

        A `str` input has its first and last characters stripped before conversion and a
        string-literal expression has its `this` converted; both yield a new string literal.
        Any other input (including None) is returned unchanged.
        """
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        # Characters in the opposite case of the dialect's resolution case are the unsafe ones
        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """
        Sets the "quoted" arg of an identifier and returns the expression.

        The value is truthy when `identify` is truthy, when the name has case-sensitive
        characters, or when the name does not match `exp.SAFE_IDENTIFIER_RE`. Expressions
        that are not identifiers are returned untouched.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenizes `sql` and parses the tokens with a parser built from `opts`."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenizes `sql` and parses the tokens into `expression_type` with a parser built from `opts`."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Generates SQL for `expression` with a generator built from `opts`."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parses `sql` and generates one SQL string per parsed expression; `opts` go to the generator only."""
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenizes `sql` with this dialect's tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        """The tokenizer instance, created on first access and cached on `self._tokenizer`."""
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Creates a new parser of this dialect's `parser_class` with `opts`."""
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        """Creates a new generator of this dialect's `generator_class` with `opts`."""
        return self.generator_class(**opts)
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
RESOLVES_IDENTIFIERS_AS_UPPERCASE: Optional[bool] = False
IDENTIFIERS_CAN_START_WITH_DIGIT = False
DPIPE_IS_STRING_CONCAT = True
STRICT_STRING_CONCAT = False
NORMALIZE_FUNCTIONS: bool | str = 'upper'
NULL_ORDERING = 'nulls_are_small'
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}
FORMAT_MAPPING: Dict[str, str] = {}
PSEUDOCOLUMNS: Set[str] = set()
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
@classmethod
def get_or_raise( cls, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> Type[sqlglot.dialects.dialect.Dialect]:
198    @classmethod
199    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
200        if not dialect:
201            return cls
202        if isinstance(dialect, _Dialect):
203            return dialect
204        if isinstance(dialect, Dialect):
205            return dialect.__class__
206
207        result = cls.get(dialect)
208        if not result:
209            raise ValueError(f"Unknown dialect '{dialect}'")
210
211        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
213    @classmethod
214    def format_time(
215        cls, expression: t.Optional[str | exp.Expression]
216    ) -> t.Optional[exp.Expression]:
217        if isinstance(expression, str):
218            return exp.Literal.string(
219                # the time formats are quoted
220                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
221            )
222
223        if expression and expression.is_string:
224            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
225
226        return expression
@classmethod
def normalize_identifier(cls, expression: ~E) -> ~E:
228    @classmethod
229    def normalize_identifier(cls, expression: E) -> E:
230        """
231        Normalizes an unquoted identifier to either lower or upper case, thus essentially
232        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
233        they will be normalized regardless of being quoted or not.
234        """
235        if isinstance(expression, exp.Identifier) and (
236            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
237        ):
238            expression.set(
239                "this",
240                expression.this.upper()
241                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
242                else expression.this.lower(),
243            )
244
245        return expression

Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized regardless of being quoted or not.

@classmethod
def case_sensitive(cls, text: str) -> bool:
247    @classmethod
248    def case_sensitive(cls, text: str) -> bool:
249        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
250        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
251            return False
252
253        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
254        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

@classmethod
def can_identify(cls, text: str, identify: str | bool = 'safe') -> bool:
256    @classmethod
257    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
258        """Checks if text can be identified given an identify option.
259
260        Args:
261            text: The text to check.
262            identify:
263                "always" or `True`: Always returns true.
264                "safe": True if the identifier is case-insensitive.
265
266        Returns:
267            Whether or not the given text can be identified.
268        """
269        if identify is True or identify == "always":
270            return True
271
272        if identify == "safe":
273            return not cls.case_sensitive(text)
274
275        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns true. "safe": True if the identifier is case-insensitive.
Returns:

Whether or not the given text can be identified.

@classmethod
def quote_identifier(cls, expression: ~E, identify: bool = True) -> ~E:
277    @classmethod
278    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
279        if isinstance(expression, exp.Identifier):
280            name = expression.this
281            expression.set(
282                "quoted",
283                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
284            )
285
286        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
288    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
289        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
291    def parse_into(
292        self, expression_type: exp.IntoType, sql: str, **opts
293    ) -> t.List[t.Optional[exp.Expression]]:
294        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: Optional[sqlglot.expressions.Expression], **opts) -> str:
296    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
297        return self.generator(**opts).generate(expression)
def transpile(self, sql: str, **opts) -> List[str]:
299    def transpile(self, sql: str, **opts) -> t.List[str]:
300        return [self.generate(expression, **opts) for expression in self.parse(sql)]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
302    def tokenize(self, sql: str) -> t.List[Token]:
303        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
311    def parser(self, **opts) -> Parser:
312        return self.parser_class(**opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
314    def generator(self, **opts) -> Generator:
315        return self.generator_class(**opts)
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START = None
BIT_END = None
HEX_START = None
HEX_END = None
BYTE_START = None
BYTE_END = None
DialectType = typing.Union[str, sqlglot.dialects.dialect.Dialect, typing.Type[sqlglot.dialects.dialect.Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Builds a generator transform that renders an expression as a call to `name`.

    The expression's arg values are flattened and passed as the function's arguments.
    """

    def _renamed(self: Generator, expression: exp.Expression) -> str:
        arguments = flatten(expression.args.values())
        return self.func(name, *arguments)

    return _renamed
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Renders APPROX_COUNT_DISTINCT, reporting the accuracy argument as unsupported if present."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.If) -> str:
def if_sql(self: Generator, expression: exp.If) -> str:
    """Renders an If expression as IF(condition, true, false)."""
    condition = expression.this
    when_true = expression.args.get("true")
    when_false = expression.args.get("false")
    return self.func("IF", condition, when_true, when_false)
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Renders a JSON extraction as a binary expression using the `->` operator."""
    operator = "->"
    return self.binary(expression, operator)
def arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Renders a scalar JSON extraction as a binary expression using the `->>` operator."""
    operator = "->>"
    return self.binary(expression, operator)
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Renders an array as its flat list of elements wrapped in square brackets."""
    elements = self.expressions(expression, flat=True)
    return "[" + elements + "]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Renders ILIKE as LIKE with the left-hand side wrapped in LOWER; operands are copied."""
    lowered = exp.Lower(this=expression.this.copy())
    pattern = expression.expression.copy()
    like = exp.Like(this=lowered, expression=pattern)
    return self.like_sql(like)
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Renders CURRENT_DATE without parentheses, adding AT TIME ZONE when a zone is given."""
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"

    return f"CURRENT_DATE AT TIME ZONE {zone}"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """
    Renders a WITH clause without the RECURSIVE modifier.

    If the clause is recursive, this is reported through `self.unsupported` and the
    modifier is dropped from the generated SQL.
    """
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        # Clear the flag on a copy: mutating the input would silently strip RECURSIVE from
        # the caller's AST, affecting any later generation from the same tree.
        expression = expression.copy()
        expression.args["recursive"] = False
    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """
    Renders SAFE_DIVIDE as IF(d <> 0, n / d, NULL).

    Operands that are binary expressions are parenthesized, since they are interpolated
    next to the `<>` and `/` operators and would otherwise bind incorrectly
    (e.g. a numerator `a + b` would turn into `a + b / d`).
    """

    def _operand(key: str) -> str:
        sql = self.sql(expression, key)
        # Simple operands are left as-is so that the output for them is unchanged
        return f"({sql})" if isinstance(expression.args.get(key), exp.Binary) else sql

    n = _operand("this")
    d = _operand("expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Reports TABLESAMPLE as unsupported and renders only the sampled expression."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Reports PIVOT as unsupported and renders nothing for it."""
    message = "PIVOT unsupported"
    self.unsupported(message)
    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Renders a TryCast through the generator's regular `cast_sql`."""
    rendered = self.cast_sql(expression)
    return rendered
def no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Reports properties as unsupported and renders nothing for them."""
    message = "Properties unsupported"
    self.unsupported(message)
    return ""
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Reports the comment column constraint as unsupported and renders nothing for it."""
    message = "CommentColumnConstraint unsupported"
    self.unsupported(message)
    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Reports MAP_FROM_ENTRIES as unsupported and renders nothing for it."""
    message = "MAP_FROM_ENTRIES unsupported"
    self.unsupported(message)
    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """
    Renders StrPosition with STRPOS.

    When a start position is given, the search runs on SUBSTR(this, position) and the
    result is shifted by `position - 1`.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"

    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Renders a struct field access as `<struct>.<field>`, with the field rendered as an identifier."""
    struct_sql = self.sql(expression, "this")
    field_sql = self.sql(exp.to_identifier(expression.expression.name))
    return f"{struct_sql}.{field_sql}"
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """
    Renders a map as `map_func_name(key1, value1, key2, value2, ...)`.

    If the keys or the values are not array expressions, this is reported as unsupported
    and the two are passed to the function as-is.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    both_arrays = isinstance(keys, exp.Array) and isinstance(values, exp.Array)
    if not both_arrays:
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave the rendered keys and values, pair by pair
    pairs = zip(keys.expressions, values.expressions)
    args = [self.sql(item) for pair in pairs for item in pair]

    return self.func(map_func_name, *args)
def format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper that builds parsers for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the format used when none is given; True selects the dialect's TIME_FORMAT.

    Returns:
        A callable that turns an argument list into the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        this = seq_get(args, 0)
        dialect_class = Dialect[dialect]

        # An explicit format argument wins over the default
        fmt = seq_get(args, 1)
        if not fmt:
            if default is True:
                fmt = Dialect[dialect].TIME_FORMAT
            else:
                fmt = default or None

        return exp_class(this=this, format=dialect_class.format_time(fmt))

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def time_format( dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Builds a helper that yields an expression's time format, or None if it is the dialect's default."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        fmt = self.format_time(expression)
        default_fmt = Dialect.get_or_raise(dialect).TIME_FORMAT
        if fmt != default_fmt:
            return fmt
        return None

    return _time_format
def create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if not (has_schema and is_partitionable):
        return self.create_sql(expression)

    # Rewrite a copy so that the input expression is left intact
    expression = expression.copy()
    prop = expression.find(exp.PartitionedByProperty)

    if prop and prop.this and not isinstance(prop.this, exp.Schema):
        schema = expression.this
        partition_names = {column.name.upper() for column in prop.this.expressions}
        partitions = [
            column for column in schema.expressions if column.name.upper() in partition_names
        ]
        remaining = [column for column in schema.expressions if column not in partitions]

        schema.set("expressions", remaining)
        prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
        expression.set("this", schema)

    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """
    Builds a parser for date delta functions.

    With exactly three arguments they are read as (unit, expression, this); otherwise as
    (this, expression) with a "DAY" unit. If `unit_mapping` is given, the unit name is
    translated through it (keyed by the lowercased name).
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            this = args[2]
            unit = args[0]
        else:
            this = seq_get(args, 0)
            unit = exp.Literal.string("DAY")

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """
    Builds a parser for date delta functions whose second argument is an INTERVAL.

    The built parser returns None for fewer than two arguments and raises ParseError if
    the second argument is not an interval.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        # A string interval value is turned into a number literal
        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=args[0], expression=amount, unit=unit)

    return func
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parses (unit, this) into a DateTrunc if `this` is a cast to date, else into a TimestampTrunc."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    is_date_cast = isinstance(this, exp.Cast) and this.is_type("date")
    if not is_date_cast:
        return exp.TimestampTrunc(this=this, unit=unit)

    return exp.DateTrunc(unit=unit, this=this)
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Renders TimestampTrunc as DATE_TRUNC(unit, this), with the unit defaulting to "day"."""
    unit_name = expression.text("unit") or "day"
    unit = exp.Literal.string(unit_name)
    return self.func("DATE_TRUNC", unit, expression.this)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Parses arguments given as (substr, this, position) into a StrPosition."""
    this = seq_get(args, 1)
    substr = seq_get(args, 0)
    position = seq_get(args, 2)
    return exp.StrPosition(this=this, substr=substr, position=position)
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Renders StrPosition as LOCATE(substr, this, position)."""
    substr = expression.args.get("substr")
    this = expression.this
    position = expression.args.get("position")
    return self.func("LOCATE", substr, this, position)
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Renders LEFT(this, n) as a Substring of a copy of `this`, starting at 1 with length n."""
    left = expression.copy()
    substring = exp.Substring(
        this=left.this, start=exp.Literal.number(1), length=left.expression
    )
    return self.sql(substring)
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """
    Renders RIGHT(this, n) as a Substring that starts at LENGTH(this) - (n - 1).

    The expression is copied first, so the input is not modified. The parameter was
    previously annotated as `exp.Left`, a copy-paste slip from `left_to_substring_sql`.
    """
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            # No length is given, so the substring extends to the end of the string
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Renders TimeStrToTime as a cast of its operand to "timestamp"."""
    cast = exp.cast(expression.this, "timestamp")
    return self.sql(cast)
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Renders DateStrToDate as a cast of its operand to "date"."""
    cast = exp.cast(expression.this, "date")
    return self.sql(cast)
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """
    Renders an encode/decode expression as a call to `name`.

    A charset other than utf-8 is reported as unsupported. The expression's "replace"
    arg is passed along only when `replace` is truthy.
    """
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    this = expression.this
    replace_arg = expression.args.get("replace") if replace else None
    return self.func(name, this, replace_arg)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Renders Min as LEAST when it carries extra expressions, otherwise as MIN."""
    if expression.expressions:
        name = "LEAST"
    else:
        name = "MIN"

    # Same rendering as rename_func(name): a call over the flattened arg values
    return self.func(name, *flatten(expression.args.values()))
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Renders Max as GREATEST when it carries extra expressions, otherwise as MAX."""
    if expression.expressions:
        name = "GREATEST"
    else:
        name = "MAX"

    # Same rendering as rename_func(name): a call over the flattened arg values
    return self.func(name, *flatten(expression.args.values()))
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """
    Renders COUNT_IF(cond) as sum(if(cond, 1, 0)).

    A DISTINCT wrapper is reported as unsupported and replaced by its first expression.
    """
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    one_or_zero = exp.func("if", condition.copy(), 1, 0)
    return self.func("sum", one_or_zero)
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """
    Renders TRIM([position ][chars ][FROM ]target[ COLLATE collation]).

    Without characters to remove and without a collation, rendering is delegated to the
    generator's own `trim_sql`.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (remove_chars or collation):
        return self.trim_sql(expression)

    prefix = ""
    if trim_type:
        prefix += f"{trim_type} "
    if remove_chars:
        prefix += f"{remove_chars} "
    if prefix:
        # FROM is only needed when a position or characters precede the target
        prefix += "FROM "

    suffix = f" COLLATE {collation}" if collation else ""
    return f"TRIM({prefix}{target}{suffix})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Renders the expression as STRPTIME(this, format)."""
    this = expression.this
    fmt = self.format_time(expression)
    return self.func("STRPTIME", this, fmt)
def ts_or_ds_to_date_sql(dialect: str) -> Callable:
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """
    Builds a TsOrDsToDate renderer for `dialect`.

    The operand is cast to "date"; if the expression has a format other than the
    dialect's TIME_FORMAT or DATE_FORMAT, it is first parsed with STRPTIME.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target = Dialect.get_or_raise(dialect)
        fmt = self.format_time(expression)
        default_formats = (target.TIME_FORMAT, target.DATE_FORMAT)

        if fmt and fmt not in default_formats:
            operand = str_to_time_sql(self, expression)
        else:
            operand = self.sql(expression, "this")

        return self.sql(exp.cast(operand, "date"))

    return _ts_or_ds_to_date_sql
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat | sqlglot.expressions.SafeConcat) -> str:
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Render a concat expression as a left-nested chain of ``DPipe`` nodes.

    The expression is copied first, so the nodes placed under the new
    ``DPipe`` wrappers come from the copy, not from the caller's tree.

    Args:
        self: The generator producing the SQL.
        expression: The concat node whose ``expressions`` are the operands.

    Returns:
        The SQL for the chained operands (a single operand is rendered as is).

    Raises:
        ValueError: If the expression has no operands (unpacking fails).
    """
    head, *tail = expression.copy().expressions

    chained = head
    for operand in tail:
        chained = exp.DPipe(this=chained, expression=operand)

    return self.sql(chained)
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render ``RegexpExtract`` as ``REGEXP_EXTRACT(this, pattern[, group])``.

    Any truthy ``position``, ``occurrence`` or ``parameters`` argument is
    reported through ``self.unsupported`` and left out of the output.

    Args:
        self: The generator producing the SQL.
        expression: The regexp-extract node to render.

    Returns:
        The SQL produced by ``self.func``.
    """
    args = expression.args
    bad_args = [key for key in ("position", "occurrence", "parameters") if args.get(key)]

    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT",
        expression.this,
        expression.expression,
        args.get("group"),
    )
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render ``RegexpReplace`` as ``REGEXP_REPLACE(this, pattern, replacement)``.

    Any truthy ``position``, ``occurrence`` or ``parameters`` argument is
    reported through ``self.unsupported`` and left out of the output.

    Args:
        self: The generator producing the SQL.
        expression: The regexp-replace node to render.

    Returns:
        The SQL produced by ``self.func``.

    Raises:
        KeyError: If the expression has no ``replacement`` argument.
    """
    args = expression.args
    bad_args = [key for key in ("position", "occurrence", "parameters") if args.get(key)]

    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE",
        expression.this,
        expression.expression,
        args["replacement"],
    )
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> List[str]:
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute one column-name fragment per pivot aggregation.

    Args:
        aggregations: The aggregation expressions of a pivot.
        dialect: The dialect used to render aggregations that have no alias.

    Returns:
        A list with, for each aggregation, its alias if it is an ``exp.Alias``,
        otherwise its SQL text with identifiers unquoted and function names
        lower-cased.
    """

    def _unquote(node):
        # Rebuild every identifier as an unquoted one; leave other nodes alone.
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    def _column_name(agg):
        if isinstance(agg, exp.Alias):
            return agg.alias

        # This case corresponds to aggregations without aliases being used as
        # suffixes (e.g. col_avg(foo)). Identifiers are unquoted here because
        # they're going to be quoted in the base parser's `_parse_pivot`
        # method, due to `to_identifier`. Otherwise, we'd end up with
        # `col_avg(`foo`)` (notice the nested quoting).
        unquoted = agg.transform(_unquote)
        return unquoted.sql(dialect=dialect, normalize_functions="lower")

    return [_column_name(agg) for agg in aggregations]
def simplify_literal(expression: ~E) -> ~E:
def simplify_literal(expression: E) -> E:
    """Run the optimizer's ``simplify`` on a non-literal right-hand operand.

    Args:
        expression: A node whose ``expression`` arg is the operand to simplify.

    Returns:
        The same ``expression`` object that was passed in.
    """
    operand = expression.expression
    if isinstance(operand, exp.Literal):
        # Already a literal: nothing to simplify.
        return expression

    # Imported lazily, as in the original code.
    # NOTE(review): presumably to avoid a circular import -- confirm.
    from sqlglot.optimizer.simplify import simplify

    # The return value is discarded.
    # NOTE(review): this relies on `simplify` rewriting the operand in place
    # within its parent -- confirm against the optimizer.
    simplify(operand)

    return expression
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Create a builder that turns a function's argument list into a binary node.

    Args:
        expr_type: The binary expression class to instantiate.

    Returns:
        A callable taking an argument list and returning
        ``expr_type(this=<first arg>, expression=<second arg>)``; the arguments
        are read with ``seq_get``.
    """

    def _build(args: t.List) -> B:
        return expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

    return _build
def parse_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Build a ``TimestampTrunc`` node from a ``(unit, value)`` argument list.

    Args:
        args: Parsed function arguments; the first is used as the unit and the
            second as the value being truncated.

    Returns:
        The resulting ``exp.TimestampTrunc`` node.
    """
    unit = seq_get(args, 0)
    value = seq_get(args, 1)
    return exp.TimestampTrunc(this=value, unit=unit)
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render an ``AnyValue`` expression as a ``MAX`` function call.

    Args:
        self: The generator producing the SQL.
        expression: The node whose ``this`` becomes the argument of ``MAX``.

    Returns:
        The SQL produced by ``self.func``.
    """
    operand = expression.this
    return self.func("MAX", operand)
def json_keyvalue_comma_sql(self, expression: sqlglot.expressions.JSONKeyValue) -> str:
def json_keyvalue_comma_sql(self, expression: exp.JSONKeyValue) -> str:
    """Render a JSON key/value pair as ``key, value`` (comma-separated).

    Args:
        self: The generator producing the SQL.
        expression: The pair node; ``this`` is the key, ``expression`` the value.

    Returns:
        The rendered key and value joined by ``", "``.
    """
    key = self.sql(expression, "this")
    value = self.sql(expression, "expression")
    return f"{key}, {value}"