sqlglot.dialects.dialect

   1from __future__ import annotations
   2
   3import logging
   4import typing as t
   5from enum import Enum, auto
   6from functools import reduce
   7
   8from sqlglot import exp
   9from sqlglot.errors import ParseError
  10from sqlglot.generator import Generator
  11from sqlglot.helper import AutoName, flatten, is_int, seq_get
  12from sqlglot.jsonpath import parse as parse_json_path
  13from sqlglot.parser import Parser
  14from sqlglot.time import TIMEZONES, format_time
  15from sqlglot.tokens import Token, Tokenizer, TokenType
  16from sqlglot.trie import new_trie
  17
  18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
  19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
  20
  21if t.TYPE_CHECKING:
  22    from sqlglot._typing import B, E, F
  23
  24    JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
  25
  26logger = logging.getLogger("sqlglot")
  27
  28
  29class Dialects(str, Enum):
  30    """Dialects supported by SQLGlot."""
  31
  32    DIALECT = ""
  33
  34    BIGQUERY = "bigquery"
  35    CLICKHOUSE = "clickhouse"
  36    DATABRICKS = "databricks"
  37    DORIS = "doris"
  38    DRILL = "drill"
  39    DUCKDB = "duckdb"
  40    HIVE = "hive"
  41    MYSQL = "mysql"
  42    ORACLE = "oracle"
  43    POSTGRES = "postgres"
  44    PRESTO = "presto"
  45    REDSHIFT = "redshift"
  46    SNOWFLAKE = "snowflake"
  47    SPARK = "spark"
  48    SPARK2 = "spark2"
  49    SQLITE = "sqlite"
  50    STARROCKS = "starrocks"
  51    TABLEAU = "tableau"
  52    TERADATA = "teradata"
  53    TRINO = "trino"
  54    TSQL = "tsql"
  55
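# Editorial usage sketch (not part of the module source): because `Dialects`
# subclasses `str`, its members compare equal to their lowercase string values,
# so plain strings can be passed wherever a member is expected:
#
#     >>> Dialects.DUCKDB == "duckdb"
#     True
#     >>> Dialects("sqlite") is Dialects.SQLITE
#     True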
  56
  57class NormalizationStrategy(str, AutoName):
  58    """Specifies the strategy according to which identifiers should be normalized."""
  59
  60    LOWERCASE = auto()
  61    """Unquoted identifiers are lowercased."""
  62
  63    UPPERCASE = auto()
  64    """Unquoted identifiers are uppercased."""
  65
  66    CASE_SENSITIVE = auto()
  67    """Always case-sensitive, regardless of quotes."""
  68
  69    CASE_INSENSITIVE = auto()
  70    """Always case-insensitive, regardless of quotes."""
  71
  72
  73class _Dialect(type):
  74    classes: t.Dict[str, t.Type[Dialect]] = {}
  75
  76    def __eq__(cls, other: t.Any) -> bool:
  77        if cls is other:
  78            return True
  79        if isinstance(other, str):
  80            return cls is cls.get(other)
  81        if isinstance(other, Dialect):
  82            return cls is type(other)
  83
  84        return False
  85
  86    def __hash__(cls) -> int:
  87        return hash(cls.__name__.lower())
  88
  89    @classmethod
  90    def __getitem__(cls, key: str) -> t.Type[Dialect]:
  91        return cls.classes[key]
  92
  93    @classmethod
  94    def get(
  95        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
  96    ) -> t.Optional[t.Type[Dialect]]:
  97        return cls.classes.get(key, default)
  98
  99    def __new__(cls, clsname, bases, attrs):
 100        klass = super().__new__(cls, clsname, bases, attrs)
 101        enum = Dialects.__members__.get(clsname.upper())
 102        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 103
 104        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
 105        klass.FORMAT_TRIE = (
 106            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
 107        )
 108        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
 109        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 110
 111        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
 112
 113        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
 114        klass.parser_class = getattr(klass, "Parser", Parser)
 115        klass.generator_class = getattr(klass, "Generator", Generator)
 116
 117        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
 118        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
 119            klass.tokenizer_class._IDENTIFIERS.items()
 120        )[0]
 121
 122        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
 123            return next(
 124                (
 125                    (s, e)
 126                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
 127                    if t == token_type
 128                ),
 129                (None, None),
 130            )
 131
 132        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
 133        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
 134        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
 135        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 136
 137        if enum not in ("", "bigquery"):
 138            klass.generator_class.SELECT_KINDS = ()
 139
 140        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
 141            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
 142                TokenType.ANTI,
 143                TokenType.SEMI,
 144            }
 145
 146        return klass
 147
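# Editorial usage sketch (not part of the module source): the metaclass above
# auto-registers every `Dialect` subclass under its lowercase name, so dialects
# can be looked up by string and classes compare equal to their names:
#
#     >>> import sqlglot.dialects  # importing the package populates the registry
#     >>> Dialect["duckdb"]
#     <class 'sqlglot.dialects.duckdb.DuckDB'>
#     >>> Dialect.get("duckdb") == "duckdb"
#     True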
 148
 149class Dialect(metaclass=_Dialect):
 150    INDEX_OFFSET = 0
 151    """Determines the base index offset for arrays."""
 152
 153    WEEK_OFFSET = 0
 154    """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
 155
 156    UNNEST_COLUMN_ONLY = False
 157    """Determines whether or not `UNNEST` table aliases are treated as column aliases."""
 158
 159    ALIAS_POST_TABLESAMPLE = False
 160    """Determines whether or not the table alias comes after tablesample."""
 161
 162    TABLESAMPLE_SIZE_IS_PERCENT = False
 163    """Determines whether or not a size in the table sample clause represents percentage."""
 164
 165    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
 166    """Specifies the strategy according to which identifiers should be normalized."""
 167
 168    IDENTIFIERS_CAN_START_WITH_DIGIT = False
 169    """Determines whether or not an unquoted identifier can start with a digit."""
 170
 171    DPIPE_IS_STRING_CONCAT = True
 172    """Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""
 173
 174    STRICT_STRING_CONCAT = False
 175    """Determines whether or not `CONCAT`'s arguments must be strings."""
 176
 177    SUPPORTS_USER_DEFINED_TYPES = True
 178    """Determines whether or not user-defined data types are supported."""
 179
 180    SUPPORTS_SEMI_ANTI_JOIN = True
 181    """Determines whether or not `SEMI` or `ANTI` joins are supported."""
 182
 183    NORMALIZE_FUNCTIONS: bool | str = "upper"
 184    """Determines how function names are going to be normalized."""
 185
 186    LOG_BASE_FIRST = True
 187    """Determines whether the base comes first in the `LOG` function."""
 188
 189    NULL_ORDERING = "nulls_are_small"
 190    """
 191    Indicates the default `NULL` ordering method to use if not explicitly set.
 192    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
 193    """
 194
 195    TYPED_DIVISION = False
 196    """
 197    Whether the behavior of `a / b` depends on the types of `a` and `b`.
 198    False means `a / b` is always float division.
 199    True means `a / b` is integer division if both `a` and `b` are integers.
 200    """
 201
 202    SAFE_DIVISION = False
 203    """Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""
 204
 205    CONCAT_COALESCE = False
 206    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
 207
 208    DATE_FORMAT = "'%Y-%m-%d'"
 209    DATEINT_FORMAT = "'%Y%m%d'"
 210    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
 211
 212    TIME_MAPPING: t.Dict[str, str] = {}
 213    """Associates this dialect's time formats with their equivalent Python `strftime` format."""
 214
 215    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
 216    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
 217    FORMAT_MAPPING: t.Dict[str, str] = {}
 218    """
 219    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
 220    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
 221    """
 222
 223    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
 224    """Mapping of an unescaped escape sequence to the corresponding character."""
 225
 226    PSEUDOCOLUMNS: t.Set[str] = set()
 227    """
 228    Columns that are auto-generated by the engine corresponding to this dialect.
 229    For example, such columns may be excluded from `SELECT *` queries.
 230    """
 231
 232    PREFER_CTE_ALIAS_COLUMN = False
 233    """
 234    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
 235    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
 236    any projection aliases in the subquery.
 237
 238    For example,
 239        WITH y(c) AS (
 240            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
 241        ) SELECT c FROM y;
 242
 243        will be rewritten as
 244
 245        WITH y(c) AS (
 246            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
 247        ) SELECT c FROM y;
 248    """
 249
 250    # --- Autofilled ---
 251
 252    tokenizer_class = Tokenizer
 253    parser_class = Parser
 254    generator_class = Generator
 255
 256    # A trie of the time_mapping keys
 257    TIME_TRIE: t.Dict = {}
 258    FORMAT_TRIE: t.Dict = {}
 259
 260    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
 261    INVERSE_TIME_TRIE: t.Dict = {}
 262
 263    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
 264
 265    # Delimiters for string literals and identifiers
 266    QUOTE_START = "'"
 267    QUOTE_END = "'"
 268    IDENTIFIER_START = '"'
 269    IDENTIFIER_END = '"'
 270
 271    # Delimiters for bit, hex, byte and unicode literals
 272    BIT_START: t.Optional[str] = None
 273    BIT_END: t.Optional[str] = None
 274    HEX_START: t.Optional[str] = None
 275    HEX_END: t.Optional[str] = None
 276    BYTE_START: t.Optional[str] = None
 277    BYTE_END: t.Optional[str] = None
 278    UNICODE_START: t.Optional[str] = None
 279    UNICODE_END: t.Optional[str] = None
 280
 281    @classmethod
 282    def get_or_raise(cls, dialect: DialectType) -> Dialect:
 283        """
 284        Look up a dialect in the global dialect registry and return it if it exists.
 285
 286        Args:
 287            dialect: The target dialect. If this is a string, it can be optionally followed by
 288                additional key-value pairs that are separated by commas and are used to specify
 289                dialect settings, such as whether the dialect's identifiers are case-sensitive.
 290
 291        Example:
 292            >>> dialect = dialect_class = get_or_raise("duckdb")
 293            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
 294
 295        Returns:
 296            The corresponding Dialect instance.
 297        """
 298
 299        if not dialect:
 300            return cls()
 301        if isinstance(dialect, _Dialect):
 302            return dialect()
 303        if isinstance(dialect, Dialect):
 304            return dialect
 305        if isinstance(dialect, str):
 306            try:
 307                dialect_name, *kv_pairs = dialect.split(",")
 308                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
 309            except ValueError:
 310                raise ValueError(
 311                    f"Invalid dialect format: '{dialect}'. "
 312                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
 313                )
 314
 315            result = cls.get(dialect_name.strip())
 316            if not result:
 317                from difflib import get_close_matches
 318
 319                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
 320                if similar:
 321                    similar = f" Did you mean {similar}?"
 322
 323                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
 324
 325            return result(**kwargs)
 326
 327        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
 328
 329    @classmethod
 330    def format_time(
 331        cls, expression: t.Optional[str | exp.Expression]
 332    ) -> t.Optional[exp.Expression]:
 333        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
 334        if isinstance(expression, str):
 335            return exp.Literal.string(
 336                # the time formats are quoted
 337                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
 338            )
 339
 340        if expression and expression.is_string:
 341            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
 342
 343        return expression
 344
 345    def __init__(self, **kwargs) -> None:
 346        normalization_strategy = kwargs.get("normalization_strategy")
 347
 348        if normalization_strategy is None:
 349            self.normalization_strategy = self.NORMALIZATION_STRATEGY
 350        else:
 351            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
 352
 353    def __eq__(self, other: t.Any) -> bool:
 354        # Does not currently take dialect state into account
 355        return type(self) == other
 356
 357    def __hash__(self) -> int:
 358        # Does not currently take dialect state into account
 359        return hash(type(self))
 360
 361    def normalize_identifier(self, expression: E) -> E:
 362        """
 363        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
 364
 365        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
 366        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
 367        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
 368        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
 369
 370        There are also dialects like Spark, which are case-insensitive even when quotes are
 371        present, and dialects like MySQL, whose resolution rules match those employed by the
 372        underlying operating system; for example, they may always be case-sensitive on Linux.
 373
 374        Finally, the normalization behavior of some engines can even be controlled through flags,
 375        like in Redshift's case, where users can explicitly set `enable_case_sensitive_identifier`.
 376
 377        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
 378        that it can analyze queries in the optimizer and successfully capture their semantics.
 379        """
 380        if (
 381            isinstance(expression, exp.Identifier)
 382            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
 383            and (
 384                not expression.quoted
 385                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
 386            )
 387        ):
 388            expression.set(
 389                "this",
 390                (
 391                    expression.this.upper()
 392                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 393                    else expression.this.lower()
 394                ),
 395            )
 396
 397        return expression
 398
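    # Editorial usage sketch (not part of the module source): normalization under
    # two different strategies; unquoted identifiers are case-folded, while quoted
    # ones are left intact unless the strategy is CASE_INSENSITIVE:
    #
    #     >>> from sqlglot import exp
    #     >>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).this
    #     'foo'
    #     >>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).this
    #     'FOO'
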
 399    def case_sensitive(self, text: str) -> bool:
 400        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
 401        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
 402            return False
 403
 404        unsafe = (
 405            str.islower
 406            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 407            else str.isupper
 408        )
 409        return any(unsafe(char) for char in text)
 410
 411    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
 412        """Checks if text can be identified given an identify option.
 413
 414        Args:
 415            text: The text to check.
 416            identify:
 417                `"always"` or `True`: Always returns `True`.
 418                `"safe"`: Only returns `True` if the identifier is case-insensitive.
 419
 420        Returns:
 421            Whether or not the given text can be identified.
 422        """
 423        if identify is True or identify == "always":
 424            return True
 425
 426        if identify == "safe":
 427            return not self.case_sensitive(text)
 428
 429        return False
 430
 431    def quote_identifier(self, expression: E, identify: bool = True) -> E:
 432        """
 433        Adds quotes to a given identifier.
 434
 435        Args:
 436            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
 437            identify: If set to `False`, the quotes will only be added if the identifier is deemed
 438                "unsafe", with respect to its characters and this dialect's normalization strategy.
 439        """
 440        if isinstance(expression, exp.Identifier):
 441            name = expression.this
 442            expression.set(
 443                "quoted",
 444                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
 445            )
 446
 447        return expression
 448
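    # Editorial usage sketch (not part of the module source): with `identify=False`,
    # quotes are only added when the identifier would not survive this dialect's
    # normalization, e.g. mixed case under Postgres' lowercasing strategy:
    #
    #     >>> from sqlglot import exp
    #     >>> Dialect.get_or_raise("postgres").quote_identifier(exp.to_identifier("FoO"), identify=False).sql()
    #     '"FoO"'
    #     >>> Dialect.get_or_raise("postgres").quote_identifier(exp.to_identifier("foo"), identify=False).sql()
    #     'foo'
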
 449    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
 450        if isinstance(path, exp.Literal):
 451            path_text = path.name
 452            if path.is_number:
 453                path_text = f"[{path_text}]"
 454
 455            try:
 456                return parse_json_path(path_text)
 457            except ParseError as e:
 458                logger.warning(f"Invalid JSON path syntax. {str(e)}")
 459
 460        return path
 461
 462    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
 463        return self.parser(**opts).parse(self.tokenize(sql), sql)
 464
 465    def parse_into(
 466        self, expression_type: exp.IntoType, sql: str, **opts
 467    ) -> t.List[t.Optional[exp.Expression]]:
 468        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
 469
 470    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
 471        return self.generator(**opts).generate(expression, copy=copy)
 472
 473    def transpile(self, sql: str, **opts) -> t.List[str]:
 474        return [
 475            self.generate(expression, copy=False, **opts) if expression else ""
 476            for expression in self.parse(sql)
 477        ]
 478
 479    def tokenize(self, sql: str) -> t.List[Token]:
 480        return self.tokenizer.tokenize(sql)
 481
 482    @property
 483    def tokenizer(self) -> Tokenizer:
 484        if not hasattr(self, "_tokenizer"):
 485            self._tokenizer = self.tokenizer_class(dialect=self)
 486        return self._tokenizer
 487
 488    def parser(self, **opts) -> Parser:
 489        return self.parser_class(dialect=self, **opts)
 490
 491    def generator(self, **opts) -> Generator:
 492        return self.generator_class(dialect=self, **opts)
 493
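# Editorial usage sketch (not part of the module source): the convenience methods
# above cover the full tokenize -> parse -> generate pipeline for one dialect:
#
#     >>> d = Dialect.get_or_raise("duckdb")
#     >>> d.transpile("SELECT 1 + 1")
#     ['SELECT 1 + 1']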
 494
 495DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 496
 497
 498def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
 499    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
 500
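# Editorial usage sketch (not part of the module source): `rename_func` is meant
# to be wired into a dialect generator's TRANSFORMS mapping; e.g. a hypothetical
# dialect that spells `exp.ApproxDistinct` as APPROX_COUNT_DISTINCT:
#
#     class Generator(generator.Generator):
#         TRANSFORMS = {
#             **generator.Generator.TRANSFORMS,
#             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
#         }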
 501
 502def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
 503    if expression.args.get("accuracy"):
 504        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
 505    return self.func("APPROX_COUNT_DISTINCT", expression.this)
 506
 507
 508def if_sql(
 509    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
 510) -> t.Callable[[Generator, exp.If], str]:
 511    def _if_sql(self: Generator, expression: exp.If) -> str:
 512        return self.func(
 513            name,
 514            expression.this,
 515            expression.args.get("true"),
 516            expression.args.get("false") or false_value,
 517        )
 518
 519    return _if_sql
 520
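# Editorial usage sketch (not part of the module source): `if_sql` builds a
# generator transform with an optional default false branch; e.g. a dialect
# that spells `exp.If` as IIF and always emits the third argument:
#
#     TRANSFORMS = {exp.If: if_sql(name="IIF", false_value="NULL")}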
 521
 522def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
 523    this = expression.this
 524    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
 525        this.replace(exp.cast(this, "json"))
 526
 527    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
 528
 529
 530def inline_array_sql(self: Generator, expression: exp.Array) -> str:
 531    return f"[{self.expressions(expression, flat=True)}]"
 532
 533
 534def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
 535    return self.like_sql(
 536        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
 537    )
 538
 539
 540def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
 541    zone = self.sql(expression, "this")
 542    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
 543
 544
 545def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
 546    if expression.args.get("recursive"):
 547        self.unsupported("Recursive CTEs are unsupported")
 548        expression.args["recursive"] = False
 549    return self.with_sql(expression)
 550
 551
 552def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
 553    n = self.sql(expression, "this")
 554    d = self.sql(expression, "expression")
 555    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
 556
 557
 558def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
 559    self.unsupported("TABLESAMPLE unsupported")
 560    return self.sql(expression.this)
 561
 562
 563def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
 564    self.unsupported("PIVOT unsupported")
 565    return ""
 566
 567
 568def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
 569    return self.cast_sql(expression)
 570
 571
 572def no_comment_column_constraint_sql(
 573    self: Generator, expression: exp.CommentColumnConstraint
 574) -> str:
 575    self.unsupported("CommentColumnConstraint unsupported")
 576    return ""
 577
 578
 579def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
 580    self.unsupported("MAP_FROM_ENTRIES unsupported")
 581    return ""
 582
 583
 584def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
 585    this = self.sql(expression, "this")
 586    substr = self.sql(expression, "substr")
 587    position = self.sql(expression, "position")
 588    if position:
 589        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
 590    return f"STRPOS({this}, {substr})"
 591
 592
 593def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
 594    return (
 595        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
 596    )
 597
 598
 599def var_map_sql(
 600    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
 601) -> str:
 602    keys = expression.args["keys"]
 603    values = expression.args["values"]
 604
 605    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
 606        self.unsupported("Cannot convert array columns into map.")
 607        return self.func(map_func_name, keys, values)
 608
 609    args = []
 610    for key, value in zip(keys.expressions, values.expressions):
 611        args.append(self.sql(key))
 612        args.append(self.sql(value))
 613
 614    return self.func(map_func_name, *args)
 615
 616
 617def format_time_lambda(
 618    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
 619) -> t.Callable[[t.List], E]:
 620    """Helper used for time expressions.
 621
 622    Args:
 623        exp_class: the expression class to instantiate.
 624        dialect: target sql dialect.
 625        default: the default format, True being time.
 626
 627    Returns:
 628        A callable that can be used to return the appropriately formatted time expression.
 629    """
 630
 631    def _format_time(args: t.List):
 632        return exp_class(
 633            this=seq_get(args, 0),
 634            format=Dialect[dialect].format_time(
 635                seq_get(args, 1)
 636                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
 637            ),
 638        )
 639
 640    return _format_time
 641
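# Editorial usage sketch (not part of the module source): `format_time_lambda`
# is intended for a dialect parser's FUNCTIONS table, so that format strings are
# converted to Python `strftime` notation at parse time; e.g.:
#
#     FUNCTIONS = {
#         "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"),
#     }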
 642
 643def time_format(
 644    dialect: DialectType = None,
 645) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
 646    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
 647        """
 648        Returns the time format for a given expression, unless it's equivalent
 649        to the default time format of the dialect of interest.
 650        """
 651        time_format = self.format_time(expression)
 652        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
 653
 654    return _time_format
 655
 656
 657def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
 658    """
 659    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
 660    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
 661    columns are removed from the create statement.
 662    """
 663    has_schema = isinstance(expression.this, exp.Schema)
 664    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
 665
 666    if has_schema and is_partitionable:
 667        prop = expression.find(exp.PartitionedByProperty)
 668        if prop and prop.this and not isinstance(prop.this, exp.Schema):
 669            schema = expression.this
 670            columns = {v.name.upper() for v in prop.this.expressions}
 671            partitions = [col for col in schema.expressions if col.name.upper() in columns]
 672            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
 673            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
 674            expression.set("this", schema)
 675
 676    return self.create_sql(expression)
 677
 678
 679def parse_date_delta(
 680    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
 681) -> t.Callable[[t.List], E]:
 682    def inner_func(args: t.List) -> E:
 683        unit_based = len(args) == 3
 684        this = args[2] if unit_based else seq_get(args, 0)
 685        unit = args[0] if unit_based else exp.Literal.string("DAY")
 686        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
 687        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
 688
 689    return inner_func
 690
 691
 692def parse_date_delta_with_interval(
 693    expression_class: t.Type[E],
 694) -> t.Callable[[t.List], t.Optional[E]]:
 695    def func(args: t.List) -> t.Optional[E]:
 696        if len(args) < 2:
 697            return None
 698
 699        interval = args[1]
 700
 701        if not isinstance(interval, exp.Interval):
 702            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
 703
 704        expression = interval.this
 705        if expression and expression.is_string:
 706            expression = exp.Literal.number(expression.this)
 707
 708        return expression_class(
 709            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
 710        )
 711
 712    return func
 713
 714
 715def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
 716    unit = seq_get(args, 0)
 717    this = seq_get(args, 1)
 718
 719    if isinstance(this, exp.Cast) and this.is_type("date"):
 720        return exp.DateTrunc(unit=unit, this=this)
 721    return exp.TimestampTrunc(this=this, unit=unit)
 722
 723
 724def date_add_interval_sql(
 725    data_type: str, kind: str
 726) -> t.Callable[[Generator, exp.Expression], str]:
 727    def func(self: Generator, expression: exp.Expression) -> str:
 728        this = self.sql(expression, "this")
 729        unit = expression.args.get("unit")
 730        unit = exp.var(unit.name.upper() if unit else "DAY")
 731        interval = exp.Interval(this=expression.expression, unit=unit)
 732        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 733
 734    return func
 735
 736
 737def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
 738    return self.func(
 739        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
 740    )
 741
 742
 743def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
 744    if not expression.expression:
 745        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
 746    if expression.text("expression").lower() in TIMEZONES:
 747        return self.sql(
 748            exp.AtTimeZone(
 749                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
 750                zone=expression.expression,
 751            )
 752        )
 753    return self.function_fallback_sql(expression)
 754
 755
 756def locate_to_strposition(args: t.List) -> exp.Expression:
 757    return exp.StrPosition(
 758        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
 759    )
 760
 761
 762def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
 763    return self.func(
 764        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
 765    )
 766
 767
 768def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 769    return self.sql(
 770        exp.Substring(
 771            this=expression.this, start=exp.Literal.number(1), length=expression.expression
 772        )
 773    )
 774
 775
 776def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
 777    return self.sql(
 778        exp.Substring(
 779            this=expression.this,
 780            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
 781        )
 782    )
 783
 784
 785def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
 786    return self.sql(exp.cast(expression.this, "timestamp"))
 787
 788
 789def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
 790    return self.sql(exp.cast(expression.this, "date"))
 791
 792
 793# Used for Presto and DuckDB, which use functions that don't support charset and assume UTF-8
 794def encode_decode_sql(
 795    self: Generator, expression: exp.Expression, name: str, replace: bool = True
 796) -> str:
 797    charset = expression.args.get("charset")
 798    if charset and charset.name.lower() != "utf-8":
 799        self.unsupported(f"Expected utf-8 character set, got {charset}.")
 800
 801    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
 802
 803
 804def min_or_least(self: Generator, expression: exp.Min) -> str:
 805    name = "LEAST" if expression.expressions else "MIN"
 806    return rename_func(name)(self, expression)
 807
 808
 809def max_or_greatest(self: Generator, expression: exp.Max) -> str:
 810    name = "GREATEST" if expression.expressions else "MAX"
 811    return rename_func(name)(self, expression)
 812
 813
 814def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
 815    cond = expression.this
 816
 817    if isinstance(expression.this, exp.Distinct):
 818        cond = expression.this.expressions[0]
 819        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
 820
 821    return self.func("sum", exp.func("if", cond, 1, 0))
 822
 823
 824def trim_sql(self: Generator, expression: exp.Trim) -> str:
 825    target = self.sql(expression, "this")
 826    trim_type = self.sql(expression, "position")
 827    remove_chars = self.sql(expression, "expression")
 828    collation = self.sql(expression, "collation")
 829
 830    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
 831    if not remove_chars and not collation:
 832        return self.trim_sql(expression)
 833
 834    trim_type = f"{trim_type} " if trim_type else ""
 835    remove_chars = f"{remove_chars} " if remove_chars else ""
 836    from_part = "FROM " if trim_type or remove_chars else ""
 837    collation = f" COLLATE {collation}" if collation else ""
 838    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
 839
 840
 841def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
 842    return self.func("STRPTIME", expression.this, self.format_time(expression))
 843
 844
 845def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
 846    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
 847
 848
 849def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
 850    delim, *rest_args = expression.expressions
 851    return self.sql(
 852        reduce(
 853            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
 854            rest_args,
 855        )
 856    )
 857
 858
 859def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
 860    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
 861    if bad_args:
 862        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
 863
 864    return self.func(
 865        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
 866    )
 867
 868
 869def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
 870    bad_args = list(
 871        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
 872    )
 873    if bad_args:
 874        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
 875
 876    return self.func(
 877        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
 878    )
 879
 880
 881def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
 882    names = []
 883    for agg in aggregations:
 884        if isinstance(agg, exp.Alias):
 885            names.append(agg.alias)
 886        else:
 887            """
 888            This case corresponds to aggregations without aliases being used as suffixes
 889            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
 890            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
 891            Otherwise, we'd end up with `col_avg(`foo`)` (notice the backticks).
 892            """
 893            agg_all_unquoted = agg.transform(
 894                lambda node: (
 895                    exp.Identifier(this=node.name, quoted=False)
 896                    if isinstance(node, exp.Identifier)
 897                    else node
 898                )
 899            )
 900            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
 901
 902    return names
 903
 904
 905def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
 906    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
 907
 908
 909# Used to represent DATE_TRUNC in Doris, Postgres and StarRocks dialects
 910def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
 911    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
 912
 913
 914def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
 915    return self.func("MAX", expression.this)
 916
 917
 918def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
 919    a = self.sql(expression.left)
 920    b = self.sql(expression.right)
 921    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
 922
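# Editorial usage sketch (not part of the module source): wiring the expansion
# above into a dialect that lacks a native boolean XOR operator:
#
#     TRANSFORMS = {exp.Xor: bool_xor_sql}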
 923
 924def is_parse_json(expression: exp.Expression) -> bool:
 925    return isinstance(expression, exp.ParseJSON) or (
 926        isinstance(expression, exp.Cast) and expression.is_type("json")
 927    )
 928
 929
 930def isnull_to_is_null(args: t.List) -> exp.Expression:
 931    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
 932
 933
 934def generatedasidentitycolumnconstraint_sql(
 935    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
 936) -> str:
 937    start = self.sql(expression, "start") or "1"
 938    increment = self.sql(expression, "increment") or "1"
 939    return f"IDENTITY({start}, {increment})"
 940
 941
 942def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
 943    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
 944        if expression.args.get("count"):
 945            self.unsupported(f"Only two arguments are supported in function {name}.")
 946
 947        return self.func(name, expression.this, expression.expression)
 948
 949    return _arg_max_or_min_sql
 950
 951
 952def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
 953    this = expression.this.copy()
 954
 955    return_type = expression.return_type
 956    if return_type.is_type(exp.DataType.Type.DATE):
 957        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
 958        # can truncate timestamp strings, because some dialects can't cast them to DATE
 959        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
 960
 961    expression.this.replace(exp.cast(this, return_type))
 962    return expression
 963
 964
 965def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
 966    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
 967        if cast and isinstance(expression, exp.TsOrDsAdd):
 968            expression = ts_or_ds_add_cast(expression)
 969
 970        return self.func(
 971            name,
 972            exp.var(expression.text("unit").upper() or "DAY"),
 973            expression.expression,
 974            expression.this,
 975        )
 976
 977    return _delta_sql
 978
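# Editorial usage sketch (not part of the module source): `date_delta_sql` builds
# generator transforms for dialects with unit-first date arithmetic functions:
#
#     TRANSFORMS = {
#         exp.DateAdd: date_delta_sql("DATEADD"),
#         exp.DateDiff: date_delta_sql("DATEDIFF"),
#     }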
 979
 980def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
 981    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
 982    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
 983    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
 984
 985    return self.sql(exp.cast(minus_one_day, "date"))
 986
 987
 988def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
 989    """Remove table refs from columns in when statements."""
 990    alias = expression.this.args.get("alias")
 991
 992    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
 993        return self.dialect.normalize_identifier(identifier).name if identifier else None
 994
 995    targets = {normalize(expression.this.this)}
 996
 997    if alias:
 998        targets.add(normalize(alias.this))
 999
1000    for when in expression.expressions:
1001        when.transform(
1002            lambda node: (
1003                exp.column(node.this)
1004                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1005                else node
1006            ),
1007            copy=False,
1008        )
1009
1010    return self.merge_sql(expression)
1011
1012
1013def parse_json_extract_path(
1014    expr_type: t.Type[F], zero_based_indexing: bool = True
1015) -> t.Callable[[t.List], F]:
1016    def _parse_json_extract_path(args: t.List) -> F:
1017        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1018        for arg in args[1:]:
1019            if not isinstance(arg, exp.Literal):
1020                # We use the fallback parser because we can't really transpile non-literals safely
1021                return expr_type.from_arg_list(args)
1022
1023            text = arg.name
1024            if is_int(text):
1025                index = int(text)
1026                segments.append(
1027                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1028                )
1029            else:
1030                segments.append(exp.JSONPathKey(this=text))
1031
1032        # This is done to avoid failing in the expression validator due to the arg count
1033        del args[2:]
1034        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
1035
1036    return _parse_json_extract_path
1037
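# Editorial usage sketch (not part of the module source): registering the helper
# for a JSON_EXTRACT_PATH-style function in a dialect parser; array subscripts
# stay zero-based unless `zero_based_indexing=False`:
#
#     FUNCTIONS = {
#         "JSON_EXTRACT_PATH": parse_json_extract_path(exp.JSONExtract),
#     }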
1038
1039def json_extract_segments(
1040    name: str, quoted_index: bool = True
1041) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1042    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1043        path = expression.expression
1044        if not isinstance(path, exp.JSONPath):
1045            return rename_func(name)(self, expression)
1046
1047        segments = []
1048        for segment in path.expressions:
1049            path = self.sql(segment)
1050            if path:
1051                if isinstance(segment, exp.JSONPathPart) and (
1052                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1053                ):
1054                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1055
1056                segments.append(path)
1057
1058        return self.func(name, expression.this, *segments)
1059
1060    return _json_extract_segments
1061
1062
1063def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1064    if isinstance(expression.this, exp.JSONPathWildcard):
1065        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1066
1067    return expression.name
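
A minimal end-to-end sketch of the module above, assuming a standard sqlglot
installation: dialect settings can be passed inline when looking a dialect up
by name.

    >>> from sqlglot.dialects.dialect import Dialect
    >>> d = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
    >>> d.normalization_strategy
    <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>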
482
483    @property
484    def tokenizer(self) -> Tokenizer:
485        if not hasattr(self, "_tokenizer"):
486            self._tokenizer = self.tokenizer_class(dialect=self)
487        return self._tokenizer
488
489    def parser(self, **opts) -> Parser:
490        return self.parser_class(dialect=self, **opts)
491
492    def generator(self, **opts) -> Generator:
493        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
346    def __init__(self, **kwargs) -> None:
347        normalization_strategy = kwargs.get("normalization_strategy")
348
349        if normalization_strategy is None:
350            self.normalization_strategy = self.NORMALIZATION_STRATEGY
351        else:
352            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
INDEX_OFFSET = 0

Determines the base index offset for arrays.

WEEK_OFFSET = 0

Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Determines whether or not UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Determines whether or not the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT = False

Determines whether or not a size in the table sample clause represents percentage.

NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE

Specifies the strategy according to which identifiers should be normalized.
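
The strategy can also be overridden per instance through the constructor shown above; a minimal sketch:

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    # __init__ uppercases the value before building the enum member
    d = Dialect(normalization_strategy="case_insensitive")
    assert d.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE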

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Determines whether or not an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Determines whether or not the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Determines whether or not CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Determines whether or not user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Determines whether or not SEMI or ANTI joins are supported.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are normalized: "upper" (or True) uppercases them, "lower" lowercases them, and False disables normalization.

LOG_BASE_FIRST = True

Determines whether the base comes first in the LOG function.

NULL_ORDERING = 'nulls_are_small'

Indicates the default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last"

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Determines whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime format.

FORMAT_MAPPING: Dict[str, str] = {}

Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.

ESCAPE_SEQUENCES: Dict[str, str] = {}

Mapping of an unescaped escape sequence to the corresponding character.

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.
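
For instance, a BigQuery-style dialect might declare its partition pseudocolumns as follows (illustrative values):

    from sqlglot.dialects.dialect import Dialect

    class MyDialect(Dialect):
        # hypothetical BigQuery-style partition pseudocolumns
        PSEUDOCOLUMNS = {"_PARTITIONTIME", "_PARTITIONDATE"}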

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
tokenizer_class = sqlglot.tokens.Tokenizer
parser_class = sqlglot.parser.Parser
generator_class = sqlglot.generator.Generator
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
INVERSE_ESCAPE_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:
282    @classmethod
283    def get_or_raise(cls, dialect: DialectType) -> Dialect:
284        """
285        Look up a dialect in the global dialect registry and return it if it exists.
286
287        Args:
288            dialect: The target dialect. If this is a string, it can be optionally followed by
289                additional key-value pairs that are separated by commas and are used to specify
290                dialect settings, such as whether the dialect's identifiers are case-sensitive.
291
292        Example:
293            >>> dialect = Dialect.get_or_raise("duckdb")
294            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
295
296        Returns:
297            The corresponding Dialect instance.
298        """
299
300        if not dialect:
301            return cls()
302        if isinstance(dialect, _Dialect):
303            return dialect()
304        if isinstance(dialect, Dialect):
305            return dialect
306        if isinstance(dialect, str):
307            try:
308                dialect_name, *kv_pairs = dialect.split(",")
309                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
310            except ValueError:
311                raise ValueError(
312                    f"Invalid dialect format: '{dialect}'. "
313                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
314                )
315
316            result = cls.get(dialect_name.strip())
317            if not result:
318                from difflib import get_close_matches
319
320                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
321                if similar:
322                    similar = f" Did you mean {similar}?"
323
324                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
325
326            return result(**kwargs)
327
328        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = Dialect.get_or_raise("duckdb")
>>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.
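
Since the comma-separated settings are forwarded to the dialect's constructor, they end up on the returned instance; a quick sketch:

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
    assert mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE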

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
330    @classmethod
331    def format_time(
332        cls, expression: t.Optional[str | exp.Expression]
333    ) -> t.Optional[exp.Expression]:
334        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
335        if isinstance(expression, str):
336            return exp.Literal.string(
337                # the time formats are quoted
338                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
339            )
340
341        if expression and expression.is_string:
342            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
343
344        return expression

Converts a time format in this dialect to its equivalent Python strftime format.
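
A sketch, assuming a dialect whose TIME_MAPPING covers the yyyy, MM and dd tokens (Hive's does at the time of writing):

    from sqlglot.dialects.dialect import Dialect

    hive = Dialect.get_or_raise("hive")
    # plain-string inputs are assumed to be quoted, hence the surrounding quotes
    print(hive.format_time("'yyyy-MM-dd'").sql())  # '%Y-%m-%d'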

def normalize_identifier(self, expression: ~E) -> ~E:
362    def normalize_identifier(self, expression: E) -> E:
363        """
364        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
365
366        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
367        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
368        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
369        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
370
371        There are also dialects like Spark, which are case-insensitive even when quotes are
372        present, and dialects like MySQL, whose resolution rules match those employed by the
373        underlying operating system, for example they may always be case-sensitive in Linux.
374
375        Finally, the normalization behavior of some engines can even be controlled through flags,
376        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
377
378        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
379        that it can analyze queries in the optimizer and successfully capture their semantics.
380        """
381        if (
382            isinstance(expression, exp.Identifier)
383            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
384            and (
385                not expression.quoted
386                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
387            )
388        ):
389            expression.set(
390                "this",
391                (
392                    expression.this.upper()
393                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
394                    else expression.this.lower()
395                ),
396            )
397
398        return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system; for example, they may always be case-sensitive on Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
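
A short sketch of the Postgres and Snowflake behaviors described above:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    pg = Dialect.get_or_raise("postgres")
    sf = Dialect.get_or_raise("snowflake")

    print(pg.normalize_identifier(exp.to_identifier("FoO")).name)  # foo
    print(sf.normalize_identifier(exp.to_identifier("FoO")).name)  # FOO

    # quoted identifiers are treated as case-sensitive and left untouched
    print(pg.normalize_identifier(exp.to_identifier("FoO", quoted=True)).name)  # FoO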

def case_sensitive(self, text: str) -> bool:
400    def case_sensitive(self, text: str) -> bool:
401        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
402        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
403            return False
404
405        unsafe = (
406            str.islower
407            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
408            else str.isupper
409        )
410        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.
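
For example, under Postgres' lowercase-normalization rules (a sketch):

    from sqlglot.dialects.dialect import Dialect

    pg = Dialect.get_or_raise("postgres")
    print(pg.case_sensitive("foo"))  # False: already lowercase, normalization can't change it
    print(pg.case_sensitive("FoO"))  # True: contains uppercase characters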

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:
412    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
413        """Checks if text can be identified given an identify option.
414
415        Args:
416            text: The text to check.
417            identify:
418                `"always"` or `True`: Always returns `True`.
419                `"safe"`: Only returns `True` if the identifier is case-insensitive.
420
421        Returns:
422            Whether or not the given text can be identified.
423        """
424        if identify is True or identify == "always":
425            return True
426
427        if identify == "safe":
428            return not self.case_sensitive(text)
429
430        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns True.
    "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether or not the given text can be identified.
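
A sketch with a lowercase-normalizing dialect:

    from sqlglot.dialects.dialect import Dialect

    pg = Dialect.get_or_raise("postgres")
    print(pg.can_identify("foo"))                     # True: "safe" and not case-sensitive
    print(pg.can_identify("FoO"))                     # False under the default "safe" mode
    print(pg.can_identify("FoO", identify="always"))  # True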

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:
432    def quote_identifier(self, expression: E, identify: bool = True) -> E:
433        """
434        Adds quotes to a given identifier.
435
436        Args:
437            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
438            identify: If set to `False`, the quotes will only be added if the identifier is deemed
439                "unsafe", with respect to its characters and this dialect's normalization strategy.
440        """
441        if isinstance(expression, exp.Identifier):
442            name = expression.this
443            expression.set(
444                "quoted",
445                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
446            )
447
448        return expression

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
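
For example (a sketch):

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    pg = Dialect.get_or_raise("postgres")

    # identify=True (the default) always quotes
    print(pg.quote_identifier(exp.to_identifier("foo")).sql(dialect="postgres"))  # "foo"

    # identify=False only quotes "unsafe" identifiers; plain lowercase "foo" is safe
    print(pg.quote_identifier(exp.to_identifier("foo"), identify=False).sql(dialect="postgres"))  # foo
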
def to_json_path( self, path: Optional[sqlglot.expressions.Expression]) -> Optional[sqlglot.expressions.Expression]:
450    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
451        if isinstance(path, exp.Literal):
452            path_text = path.name
453            if path.is_number:
454                path_text = f"[{path_text}]"
455
456            try:
457                return parse_json_path(path_text)
458            except ParseError as e:
459                logger.warning(f"Invalid JSON path syntax. {str(e)}")
460
461        return path
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
463    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
464        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
466    def parse_into(
467        self, expression_type: exp.IntoType, sql: str, **opts
468    ) -> t.List[t.Optional[exp.Expression]]:
469        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
471    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
472        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
474    def transpile(self, sql: str, **opts) -> t.List[str]:
475        return [
476            self.generate(expression, copy=False, **opts) if expression else ""
477            for expression in self.parse(sql)
478        ]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
480    def tokenize(self, sql: str) -> t.List[Token]:
481        return self.tokenizer.tokenize(sql)
tokenizer: sqlglot.tokens.Tokenizer
483    @property
484    def tokenizer(self) -> Tokenizer:
485        if not hasattr(self, "_tokenizer"):
486            self._tokenizer = self.tokenizer_class(dialect=self)
487        return self._tokenizer
def parser(self, **opts) -> sqlglot.parser.Parser:
489    def parser(self, **opts) -> Parser:
490        return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
492    def generator(self, **opts) -> Generator:
493        return self.generator_class(dialect=self, **opts)
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
499def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
500    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
503def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
504    if expression.args.get("accuracy"):
505        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
506    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
509def if_sql(
510    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
511) -> t.Callable[[Generator, exp.If], str]:
512    def _if_sql(self: Generator, expression: exp.If) -> str:
513        return self.func(
514            name,
515            expression.this,
516            expression.args.get("true"),
517            expression.args.get("false") or false_value,
518        )
519
520    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]) -> str:
523def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
524    this = expression.this
525    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
526        this.replace(exp.cast(this, "json"))
527
528    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
531def inline_array_sql(self: Generator, expression: exp.Array) -> str:
532    return f"[{self.expressions(expression, flat=True)}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
535def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
536    return self.like_sql(
537        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
538    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
541def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
542    zone = self.sql(expression, "this")
543    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
546def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
547    if expression.args.get("recursive"):
548        self.unsupported("Recursive CTEs are unsupported")
549        expression.args["recursive"] = False
550    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
553def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
554    n = self.sql(expression, "this")
555    d = self.sql(expression, "expression")
556    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
559def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
560    self.unsupported("TABLESAMPLE unsupported")
561    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
564def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
565    self.unsupported("PIVOT unsupported")
566    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
569def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
570    return self.cast_sql(expression)
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
573def no_comment_column_constraint_sql(
574    self: Generator, expression: exp.CommentColumnConstraint
575) -> str:
576    self.unsupported("CommentColumnConstraint unsupported")
577    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
580def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
581    self.unsupported("MAP_FROM_ENTRIES unsupported")
582    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
585def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
586    this = self.sql(expression, "this")
587    substr = self.sql(expression, "substr")
588    position = self.sql(expression, "position")
589    if position:
590        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
591    return f"STRPOS({this}, {substr})"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
594def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
595    return (
596        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
597    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
600def var_map_sql(
601    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
602) -> str:
603    keys = expression.args["keys"]
604    values = expression.args["values"]
605
606    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
607        self.unsupported("Cannot convert array columns into map.")
608        return self.func(map_func_name, keys, values)
609
610    args = []
611    for key, value in zip(keys.expressions, values.expressions):
612        args.append(self.sql(key))
613        args.append(self.sql(value))
614
615    return self.func(map_func_name, *args)
def format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
618def format_time_lambda(
619    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
620) -> t.Callable[[t.List], E]:
621    """Helper used for time expressions.
622
623    Args:
624        exp_class: the expression class to instantiate.
625        dialect: target sql dialect.
626        default: the default format, True being time.
627
628    Returns:
629        A callable that can be used to return the appropriately formatted time expression.
630    """
631
632    def _format_time(args: t.List):
633        return exp_class(
634            this=seq_get(args, 0),
635            format=Dialect[dialect].format_time(
636                seq_get(args, 1)
637                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
638            ),
639        )
640
641    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.
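
A dialect would typically wire this helper into its parser's function registry; a sketch (the TO_DATE entry here is illustrative, not a statement about any specific dialect):

    from sqlglot import exp
    from sqlglot.dialects.dialect import format_time_lambda

    # hypothetical registry entry, in the style of a dialect Parser's FUNCTIONS mapping
    FUNCTIONS = {
        "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"),
    }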

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
644def time_format(
645    dialect: DialectType = None,
646) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
647    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
648        """
649        Returns the time format for a given expression, unless it's equivalent
650        to the default time format of the dialect of interest.
651        """
652        time_format = self.format_time(expression)
653        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
654
655    return _time_format
def create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
658def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
659    """
660    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
661    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
662    columns are removed from the create statement.
663    """
664    has_schema = isinstance(expression.this, exp.Schema)
665    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
666
667    if has_schema and is_partitionable:
668        prop = expression.find(exp.PartitionedByProperty)
669        if prop and prop.this and not isinstance(prop.this, exp.Schema):
670            schema = expression.this
671            columns = {v.name.upper() for v in prop.this.expressions}
672            partitions = [col for col in schema.expressions if col.name.upper() in columns]
673            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
674            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
675            expression.set("this", schema)
676
677    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
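
A sketch of the rewrite, with the partition column moving out of the main schema (exact output depends on the target dialect):

    # before: dt appears in the column list and is named in PARTITIONED BY
    before = "CREATE TABLE t (x INT, dt STRING) PARTITIONED BY (dt)"
    # after: dt's definition lives in the PARTITIONED BY schema instead
    after = "CREATE TABLE t (x INT) PARTITIONED BY (dt STRING)"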

def parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
680def parse_date_delta(
681    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
682) -> t.Callable[[t.List], E]:
683    def inner_func(args: t.List) -> E:
684        unit_based = len(args) == 3
685        this = args[2] if unit_based else seq_get(args, 0)
686        unit = args[0] if unit_based else exp.Literal.string("DAY")
687        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
688        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
689
690    return inner_func
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
693def parse_date_delta_with_interval(
694    expression_class: t.Type[E],
695) -> t.Callable[[t.List], t.Optional[E]]:
696    def func(args: t.List) -> t.Optional[E]:
697        if len(args) < 2:
698            return None
699
700        interval = args[1]
701
702        if not isinstance(interval, exp.Interval):
703            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
704
705        expression = interval.this
706        if expression and expression.is_string:
707            expression = exp.Literal.number(expression.this)
708
709        return expression_class(
710            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
711        )
712
713    return func
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
716def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
717    unit = seq_get(args, 0)
718    this = seq_get(args, 1)
719
720    if isinstance(this, exp.Cast) and this.is_type("date"):
721        return exp.DateTrunc(unit=unit, this=this)
722    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
725def date_add_interval_sql(
726    data_type: str, kind: str
727) -> t.Callable[[Generator, exp.Expression], str]:
728    def func(self: Generator, expression: exp.Expression) -> str:
729        this = self.sql(expression, "this")
730        unit = expression.args.get("unit")
731        unit = exp.var(unit.name.upper() if unit else "DAY")
732        interval = exp.Interval(this=expression.expression, unit=unit)
733        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
734
735    return func
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
738def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
739    return self.func(
740        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
741    )
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
744def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
745    if not expression.expression:
746        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
747    if expression.text("expression").lower() in TIMEZONES:
748        return self.sql(
749            exp.AtTimeZone(
750                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
751                zone=expression.expression,
752            )
753        )
754    return self.function_fallback_sql(expression)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
757def locate_to_strposition(args: t.List) -> exp.Expression:
758    return exp.StrPosition(
759        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
760    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
763def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
764    return self.func(
765        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
766    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
769def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
770    return self.sql(
771        exp.Substring(
772            this=expression.this, start=exp.Literal.number(1), length=expression.expression
773        )
774    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Right) -> str:
777def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
778    return self.sql(
779        exp.Substring(
780            this=expression.this,
781            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
782        )
783    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
786def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
787    return self.sql(exp.cast(expression.this, "timestamp"))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
790def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
791    return self.sql(exp.cast(expression.this, "date"))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
795def encode_decode_sql(
796    self: Generator, expression: exp.Expression, name: str, replace: bool = True
797) -> str:
798    charset = expression.args.get("charset")
799    if charset and charset.name.lower() != "utf-8":
800        self.unsupported(f"Expected utf-8 character set, got {charset}.")
801
802    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
805def min_or_least(self: Generator, expression: exp.Min) -> str:
806    name = "LEAST" if expression.expressions else "MIN"
807    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
810def max_or_greatest(self: Generator, expression: exp.Max) -> str:
811    name = "GREATEST" if expression.expressions else "MAX"
812    return rename_func(name)(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
815def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
816    cond = expression.this
817
818    if isinstance(expression.this, exp.Distinct):
819        cond = expression.this.expressions[0]
820        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
821
822    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
825def trim_sql(self: Generator, expression: exp.Trim) -> str:
826    target = self.sql(expression, "this")
827    trim_type = self.sql(expression, "position")
828    remove_chars = self.sql(expression, "expression")
829    collation = self.sql(expression, "collation")
830
831    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
832    if not remove_chars and not collation:
833        return self.trim_sql(expression)
834
835    trim_type = f"{trim_type} " if trim_type else ""
836    remove_chars = f"{remove_chars} " if remove_chars else ""
837    from_part = "FROM " if trim_type or remove_chars else ""
838    collation = f" COLLATE {collation}" if collation else ""
839    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
842def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
843    return self.func("STRPTIME", expression.this, self.format_time(expression))
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
846def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
847    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
850def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
851    delim, *rest_args = expression.expressions
852    return self.sql(
853        reduce(
854            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
855            rest_args,
856        )
857    )
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
860def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
861    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
862    if bad_args:
863        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
864
865    return self.func(
866        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
867    )
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
870def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
871    bad_args = list(
872        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
873    )
874    if bad_args:
875        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
876
877    return self.func(
878        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
879    )
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
882def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
883    names = []
884    for agg in aggregations:
885        if isinstance(agg, exp.Alias):
886            names.append(agg.alias)
887        else:
888            """
889            This case corresponds to aggregations without aliases being used as suffixes
890            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
891            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
892            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
893            """
894            agg_all_unquoted = agg.transform(
895                lambda node: (
896                    exp.Identifier(this=node.name, quoted=False)
897                    if isinstance(node, exp.Identifier)
898                    else node
899                )
900            )
901            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
902
903    return names
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
906def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
907    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
def parse_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
911def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
912    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
915def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
916    return self.func("MAX", expression.this)
def bool_xor_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Xor) -> str:
919def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
920    a = self.sql(expression.left)
921    b = self.sql(expression.right)
922    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
def is_parse_json(expression: sqlglot.expressions.Expression) -> bool:
925def is_parse_json(expression: exp.Expression) -> bool:
926    return isinstance(expression, exp.ParseJSON) or (
927        isinstance(expression, exp.Cast) and expression.is_type("json")
928    )
def isnull_to_is_null(args: List) -> sqlglot.expressions.Expression:
931def isnull_to_is_null(args: t.List) -> exp.Expression:
932    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
def generatedasidentitycolumnconstraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.GeneratedAsIdentityColumnConstraint) -> str:
935def generatedasidentitycolumnconstraint_sql(
936    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
937) -> str:
938    start = self.sql(expression, "start") or "1"
939    increment = self.sql(expression, "increment") or "1"
940    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.ArgMax | sqlglot.expressions.ArgMin], str]:
943def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
944    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
945        if expression.args.get("count"):
946            self.unsupported(f"Only two arguments are supported in function {name}.")
947
948        return self.func(name, expression.this, expression.expression)
949
950    return _arg_max_or_min_sql
def ts_or_ds_add_cast( expression: sqlglot.expressions.TsOrDsAdd) -> sqlglot.expressions.TsOrDsAdd:
953def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
954    this = expression.this.copy()
955
956    return_type = expression.return_type
957    if return_type.is_type(exp.DataType.Type.DATE):
958        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
959        # can truncate timestamp strings, because some dialects can't cast them to DATE
960        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
961
962    expression.this.replace(exp.cast(this, return_type))
963    return expression
def date_delta_sql( name: str, cast: bool = False) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.DateAdd, sqlglot.expressions.TsOrDsAdd, sqlglot.expressions.DateDiff, sqlglot.expressions.TsOrDsDiff]], str]:
966def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
967    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
968        if cast and isinstance(expression, exp.TsOrDsAdd):
969            expression = ts_or_ds_add_cast(expression)
970
971        return self.func(
972            name,
973            exp.var(expression.text("unit").upper() or "DAY"),
974            expression.expression,
975            expression.this,
976        )
977
978    return _delta_sql
def no_last_day_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.LastDay) -> str:
981def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
982    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
983    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
984    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
985
986    return self.sql(exp.cast(minus_one_day, "date"))
def merge_without_target_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Merge) -> str:
 989def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
 990    """Remove table refs from columns in when statements."""
 991    alias = expression.this.args.get("alias")
 992
 993    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
 994        return self.dialect.normalize_identifier(identifier).name if identifier else None
 995
 996    targets = {normalize(expression.this.this)}
 997
 998    if alias:
 999        targets.add(normalize(alias.this))
1000
1001    for when in expression.expressions:
1002        when.transform(
1003            lambda node: (
1004                exp.column(node.this)
1005                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1006                else node
1007            ),
1008            copy=False,
1009        )
1010
1011    return self.merge_sql(expression)

Remove table refs from columns in when statements.
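
A sketch of the effect, with t as the MERGE target:

    # before: WHEN clause columns carry the target's table reference
    before = "MERGE INTO t USING s ON t.id = s.id WHEN MATCHED THEN UPDATE SET t.x = s.x"
    # after: references to the target (or its alias) are stripped from those columns;
    # the ON condition is left as-is, since only the WHEN clauses are transformed
    after = "MERGE INTO t USING s ON t.id = s.id WHEN MATCHED THEN UPDATE SET x = s.x"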

def parse_json_extract_path( expr_type: Type[~F], zero_based_indexing: bool = True) -> Callable[[List], ~F]:
1014def parse_json_extract_path(
1015    expr_type: t.Type[F], zero_based_indexing: bool = True
1016) -> t.Callable[[t.List], F]:
1017    def _parse_json_extract_path(args: t.List) -> F:
1018        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1019        for arg in args[1:]:
1020            if not isinstance(arg, exp.Literal):
1021                # We use the fallback parser because we can't really transpile non-literals safely
1022                return expr_type.from_arg_list(args)
1023
1024            text = arg.name
1025            if is_int(text):
1026                index = int(text)
1027                segments.append(
1028                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1029                )
1030            else:
1031                segments.append(exp.JSONPathKey(this=text))
1032
1033        # This is done to avoid failing in the expression validator due to the arg count
1034        del args[2:]
1035        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
1036
1037    return _parse_json_extract_path
def json_extract_segments( name: str, quoted_index: bool = True) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]], str]:
1040def json_extract_segments(
1041    name: str, quoted_index: bool = True
1042) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1043    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1044        path = expression.expression
1045        if not isinstance(path, exp.JSONPath):
1046            return rename_func(name)(self, expression)
1047
1048        segments = []
1049        for segment in path.expressions:
1050            path = self.sql(segment)
1051            if path:
1052                if isinstance(segment, exp.JSONPathPart) and (
1053                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1054                ):
1055                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1056
1057                segments.append(path)
1058
1059        return self.func(name, expression.this, *segments)
1060
1061    return _json_extract_segments
def json_path_key_only_name( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONPathKey) -> str:
1064def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1065    if isinstance(expression.this, exp.JSONPathWildcard):
1066        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1067
1068    return expression.name