Edit on GitHub

sqlglot.dialects.dialect

   1from __future__ import annotations
   2
   3import logging
   4import typing as t
   5from enum import Enum, auto
   6from functools import reduce
   7
   8from sqlglot import exp
   9from sqlglot.errors import ParseError
  10from sqlglot.generator import Generator
  11from sqlglot.helper import AutoName, flatten, is_int, seq_get
  12from sqlglot.jsonpath import parse as parse_json_path
  13from sqlglot.parser import Parser
  14from sqlglot.time import TIMEZONES, format_time
  15from sqlglot.tokens import Token, Tokenizer, TokenType
  16from sqlglot.trie import new_trie
  17
# Type aliases grouping the date arithmetic expression classes; presumably used to annotate
# date add/diff/sub helpers elsewhere in this module — verify against their call sites.
DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]

# Generic type variables, imported only for static type checking.
if t.TYPE_CHECKING:
    from sqlglot._typing import B, E

# Shared package logger (used e.g. to report invalid JSON path syntax).
logger = logging.getLogger("sqlglot")
  25
  26
class Dialects(str, Enum):
    """
    Dialects supported by SQLGlot.

    Members mix in `str`, so each member compares equal to its plain string value
    (e.g. `Dialects.DUCKDB == "duckdb"`). A dialect class whose upper-cased name matches a
    member is registered under that member's value.
    """

    # The base (generic) dialect; the `Dialect` class itself is registered under "".
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
  53
  54
class NormalizationStrategy(str, AutoName):
    """
    Specifies the strategy according to which identifiers should be normalized.

    A dialect's default comes from `Dialect.NORMALIZATION_STRATEGY`. It can be overridden per
    instance through the `normalization_strategy` setting. Member values are presumably the
    member names (via `AutoName`), which is why `Dialect.__init__` upper-cases user input before
    the lookup. TODO confirm against `AutoName`.
    """

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""
  69
  70
class _Dialect(type):
    """
    Metaclass of `Dialect`.

    At class-creation time it registers the class in a global name -> class registry. It then
    derives the "autofilled" attributes (time/format tries, inverse mappings, component
    classes, literal delimiters) from the settings the class declares.
    """

    # Global registry: `Dialects` enum value (or lowercased class name) -> dialect class.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        # A dialect class equals itself, the name it is registered under, and its own instances.
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        # Hashes the lowercased class name. This matches hash(registry key) whenever the key is
        # the lowercased class name. The base `Dialect` class is an exception: it is registered
        # under "" (Dialects.DIALECT).
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        # Enables `Dialect["duckdb"]`; raises KeyError for unregistered names.
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        # Non-raising registry lookup.
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        # Register under the enum value when the class name matches a `Dialects` member,
        # otherwise under the lowercased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Tries for translating time formats to Python strftime and back.
        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        # Nested Tokenizer/Parser/Generator classes are looked up through the MRO, so a dialect
        # inherits its parent's components and falls back to the base sqlglot classes.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first entry of the tokenizer's quote / identifier mappings becomes the canonical
        # delimiter pair.
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # (start, end) delimiters of the first format string producing `token_type`, or
            # (None, None) if the tokenizer has none. The loop variable `t` shadows the
            # `typing` alias only inside the generator expression.
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        # `enum` is a str-mixin member (or None), so it compares directly against plain strings.
        # NOTE(review): if a dialect does not define its own Generator/Parser, the assignments
        # below land on the shared base class and affect every dialect that uses it.
        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        # Without SEMI/ANTI joins, those keywords can be used as table aliases.
        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass
 145
 146
 147class Dialect(metaclass=_Dialect):
 148    INDEX_OFFSET = 0
 149    """Determines the base index offset for arrays."""
 150
 151    WEEK_OFFSET = 0
 152    """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
 153
 154    UNNEST_COLUMN_ONLY = False
 155    """Determines whether or not `UNNEST` table aliases are treated as column aliases."""
 156
 157    ALIAS_POST_TABLESAMPLE = False
 158    """Determines whether or not the table alias comes after tablesample."""
 159
 160    TABLESAMPLE_SIZE_IS_PERCENT = False
 161    """Determines whether or not a size in the table sample clause represents percentage."""
 162
 163    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
 164    """Specifies the strategy according to which identifiers should be normalized."""
 165
 166    IDENTIFIERS_CAN_START_WITH_DIGIT = False
 167    """Determines whether or not an unquoted identifier can start with a digit."""
 168
 169    DPIPE_IS_STRING_CONCAT = True
 170    """Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""
 171
 172    STRICT_STRING_CONCAT = False
 173    """Determines whether or not `CONCAT`'s arguments must be strings."""
 174
 175    SUPPORTS_USER_DEFINED_TYPES = True
 176    """Determines whether or not user-defined data types are supported."""
 177
 178    SUPPORTS_SEMI_ANTI_JOIN = True
 179    """Determines whether or not `SEMI` or `ANTI` joins are supported."""
 180
 181    NORMALIZE_FUNCTIONS: bool | str = "upper"
 182    """Determines how function names are going to be normalized."""
 183
 184    LOG_BASE_FIRST = True
 185    """Determines whether the base comes first in the `LOG` function."""
 186
 187    NULL_ORDERING = "nulls_are_small"
 188    """
 189    Indicates the default `NULL` ordering method to use if not explicitly set.
 190    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
 191    """
 192
 193    TYPED_DIVISION = False
 194    """
 195    Whether the behavior of `a / b` depends on the types of `a` and `b`.
 196    False means `a / b` is always float division.
 197    True means `a / b` is integer division if both `a` and `b` are integers.
 198    """
 199
 200    SAFE_DIVISION = False
 201    """Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""
 202
 203    CONCAT_COALESCE = False
 204    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
 205
 206    DATE_FORMAT = "'%Y-%m-%d'"
 207    DATEINT_FORMAT = "'%Y%m%d'"
 208    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
 209
 210    TIME_MAPPING: t.Dict[str, str] = {}
 211    """Associates this dialect's time formats with their equivalent Python `strftime` format."""
 212
 213    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
 214    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
 215    FORMAT_MAPPING: t.Dict[str, str] = {}
 216    """
 217    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
 218    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
 219    """
 220
 221    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
 222    """Mapping of an unescaped escape sequence to the corresponding character."""
 223
 224    PSEUDOCOLUMNS: t.Set[str] = set()
 225    """
 226    Columns that are auto-generated by the engine corresponding to this dialect.
 227    For example, such columns may be excluded from `SELECT *` queries.
 228    """
 229
 230    PREFER_CTE_ALIAS_COLUMN = False
 231    """
 232    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
 233    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
 234    any projection aliases in the subquery.
 235
 236    For example,
 237        WITH y(c) AS (
 238            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
 239        ) SELECT c FROM y;
 240
 241        will be rewritten as
 242
 243        WITH y(c) AS (
 244            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
 245        ) SELECT c FROM y;
 246    """
 247
 248    # --- Autofilled ---
 249
 250    tokenizer_class = Tokenizer
 251    parser_class = Parser
 252    generator_class = Generator
 253
 254    # A trie of the time_mapping keys
 255    TIME_TRIE: t.Dict = {}
 256    FORMAT_TRIE: t.Dict = {}
 257
 258    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
 259    INVERSE_TIME_TRIE: t.Dict = {}
 260
 261    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
 262
 263    # Delimiters for string literals and identifiers
 264    QUOTE_START = "'"
 265    QUOTE_END = "'"
 266    IDENTIFIER_START = '"'
 267    IDENTIFIER_END = '"'
 268
 269    # Delimiters for bit, hex, byte and unicode literals
 270    BIT_START: t.Optional[str] = None
 271    BIT_END: t.Optional[str] = None
 272    HEX_START: t.Optional[str] = None
 273    HEX_END: t.Optional[str] = None
 274    BYTE_START: t.Optional[str] = None
 275    BYTE_END: t.Optional[str] = None
 276    UNICODE_START: t.Optional[str] = None
 277    UNICODE_END: t.Optional[str] = None
 278
 279    @classmethod
 280    def get_or_raise(cls, dialect: DialectType) -> Dialect:
 281        """
 282        Look up a dialect in the global dialect registry and return it if it exists.
 283
 284        Args:
 285            dialect: The target dialect. If this is a string, it can be optionally followed by
 286                additional key-value pairs that are separated by commas and are used to specify
 287                dialect settings, such as whether the dialect's identifiers are case-sensitive.
 288
 289        Example:
 290            >>> dialect = dialect_class = get_or_raise("duckdb")
 291            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
 292
 293        Returns:
 294            The corresponding Dialect instance.
 295        """
 296
 297        if not dialect:
 298            return cls()
 299        if isinstance(dialect, _Dialect):
 300            return dialect()
 301        if isinstance(dialect, Dialect):
 302            return dialect
 303        if isinstance(dialect, str):
 304            try:
 305                dialect_name, *kv_pairs = dialect.split(",")
 306                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
 307            except ValueError:
 308                raise ValueError(
 309                    f"Invalid dialect format: '{dialect}'. "
 310                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
 311                )
 312
 313            result = cls.get(dialect_name.strip())
 314            if not result:
 315                from difflib import get_close_matches
 316
 317                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
 318                if similar:
 319                    similar = f" Did you mean {similar}?"
 320
 321                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
 322
 323            return result(**kwargs)
 324
 325        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
 326
 327    @classmethod
 328    def format_time(
 329        cls, expression: t.Optional[str | exp.Expression]
 330    ) -> t.Optional[exp.Expression]:
 331        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
 332        if isinstance(expression, str):
 333            return exp.Literal.string(
 334                # the time formats are quoted
 335                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
 336            )
 337
 338        if expression and expression.is_string:
 339            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
 340
 341        return expression
 342
 343    def __init__(self, **kwargs) -> None:
 344        normalization_strategy = kwargs.get("normalization_strategy")
 345
 346        if normalization_strategy is None:
 347            self.normalization_strategy = self.NORMALIZATION_STRATEGY
 348        else:
 349            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
 350
 351    def __eq__(self, other: t.Any) -> bool:
 352        # Does not currently take dialect state into account
 353        return type(self) == other
 354
 355    def __hash__(self) -> int:
 356        # Does not currently take dialect state into account
 357        return hash(type(self))
 358
 359    def normalize_identifier(self, expression: E) -> E:
 360        """
 361        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
 362
 363        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
 364        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
 365        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
 366        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
 367
 368        There are also dialects like Spark, which are case-insensitive even when quotes are
 369        present, and dialects like MySQL, whose resolution rules match those employed by the
 370        underlying operating system, for example they may always be case-sensitive in Linux.
 371
 372        Finally, the normalization behavior of some engines can even be controlled through flags,
 373        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
 374
 375        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
 376        that it can analyze queries in the optimizer and successfully capture their semantics.
 377        """
 378        if (
 379            isinstance(expression, exp.Identifier)
 380            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
 381            and (
 382                not expression.quoted
 383                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
 384            )
 385        ):
 386            expression.set(
 387                "this",
 388                (
 389                    expression.this.upper()
 390                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 391                    else expression.this.lower()
 392                ),
 393            )
 394
 395        return expression
 396
 397    def case_sensitive(self, text: str) -> bool:
 398        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
 399        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
 400            return False
 401
 402        unsafe = (
 403            str.islower
 404            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 405            else str.isupper
 406        )
 407        return any(unsafe(char) for char in text)
 408
 409    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
 410        """Checks if text can be identified given an identify option.
 411
 412        Args:
 413            text: The text to check.
 414            identify:
 415                `"always"` or `True`: Always returns `True`.
 416                `"safe"`: Only returns `True` if the identifier is case-insensitive.
 417
 418        Returns:
 419            Whether or not the given text can be identified.
 420        """
 421        if identify is True or identify == "always":
 422            return True
 423
 424        if identify == "safe":
 425            return not self.case_sensitive(text)
 426
 427        return False
 428
 429    def quote_identifier(self, expression: E, identify: bool = True) -> E:
 430        """
 431        Adds quotes to a given identifier.
 432
 433        Args:
 434            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
 435            identify: If set to `False`, the quotes will only be added if the identifier is deemed
 436                "unsafe", with respect to its characters and this dialect's normalization strategy.
 437        """
 438        if isinstance(expression, exp.Identifier):
 439            name = expression.this
 440            expression.set(
 441                "quoted",
 442                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
 443            )
 444
 445        return expression
 446
 447    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
 448        if isinstance(path, exp.Literal):
 449            path_text = path.name
 450            if path.is_number:
 451                path_text = f"[{path_text}]"
 452
 453            try:
 454                return parse_json_path(path_text)
 455            except ParseError as e:
 456                logger.warning(f"Invalid JSON path syntax. {str(e)}")
 457
 458        return path
 459
 460    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
 461        return self.parser(**opts).parse(self.tokenize(sql), sql)
 462
 463    def parse_into(
 464        self, expression_type: exp.IntoType, sql: str, **opts
 465    ) -> t.List[t.Optional[exp.Expression]]:
 466        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
 467
 468    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
 469        return self.generator(**opts).generate(expression, copy=copy)
 470
 471    def transpile(self, sql: str, **opts) -> t.List[str]:
 472        return [
 473            self.generate(expression, copy=False, **opts) if expression else ""
 474            for expression in self.parse(sql)
 475        ]
 476
 477    def tokenize(self, sql: str) -> t.List[Token]:
 478        return self.tokenizer.tokenize(sql)
 479
 480    @property
 481    def tokenizer(self) -> Tokenizer:
 482        if not hasattr(self, "_tokenizer"):
 483            self._tokenizer = self.tokenizer_class(dialect=self)
 484        return self._tokenizer
 485
 486    def parser(self, **opts) -> Parser:
 487        return self.parser_class(dialect=self, **opts)
 488
 489    def generator(self, **opts) -> Generator:
 490        return self.generator_class(dialect=self, **opts)
 491
 492
# Anything accepted where a dialect is expected (see `Dialect.get_or_raise`): a registered name,
# optionally followed by ", key = value" settings; a Dialect instance; a Dialect subclass; or
# None / empty, which falls back to the class `get_or_raise` is called on.
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 494
 495
 496def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
 497    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
 498
 499
 500def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
 501    if expression.args.get("accuracy"):
 502        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
 503    return self.func("APPROX_COUNT_DISTINCT", expression.this)
 504
 505
 506def if_sql(
 507    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
 508) -> t.Callable[[Generator, exp.If], str]:
 509    def _if_sql(self: Generator, expression: exp.If) -> str:
 510        return self.func(
 511            name,
 512            expression.this,
 513            expression.args.get("true"),
 514            expression.args.get("false") or false_value,
 515        )
 516
 517    return _if_sql
 518
 519
 520def arrow_json_extract_sql(
 521    self: Generator, expression: exp.JSONExtract | exp.JSONExtractScalar
 522) -> str:
 523    this = expression.this
 524    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
 525        this.replace(exp.cast(this, "json"))
 526
 527    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
 528
 529
 530def inline_array_sql(self: Generator, expression: exp.Array) -> str:
 531    return f"[{self.expressions(expression, flat=True)}]"
 532
 533
 534def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
 535    return self.like_sql(
 536        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
 537    )
 538
 539
 540def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
 541    zone = self.sql(expression, "this")
 542    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
 543
 544
 545def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
 546    if expression.args.get("recursive"):
 547        self.unsupported("Recursive CTEs are unsupported")
 548        expression.args["recursive"] = False
 549    return self.with_sql(expression)
 550
 551
 552def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
 553    n = self.sql(expression, "this")
 554    d = self.sql(expression, "expression")
 555    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
 556
 557
 558def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
 559    self.unsupported("TABLESAMPLE unsupported")
 560    return self.sql(expression.this)
 561
 562
 563def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
 564    self.unsupported("PIVOT unsupported")
 565    return ""
 566
 567
 568def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
 569    return self.cast_sql(expression)
 570
 571
 572def no_comment_column_constraint_sql(
 573    self: Generator, expression: exp.CommentColumnConstraint
 574) -> str:
 575    self.unsupported("CommentColumnConstraint unsupported")
 576    return ""
 577
 578
 579def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
 580    self.unsupported("MAP_FROM_ENTRIES unsupported")
 581    return ""
 582
 583
 584def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
 585    this = self.sql(expression, "this")
 586    substr = self.sql(expression, "substr")
 587    position = self.sql(expression, "position")
 588    if position:
 589        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
 590    return f"STRPOS({this}, {substr})"
 591
 592
 593def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
 594    return (
 595        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
 596    )
 597
 598
 599def var_map_sql(
 600    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
 601) -> str:
 602    keys = expression.args["keys"]
 603    values = expression.args["values"]
 604
 605    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
 606        self.unsupported("Cannot convert array columns into map.")
 607        return self.func(map_func_name, keys, values)
 608
 609    args = []
 610    for key, value in zip(keys.expressions, values.expressions):
 611        args.append(self.sql(key))
 612        args.append(self.sql(value))
 613
 614    return self.func(map_func_name, *args)
 615
 616
 617def format_time_lambda(
 618    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
 619) -> t.Callable[[t.List], E]:
 620    """Helper used for time expressions.
 621
 622    Args:
 623        exp_class: the expression class to instantiate.
 624        dialect: target sql dialect.
 625        default: the default format, True being time.
 626
 627    Returns:
 628        A callable that can be used to return the appropriately formatted time expression.
 629    """
 630
 631    def _format_time(args: t.List):
 632        return exp_class(
 633            this=seq_get(args, 0),
 634            format=Dialect[dialect].format_time(
 635                seq_get(args, 1)
 636                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
 637            ),
 638        )
 639
 640    return _format_time
 641
 642
 643def time_format(
 644    dialect: DialectType = None,
 645) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
 646    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
 647        """
 648        Returns the time format for a given expression, unless it's equivalent
 649        to the default time format of the dialect of interest.
 650        """
 651        time_format = self.format_time(expression)
 652        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
 653
 654    return _time_format
 655
 656
 657def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
 658    """
 659    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
 660    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
 661    columns are removed from the create statement.
 662    """
 663    has_schema = isinstance(expression.this, exp.Schema)
 664    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
 665
 666    if has_schema and is_partitionable:
 667        prop = expression.find(exp.PartitionedByProperty)
 668        if prop and prop.this and not isinstance(prop.this, exp.Schema):
 669            schema = expression.this
 670            columns = {v.name.upper() for v in prop.this.expressions}
 671            partitions = [col for col in schema.expressions if col.name.upper() in columns]
 672            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
 673            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
 674            expression.set("this", schema)
 675
 676    return self.create_sql(expression)
 677
 678
 679def parse_date_delta(
 680    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
 681) -> t.Callable[[t.List], E]:
 682    def inner_func(args: t.List) -> E:
 683        unit_based = len(args) == 3
 684        this = args[2] if unit_based else seq_get(args, 0)
 685        unit = args[0] if unit_based else exp.Literal.string("DAY")
 686        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
 687        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
 688
 689    return inner_func
 690
 691
 692def parse_date_delta_with_interval(
 693    expression_class: t.Type[E],
 694) -> t.Callable[[t.List], t.Optional[E]]:
 695    def func(args: t.List) -> t.Optional[E]:
 696        if len(args) < 2:
 697            return None
 698
 699        interval = args[1]
 700
 701        if not isinstance(interval, exp.Interval):
 702            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
 703
 704        expression = interval.this
 705        if expression and expression.is_string:
 706            expression = exp.Literal.number(expression.this)
 707
 708        return expression_class(
 709            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
 710        )
 711
 712    return func
 713
 714
 715def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
 716    unit = seq_get(args, 0)
 717    this = seq_get(args, 1)
 718
 719    if isinstance(this, exp.Cast) and this.is_type("date"):
 720        return exp.DateTrunc(unit=unit, this=this)
 721    return exp.TimestampTrunc(this=this, unit=unit)
 722
 723
 724def date_add_interval_sql(
 725    data_type: str, kind: str
 726) -> t.Callable[[Generator, exp.Expression], str]:
 727    def func(self: Generator, expression: exp.Expression) -> str:
 728        this = self.sql(expression, "this")
 729        unit = expression.args.get("unit")
 730        unit = exp.var(unit.name.upper() if unit else "DAY")
 731        interval = exp.Interval(this=expression.expression, unit=unit)
 732        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 733
 734    return func
 735
 736
 737def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
 738    return self.func(
 739        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
 740    )
 741
 742
 743def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
 744    if not expression.expression:
 745        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
 746    if expression.text("expression").lower() in TIMEZONES:
 747        return self.sql(
 748            exp.AtTimeZone(
 749                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
 750                zone=expression.expression,
 751            )
 752        )
 753    return self.function_fallback_sql(expression)
 754
 755
 756def locate_to_strposition(args: t.List) -> exp.Expression:
 757    return exp.StrPosition(
 758        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
 759    )
 760
 761
 762def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
 763    return self.func(
 764        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
 765    )
 766
 767
 768def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 769    return self.sql(
 770        exp.Substring(
 771            this=expression.this, start=exp.Literal.number(1), length=expression.expression
 772        )
 773    )
 774
 775
 776def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 777    return self.sql(
 778        exp.Substring(
 779            this=expression.this,
 780            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
 781        )
 782    )
 783
 784
 785def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
 786    return self.sql(exp.cast(expression.this, "timestamp"))
 787
 788
 789def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
 790    return self.sql(exp.cast(expression.this, "date"))
 791
 792
 793# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
 794def encode_decode_sql(
 795    self: Generator, expression: exp.Expression, name: str, replace: bool = True
 796) -> str:
 797    charset = expression.args.get("charset")
 798    if charset and charset.name.lower() != "utf-8":
 799        self.unsupported(f"Expected utf-8 character set, got {charset}.")
 800
 801    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
 802
 803
 804def min_or_least(self: Generator, expression: exp.Min) -> str:
 805    name = "LEAST" if expression.expressions else "MIN"
 806    return rename_func(name)(self, expression)
 807
 808
 809def max_or_greatest(self: Generator, expression: exp.Max) -> str:
 810    name = "GREATEST" if expression.expressions else "MAX"
 811    return rename_func(name)(self, expression)
 812
 813
 814def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
 815    cond = expression.this
 816
 817    if isinstance(expression.this, exp.Distinct):
 818        cond = expression.this.expressions[0]
 819        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
 820
 821    return self.func("sum", exp.func("if", cond, 1, 0))
 822
 823
 824def trim_sql(self: Generator, expression: exp.Trim) -> str:
 825    target = self.sql(expression, "this")
 826    trim_type = self.sql(expression, "position")
 827    remove_chars = self.sql(expression, "expression")
 828    collation = self.sql(expression, "collation")
 829
 830    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
 831    if not remove_chars and not collation:
 832        return self.trim_sql(expression)
 833
 834    trim_type = f"{trim_type} " if trim_type else ""
 835    remove_chars = f"{remove_chars} " if remove_chars else ""
 836    from_part = "FROM " if trim_type or remove_chars else ""
 837    collation = f" COLLATE {collation}" if collation else ""
 838    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
 839
 840
 841def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
 842    return self.func("STRPTIME", expression.this, self.format_time(expression))
 843
 844
 845def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
 846    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
 847
 848
 849def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
 850    delim, *rest_args = expression.expressions
 851    return self.sql(
 852        reduce(
 853            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
 854            rest_args,
 855        )
 856    )
 857
 858
 859def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
 860    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
 861    if bad_args:
 862        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
 863
 864    return self.func(
 865        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
 866    )
 867
 868
 869def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
 870    bad_args = list(
 871        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
 872    )
 873    if bad_args:
 874        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
 875
 876    return self.func(
 877        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
 878    )
 879
 880
 881def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
 882    names = []
 883    for agg in aggregations:
 884        if isinstance(agg, exp.Alias):
 885            names.append(agg.alias)
 886        else:
 887            """
 888            This case corresponds to aggregations without aliases being used as suffixes
 889            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
 890            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
 891            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
 892            """
 893            agg_all_unquoted = agg.transform(
 894                lambda node: (
 895                    exp.Identifier(this=node.name, quoted=False)
 896                    if isinstance(node, exp.Identifier)
 897                    else node
 898                )
 899            )
 900            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
 901
 902    return names
 903
 904
 905def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
 906    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
 907
 908
 909# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
 910def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
 911    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
 912
 913
 914def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
 915    return self.func("MAX", expression.this)
 916
 917
 918def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
 919    a = self.sql(expression.left)
 920    b = self.sql(expression.right)
 921    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
 922
 923
 924def is_parse_json(expression: exp.Expression) -> bool:
 925    return isinstance(expression, exp.ParseJSON) or (
 926        isinstance(expression, exp.Cast) and expression.is_type("json")
 927    )
 928
 929
 930def isnull_to_is_null(args: t.List) -> exp.Expression:
 931    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
 932
 933
 934def generatedasidentitycolumnconstraint_sql(
 935    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
 936) -> str:
 937    start = self.sql(expression, "start") or "1"
 938    increment = self.sql(expression, "increment") or "1"
 939    return f"IDENTITY({start}, {increment})"
 940
 941
 942def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
 943    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
 944        if expression.args.get("count"):
 945            self.unsupported(f"Only two arguments are supported in function {name}.")
 946
 947        return self.func(name, expression.this, expression.expression)
 948
 949    return _arg_max_or_min_sql
 950
 951
 952def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
 953    this = expression.this.copy()
 954
 955    return_type = expression.return_type
 956    if return_type.is_type(exp.DataType.Type.DATE):
 957        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
 958        # can truncate timestamp strings, because some dialects can't cast them to DATE
 959        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
 960
 961    expression.this.replace(exp.cast(this, return_type))
 962    return expression
 963
 964
 965def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
 966    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
 967        if cast and isinstance(expression, exp.TsOrDsAdd):
 968            expression = ts_or_ds_add_cast(expression)
 969
 970        return self.func(
 971            name,
 972            exp.var(expression.text("unit").upper() or "DAY"),
 973            expression.expression,
 974            expression.this,
 975        )
 976
 977    return _delta_sql
 978
 979
 980def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
 981    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
 982    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
 983    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
 984
 985    return self.sql(exp.cast(minus_one_day, "date"))
 986
 987
 988def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
 989    """Remove table refs from columns in when statements."""
 990    alias = expression.this.args.get("alias")
 991
 992    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
 993        return self.dialect.normalize_identifier(identifier).name if identifier else None
 994
 995    targets = {normalize(expression.this.this)}
 996
 997    if alias:
 998        targets.add(normalize(alias.this))
 999
1000    for when in expression.expressions:
1001        when.transform(
1002            lambda node: (
1003                exp.column(node.this)
1004                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1005                else node
1006            ),
1007            copy=False,
1008        )
1009
1010    return self.merge_sql(expression)
1011
1012
def parse_json_extract_path(
    expr_type: t.Type[E],
    supports_null_if_invalid: bool = False,
) -> t.Callable[[t.List], E]:
    """
    Build a parser callback that turns `F(json, part1, part2, ...)` into `expr_type`.

    Literal path arguments become JSON path parts. Integer-like literals become subscripts;
    all other literals become keys. When `supports_null_if_invalid` is set, a non-literal
    argument is kept as `null_if_invalid`; the last such argument wins. That value is only
    forwarded to `exp.JSONExtractScalar`.
    """

    def _parse_json_extract_path(args: t.List) -> E:
        path_parts: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        null_if_invalid = None

        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                if supports_null_if_invalid:
                    null_if_invalid = arg
                continue

            text = arg.name
            part = exp.JSONPathSubscript(this=int(text)) if is_int(text) else exp.JSONPathKey(this=text)
            path_parts.append(part)

        json_doc = seq_get(args, 0)
        jsonpath = exp.JSONPath(expressions=path_parts)

        # Trim the caller's list in place so the expression validator's arg-count check
        # doesn't fail on the now-consumed path arguments.
        del args[2:]

        if expr_type is exp.JSONExtractScalar:
            return expr_type(this=json_doc, expression=jsonpath, null_if_invalid=null_if_invalid)

        return expr_type(this=json_doc, expression=jsonpath)

    return _parse_json_extract_path
1043
1044
def json_path_segments(self: Generator, expression: exp.JSONPath) -> t.List[str]:
    """Render each JSON path part and wrap it in the dialect's string quotes; empty parts are skipped."""
    rendered = (self.sql(segment) for segment in expression.expressions)
    return [
        f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" for path in rendered if path
    ]
logger = <Logger sqlglot (WARNING)>
class Dialects(builtins.str, enum.Enum):
class Dialects(str, Enum):
    """
    Dialects supported by SQLGlot.

    Members subclass `str`, so each compares equal to its plain string value
    (e.g. `Dialects.DUCKDB == "duckdb"`). `DIALECT` maps to the empty string.
    """

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"

Dialects supported by SQLGlot.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DORIS = <Dialects.DORIS: 'doris'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class NormalizationStrategy(builtins.str, sqlglot.helper.AutoName):
class NormalizationStrategy(str, AutoName):
    """
    Specifies the strategy according to which identifiers should be normalized.

    Member values equal their names (e.g. `NormalizationStrategy.LOWERCASE.value == "LOWERCASE"`).
    """

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""

Specifies the strategy according to which identifiers should be normalized.

LOWERCASE = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Unquoted identifiers are lowercased.

UPPERCASE = <NormalizationStrategy.UPPERCASE: 'UPPERCASE'>

Unquoted identifiers are uppercased.

CASE_SENSITIVE = <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>

Always case-sensitive, regardless of quotes.

CASE_INSENSITIVE = <NormalizationStrategy.CASE_INSENSITIVE: 'CASE_INSENSITIVE'>

Always case-insensitive, regardless of quotes.

Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """Determines the base index offset for arrays."""

    WEEK_OFFSET = 0
    """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Determines whether or not `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Determines whether or not the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Determines whether or not a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Determines whether or not an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Determines whether or not `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Determines whether or not user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Determines whether or not `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """Determines how function names are going to be normalized."""

    LOG_BASE_FIRST = True
    """Determines whether the base comes first in the `LOG` function."""

    NULL_ORDERING = "nulls_are_small"
    """
    Indicates the default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` format."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an unescaped escape sequence to the corresponding character."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

        will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.
                Each setting is split on its first `=`, so values may themselves contain `=`.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.

        Raises:
            ValueError: If a setting is not of the form `key = value`, the dialect name is
                unknown, or `dialect` is of an unsupported type.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            dialect_name, *kv_pairs = dialect.split(",")
            # Strip once, so lookup, suggestions and error messages all use the same name.
            dialect_name = dialect_name.strip()

            try:
                # Split on the first "=" only, so that setting values may contain "=".
                kwargs = {
                    k.strip(): v.strip() for k, v in (kv.split("=", 1) for kv in kv_pairs)
                }
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                ) from None

            result = cls.get(dialect_name)
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        """
        Parse a literal JSON path into a JSON path expression.

        Number literals are treated as subscripts (`[n]`). If parsing fails, a warning is
        logged and `path` is returned unchanged; non-literal paths are also returned as-is.
        """
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning("Invalid JSON path syntax. %s", e)

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily created once per dialect instance and then reused.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
def __init__(self, **kwargs) -> None:
    """
    Create a dialect instance, optionally overriding its normalization strategy.

    Keyword Args:
        normalization_strategy: a `NormalizationStrategy` value, matched case-insensitively.
            When omitted (or None), the class-level `NORMALIZATION_STRATEGY` is used.
    """
    strategy = kwargs.get("normalization_strategy")
    self.normalization_strategy = (
        self.NORMALIZATION_STRATEGY
        if strategy is None
        else NormalizationStrategy(strategy.upper())
    )
INDEX_OFFSET = 0

Determines the base index offset for arrays.

WEEK_OFFSET = 0

Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Determines whether or not UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Determines whether or not the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT = False

Determines whether or not a size in the table sample clause represents percentage.

NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Determines whether or not an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Determines whether or not the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Determines whether or not CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Determines whether or not user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Determines whether or not SEMI or ANTI joins are supported.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are going to be normalized.

LOG_BASE_FIRST = True

Determines whether the base comes first in the LOG function.

NULL_ORDERING = 'nulls_are_small'

Indicates the default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last"

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Determines whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime format.

FORMAT_MAPPING: Dict[str, str] = {}

Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.

ESCAPE_SEQUENCES: Dict[str, str] = {}

Mapping of an unescaped escape sequence to the corresponding character.

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

    WITH y(c) AS (
        SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
    ) SELECT c FROM y;

will be rewritten as

    WITH y(c) AS (
        SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
    ) SELECT c FROM y;
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
INVERSE_ESCAPE_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:
280    @classmethod
281    def get_or_raise(cls, dialect: DialectType) -> Dialect:
282        """
283        Look up a dialect in the global dialect registry and return it if it exists.
284
285        Args:
286            dialect: The target dialect. If this is a string, it can be optionally followed by
287                additional key-value pairs that are separated by commas and are used to specify
288                dialect settings, such as whether the dialect's identifiers are case-sensitive.
289
290        Example:
291            >>> dialect = dialect_class = get_or_raise("duckdb")
292            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
293
294        Returns:
295            The corresponding Dialect instance.
296        """
297
298        if not dialect:
299            return cls()
300        if isinstance(dialect, _Dialect):
301            return dialect()
302        if isinstance(dialect, Dialect):
303            return dialect
304        if isinstance(dialect, str):
305            try:
306                dialect_name, *kv_pairs = dialect.split(",")
307                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
308            except ValueError:
309                raise ValueError(
310                    f"Invalid dialect format: '{dialect}'. "
311                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
312                )
313
314            result = cls.get(dialect_name.strip())
315            if not result:
316                from difflib import get_close_matches
317
318                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
319                if similar:
320                    similar = f" Did you mean {similar}?"
321
322                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
323
324            return result(**kwargs)
325
326        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
328    @classmethod
329    def format_time(
330        cls, expression: t.Optional[str | exp.Expression]
331    ) -> t.Optional[exp.Expression]:
332        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
333        if isinstance(expression, str):
334            return exp.Literal.string(
335                # the time formats are quoted
336                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
337            )
338
339        if expression and expression.is_string:
340            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
341
342        return expression

Converts a time format in this dialect to its equivalent Python strftime format.

def normalize_identifier(self, expression: ~E) -> ~E:
360    def normalize_identifier(self, expression: E) -> E:
361        """
362        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
363
364        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
365        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
366        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
367        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
368
369        There are also dialects like Spark, which are case-insensitive even when quotes are
370        present, and dialects like MySQL, whose resolution rules match those employed by the
371        underlying operating system, for example they may always be case-sensitive in Linux.
372
373        Finally, the normalization behavior of some engines can even be controlled through flags,
374        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
375
376        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
377        that it can analyze queries in the optimizer and successfully capture their semantics.
378        """
379        if (
380            isinstance(expression, exp.Identifier)
381            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
382            and (
383                not expression.quoted
384                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
385            )
386        ):
387            expression.set(
388                "this",
389                (
390                    expression.this.upper()
391                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
392                    else expression.this.lower()
393                ),
394            )
395
396        return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.

def case_sensitive(self, text: str) -> bool:
398    def case_sensitive(self, text: str) -> bool:
399        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
400        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
401            return False
402
403        unsafe = (
404            str.islower
405            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
406            else str.isupper
407        )
408        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:
410    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
411        """Checks if text can be identified given an identify option.
412
413        Args:
414            text: The text to check.
415            identify:
416                `"always"` or `True`: Always returns `True`.
417                `"safe"`: Only returns `True` if the identifier is case-insensitive.
418
419        Returns:
420            Whether or not the given text can be identified.
421        """
422        if identify is True or identify == "always":
423            return True
424
425        if identify == "safe":
426            return not self.case_sensitive(text)
427
428        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns True. "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether or not the given text can be identified.

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:
430    def quote_identifier(self, expression: E, identify: bool = True) -> E:
431        """
432        Adds quotes to a given identifier.
433
434        Args:
435            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
436            identify: If set to `False`, the quotes will only be added if the identifier is deemed
437                "unsafe", with respect to its characters and this dialect's normalization strategy.
438        """
439        if isinstance(expression, exp.Identifier):
440            name = expression.this
441            expression.set(
442                "quoted",
443                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
444            )
445
446        return expression

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
def to_json_path( self, path: Optional[sqlglot.expressions.Expression]) -> Optional[sqlglot.expressions.Expression]:
448    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
449        if isinstance(path, exp.Literal):
450            path_text = path.name
451            if path.is_number:
452                path_text = f"[{path_text}]"
453
454            try:
455                return parse_json_path(path_text)
456            except ParseError as e:
457                logger.warning(f"Invalid JSON path syntax. {str(e)}")
458
459        return path
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
461    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
462        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
464    def parse_into(
465        self, expression_type: exp.IntoType, sql: str, **opts
466    ) -> t.List[t.Optional[exp.Expression]]:
467        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
469    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
470        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
472    def transpile(self, sql: str, **opts) -> t.List[str]:
473        return [
474            self.generate(expression, copy=False, **opts) if expression else ""
475            for expression in self.parse(sql)
476        ]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
478    def tokenize(self, sql: str) -> t.List[Token]:
479        return self.tokenizer.tokenize(sql)
tokenizer: sqlglot.tokens.Tokenizer
481    @property
482    def tokenizer(self) -> Tokenizer:
483        if not hasattr(self, "_tokenizer"):
484            self._tokenizer = self.tokenizer_class(dialect=self)
485        return self._tokenizer
def parser(self, **opts) -> sqlglot.parser.Parser:
487    def parser(self, **opts) -> Parser:
488        return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
490    def generator(self, **opts) -> Generator:
491        return self.generator_class(dialect=self, **opts)
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a transform that renders an expression as a call to the function `name`.

    All of the expression's argument values are flattened and passed positionally.
    """

    def _renamed(self, expression):
        call_args = flatten(expression.args.values())
        return self.func(name, *call_args)

    return _renamed
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT(this), reporting any `accuracy` argument as unsupported."""
    has_accuracy = expression.args.get("accuracy")
    if has_accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    target = expression.this
    return self.func("APPROX_COUNT_DISTINCT", target)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Build a transform that renders `exp.If` as the function call `name(cond, true, false)`.

    Args:
        name: The function name to emit.
        false_value: Used as the third argument when the expression's "false" branch is
            missing or falsy.
    """

    def _if_sql(self: Generator, expression: exp.If) -> str:
        condition = expression.this
        when_true = expression.args.get("true")
        when_false = expression.args.get("false") or false_value
        return self.func(name, condition, when_true, when_false)

    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONExtractScalar) -> str:
def arrow_json_extract_sql(
    self: Generator, expression: exp.JSONExtract | exp.JSONExtractScalar
) -> str:
    """Render JSON extraction as `this -> path` (JSONExtract) or `this ->> path` (otherwise).

    When the generator sets JSON_TYPE_REQUIRED_FOR_EXTRACTION, a string-literal operand is
    first replaced in the tree by a cast of that literal to JSON.
    """
    operand = expression.this
    needs_json_cast = (
        self.JSON_TYPE_REQUIRED_FOR_EXTRACTION
        and isinstance(operand, exp.Literal)
        and operand.is_string
    )
    if needs_json_cast:
        operand.replace(exp.cast(operand, "json"))

    operator = "->" if isinstance(expression, exp.JSONExtract) else "->>"
    return self.binary(expression, operator)
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array as its flat element list wrapped in square brackets."""
    elements = self.expressions(expression, flat=True)
    return f"[{elements}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate `x ILIKE pattern` as `LOWER(x) LIKE LOWER(pattern)` for dialects without ILIKE.

    Both operands must be lowercased. If only the value were lowered, any pattern containing
    an uppercase character (e.g. `x ILIKE 'A%'`) could never match.
    """
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this),
            expression=exp.Lower(this=expression.expression),
        )
    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, appending AT TIME ZONE when a zone is set."""
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"
    return f"CURRENT_DATE AT TIME ZONE {zone}"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a WITH clause for dialects without recursive CTEs.

    A set "recursive" flag is reported as unsupported and cleared in place before rendering.
    """
    is_recursive = expression.args.get("recursive")
    if is_recursive:
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate safe division with IF so that a zero denominator yields NULL."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    quotient = f"({numerator}) / ({denominator})"
    return f"IF(({denominator}) <> 0, {quotient}, NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Report TABLESAMPLE as unsupported and render only the sampled source."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled_source = expression.this
    return self.sql(sampled_source)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Report PIVOT as unsupported and emit nothing in its place."""
    message = "PIVOT unsupported"
    self.unsupported(message)
    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render a TryCast node through the generator's regular CAST rendering."""
    rendered = self.cast_sql(expression)
    return rendered
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Report a column COMMENT constraint as unsupported and drop it from the output."""
    message = "CommentColumnConstraint unsupported"
    self.unsupported(message)
    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Report MAP_FROM_ENTRIES as unsupported and emit nothing in its place."""
    message = "MAP_FROM_ENTRIES unsupported"
    self.unsupported(message)
    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition with STRPOS.

    With a start position, the search runs on the suffix that begins there, and the result is
    shifted back into the coordinates of the full string.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")
    if not start:
        return f"STRPOS({haystack}, {needle})"
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as `<struct>.<field>`, emitting the field as an identifier."""
    struct_sql = self.sql(expression, "this")
    field = exp.to_identifier(expression.expression.name)
    return f"{struct_sql}.{self.sql(field)}"
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as `map_func_name(k1, v1, k2, v2, ...)`.

    Keys and values can only be interleaved when both are array literals. Otherwise the
    conversion is reported as unsupported, and the two arguments are passed through unchanged.
    Pairing uses `zip`, so elements past the shorter array are dropped.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if isinstance(keys, exp.Array) and isinstance(values, exp.Array):
        interleaved = [
            self.sql(item)
            for pair in zip(keys.expressions, values.expressions)
            for item in pair
        ]
        return self.func(map_func_name, *interleaved)

    self.unsupported("Cannot convert array columns into map.")
    return self.func(map_func_name, keys, values)
def format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: name of the dialect (looked up via `Dialect[dialect]`) whose time format
            conventions are applied.
        default: format used when the parsed arguments provide none; `True` selects the
            dialect's TIME_FORMAT.

    Returns:
        A callable that turns a parsed argument list into `exp_class(this=..., format=...)`.
    """

    def _format_time(args: t.List):
        target = Dialect[dialect]
        if default is True:
            fallback = target.TIME_FORMAT
        else:
            fallback = default or None
        return exp_class(
            this=seq_get(args, 0),
            format=target.format_time(seq_get(args, 1) or fallback),
        )

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the format to use when the parsed arguments provide none; True selects the dialect's TIME_FORMAT.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a transform that returns an expression's time format, or None if it is the default.

    Args:
        dialect: Dialect whose TIME_FORMAT counts as the default. It is resolved with
            `Dialect.get_or_raise` on every call, not when this factory runs.
    """

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        fmt = self.format_time(expression)
        default_fmt = Dialect.get_or_raise(dialect).TIME_FORMAT
        if fmt == default_fmt:
            return None
        return fmt

    return _time_format
def create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema.

    This applies to TABLE/VIEW creates that have a schema and whose PARTITIONED BY value is a
    list of column names rather than a schema. In that case, the column definitions named
    there (compared case-insensitively) are moved out of the table schema and into a
    PARTITIONED BY schema.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
    prop = expression.find(exp.PartitionedByProperty) if has_schema and is_partitionable else None

    if prop and prop.this and not isinstance(prop.this, exp.Schema):
        schema = expression.this
        partition_names = {column.name.upper() for column in prop.this.expressions}
        partitions = [c for c in schema.expressions if c.name.upper() in partition_names]
        schema.set("expressions", [c for c in schema.expressions if c not in partitions])
        prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
        expression.set("this", schema)

    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date-delta functions.

    Three arguments are read as `(unit, expression, this)`. Otherwise they are read as
    `(this, expression)` and the unit defaults to a 'DAY' string literal. When `unit_mapping`
    is given, the unit's lowercased name is looked up in it (falling back to the unit's own
    name) and the result is wrapped with `exp.var`.
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            unit, this = args[0], args[2]
        else:
            unit, this = exp.Literal.string("DAY"), seq_get(args, 0)

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for `FUNC(this, INTERVAL ...)` style date arithmetic.

    The returned callable behaves as follows:

    - It returns None when given fewer than two arguments.
    - It raises ParseError when the second argument is not an `exp.Interval`.
    - A string interval value is converted into a number literal.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=args[0], expression=amount, unit=unit)

    return func
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse `(unit, this)` into DateTrunc if `this` is a cast to DATE, else TimestampTrunc."""
    unit, this = seq_get(args, 0), seq_get(args, 1)
    truncates_date = isinstance(this, exp.Cast) and this.is_type("date")
    if truncates_date:
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a transform that renders `{data_type}_{kind}(this, INTERVAL ...)`.

    The interval's unit is the expression's unit name, uppercased, or DAY when absent.
    """

    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit_arg = expression.args.get("unit")
        unit_name = unit_arg.name.upper() if unit_arg else "DAY"
        unit = exp.var(unit_name)
        interval_sql = self.sql(exp.Interval(this=expression.expression, unit=unit))
        return f"{data_type}_{kind}({this}, {interval_sql})"

    return func
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('<UNIT>', this), defaulting the unit to DAY."""
    unit_name = expression.text("unit").upper() or "DAY"
    return self.func("DATE_TRUNC", exp.Literal.string(unit_name), expression.this)
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    """Render a Timestamp node for dialects without a TIMESTAMP function.

    - With no second argument, it is rendered as a CAST to TIMESTAMP.
    - When the second argument's text (lowercased) is in TIMEZONES, it is rendered as
      `CAST(this AS TIMESTAMP) AT TIME ZONE zone`.
    - Anything else falls back to generic function rendering.
    """
    zone = expression.expression
    if not zone:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))

    if expression.text("expression").lower() not in TIMEZONES:
        return self.function_fallback_sql(expression)

    as_timestamp = exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)
    return self.sql(exp.AtTimeZone(this=as_timestamp, zone=zone))
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Parse LOCATE-style `(substr, this, position)` arguments into StrPosition."""
    substr, this, position = (seq_get(args, i) for i in range(3))
    return exp.StrPosition(this=this, substr=substr, position=position)
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, this, position)."""
    substr = expression.args.get("substr")
    position = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Render LEFT(this, n) as SUBSTRING(this, 1, n)."""
    first_position = exp.Literal.number(1)
    substring = exp.Substring(
        this=expression.this, start=first_position, length=expression.expression
    )
    return self.sql(substring)
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Render RIGHT(this, n) as SUBSTRING(this, LENGTH(this) - (n - 1)).

    The substring starts n - 1 characters before the last one and runs to the end of the
    string, which yields the last n characters.
    """
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render TimeStrToTime as a cast of its argument to timestamp."""
    as_timestamp = exp.cast(expression.this, "timestamp")
    return self.sql(as_timestamp)
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render DateStrToDate as a cast of its argument to date."""
    as_date = exp.cast(expression.this, "date")
    return self.sql(as_date)
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render the expression as the function call `name(this, replace)`.

    Any charset other than "utf-8" (compared case-insensitively) is reported as unsupported;
    the charset is never emitted. When `replace` is False, None is passed in place of the
    expression's "replace" argument.
    """
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replacement = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replacement)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render Min as LEAST when it has extra operands in `expressions`, otherwise as MIN."""
    return rename_func("LEAST" if expression.expressions else "MIN")(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render Max as GREATEST when it has extra operands in `expressions`, otherwise as MAX."""
    return rename_func("GREATEST" if expression.expressions else "MAX")(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as sum(if(cond, 1, 0)).

    DISTINCT cannot be carried over. Only the first DISTINCT operand is used as the condition,
    and the loss is reported as unsupported.
    """
    condition = expression.this
    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", condition, 1, 0))
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM, using the full `TRIM([position] [chars] FROM target [COLLATE c])` form only
    when removal characters or a collation are present."""
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not chars and not collation:
        return self.trim_sql(expression)

    prefix = "".join(f"{part} " for part in (position, chars) if part)
    if prefix:
        prefix += "FROM "
    suffix = f" COLLATE {collation}" if collation else ""
    return f"TRIM({prefix}{target}{suffix})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(this, format)."""
    source = expression.this
    return self.func("STRPTIME", source, self.format_time(expression))
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    """Render CONCAT(a, b, c, ...) as a left-nested `a || b || c` chain."""

    def _pipe(left, right):
        return exp.DPipe(this=left, expression=right)

    return self.sql(reduce(_pipe, expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Render CONCAT_WS(sep, a, b, ...) as `a || (sep || b) || ...` using DPipe nodes.

    The first operand is the separator; it is placed before every operand after the first.
    """
    delim, *rest_args = expression.expressions

    def _join_with_delim(acc, nxt):
        return exp.DPipe(this=acc, expression=exp.DPipe(this=delim, expression=nxt))

    return self.sql(reduce(_join_with_delim, rest_args))
857    )
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT(this, pattern, group), reporting any dropped arguments."""
    bad_args = [
        arg for arg in ("position", "occurrence", "parameters") if expression.args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    group = expression.args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE(this, pattern, replacement), reporting any dropped arguments."""
    dropped = ("position", "occurrence", "parameters", "modifiers")
    bad_args = [arg for arg in dropped if expression.args.get(arg)]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    replacement = expression.args["replacement"]
    return self.func("REGEXP_REPLACE", expression.this, expression.expression, replacement)
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the name suffix contributed by each PIVOT aggregation.

    An aliased aggregation contributes its alias. An unaliased one contributes
    its SQL text in ``dialect`` (with ``normalize_functions="lower"``), after
    every identifier in it has been rebuilt as unquoted.
    """

    def _unquote(node: exp.Expression) -> exp.Expression:
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    def _name_of(aggregation: exp.Expression) -> str:
        if isinstance(aggregation, exp.Alias):
            return aggregation.alias
        # Unaliased aggregations are used as suffixes (e.g. col_avg(foo)). Identifiers are
        # unquoted here because the base parser's `_parse_pivot` quotes the final name via
        # `to_identifier`; otherwise we'd end up with `col_avg(`foo`)` (quoted twice).
        unquoted = aggregation.transform(_unquote)
        return unquoted.sql(dialect=dialect, normalize_functions="lower")

    return [_name_of(aggregation) for aggregation in aggregations]
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Return a parser callback that maps a function's arguments onto ``expr_type``.

    The first argument becomes ``this`` and the second becomes ``expression``;
    both are fetched with ``seq_get``.
    """

    def _build(args: t.List) -> B:
        left = seq_get(args, 0)
        right = seq_get(args, 1)
        return expr_type(this=left, expression=right)

    return _build
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Build a TimestampTrunc from unit-first arguments: ``(unit, value)``."""
    unit = seq_get(args, 0)
    value = seq_get(args, 1)
    return exp.TimestampTrunc(this=value, unit=unit)
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render ``ANY_VALUE(x)`` as ``MAX(x)``."""
    argument = expression.this
    return self.func("MAX", argument)
def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Emulate boolean XOR as ``(a AND (NOT b)) OR ((NOT a) AND b)``.

    Precedence handling:
    - Operands that are themselves AND/OR/XOR connectors are parenthesized.
      Otherwise, e.g. a left operand ``x AND y`` would produce ``NOT x AND y``
      instead of ``NOT (x AND y)``.
    - The whole result is parenthesized when it sits directly under NOT or
      under an AND/OR connector. Its top-level OR would otherwise bind looser
      than the enclosing operator.

    In any other position (e.g. a SELECT projection or WHERE clause) the output
    is the bare emulation, as before.
    """

    def _operand(node: t.Optional[exp.Expression]) -> str:
        sql = self.sql(node)
        # Connectors bind looser than AND/NOT, so they must be grouped explicitly.
        return f"({sql})" if isinstance(node, (exp.Connector, exp.Xor)) else sql

    a = _operand(expression.left)
    b = _operand(expression.right)
    sql = f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"

    # A parent XOR is skipped here because its own `_operand` call already
    # parenthesizes this result; wrapping here too would double the parens.
    parent = expression.parent
    if isinstance(parent, exp.Not) or (
        isinstance(parent, exp.Connector) and not isinstance(parent, exp.Xor)
    ):
        sql = f"({sql})"

    return sql
def is_parse_json(expression: exp.Expression) -> bool:
    """Return True for a PARSE_JSON node or a Cast whose target type is JSON."""
    if isinstance(expression, exp.ParseJSON):
        return True
    return isinstance(expression, exp.Cast) and expression.is_type("json")
def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Translate ``ISNULL(x)`` into the parenthesized predicate ``(x IS NULL)``."""
    operand = seq_get(args, 0)
    predicate = exp.Is(this=operand, expression=exp.null())
    return exp.Paren(this=predicate)
def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    """Render an identity column constraint as ``IDENTITY(start, increment)``.

    Each value falls back to ``1`` when the generator renders it as an empty
    string.
    """
    seed, step = (self.sql(expression, arg) or "1" for arg in ("start", "increment"))
    return f"IDENTITY({seed}, {step})"
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    """Return a generator callback rendering ArgMax/ArgMin as ``name(this, expression)``.

    A ``count`` argument has no slot in this two-argument form; if set, it is
    reported via ``self.unsupported`` and dropped.
    """

    def _render(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        has_count = bool(expression.args.get("count"))
        if has_count:
            self.unsupported(f"Only two arguments are supported in function {name}.")
        return self.func(name, expression.this, expression.expression)

    return _render
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    """Cast the operand of ``expression`` to its ``return_type``, in place.

    The original ``this`` node is replaced by a cast of a copy of itself, and the
    same (mutated) ``expression`` is returned. When the return type is DATE, the
    copy is first cast to TIMESTAMP.
    """
    operand = expression.this.copy()
    target_type = expression.return_type

    # Casting through TIMESTAMP first makes sure timestamp strings can be truncated,
    # because some dialects can't cast them to DATE directly.
    if target_type.is_type(exp.DataType.Type.DATE):
        operand = exp.cast(operand, exp.DataType.Type.TIMESTAMP)

    casted = exp.cast(operand, target_type)
    expression.this.replace(casted)
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    """Return a generator callback rendering a date delta as ``name(UNIT, amount, value)``.

    The unit is upper-cased and defaults to ``DAY`` when missing. When ``cast`` is
    True, TsOrDsAdd nodes are first passed through ``ts_or_ds_add_cast``.
    """

    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        needs_cast = cast and isinstance(expression, exp.TsOrDsAdd)
        if needs_cast:
            expression = ts_or_ds_add_cast(expression)

        unit_name = expression.text("unit").upper() or "DAY"
        return self.func(name, exp.var(unit_name), expression.expression, expression.this)

    return _delta_sql
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    """Emulate ``LAST_DAY(d)`` for dialects rendered through this helper.

    The value is truncated to the start of its month, moved forward one month,
    moved back one day, and the result is cast to DATE.
    """
    last_day = exp.func(
        "date_sub",
        exp.func("date_add", exp.func("date_trunc", "month", expression.this), 1, "month"),
        1,
        "day",
    )
    return self.sql(exp.cast(last_day, "date"))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove target-table refs from columns in the THEN part of WHEN clauses.

    Columns qualified with the MERGE target's table name or alias are replaced by
    unqualified columns inside each WHEN clause's ``then`` action. Names are
    compared after ``self.dialect.normalize_identifier``.

    The WHEN condition is deliberately left untouched. References to the target
    table are still valid there, and stripping the qualifier could make a column
    ambiguous between the target and the source.
    """
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    def _unqualify(node: exp.Expression) -> exp.Expression:
        if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets:
            return exp.column(node.this)
        return node

    for when in expression.expressions:
        # Only the action (THEN ...) is rewritten; the match condition keeps its qualifiers.
        then = when.args.get("then")
        if then:
            then.transform(_unqualify, copy=False)

    return self.merge_sql(expression)

Remove table refs from columns in when statements.

def parse_json_extract_path(
    expr_type: t.Type[E],
    supports_null_if_invalid: bool = False,
) -> t.Callable[[t.List], E]:
    """Return a parser callback turning ``FN(json, part1, part2, ...)`` into ``expr_type``.

    - Every literal path argument becomes a JSONPath segment after the root.
      Integer-looking literals (per ``is_int``) become subscripts; all others
      become keys.
    - A non-literal argument is remembered as ``null_if_invalid`` when
      ``supports_null_if_invalid`` is set (the last one wins); otherwise it is
      ignored.
    - ``null_if_invalid`` is only passed on when ``expr_type`` is
      ``exp.JSONExtractScalar``.
    """

    def _to_segment(literal: exp.Literal) -> exp.JSONPathPart:
        text = literal.name
        if is_int(text):
            return exp.JSONPathSubscript(this=int(text))
        return exp.JSONPathKey(this=text)

    def _parse_json_extract_path(args: t.List) -> E:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        null_if_invalid = None

        for path_arg in args[1:]:
            if isinstance(path_arg, exp.Literal):
                segments.append(_to_segment(path_arg))
            elif supports_null_if_invalid:
                null_if_invalid = path_arg

        this = seq_get(args, 0)
        jsonpath = exp.JSONPath(expressions=segments)

        # Trim the caller's argument list in place so the expression validator does not
        # fail because of the original argument count.
        del args[2:]

        kwargs = {"this": this, "expression": jsonpath}
        if expr_type is exp.JSONExtractScalar:
            kwargs["null_if_invalid"] = null_if_invalid
        return expr_type(**kwargs)

    return _parse_json_extract_path
def json_path_segments(self: Generator, expression: exp.JSONPath) -> t.List[str]:
    """Render each JSONPath part and wrap it in the dialect's string quotes.

    Parts that render to an empty string are skipped.
    """
    rendered = (self.sql(part) for part in expression.expressions)
    return [
        f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" for path in rendered if path
    ]