
sqlglot.dialects.dialect

   1from __future__ import annotations
   2
   3import logging
   4import typing as t
   5from enum import Enum, auto
   6from functools import reduce
   7
   8from sqlglot import exp
   9from sqlglot.errors import ParseError
  10from sqlglot.generator import Generator
  11from sqlglot.helper import AutoName, flatten, is_int, seq_get
  12from sqlglot.jsonpath import parse as parse_json_path
  13from sqlglot.parser import Parser
  14from sqlglot.time import TIMEZONES, format_time
  15from sqlglot.tokens import Token, Tokenizer, TokenType
  16from sqlglot.trie import new_trie
  17
  18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
  19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
  20JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
  21
  22
  23if t.TYPE_CHECKING:
  24    from sqlglot._typing import B, E, F
  25
  26logger = logging.getLogger("sqlglot")
  27
  28UNESCAPED_SEQUENCES = {
  29    "\\a": "\a",
  30    "\\b": "\b",
  31    "\\f": "\f",
  32    "\\n": "\n",
  33    "\\r": "\r",
  34    "\\t": "\t",
  35    "\\v": "\v",
  36    "\\\\": "\\",
  37}
  38
  39
  40class Dialects(str, Enum):
  41    """Dialects supported by SQLGLot."""
  42
  43    DIALECT = ""
  44
  45    ATHENA = "athena"
  46    BIGQUERY = "bigquery"
  47    CLICKHOUSE = "clickhouse"
  48    DATABRICKS = "databricks"
  49    DORIS = "doris"
  50    DRILL = "drill"
  51    DUCKDB = "duckdb"
  52    HIVE = "hive"
  53    MYSQL = "mysql"
  54    ORACLE = "oracle"
  55    POSTGRES = "postgres"
  56    PRESTO = "presto"
  57    PRQL = "prql"
  58    REDSHIFT = "redshift"
  59    SNOWFLAKE = "snowflake"
  60    SPARK = "spark"
  61    SPARK2 = "spark2"
  62    SQLITE = "sqlite"
  63    STARROCKS = "starrocks"
  64    TABLEAU = "tableau"
  65    TERADATA = "teradata"
  66    TRINO = "trino"
  67    TSQL = "tsql"
  68
  69
  70class NormalizationStrategy(str, AutoName):
  71    """Specifies the strategy according to which identifiers should be normalized."""
  72
  73    LOWERCASE = auto()
  74    """Unquoted identifiers are lowercased."""
  75
  76    UPPERCASE = auto()
  77    """Unquoted identifiers are uppercased."""
  78
  79    CASE_SENSITIVE = auto()
  80    """Always case-sensitive, regardless of quotes."""
  81
  82    CASE_INSENSITIVE = auto()
  83    """Always case-insensitive, regardless of quotes."""
  84
  85
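# Example sketch (not part of the module source): how a strategy changes
# identifier resolution, using the settings-string form accepted by
# Dialect.get_or_raise (defined below). Postgres normally lowercases only
# unquoted identifiers; forcing CASE_INSENSITIVE normalizes quoted ones too.
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

pg_ci = Dialect.get_or_raise("postgres, normalization_strategy = case_insensitive")
assert pg_ci.normalize_identifier(exp.to_identifier("FoO", quoted=True)).name == "foo"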
  86class _Dialect(type):
  87    classes: t.Dict[str, t.Type[Dialect]] = {}
  88
  89    def __eq__(cls, other: t.Any) -> bool:
  90        if cls is other:
  91            return True
  92        if isinstance(other, str):
  93            return cls is cls.get(other)
  94        if isinstance(other, Dialect):
  95            return cls is type(other)
  96
  97        return False
  98
  99    def __hash__(cls) -> int:
 100        return hash(cls.__name__.lower())
 101
 102    @classmethod
 103    def __getitem__(cls, key: str) -> t.Type[Dialect]:
 104        return cls.classes[key]
 105
 106    @classmethod
 107    def get(
 108        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
 109    ) -> t.Optional[t.Type[Dialect]]:
 110        return cls.classes.get(key, default)
 111
 112    def __new__(cls, clsname, bases, attrs):
 113        klass = super().__new__(cls, clsname, bases, attrs)
 114        enum = Dialects.__members__.get(clsname.upper())
 115        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 116
 117        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
 118        klass.FORMAT_TRIE = (
 119            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
 120        )
 121        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
 122        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 123
 124        base = seq_get(bases, 0)
 125        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
 126        base_parser = (getattr(base, "parser_class", Parser),)
 127        base_generator = (getattr(base, "generator_class", Generator),)
 128
 129        klass.tokenizer_class = klass.__dict__.get(
 130            "Tokenizer", type("Tokenizer", base_tokenizer, {})
 131        )
 132        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
 133        klass.generator_class = klass.__dict__.get(
 134            "Generator", type("Generator", base_generator, {})
 135        )
 136
 137        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
 138        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
 139            klass.tokenizer_class._IDENTIFIERS.items()
 140        )[0]
 141
 142        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
 143            return next(
 144                (
 145                    (s, e)
 146                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
 147                    if t == token_type
 148                ),
 149                (None, None),
 150            )
 151
 152        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
 153        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
 154        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
 155        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 156
 157        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
 158            klass.UNESCAPED_SEQUENCES = {
 159                **UNESCAPED_SEQUENCES,
 160                **klass.UNESCAPED_SEQUENCES,
 161            }
 162
 163        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
 164
 165        if enum not in ("", "bigquery"):
 166            klass.generator_class.SELECT_KINDS = ()
 167
 168        if enum not in ("", "athena", "presto", "trino"):
 169            klass.generator_class.TRY_SUPPORTED = False
 170
 171        if enum not in ("", "databricks", "hive", "spark", "spark2"):
 172            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
 173            for modifier in ("cluster", "distribute", "sort"):
 174                modifier_transforms.pop(modifier, None)
 175
 176            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
 177
 178        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
 179            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
 180                TokenType.ANTI,
 181                TokenType.SEMI,
 182            }
 183
 184        return klass
 185
 186
 187class Dialect(metaclass=_Dialect):
 188    INDEX_OFFSET = 0
 189    """The base index offset for arrays."""
 190
 191    WEEK_OFFSET = 0
 192    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
 193
 194    UNNEST_COLUMN_ONLY = False
 195    """Whether `UNNEST` table aliases are treated as column aliases."""
 196
 197    ALIAS_POST_TABLESAMPLE = False
 198    """Whether the table alias comes after tablesample."""
 199
 200    TABLESAMPLE_SIZE_IS_PERCENT = False
 201    """Whether a size in the table sample clause represents percentage."""
 202
 203    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
 204    """Specifies the strategy according to which identifiers should be normalized."""
 205
 206    IDENTIFIERS_CAN_START_WITH_DIGIT = False
 207    """Whether an unquoted identifier can start with a digit."""
 208
 209    DPIPE_IS_STRING_CONCAT = True
 210    """Whether the DPIPE token (`||`) is a string concatenation operator."""
 211
 212    STRICT_STRING_CONCAT = False
 213    """Whether `CONCAT`'s arguments must be strings."""
 214
 215    SUPPORTS_USER_DEFINED_TYPES = True
 216    """Whether user-defined data types are supported."""
 217
 218    SUPPORTS_SEMI_ANTI_JOIN = True
 219    """Whether `SEMI` or `ANTI` joins are supported."""
 220
 221    NORMALIZE_FUNCTIONS: bool | str = "upper"
 222    """
 223    Determines how function names are going to be normalized.
 224    Possible values:
 225        "upper" or True: Convert names to uppercase.
 226        "lower": Convert names to lowercase.
 227        False: Disables function name normalization.
 228    """
 229
 230    LOG_BASE_FIRST: t.Optional[bool] = True
 231    """
 232    Whether the base comes first in the `LOG` function.
 233    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
 234    """
 235
 236    NULL_ORDERING = "nulls_are_small"
 237    """
 238    Default `NULL` ordering method to use if not explicitly set.
 239    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
 240    """
 241
 242    TYPED_DIVISION = False
 243    """
 244    Whether the behavior of `a / b` depends on the types of `a` and `b`.
 245    False means `a / b` is always float division.
 246    True means `a / b` is integer division if both `a` and `b` are integers.
 247    """
 248
 249    SAFE_DIVISION = False
 250    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""
 251
 252    CONCAT_COALESCE = False
 253    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
 254
 255    HEX_LOWERCASE = False
 256    """Whether the `HEX` function returns a lowercase hexadecimal string."""
 257
 258    DATE_FORMAT = "'%Y-%m-%d'"
 259    DATEINT_FORMAT = "'%Y%m%d'"
 260    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
 261
 262    TIME_MAPPING: t.Dict[str, str] = {}
 263    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""
 264
 265    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
 266    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
 267    FORMAT_MAPPING: t.Dict[str, str] = {}
 268    """
 269    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
 270    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
 271    """
 272
 273    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
 274    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
 275
 276    PSEUDOCOLUMNS: t.Set[str] = set()
 277    """
 278    Columns that are auto-generated by the engine corresponding to this dialect.
 279    For example, such columns may be excluded from `SELECT *` queries.
 280    """
 281
 282    PREFER_CTE_ALIAS_COLUMN = False
 283    """
 284    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
 285    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
 286    any projection aliases in the subquery.
 287
 288    For example,
 289        WITH y(c) AS (
 290            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
 291        ) SELECT c FROM y;
 292
 293        will be rewritten as
 294
 295        WITH y(c) AS (
 296            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
 297        ) SELECT c FROM y;
 298    """
 299
 300    # --- Autofilled ---
 301
 302    tokenizer_class = Tokenizer
 303    parser_class = Parser
 304    generator_class = Generator
 305
 306    # A trie of the time_mapping keys
 307    TIME_TRIE: t.Dict = {}
 308    FORMAT_TRIE: t.Dict = {}
 309
 310    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
 311    INVERSE_TIME_TRIE: t.Dict = {}
 312
 313    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
 314
 315    # Delimiters for string literals and identifiers
 316    QUOTE_START = "'"
 317    QUOTE_END = "'"
 318    IDENTIFIER_START = '"'
 319    IDENTIFIER_END = '"'
 320
 321    # Delimiters for bit, hex, byte and unicode literals
 322    BIT_START: t.Optional[str] = None
 323    BIT_END: t.Optional[str] = None
 324    HEX_START: t.Optional[str] = None
 325    HEX_END: t.Optional[str] = None
 326    BYTE_START: t.Optional[str] = None
 327    BYTE_END: t.Optional[str] = None
 328    UNICODE_START: t.Optional[str] = None
 329    UNICODE_END: t.Optional[str] = None
 330
 331    # Separator of COPY statement parameters
 332    COPY_PARAMS_ARE_CSV = True
 333
 334    @classmethod
 335    def get_or_raise(cls, dialect: DialectType) -> Dialect:
 336        """
 337        Look up a dialect in the global dialect registry and return it if it exists.
 338
 339        Args:
 340            dialect: The target dialect. If this is a string, it can be optionally followed by
 341                additional key-value pairs that are separated by commas and are used to specify
 342                dialect settings, such as whether the dialect's identifiers are case-sensitive.
 343
 344        Example:
  345            >>> dialect = Dialect.get_or_raise("duckdb")
  346            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
 347
 348        Returns:
 349            The corresponding Dialect instance.
 350        """
 351
 352        if not dialect:
 353            return cls()
 354        if isinstance(dialect, _Dialect):
 355            return dialect()
 356        if isinstance(dialect, Dialect):
 357            return dialect
 358        if isinstance(dialect, str):
 359            try:
 360                dialect_name, *kv_pairs = dialect.split(",")
 361                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
 362            except ValueError:
 363                raise ValueError(
 364                    f"Invalid dialect format: '{dialect}'. "
 365                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
 366                )
 367
 368            result = cls.get(dialect_name.strip())
 369            if not result:
 370                from difflib import get_close_matches
 371
 372                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
 373                if similar:
 374                    similar = f" Did you mean {similar}?"
 375
 376                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
 377
 378            return result(**kwargs)
 379
 380        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
 381
 382    @classmethod
 383    def format_time(
 384        cls, expression: t.Optional[str | exp.Expression]
 385    ) -> t.Optional[exp.Expression]:
 386        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
 387        if isinstance(expression, str):
 388            return exp.Literal.string(
 389                # the time formats are quoted
 390                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
 391            )
 392
 393        if expression and expression.is_string:
 394            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
 395
 396        return expression
 397
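# Example sketch (not part of the module source): converting a quoted Hive
# time format into its Python strftime equivalent. Importing this module
# registers all bundled dialects, so the "hive" lookup resolves.
from sqlglot.dialects.dialect import Dialect

fmt = Dialect["hive"].format_time("'dd/MM/yyyy'")  # string input keeps its quotes
assert fmt.sql() == "'%d/%m/%Y'"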
 398    def __init__(self, **kwargs) -> None:
 399        normalization_strategy = kwargs.get("normalization_strategy")
 400
 401        if normalization_strategy is None:
 402            self.normalization_strategy = self.NORMALIZATION_STRATEGY
 403        else:
 404            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
 405
 406    def __eq__(self, other: t.Any) -> bool:
 407        # Does not currently take dialect state into account
 408        return type(self) == other
 409
 410    def __hash__(self) -> int:
 411        # Does not currently take dialect state into account
 412        return hash(type(self))
 413
 414    def normalize_identifier(self, expression: E) -> E:
 415        """
 416        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
 417
 418        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
 419        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
 420        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
 421        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
 422
 423        There are also dialects like Spark, which are case-insensitive even when quotes are
 424        present, and dialects like MySQL, whose resolution rules match those employed by the
 425        underlying operating system, for example they may always be case-sensitive in Linux.
 426
 427        Finally, the normalization behavior of some engines can even be controlled through flags,
 428        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
 429
 430        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
 431        that it can analyze queries in the optimizer and successfully capture their semantics.
 432        """
 433        if (
 434            isinstance(expression, exp.Identifier)
 435            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
 436            and (
 437                not expression.quoted
 438                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
 439            )
 440        ):
 441            expression.set(
 442                "this",
 443                (
 444                    expression.this.upper()
 445                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 446                    else expression.this.lower()
 447                ),
 448            )
 449
 450        return expression
 451
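# Example sketch (not part of the module source): the behaviors described
# above, concretely. Postgres lowercases unquoted identifiers, Snowflake
# uppercases them, and quoting blocks normalization for both.
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

assert Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name == "foo"
assert Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name == "FOO"
assert Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO", quoted=True)).name == "FoO"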
 452    def case_sensitive(self, text: str) -> bool:
 453        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
 454        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
 455            return False
 456
 457        unsafe = (
 458            str.islower
 459            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 460            else str.isupper
 461        )
 462        return any(unsafe(char) for char in text)
 463
 464    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
 465        """Checks if text can be identified given an identify option.
 466
 467        Args:
 468            text: The text to check.
 469            identify:
 470                `"always"` or `True`: Always returns `True`.
 471                `"safe"`: Only returns `True` if the identifier is case-insensitive.
 472
 473        Returns:
 474            Whether the given text can be identified.
 475        """
 476        if identify is True or identify == "always":
 477            return True
 478
 479        if identify == "safe":
 480            return not self.case_sensitive(text)
 481
 482        return False
 483
 484    def quote_identifier(self, expression: E, identify: bool = True) -> E:
 485        """
 486        Adds quotes to a given identifier.
 487
 488        Args:
 489            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
 490            identify: If set to `False`, the quotes will only be added if the identifier is deemed
 491                "unsafe", with respect to its characters and this dialect's normalization strategy.
 492        """
 493        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
 494            name = expression.this
 495            expression.set(
 496                "quoted",
 497                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
 498            )
 499
 500        return expression
 501
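# Example sketch (not part of the module source): with identify=False,
# quotes are only added to "unsafe" names, i.e. names with case-sensitive
# characters or characters that SAFE_IDENTIFIER_RE rejects.
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

postgres = Dialect.get_or_raise("postgres")
assert postgres.quote_identifier(exp.to_identifier("foo"), identify=False).sql() == "foo"
assert postgres.quote_identifier(exp.to_identifier("FoO"), identify=False).sql() == '"FoO"'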
 502    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
 503        if isinstance(path, exp.Literal):
 504            path_text = path.name
 505            if path.is_number:
 506                path_text = f"[{path_text}]"
 507
 508            try:
 509                return parse_json_path(path_text)
 510            except ParseError as e:
 511                logger.warning(f"Invalid JSON path syntax. {str(e)}")
 512
 513        return path
 514
 515    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
 516        return self.parser(**opts).parse(self.tokenize(sql), sql)
 517
 518    def parse_into(
 519        self, expression_type: exp.IntoType, sql: str, **opts
 520    ) -> t.List[t.Optional[exp.Expression]]:
 521        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
 522
 523    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
 524        return self.generator(**opts).generate(expression, copy=copy)
 525
 526    def transpile(self, sql: str, **opts) -> t.List[str]:
 527        return [
 528            self.generate(expression, copy=False, **opts) if expression else ""
 529            for expression in self.parse(sql)
 530        ]
 531
 532    def tokenize(self, sql: str) -> t.List[Token]:
 533        return self.tokenizer.tokenize(sql)
 534
 535    @property
 536    def tokenizer(self) -> Tokenizer:
 537        if not hasattr(self, "_tokenizer"):
 538            self._tokenizer = self.tokenizer_class(dialect=self)
 539        return self._tokenizer
 540
 541    def parser(self, **opts) -> Parser:
 542        return self.parser_class(dialect=self, **opts)
 543
 544    def generator(self, **opts) -> Generator:
 545        return self.generator_class(dialect=self, **opts)
 546
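# Example sketch (not part of the module source): the parse/generate/transpile
# round trip on a concrete dialect instance.
from sqlglot.dialects.dialect import Dialect

duckdb = Dialect.get_or_raise("duckdb")
assert duckdb.transpile("SELECT 1") == ["SELECT 1"]
tree = duckdb.parse("SELECT 1 AS x")[0]
assert duckdb.generate(tree) == "SELECT 1 AS x"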
 547
 548DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 549
 550
 551def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
 552    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
 553
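# Example sketch (not part of the module source): typical wiring. A dialect
# installs rename_func in its Generator's TRANSFORMS so that a node renders
# under a different function name. `Custom` is a hypothetical dialect used
# only for illustration; defining it also registers it with the _Dialect
# metaclass, so Dialect.get_or_raise("custom") resolves to it afterwards.
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, rename_func
from sqlglot.generator import Generator

class Custom(Dialect):
    class Generator(Generator):
        TRANSFORMS = {
            **Generator.TRANSFORMS,
            exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
        }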
 554
 555def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
 556    if expression.args.get("accuracy"):
 557        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
 558    return self.func("APPROX_COUNT_DISTINCT", expression.this)
 559
 560
 561def if_sql(
 562    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
 563) -> t.Callable[[Generator, exp.If], str]:
 564    def _if_sql(self: Generator, expression: exp.If) -> str:
 565        return self.func(
 566            name,
 567            expression.this,
 568            expression.args.get("true"),
 569            expression.args.get("false") or false_value,
 570        )
 571
 572    return _if_sql
 573
 574
 575def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
 576    this = expression.this
 577    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
 578        this.replace(exp.cast(this, exp.DataType.Type.JSON))
 579
 580    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
 581
 582
 583def inline_array_sql(self: Generator, expression: exp.Array) -> str:
 584    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"
 585
 586
 587def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
 588    elem = seq_get(expression.expressions, 0)
 589    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
 590        return self.func("ARRAY", elem)
 591    return inline_array_sql(self, expression)
 592
 593
 594def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
 595    return self.like_sql(
 596        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
 597    )
 598
 599
 600def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
 601    zone = self.sql(expression, "this")
 602    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
 603
 604
 605def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
 606    if expression.args.get("recursive"):
 607        self.unsupported("Recursive CTEs are unsupported")
 608        expression.args["recursive"] = False
 609    return self.with_sql(expression)
 610
 611
 612def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
 613    n = self.sql(expression, "this")
 614    d = self.sql(expression, "expression")
 615    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
 616
 617
 618def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
 619    self.unsupported("TABLESAMPLE unsupported")
 620    return self.sql(expression.this)
 621
 622
 623def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
 624    self.unsupported("PIVOT unsupported")
 625    return ""
 626
 627
 628def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
 629    return self.cast_sql(expression)
 630
 631
 632def no_comment_column_constraint_sql(
 633    self: Generator, expression: exp.CommentColumnConstraint
 634) -> str:
 635    self.unsupported("CommentColumnConstraint unsupported")
 636    return ""
 637
 638
 639def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
 640    self.unsupported("MAP_FROM_ENTRIES unsupported")
 641    return ""
 642
 643
 644def str_position_sql(
 645    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
 646) -> str:
 647    this = self.sql(expression, "this")
 648    substr = self.sql(expression, "substr")
 649    position = self.sql(expression, "position")
 650    instance = expression.args.get("instance") if generate_instance else None
 651    position_offset = ""
 652
 653    if position:
 654        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
 655        this = self.func("SUBSTR", this, position)
 656        position_offset = f" + {position} - 1"
 657
 658    return self.func("STRPOS", this, substr, instance) + position_offset
 659
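# Illustration of the rewrite above: a position argument is folded into
# SUBSTR plus an offset, e.g.
#   StrPosition(this=haystack, substr=needle, position=5)
#     -> STRPOS(SUBSTR(haystack, 5), needle) + 5 - 1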
 660
 661def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
 662    return (
 663        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
 664    )
 665
 666
 667def var_map_sql(
 668    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
 669) -> str:
 670    keys = expression.args["keys"]
 671    values = expression.args["values"]
 672
 673    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
 674        self.unsupported("Cannot convert array columns into map.")
 675        return self.func(map_func_name, keys, values)
 676
 677    args = []
 678    for key, value in zip(keys.expressions, values.expressions):
 679        args.append(self.sql(key))
 680        args.append(self.sql(value))
 681
 682    return self.func(map_func_name, *args)
 683
 684
 685def build_formatted_time(
 686    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
 687) -> t.Callable[[t.List], E]:
 688    """Helper used for time expressions.
 689
 690    Args:
 691        exp_class: the expression class to instantiate.
 692        dialect: target sql dialect.
 693        default: the default format, True being time.
 694
 695    Returns:
 696        A callable that can be used to return the appropriately formatted time expression.
 697    """
 698
 699    def _builder(args: t.List):
 700        return exp_class(
 701            this=seq_get(args, 0),
 702            format=Dialect[dialect].format_time(
 703                seq_get(args, 1)
 704                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
 705            ),
 706        )
 707
 708    return _builder
 709
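# Example sketch (not part of the module source): building an exp.StrToTime
# whose format is interpreted in Hive notation ('yyyy-MM-dd', ...) and
# converted to the strftime mapping.
from sqlglot import exp
from sqlglot.dialects.dialect import build_formatted_time

builder = build_formatted_time(exp.StrToTime, "hive")
node = builder([exp.column("x"), exp.Literal.string("yyyy-MM-dd")])
assert node.args["format"].name == "%Y-%m-%d"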
 710
 711def time_format(
 712    dialect: DialectType = None,
 713) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
 714    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
 715        """
 716        Returns the time format for a given expression, unless it's equivalent
 717        to the default time format of the dialect of interest.
 718        """
 719        time_format = self.format_time(expression)
 720        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
 721
 722    return _time_format
 723
 724
 725def build_date_delta(
 726    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
 727) -> t.Callable[[t.List], E]:
 728    def _builder(args: t.List) -> E:
 729        unit_based = len(args) == 3
 730        this = args[2] if unit_based else seq_get(args, 0)
 731        unit = args[0] if unit_based else exp.Literal.string("DAY")
 732        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
 733        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
 734
 735    return _builder
 736
 737
 738def build_date_delta_with_interval(
 739    expression_class: t.Type[E],
 740) -> t.Callable[[t.List], t.Optional[E]]:
 741    def _builder(args: t.List) -> t.Optional[E]:
 742        if len(args) < 2:
 743            return None
 744
 745        interval = args[1]
 746
 747        if not isinstance(interval, exp.Interval):
 748            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
 749
 750        expression = interval.this
 751        if expression and expression.is_string:
 752            expression = exp.Literal.number(expression.this)
 753
 754        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
 755
 756    return _builder
 757
 758
 759def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
 760    unit = seq_get(args, 0)
 761    this = seq_get(args, 1)
 762
 763    if isinstance(this, exp.Cast) and this.is_type("date"):
 764        return exp.DateTrunc(unit=unit, this=this)
 765    return exp.TimestampTrunc(this=this, unit=unit)
 766
 767
 768def date_add_interval_sql(
 769    data_type: str, kind: str
 770) -> t.Callable[[Generator, exp.Expression], str]:
 771    def func(self: Generator, expression: exp.Expression) -> str:
 772        this = self.sql(expression, "this")
 773        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
 774        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 775
 776    return func
 777
 778
 779def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
 780    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
 781        args = [unit_to_str(expression), expression.this]
 782        if zone:
 783            args.append(expression.args.get("zone"))
 784        return self.func("DATE_TRUNC", *args)
 785
 786    return _timestamptrunc_sql
 787
 788
 789def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
 790    if not expression.expression:
 791        from sqlglot.optimizer.annotate_types import annotate_types
 792
 793        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
 794        return self.sql(exp.cast(expression.this, target_type))
 795    if expression.text("expression").lower() in TIMEZONES:
 796        return self.sql(
 797            exp.AtTimeZone(
 798                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
 799                zone=expression.expression,
 800            )
 801        )
 802    return self.func("TIMESTAMP", expression.this, expression.expression)
 803
 804
 805def locate_to_strposition(args: t.List) -> exp.Expression:
 806    return exp.StrPosition(
 807        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
 808    )
 809
 810
 811def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
 812    return self.func(
 813        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
 814    )
 815
 816
 817def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 818    return self.sql(
 819        exp.Substring(
 820            this=expression.this, start=exp.Literal.number(1), length=expression.expression
 821        )
 822    )
 823
 824
  825def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
 826    return self.sql(
 827        exp.Substring(
 828            this=expression.this,
 829            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
 830        )
 831    )
 832
 833
 834def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
 835    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
 836
 837
 838def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
 839    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))
 840
 841
 842# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
 843def encode_decode_sql(
 844    self: Generator, expression: exp.Expression, name: str, replace: bool = True
 845) -> str:
 846    charset = expression.args.get("charset")
 847    if charset and charset.name.lower() != "utf-8":
 848        self.unsupported(f"Expected utf-8 character set, got {charset}.")
 849
 850    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
 851
 852
 853def min_or_least(self: Generator, expression: exp.Min) -> str:
 854    name = "LEAST" if expression.expressions else "MIN"
 855    return rename_func(name)(self, expression)
 856
 857
 858def max_or_greatest(self: Generator, expression: exp.Max) -> str:
 859    name = "GREATEST" if expression.expressions else "MAX"
 860    return rename_func(name)(self, expression)
 861
 862
 863def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
 864    cond = expression.this
 865
 866    if isinstance(expression.this, exp.Distinct):
 867        cond = expression.this.expressions[0]
 868        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
 869
 870    return self.func("sum", exp.func("if", cond, 1, 0))
 871
 872
 873def trim_sql(self: Generator, expression: exp.Trim) -> str:
 874    target = self.sql(expression, "this")
 875    trim_type = self.sql(expression, "position")
 876    remove_chars = self.sql(expression, "expression")
 877    collation = self.sql(expression, "collation")
 878
 879    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
 880    if not remove_chars and not collation:
 881        return self.trim_sql(expression)
 882
 883    trim_type = f"{trim_type} " if trim_type else ""
 884    remove_chars = f"{remove_chars} " if remove_chars else ""
 885    from_part = "FROM " if trim_type or remove_chars else ""
 886    collation = f" COLLATE {collation}" if collation else ""
 887    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
 888
 889
 890def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
 891    return self.func("STRPTIME", expression.this, self.format_time(expression))
 892
 893
 894def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
 895    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
 896
 897
 898def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
 899    delim, *rest_args = expression.expressions
 900    return self.sql(
 901        reduce(
 902            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
 903            rest_args,
 904        )
 905    )
 906
 907
 908def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
 909    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
 910    if bad_args:
 911        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
 912
 913    return self.func(
 914        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
 915    )
 916
 917
 918def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
 919    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
 920    if bad_args:
 921        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
 922
 923    return self.func(
 924        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
 925    )
 926
 927
 928def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
 929    names = []
 930    for agg in aggregations:
 931        if isinstance(agg, exp.Alias):
 932            names.append(agg.alias)
 933        else:
 934            """
 935            This case corresponds to aggregations without aliases being used as suffixes
 936            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
 937            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
  938            Otherwise, we'd end up with `col_avg(`foo`)` (notice the backticks).
 939            """
 940            agg_all_unquoted = agg.transform(
 941                lambda node: (
 942                    exp.Identifier(this=node.name, quoted=False)
 943                    if isinstance(node, exp.Identifier)
 944                    else node
 945                )
 946            )
 947            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
 948
 949    return names
 950
 951
 952def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
 953    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
 954
 955
 956# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
 957def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
 958    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
 959
 960
 961def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
 962    return self.func("MAX", expression.this)
 963
 964
 965def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
 966    a = self.sql(expression.left)
 967    b = self.sql(expression.right)
 968    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
 969
 970
 971def is_parse_json(expression: exp.Expression) -> bool:
 972    return isinstance(expression, exp.ParseJSON) or (
 973        isinstance(expression, exp.Cast) and expression.is_type("json")
 974    )
 975
 976
 977def isnull_to_is_null(args: t.List) -> exp.Expression:
 978    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
 979
 980
 981def generatedasidentitycolumnconstraint_sql(
 982    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
 983) -> str:
 984    start = self.sql(expression, "start") or "1"
 985    increment = self.sql(expression, "increment") or "1"
 986    return f"IDENTITY({start}, {increment})"
 987
 988
 989def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
 990    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
 991        if expression.args.get("count"):
 992            self.unsupported(f"Only two arguments are supported in function {name}.")
 993
 994        return self.func(name, expression.this, expression.expression)
 995
 996    return _arg_max_or_min_sql
 997
 998
 999def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
1000    this = expression.this.copy()
1001
1002    return_type = expression.return_type
1003    if return_type.is_type(exp.DataType.Type.DATE):
1004        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
1005        # can truncate timestamp strings, because some dialects can't cast them to DATE
1006        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
1007
1008    expression.this.replace(exp.cast(this, return_type))
1009    return expression
1010
1011
1012def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
1013    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
1014        if cast and isinstance(expression, exp.TsOrDsAdd):
1015            expression = ts_or_ds_add_cast(expression)
1016
1017        return self.func(
1018            name,
1019            unit_to_var(expression),
1020            expression.expression,
1021            expression.this,
1022        )
1023
1024    return _delta_sql
1025
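# Example sketch (not part of the module source): generators typically hook
# this helper up via TRANSFORMS, mirroring dialects that spell date deltas
# as DATEADD/DATEDIFF:
#
#   TRANSFORMS = {
#       exp.DateAdd: date_delta_sql("DATEADD"),
#       exp.DateDiff: date_delta_sql("DATEDIFF"),
#   }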
1026
1027def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1028    unit = expression.args.get("unit")
1029
1030    if isinstance(unit, exp.Placeholder):
1031        return unit
1032    if unit:
1033        return exp.Literal.string(unit.name)
1034    return exp.Literal.string(default) if default else None
1035
1036
1037def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1038    unit = expression.args.get("unit")
1039
1040    if isinstance(unit, (exp.Var, exp.Placeholder)):
1041        return unit
1042    return exp.Var(this=default) if default else None
1043
1044
1045def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
1046    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
1047    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
1048    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
1049
1050    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1051
1052
1053def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
1054    """Remove table refs from columns in when statements."""
1055    alias = expression.this.args.get("alias")
1056
1057    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
1058        return self.dialect.normalize_identifier(identifier).name if identifier else None
1059
1060    targets = {normalize(expression.this.this)}
1061
1062    if alias:
1063        targets.add(normalize(alias.this))
1064
1065    for when in expression.expressions:
1066        when.transform(
1067            lambda node: (
1068                exp.column(node.this)
1069                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1070                else node
1071            ),
1072            copy=False,
1073        )
1074
1075    return self.merge_sql(expression)
1076
1077
1078def build_json_extract_path(
1079    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
1080) -> t.Callable[[t.List], F]:
1081    def _builder(args: t.List) -> F:
1082        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1083        for arg in args[1:]:
1084            if not isinstance(arg, exp.Literal):
1085                # We use the fallback parser because we can't really transpile non-literals safely
1086                return expr_type.from_arg_list(args)
1087
1088            text = arg.name
1089            if is_int(text):
1090                index = int(text)
1091                segments.append(
1092                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1093                )
1094            else:
1095                segments.append(exp.JSONPathKey(this=text))
1096
1097        # This is done to avoid failing in the expression validator due to the arg count
1098        del args[2:]
1099        return expr_type(
1100            this=seq_get(args, 0),
1101            expression=exp.JSONPath(expressions=segments),
1102            only_json_types=arrow_req_json_type,
1103        )
1104
1105    return _builder
1106
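# Example sketch (not part of the module source): literal path segments are
# folded into a single exp.JSONPath ($.a[0] here).
from sqlglot import exp
from sqlglot.dialects.dialect import build_json_extract_path

builder = build_json_extract_path(exp.JSONExtract)
node = builder([exp.column("x"), exp.Literal.string("a"), exp.Literal.number(0)])
assert isinstance(node.expression, exp.JSONPath)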
1107
1108def json_extract_segments(
1109    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1110) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1111    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1112        path = expression.expression
1113        if not isinstance(path, exp.JSONPath):
1114            return rename_func(name)(self, expression)
1115
1116        segments = []
1117        for segment in path.expressions:
1118            path = self.sql(segment)
1119            if path:
1120                if isinstance(segment, exp.JSONPathPart) and (
1121                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1122                ):
1123                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1124
1125                segments.append(path)
1126
1127        if op:
1128            return f" {op} ".join([self.sql(expression.this), *segments])
1129        return self.func(name, expression.this, *segments)
1130
1131    return _json_extract_segments
1132
1133
1134def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1135    if isinstance(expression.this, exp.JSONPathWildcard):
1136        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1137
1138    return expression.name
1139
1140
1141def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1142    cond = expression.expression
1143    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1144        alias = cond.expressions[0]
1145        cond = cond.this
1146    elif isinstance(cond, exp.Predicate):
1147        alias = "_u"
1148    else:
1149        self.unsupported("Unsupported filter condition")
1150        return ""
1151
1152    unnest = exp.Unnest(expressions=[expression.this])
1153    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1154    return self.sql(exp.Array(expressions=[filtered]))
1155
1156
1157def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
1158    return self.func(
1159        "TO_NUMBER",
1160        expression.this,
1161        expression.args.get("format"),
1162        expression.args.get("nlsparam"),
1163    )
1164
1165
1166def build_default_decimal_type(
1167    precision: t.Optional[int] = None, scale: t.Optional[int] = None
1168) -> t.Callable[[exp.DataType], exp.DataType]:
1169    def _builder(dtype: exp.DataType) -> exp.DataType:
1170        if dtype.expressions or precision is None:
1171            return dtype
1172
1173        params = f"{precision}{f', {scale}' if scale is not None else ''}"
1174        return exp.DataType.build(f"DECIMAL({params})")
1175
1176    return _builder
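# Example sketch (not part of the module source): unparameterized DECIMALs
# pick up the defaults; explicitly parameterized types come back unchanged.
from sqlglot import exp
from sqlglot.dialects.dialect import build_default_decimal_type

coerce = build_default_decimal_type(precision=38, scale=9)
assert coerce(exp.DataType.build("DECIMAL")).sql() == "DECIMAL(38, 9)"
assert coerce(exp.DataType.build("DECIMAL(10, 2)")).sql() == "DECIMAL(10, 2)"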
logger = <Logger sqlglot (WARNING)>
UNESCAPED_SEQUENCES = {'\\a': '\x07', '\\b': '\x08', '\\f': '\x0c', '\\n': '\n', '\\r': '\r', '\\t': '\t', '\\v': '\x0b', '\\\\': '\\'}
class Dialects(builtins.str, enum.Enum):
41class Dialects(str, Enum):
42    """Dialects supported by SQLGLot."""
43
44    DIALECT = ""
45
46    ATHENA = "athena"
47    BIGQUERY = "bigquery"
48    CLICKHOUSE = "clickhouse"
49    DATABRICKS = "databricks"
50    DORIS = "doris"
51    DRILL = "drill"
52    DUCKDB = "duckdb"
53    HIVE = "hive"
54    MYSQL = "mysql"
55    ORACLE = "oracle"
56    POSTGRES = "postgres"
57    PRESTO = "presto"
58    PRQL = "prql"
59    REDSHIFT = "redshift"
60    SNOWFLAKE = "snowflake"
61    SPARK = "spark"
62    SPARK2 = "spark2"
63    SQLITE = "sqlite"
64    STARROCKS = "starrocks"
65    TABLEAU = "tableau"
66    TERADATA = "teradata"
67    TRINO = "trino"
68    TSQL = "tsql"

Dialects supported by SQLGLot.

DIALECT = <Dialects.DIALECT: ''>
ATHENA = <Dialects.ATHENA: 'athena'>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DORIS = <Dialects.DORIS: 'doris'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
PRQL = <Dialects.PRQL: 'prql'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class NormalizationStrategy(builtins.str, sqlglot.helper.AutoName):
71class NormalizationStrategy(str, AutoName):
72    """Specifies the strategy according to which identifiers should be normalized."""
73
74    LOWERCASE = auto()
75    """Unquoted identifiers are lowercased."""
76
77    UPPERCASE = auto()
78    """Unquoted identifiers are uppercased."""
79
80    CASE_SENSITIVE = auto()
81    """Always case-sensitive, regardless of quotes."""
82
83    CASE_INSENSITIVE = auto()
84    """Always case-insensitive, regardless of quotes."""

Specifies the strategy according to which identifiers should be normalized.

LOWERCASE = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Unquoted identifiers are lowercased.

UPPERCASE = <NormalizationStrategy.UPPERCASE: 'UPPERCASE'>

Unquoted identifiers are uppercased.

CASE_SENSITIVE = <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>

Always case-sensitive, regardless of quotes.

CASE_INSENSITIVE = <NormalizationStrategy.CASE_INSENSITIVE: 'CASE_INSENSITIVE'>

Always case-insensitive, regardless of quotes.

Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
188class Dialect(metaclass=_Dialect):
189    INDEX_OFFSET = 0
190    """The base index offset for arrays."""
191
192    WEEK_OFFSET = 0
193    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
194
195    UNNEST_COLUMN_ONLY = False
196    """Whether `UNNEST` table aliases are treated as column aliases."""
197
198    ALIAS_POST_TABLESAMPLE = False
199    """Whether the table alias comes after tablesample."""
200
201    TABLESAMPLE_SIZE_IS_PERCENT = False
202    """Whether a size in the table sample clause represents percentage."""
203
204    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
205    """Specifies the strategy according to which identifiers should be normalized."""
206
207    IDENTIFIERS_CAN_START_WITH_DIGIT = False
208    """Whether an unquoted identifier can start with a digit."""
209
210    DPIPE_IS_STRING_CONCAT = True
211    """Whether the DPIPE token (`||`) is a string concatenation operator."""
212
213    STRICT_STRING_CONCAT = False
214    """Whether `CONCAT`'s arguments must be strings."""
215
216    SUPPORTS_USER_DEFINED_TYPES = True
217    """Whether user-defined data types are supported."""
218
219    SUPPORTS_SEMI_ANTI_JOIN = True
220    """Whether `SEMI` or `ANTI` joins are supported."""
221
222    NORMALIZE_FUNCTIONS: bool | str = "upper"
223    """
224    Determines how function names are going to be normalized.
225    Possible values:
226        "upper" or True: Convert names to uppercase.
227        "lower": Convert names to lowercase.
228        False: Disables function name normalization.
229    """
230
231    LOG_BASE_FIRST: t.Optional[bool] = True
232    """
233    Whether the base comes first in the `LOG` function.
234    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
235    """
236
237    NULL_ORDERING = "nulls_are_small"
238    """
239    Default `NULL` ordering method to use if not explicitly set.
240    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
241    """
242
243    TYPED_DIVISION = False
244    """
245    Whether the behavior of `a / b` depends on the types of `a` and `b`.
246    False means `a / b` is always float division.
247    True means `a / b` is integer division if both `a` and `b` are integers.
248    """
249
250    SAFE_DIVISION = False
251    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""
252
253    CONCAT_COALESCE = False
254    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
255
256    HEX_LOWERCASE = False
257    """Whether the `HEX` function returns a lowercase hexadecimal string."""
258
259    DATE_FORMAT = "'%Y-%m-%d'"
260    DATEINT_FORMAT = "'%Y%m%d'"
261    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
262
263    TIME_MAPPING: t.Dict[str, str] = {}
264    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""
265
266    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
267    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
268    FORMAT_MAPPING: t.Dict[str, str] = {}
269    """
270    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
271    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
272    """
273
274    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
275    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
276
277    PSEUDOCOLUMNS: t.Set[str] = set()
278    """
279    Columns that are auto-generated by the engine corresponding to this dialect.
280    For example, such columns may be excluded from `SELECT *` queries.
281    """
282
283    PREFER_CTE_ALIAS_COLUMN = False
284    """
285    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
286    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
287    any projection aliases in the subquery.
288
289    For example,
290        WITH y(c) AS (
291            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
292        ) SELECT c FROM y;
293
294        will be rewritten as
295
296        WITH y(c) AS (
297            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
298        ) SELECT c FROM y;
299    """
300
301    # --- Autofilled ---
302
303    tokenizer_class = Tokenizer
304    parser_class = Parser
305    generator_class = Generator
306
307    # A trie of the time_mapping keys
308    TIME_TRIE: t.Dict = {}
309    FORMAT_TRIE: t.Dict = {}
310
311    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
312    INVERSE_TIME_TRIE: t.Dict = {}
313
314    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
315
316    # Delimiters for string literals and identifiers
317    QUOTE_START = "'"
318    QUOTE_END = "'"
319    IDENTIFIER_START = '"'
320    IDENTIFIER_END = '"'
321
322    # Delimiters for bit, hex, byte and unicode literals
323    BIT_START: t.Optional[str] = None
324    BIT_END: t.Optional[str] = None
325    HEX_START: t.Optional[str] = None
326    HEX_END: t.Optional[str] = None
327    BYTE_START: t.Optional[str] = None
328    BYTE_END: t.Optional[str] = None
329    UNICODE_START: t.Optional[str] = None
330    UNICODE_END: t.Optional[str] = None
331
332    # Separator of COPY statement parameters
333    COPY_PARAMS_ARE_CSV = True
334
335    @classmethod
336    def get_or_raise(cls, dialect: DialectType) -> Dialect:
337        """
338        Look up a dialect in the global dialect registry and return it if it exists.
339
340        Args:
341            dialect: The target dialect. If this is a string, it can be optionally followed by
342                additional key-value pairs that are separated by commas and are used to specify
343                dialect settings, such as whether the dialect's identifiers are case-sensitive.
344
345        Example:
346            >>> dialect = dialect_class = get_or_raise("duckdb")
347            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
348
349        Returns:
350            The corresponding Dialect instance.
351        """
352
353        if not dialect:
354            return cls()
355        if isinstance(dialect, _Dialect):
356            return dialect()
357        if isinstance(dialect, Dialect):
358            return dialect
359        if isinstance(dialect, str):
360            try:
361                dialect_name, *kv_pairs = dialect.split(",")
362                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
363            except ValueError:
364                raise ValueError(
365                    f"Invalid dialect format: '{dialect}'. "
366                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
367                )
368
369            result = cls.get(dialect_name.strip())
370            if not result:
371                from difflib import get_close_matches
372
373                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
374                if similar:
375                    similar = f" Did you mean {similar}?"
376
377                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
378
379            return result(**kwargs)
380
381        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
382
383    @classmethod
384    def format_time(
385        cls, expression: t.Optional[str | exp.Expression]
386    ) -> t.Optional[exp.Expression]:
387        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
388        if isinstance(expression, str):
389            return exp.Literal.string(
390                # the time formats are quoted
391                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
392            )
393
394        if expression and expression.is_string:
395            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
396
397        return expression
398
399    def __init__(self, **kwargs) -> None:
400        normalization_strategy = kwargs.get("normalization_strategy")
401
402        if normalization_strategy is None:
403            self.normalization_strategy = self.NORMALIZATION_STRATEGY
404        else:
405            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
406
407    def __eq__(self, other: t.Any) -> bool:
408        # Does not currently take dialect state into account
409        return type(self) == other
410
411    def __hash__(self) -> int:
412        # Does not currently take dialect state into account
413        return hash(type(self))
414
415    def normalize_identifier(self, expression: E) -> E:
416        """
417        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
418
419        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
420        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
421        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
422        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
423
424        There are also dialects like Spark, which are case-insensitive even when quotes are
425        present, and dialects like MySQL, whose resolution rules match those employed by the
426        underlying operating system, for example they may always be case-sensitive in Linux.
427
428        Finally, the normalization behavior of some engines can even be controlled through flags,
429        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
430
431        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
432        that it can analyze queries in the optimizer and successfully capture their semantics.
433        """
434        if (
435            isinstance(expression, exp.Identifier)
436            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
437            and (
438                not expression.quoted
439                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
440            )
441        ):
442            expression.set(
443                "this",
444                (
445                    expression.this.upper()
446                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
447                    else expression.this.lower()
448                ),
449            )
450
451        return expression
452
453    def case_sensitive(self, text: str) -> bool:
454        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
455        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
456            return False
457
458        unsafe = (
459            str.islower
460            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
461            else str.isupper
462        )
463        return any(unsafe(char) for char in text)
464
465    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
466        """Checks if text can be identified given an identify option.
467
468        Args:
469            text: The text to check.
470            identify:
471                `"always"` or `True`: Always returns `True`.
472                `"safe"`: Only returns `True` if the identifier is case-insensitive.
473
474        Returns:
475            Whether the given text can be identified.
476        """
477        if identify is True or identify == "always":
478            return True
479
480        if identify == "safe":
481            return not self.case_sensitive(text)
482
483        return False
484
485    def quote_identifier(self, expression: E, identify: bool = True) -> E:
486        """
487        Adds quotes to a given identifier.
488
489        Args:
490            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
491            identify: If set to `False`, the quotes will only be added if the identifier is deemed
492                "unsafe", with respect to its characters and this dialect's normalization strategy.
493        """
494        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
495            name = expression.this
496            expression.set(
497                "quoted",
498                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
499            )
500
501        return expression
502
503    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
504        if isinstance(path, exp.Literal):
505            path_text = path.name
506            if path.is_number:
507                path_text = f"[{path_text}]"
508
509            try:
510                return parse_json_path(path_text)
511            except ParseError as e:
512                logger.warning(f"Invalid JSON path syntax. {str(e)}")
513
514        return path
515
516    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
517        return self.parser(**opts).parse(self.tokenize(sql), sql)
518
519    def parse_into(
520        self, expression_type: exp.IntoType, sql: str, **opts
521    ) -> t.List[t.Optional[exp.Expression]]:
522        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
523
524    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
525        return self.generator(**opts).generate(expression, copy=copy)
526
527    def transpile(self, sql: str, **opts) -> t.List[str]:
528        return [
529            self.generate(expression, copy=False, **opts) if expression else ""
530            for expression in self.parse(sql)
531        ]
532
533    def tokenize(self, sql: str) -> t.List[Token]:
534        return self.tokenizer.tokenize(sql)
535
536    @property
537    def tokenizer(self) -> Tokenizer:
538        if not hasattr(self, "_tokenizer"):
539            self._tokenizer = self.tokenizer_class(dialect=self)
540        return self._tokenizer
541
542    def parser(self, **opts) -> Parser:
543        return self.parser_class(dialect=self, **opts)
544
545    def generator(self, **opts) -> Generator:
546        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
INDEX_OFFSET = 0

The base index offset for arrays.

WEEK_OFFSET = 0

First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Whether UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Whether the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT = False

Whether a size in the table sample clause represents a percentage.

NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Whether an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Whether the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Whether CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Whether user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Whether SEMI or ANTI joins are supported.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are going to be normalized.

Possible values:
  • "upper" or True: Convert names to uppercase.
  • "lower": Convert names to lowercase.
  • False: Disables function name normalization.

LOG_BASE_FIRST: Optional[bool] = True

Whether the base comes first in the LOG function. Possible values: True, False, None (two arguments are not supported by LOG)

NULL_ORDERING = 'nulls_are_small'

Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last"

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

HEX_LOWERCASE = False

Whether the HEX function returns a lowercase hexadecimal string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime formats.
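For illustration, a minimal custom dialect with a hypothetical Java-style mapping (real dialects such as Hive define many more entries); the metaclass autofills TIME_TRIE from it:

>>> from sqlglot.dialects.dialect import Dialect
>>> class MyDialect(Dialect):
...     TIME_MAPPING = {"yyyy": "%Y", "MM": "%m", "dd": "%d"}
...
>>> MyDialect.format_time("'yyyy-MM-dd'").sql()
"'%Y-%m-%d'"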

FORMAT_MAPPING: Dict[str, str] = {}

Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.

UNESCAPED_SEQUENCES: Dict[str, str] = {}

Mapping of an escaped sequence (the two-character \n) to its unescaped version (a literal newline character).

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
ESCAPED_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
COPY_PARAMS_ARE_CSV = True
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = Dialect.get_or_raise("duckdb")
>>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.
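A short sketch of the settings syntax described above (the repr is shown as rendered elsewhere on this page):

>>> from sqlglot.dialects.dialect import Dialect
>>> d = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
>>> d.normalization_strategy
<NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>
>>> d == "mysql"
True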

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:

Converts a time format in this dialect to its equivalent Python strftime format.
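For example, assuming Hive's default TIME_MAPPING, where yyyy, MM and dd map to %Y, %m and %d:

>>> from sqlglot.dialects.hive import Hive
>>> Hive.format_time("'dd/MM/yyyy'").sql()
"'%d/%m/%Y'"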

def normalize_identifier(self, expression: ~E) -> ~E:

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
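A minimal sketch of these behaviors, using the default strategies of two registered dialects:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name
'foo'
>>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name
'FOO'
>>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO", quoted=True)).name
'FoO'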

def case_sensitive(self, text: str) -> bool:

Checks if text contains any case sensitive characters, based on the dialect's rules.
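For a LOWERCASE dialect such as Postgres, any uppercase character makes the text case sensitive (sketch under default settings):

>>> from sqlglot.dialects.dialect import Dialect
>>> d = Dialect.get_or_raise("postgres")
>>> d.case_sensitive("foo"), d.case_sensitive("Foo")
(False, True)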

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns True. "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether the given text can be identified.
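Illustrating both options (sketch; DuckDB lowercases unquoted identifiers by default):

>>> from sqlglot.dialects.dialect import Dialect
>>> d = Dialect.get_or_raise("duckdb")
>>> d.can_identify("foo", "safe"), d.can_identify("Foo", "safe"), d.can_identify("Foo", "always")
(True, False, True)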

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
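A sketch of the identify=False path, where quotes are added only to "unsafe" names:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> d = Dialect.get_or_raise("postgres")
>>> d.quote_identifier(exp.to_identifier("foo"), identify=False).sql(dialect="postgres")
'foo'
>>> d.quote_identifier(exp.to_identifier("Foo"), identify=False).sql(dialect="postgres")
'"Foo"'
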
def to_json_path( self, path: Optional[sqlglot.expressions.Expression]) -> Optional[sqlglot.expressions.Expression]:
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
def transpile(self, sql: str, **opts) -> List[str]:
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
tokenizer: sqlglot.tokens.Tokenizer
def parser(self, **opts) -> sqlglot.parser.Parser:
def generator(self, **opts) -> sqlglot.generator.Generator:
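Taken together, these entry points give the tokenize → parse → generate round trip, which transpile composes (minimal sketch):

>>> from sqlglot.dialects.dialect import Dialect
>>> d = Dialect.get_or_raise("duckdb")
>>> ast = d.parse("SELECT 1")[0]
>>> d.generate(ast)
'SELECT 1'
>>> d.transpile("SELECT 1")
['SELECT 1']
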
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
552def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
553    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
556def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
557    if expression.args.get("accuracy"):
558        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
559    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
562def if_sql(
563    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
564) -> t.Callable[[Generator, exp.If], str]:
565    def _if_sql(self: Generator, expression: exp.If) -> str:
566        return self.func(
567            name,
568            expression.this,
569            expression.args.get("true"),
570            expression.args.get("false") or false_value,
571        )
572
573    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]) -> str:
576def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
577    this = expression.this
578    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
579        this.replace(exp.cast(this, exp.DataType.Type.JSON))
580
581    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
584def inline_array_sql(self: Generator, expression: exp.Array) -> str:
585    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"
def inline_array_unless_query( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
588def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
589    elem = seq_get(expression.expressions, 0)
590    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
591        return self.func("ARRAY", elem)
592    return inline_array_sql(self, expression)
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
595def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
596    return self.like_sql(
597        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
598    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
601def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
602    zone = self.sql(expression, "this")
603    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
606def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
607    if expression.args.get("recursive"):
608        self.unsupported("Recursive CTEs are unsupported")
609        expression.args["recursive"] = False
610    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
613def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
614    n = self.sql(expression, "this")
615    d = self.sql(expression, "expression")
616    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
619def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
620    self.unsupported("TABLESAMPLE unsupported")
621    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
624def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
625    self.unsupported("PIVOT unsupported")
626    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
629def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
630    return self.cast_sql(expression)
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
633def no_comment_column_constraint_sql(
634    self: Generator, expression: exp.CommentColumnConstraint
635) -> str:
636    self.unsupported("CommentColumnConstraint unsupported")
637    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
640def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
641    self.unsupported("MAP_FROM_ENTRIES unsupported")
642    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition, generate_instance: bool = False) -> str:
645def str_position_sql(
646    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
647) -> str:
648    this = self.sql(expression, "this")
649    substr = self.sql(expression, "substr")
650    position = self.sql(expression, "position")
651    instance = expression.args.get("instance") if generate_instance else None
652    position_offset = ""
653
654    if position:
655        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
656        this = self.func("SUBSTR", this, position)
657        position_offset = f" + {position} - 1"
658
659    return self.func("STRPOS", this, substr, instance) + position_offset
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
662def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
663    return (
664        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
665    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
668def var_map_sql(
669    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
670) -> str:
671    keys = expression.args["keys"]
672    values = expression.args["values"]
673
674    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
675        self.unsupported("Cannot convert array columns into map.")
676        return self.func(map_func_name, keys, values)
677
678    args = []
679    for key, value in zip(keys.expressions, values.expressions):
680        args.append(self.sql(key))
681        args.append(self.sql(value))
682
683    return self.func(map_func_name, *args)
def build_formatted_time( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
686def build_formatted_time(
687    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
688) -> t.Callable[[t.List], E]:
689    """Helper used for time expressions.
690
691    Args:
692        exp_class: the expression class to instantiate.
693        dialect: target sql dialect.
694        default: the default format, True being time.
695
696    Returns:
697        A callable that can be used to return the appropriately formatted time expression.
698    """
699
700    def _builder(args: t.List):
701        return exp_class(
702            this=seq_get(args, 0),
703            format=Dialect[dialect].format_time(
704                seq_get(args, 1)
705                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
706            ),
707        )
708
709    return _builder

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.
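A sketch of the builder in action, assuming Hive's TIME_MAPPING and the default generator's fallback rendering for exp.StrToTime:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import build_formatted_time
>>> builder = build_formatted_time(exp.StrToTime, "hive")
>>> builder([exp.Literal.string("2020-01-01"), exp.Literal.string("yyyy-MM-dd")]).sql()
"STR_TO_TIME('2020-01-01', '%Y-%m-%d')"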

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
712def time_format(
713    dialect: DialectType = None,
714) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
715    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
716        """
717        Returns the time format for a given expression, unless it's equivalent
718        to the default time format of the dialect of interest.
719        """
720        time_format = self.format_time(expression)
721        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
722
723    return _time_format
def build_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
726def build_date_delta(
727    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
728) -> t.Callable[[t.List], E]:
729    def _builder(args: t.List) -> E:
730        unit_based = len(args) == 3
731        this = args[2] if unit_based else seq_get(args, 0)
732        unit = args[0] if unit_based else exp.Literal.string("DAY")
733        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
734        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
735
736    return _builder
def build_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
739def build_date_delta_with_interval(
740    expression_class: t.Type[E],
741) -> t.Callable[[t.List], t.Optional[E]]:
742    def _builder(args: t.List) -> t.Optional[E]:
743        if len(args) < 2:
744            return None
745
746        interval = args[1]
747
748        if not isinstance(interval, exp.Interval):
749            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
750
751        expression = interval.this
752        if expression and expression.is_string:
753            expression = exp.Literal.number(expression.this)
754
755        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
756
757    return _builder
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
760def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
761    unit = seq_get(args, 0)
762    this = seq_get(args, 1)
763
764    if isinstance(this, exp.Cast) and this.is_type("date"):
765        return exp.DateTrunc(unit=unit, this=this)
766    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
769def date_add_interval_sql(
770    data_type: str, kind: str
771) -> t.Callable[[Generator, exp.Expression], str]:
772    def func(self: Generator, expression: exp.Expression) -> str:
773        this = self.sql(expression, "this")
774        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
775        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
776
777    return func
def timestamptrunc_sql( zone: bool = False) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.TimestampTrunc], str]:
780def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
781    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
782        args = [unit_to_str(expression), expression.this]
783        if zone:
784            args.append(expression.args.get("zone"))
785        return self.func("DATE_TRUNC", *args)
786
787    return _timestamptrunc_sql
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
790def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
791    if not expression.expression:
792        from sqlglot.optimizer.annotate_types import annotate_types
793
794        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
795        return self.sql(exp.cast(expression.this, target_type))
796    if expression.text("expression").lower() in TIMEZONES:
797        return self.sql(
798            exp.AtTimeZone(
799                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
800                zone=expression.expression,
801            )
802        )
803    return self.func("TIMESTAMP", expression.this, expression.expression)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
806def locate_to_strposition(args: t.List) -> exp.Expression:
807    return exp.StrPosition(
808        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
809    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
812def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
813    return self.func(
814        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
815    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
818def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
819    return self.sql(
820        exp.Substring(
821            this=expression.this, start=exp.Literal.number(1), length=expression.expression
822        )
823    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Right) -> str:
826def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
827    return self.sql(
828        exp.Substring(
829            this=expression.this,
830            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
831        )
832    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
835def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
836    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
839def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
840    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
844def encode_decode_sql(
845    self: Generator, expression: exp.Expression, name: str, replace: bool = True
846) -> str:
847    charset = expression.args.get("charset")
848    if charset and charset.name.lower() != "utf-8":
849        self.unsupported(f"Expected utf-8 character set, got {charset}.")
850
851    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
854def min_or_least(self: Generator, expression: exp.Min) -> str:
855    name = "LEAST" if expression.expressions else "MIN"
856    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
859def max_or_greatest(self: Generator, expression: exp.Max) -> str:
860    name = "GREATEST" if expression.expressions else "MAX"
861    return rename_func(name)(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
864def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
865    cond = expression.this
866
867    if isinstance(expression.this, exp.Distinct):
868        cond = expression.this.expressions[0]
869        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
870
871    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
874def trim_sql(self: Generator, expression: exp.Trim) -> str:
875    target = self.sql(expression, "this")
876    trim_type = self.sql(expression, "position")
877    remove_chars = self.sql(expression, "expression")
878    collation = self.sql(expression, "collation")
879
880    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
881    if not remove_chars and not collation:
882        return self.trim_sql(expression)
883
884    trim_type = f"{trim_type} " if trim_type else ""
885    remove_chars = f"{remove_chars} " if remove_chars else ""
886    from_part = "FROM " if trim_type or remove_chars else ""
887    collation = f" COLLATE {collation}" if collation else ""
888    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
891def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
892    return self.func("STRPTIME", expression.this, self.format_time(expression))
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
895def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
896    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
899def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
900    delim, *rest_args = expression.expressions
901    return self.sql(
902        reduce(
903            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
904            rest_args,
905        )
906    )
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
909def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
910    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
911    if bad_args:
912        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
913
914    return self.func(
915        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
916    )
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
919def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
920    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
921    if bad_args:
922        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
923
924    return self.func(
925        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
926    )
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
929def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
930    names = []
931    for agg in aggregations:
932        if isinstance(agg, exp.Alias):
933            names.append(agg.alias)
934        else:
935            """
936            This case corresponds to aggregations without aliases being used as suffixes
937            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
938            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
939            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
940            """
941            agg_all_unquoted = agg.transform(
942                lambda node: (
943                    exp.Identifier(this=node.name, quoted=False)
944                    if isinstance(node, exp.Identifier)
945                    else node
946                )
947            )
948            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
949
950    return names
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
953def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
954    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
def build_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
958def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
959    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
962def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
963    return self.func("MAX", expression.this)
def bool_xor_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Xor) -> str:
966def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
967    a = self.sql(expression.left)
968    b = self.sql(expression.right)
969    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
def is_parse_json(expression: sqlglot.expressions.Expression) -> bool:
972def is_parse_json(expression: exp.Expression) -> bool:
973    return isinstance(expression, exp.ParseJSON) or (
974        isinstance(expression, exp.Cast) and expression.is_type("json")
975    )
def isnull_to_is_null(args: List) -> sqlglot.expressions.Expression:
978def isnull_to_is_null(args: t.List) -> exp.Expression:
979    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
def generatedasidentitycolumnconstraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.GeneratedAsIdentityColumnConstraint) -> str:
982def generatedasidentitycolumnconstraint_sql(
983    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
984) -> str:
985    start = self.sql(expression, "start") or "1"
986    increment = self.sql(expression, "increment") or "1"
987    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.ArgMax | sqlglot.expressions.ArgMin], str]:
990def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
991    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
992        if expression.args.get("count"):
993            self.unsupported(f"Only two arguments are supported in function {name}.")
994
995        return self.func(name, expression.this, expression.expression)
996
997    return _arg_max_or_min_sql
def ts_or_ds_add_cast( expression: sqlglot.expressions.TsOrDsAdd) -> sqlglot.expressions.TsOrDsAdd:
1000def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
1001    this = expression.this.copy()
1002
1003    return_type = expression.return_type
1004    if return_type.is_type(exp.DataType.Type.DATE):
1005        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
1006        # can truncate timestamp strings, because some dialects can't cast them to DATE
1007        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
1008
1009    expression.this.replace(exp.cast(this, return_type))
1010    return expression
def date_delta_sql( name: str, cast: bool = False) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.DateAdd, sqlglot.expressions.TsOrDsAdd, sqlglot.expressions.DateDiff, sqlglot.expressions.TsOrDsDiff]], str]:
1013def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
1014    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
1015        if cast and isinstance(expression, exp.TsOrDsAdd):
1016            expression = ts_or_ds_add_cast(expression)
1017
1018        return self.func(
1019            name,
1020            unit_to_var(expression),
1021            expression.expression,
1022            expression.this,
1023        )
1024
1025    return _delta_sql
def unit_to_str( expression: sqlglot.expressions.Expression, default: str = 'DAY') -> Optional[sqlglot.expressions.Expression]:
1028def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1029    unit = expression.args.get("unit")
1030
1031    if isinstance(unit, exp.Placeholder):
1032        return unit
1033    if unit:
1034        return exp.Literal.string(unit.name)
1035    return exp.Literal.string(default) if default else None
def unit_to_var( expression: sqlglot.expressions.Expression, default: str = 'DAY') -> Optional[sqlglot.expressions.Expression]:
1038def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1039    unit = expression.args.get("unit")
1040
1041    if isinstance(unit, (exp.Var, exp.Placeholder)):
1042        return unit
1043    return exp.Var(this=default) if default else None
def no_last_day_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.LastDay) -> str:
1046def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
1047    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
1048    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
1049    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
1050
1051    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
def merge_without_target_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Merge) -> str:
1054def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
1055    """Remove table refs from columns in when statements."""
1056    alias = expression.this.args.get("alias")
1057
1058    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
1059        return self.dialect.normalize_identifier(identifier).name if identifier else None
1060
1061    targets = {normalize(expression.this.this)}
1062
1063    if alias:
1064        targets.add(normalize(alias.this))
1065
1066    for when in expression.expressions:
1067        when.transform(
1068            lambda node: (
1069                exp.column(node.this)
1070                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1071                else node
1072            ),
1073            copy=False,
1074        )
1075
1076    return self.merge_sql(expression)

Remove table refs from columns in when statements.

def build_json_extract_path( expr_type: Type[~F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False) -> Callable[[List], ~F]:
1079def build_json_extract_path(
1080    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
1081) -> t.Callable[[t.List], F]:
1082    def _builder(args: t.List) -> F:
1083        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1084        for arg in args[1:]:
1085            if not isinstance(arg, exp.Literal):
1086                # We use the fallback parser because we can't really transpile non-literals safely
1087                return expr_type.from_arg_list(args)
1088
1089            text = arg.name
1090            if is_int(text):
1091                index = int(text)
1092                segments.append(
1093                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1094                )
1095            else:
1096                segments.append(exp.JSONPathKey(this=text))
1097
1098        # This is done to avoid failing in the expression validator due to the arg count
1099        del args[2:]
1100        return expr_type(
1101            this=seq_get(args, 0),
1102            expression=exp.JSONPath(expressions=segments),
1103            only_json_types=arrow_req_json_type,
1104        )
1105
1106    return _builder
def json_extract_segments( name: str, quoted_index: bool = True, op: Optional[str] = None) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]], str]:
1109def json_extract_segments(
1110    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1111) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1112    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1113        path = expression.expression
1114        if not isinstance(path, exp.JSONPath):
1115            return rename_func(name)(self, expression)
1116
1117        segments = []
1118        for segment in path.expressions:
1119            path = self.sql(segment)
1120            if path:
1121                if isinstance(segment, exp.JSONPathPart) and (
1122                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1123                ):
1124                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1125
1126                segments.append(path)
1127
1128        if op:
1129            return f" {op} ".join([self.sql(expression.this), *segments])
1130        return self.func(name, expression.this, *segments)
1131
1132    return _json_extract_segments
def json_path_key_only_name( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONPathKey) -> str:
1135def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1136    if isinstance(expression.this, exp.JSONPathWildcard):
1137        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1138
1139    return expression.name
def filter_array_using_unnest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ArrayFilter) -> str:
1142def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1143    cond = expression.expression
1144    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1145        alias = cond.expressions[0]
1146        cond = cond.this
1147    elif isinstance(cond, exp.Predicate):
1148        alias = "_u"
1149    else:
1150        self.unsupported("Unsupported filter condition")
1151        return ""
1152
1153    unnest = exp.Unnest(expressions=[expression.this])
1154    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1155    return self.sql(exp.Array(expressions=[filtered]))
def to_number_with_nls_param( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ToNumber) -> str:
1158def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
1159    return self.func(
1160        "TO_NUMBER",
1161        expression.this,
1162        expression.args.get("format"),
1163        expression.args.get("nlsparam"),
1164    )
def build_default_decimal_type( precision: Optional[int] = None, scale: Optional[int] = None) -> Callable[[sqlglot.expressions.DataType], sqlglot.expressions.DataType]:
1167def build_default_decimal_type(
1168    precision: t.Optional[int] = None, scale: t.Optional[int] = None
1169) -> t.Callable[[exp.DataType], exp.DataType]:
1170    def _builder(dtype: exp.DataType) -> exp.DataType:
1171        if dtype.expressions or precision is None:
1172            return dtype
1173
1174        params = f"{precision}{f', {scale}' if scale is not None else ''}"
1175        return exp.DataType.build(f"DECIMAL({params})")
1176
1177    return _builder