sqlglot.dialects.dialect
class Dialects(str, Enum):
    """Dialects supported by SQLGLot.

    The string value of each member is the key under which the dialect class is
    registered by the `_Dialect` metaclass (see `_Dialect.__new__`).
    """

    # Empty string: the default/base dialect.
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""
class _Dialect(type):
    """Metaclass that registers every `Dialect` subclass and autofills its
    tokenizer/parser/generator wiring and derived lookup tables."""

    # Global registry: dialect key (e.g. "mysql") -> dialect class.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        # A dialect class compares equal to itself, its registry key (string),
        # and instances of itself.
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        # Consistent with string comparison in __eq__ (keys are lowercase).
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Register under the matching Dialects enum value when one exists,
        # otherwise under the lowercased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Derived time-format lookup structures.
        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        # A dialect may override any of these with a nested class of the same name.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # First configured quote/identifier pair is the canonical one.
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # Find the delimiters configured for a given literal token type,
            # or (None, None) when the dialect has none.
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        # Only the base dialect and BigQuery keep SELECT kinds (Dialects is a
        # str-mixin enum, so comparing against raw strings works here).
        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            # SEMI/ANTI are plain identifiers in this dialect, so allow them
            # as table aliases.
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass
class Dialect(metaclass=_Dialect):
    """Base class for SQL dialects.

    Subclasses are automatically registered by the `_Dialect` metaclass, which
    also autofills the attributes in the "Autofilled" section below.
    """

    INDEX_OFFSET = 0
    """Determines the base index offset for arrays."""

    WEEK_OFFSET = 0
    """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Determines whether or not `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Determines whether or not the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Determines whether or not a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Determines whether or not an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Determines whether or not `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Determines whether or not user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Determines whether or not `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """Determines how function names are going to be normalized."""

    LOG_BASE_FIRST = True
    """Determines whether the base comes first in the `LOG` function."""

    NULL_ORDERING = "nulls_are_small"
    """
    Indicates the default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` format."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an unescaped escape sequence to the corresponding character."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled (see _Dialect.__new__) ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.

        Raises:
            ValueError: If the dialect name is unknown, the settings string is
                malformed, or `dialect` has an unsupported type.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                # Suggest the closest known dialect name, if any.
                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        # Per-instance override of the class-level normalization strategy.
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        # A character is "unsafe" when keeping it as-is would differ from the
        # dialect's normalized form.
        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        # Parse a literal into a structured JSON path; fall back to the raw
        # literal (with a warning) when the path doesn't parse.
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql` into a list of expression trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse `sql` into expressions of the given type."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        """Render an expression tree back into SQL in this dialect."""
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and re-generate each statement in this dialect."""
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily created and cached, since tokenizers are stateless across calls.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)
439 """ 440 if isinstance(expression, exp.Identifier): 441 name = expression.this 442 expression.set( 443 "quoted", 444 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 445 ) 446 447 return expression 448 449 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 450 if isinstance(path, exp.Literal): 451 path_text = path.name 452 if path.is_number: 453 path_text = f"[{path_text}]" 454 455 try: 456 return parse_json_path(path_text) 457 except ParseError as e: 458 logger.warning(f"Invalid JSON path syntax. {str(e)}") 459 460 return path 461 462 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 463 return self.parser(**opts).parse(self.tokenize(sql), sql) 464 465 def parse_into( 466 self, expression_type: exp.IntoType, sql: str, **opts 467 ) -> t.List[t.Optional[exp.Expression]]: 468 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 469 470 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 471 return self.generator(**opts).generate(expression, copy=copy) 472 473 def transpile(self, sql: str, **opts) -> t.List[str]: 474 return [ 475 self.generate(expression, copy=False, **opts) if expression else "" 476 for expression in self.parse(sql) 477 ] 478 479 def tokenize(self, sql: str) -> t.List[Token]: 480 return self.tokenizer.tokenize(sql) 481 482 @property 483 def tokenizer(self) -> Tokenizer: 484 if not hasattr(self, "_tokenizer"): 485 self._tokenizer = self.tokenizer_class(dialect=self) 486 return self._tokenizer 487 488 def parser(self, **opts) -> Parser: 489 return self.parser_class(dialect=self, **opts) 490 491 def generator(self, **opts) -> Generator: 492 return self.generator_class(dialect=self, **opts) 493 494 495DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 496 497 498def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 499 return lambda self, expression: self.func(name, 
*flatten(expression.args.values())) 500 501 502def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 503 if expression.args.get("accuracy"): 504 self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") 505 return self.func("APPROX_COUNT_DISTINCT", expression.this) 506 507 508def if_sql( 509 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 510) -> t.Callable[[Generator, exp.If], str]: 511 def _if_sql(self: Generator, expression: exp.If) -> str: 512 return self.func( 513 name, 514 expression.this, 515 expression.args.get("true"), 516 expression.args.get("false") or false_value, 517 ) 518 519 return _if_sql 520 521 522def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 523 this = expression.this 524 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 525 this.replace(exp.cast(this, "json")) 526 527 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") 528 529 530def inline_array_sql(self: Generator, expression: exp.Array) -> str: 531 return f"[{self.expressions(expression, flat=True)}]" 532 533 534def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 535 return self.like_sql( 536 exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression) 537 ) 538 539 540def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 541 zone = self.sql(expression, "this") 542 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 543 544 545def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 546 if expression.args.get("recursive"): 547 self.unsupported("Recursive CTEs are unsupported") 548 expression.args["recursive"] = False 549 return self.with_sql(expression) 550 551 552def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: 553 n = self.sql(expression, "this") 554 d = self.sql(expression, 
"expression") 555 return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)" 556 557 558def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 559 self.unsupported("TABLESAMPLE unsupported") 560 return self.sql(expression.this) 561 562 563def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 564 self.unsupported("PIVOT unsupported") 565 return "" 566 567 568def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 569 return self.cast_sql(expression) 570 571 572def no_comment_column_constraint_sql( 573 self: Generator, expression: exp.CommentColumnConstraint 574) -> str: 575 self.unsupported("CommentColumnConstraint unsupported") 576 return "" 577 578 579def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 580 self.unsupported("MAP_FROM_ENTRIES unsupported") 581 return "" 582 583 584def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: 585 this = self.sql(expression, "this") 586 substr = self.sql(expression, "substr") 587 position = self.sql(expression, "position") 588 if position: 589 return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1" 590 return f"STRPOS({this}, {substr})" 591 592 593def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 594 return ( 595 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 596 ) 597 598 599def var_map_sql( 600 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 601) -> str: 602 keys = expression.args["keys"] 603 values = expression.args["values"] 604 605 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 606 self.unsupported("Cannot convert array columns into map.") 607 return self.func(map_func_name, keys, values) 608 609 args = [] 610 for key, value in zip(keys.expressions, values.expressions): 611 args.append(self.sql(key)) 612 args.append(self.sql(value)) 613 614 return self.func(map_func_name, *args) 615 616 
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        target_dialect = Dialect[dialect]
        fmt = seq_get(args, 1)
        if not fmt:
            # No explicit format: fall back to the dialect default (True means
            # the dialect's TIME_FORMAT) or the provided default string.
            fmt = target_dialect.TIME_FORMAT if default is True else default or None
        return exp_class(this=seq_get(args, 0), format=target_dialect.format_time(fmt))

    return _format_time


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a transform returning an expression's time format, or None when it
    equals the default time format of the given dialect."""

    def _time_format(
        self: Generator, expression: exp.UnixToStr | exp.StrToUnix
    ) -> t.Optional[str]:
        fmt = self.format_time(expression)
        if fmt == Dialect.get_or_raise(dialect).TIME_FORMAT:
            return None
        return fmt

    return _time_format


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    schema = expression.this
    if isinstance(schema, exp.Schema) and expression.args.get("kind") in ("TABLE", "VIEW"):
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            partition_names = {v.name.upper() for v in prop.this.expressions}
            partitions = [
                col for col in schema.expressions if col.name.upper() in partition_names
            ]
            # Move partition columns out of the schema and into the property.
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date arithmetic accepting both the 2-arg form and the
    unit-first 3-arg form, optionally remapping unit names."""

    def inner_func(args: t.List) -> E:
        has_unit = len(args) == 3
        this = args[2] if has_unit else seq_get(args, 0)
        unit = args[0] if has_unit else exp.Literal.string("DAY")
        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date arithmetic whose second argument is an INTERVAL."""

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            # Normalize quoted amounts like INTERVAL '5' DAY into numbers.
            amount = exp.Literal.number(amount.this)

        return expression_class(
            this=args[0], expression=amount, unit=exp.Literal.string(interval.text("unit"))
        )

    return func
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, value): DateTrunc for DATE-cast operands,
    TimestampTrunc otherwise."""
    unit = seq_get(args, 0)
    operand = seq_get(args, 1)

    if isinstance(operand, exp.Cast) and operand.is_type("date"):
        return exp.DateTrunc(unit=unit, this=operand)
    return exp.TimestampTrunc(this=operand, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a transform rendering date arithmetic as `<TYPE>_<KIND>(x, INTERVAL n unit)`."""

    def func(self: Generator, expression: exp.Expression) -> str:
        target = self.sql(expression, "this")
        unit_arg = expression.args.get("unit")
        unit = exp.var(unit_arg.name.upper() if unit_arg else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({target}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('UNIT', value), defaulting to DAY."""
    unit = exp.Literal.string(expression.text("unit").upper() or "DAY")
    return self.func("DATE_TRUNC", unit, expression.this)


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    """Rewrite TIMESTAMP(...) into casts, adding AT TIME ZONE for known zones."""
    if not expression.expression:
        # Single-argument form: a plain cast suffices.
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
    if expression.text("expression").lower() in TIMEZONES:
        casted = exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)
        return self.sql(exp.AtTimeZone(this=casted, zone=expression.expression))
    # Second argument isn't a recognized time zone: emit the function as-is.
    return self.function_fallback_sql(expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    """Parse LOCATE(substr, str[, pos]) into StrPosition (operands are swapped)."""
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, str[, pos])."""
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Rewrite LEFT(s, n) as SUBSTRING(s, 1, n)."""
    substring = exp.Substring(
        this=expression.this, start=exp.Literal.number(1), length=expression.expression
    )
    return self.sql(substring)
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Rewrite RIGHT(s, n) as SUBSTRING(s, LENGTH(s) - (n - 1)).

    The parameter annotation is `exp.Right` (it was previously mis-annotated as
    `exp.Left`, although the handler always operated on RIGHT expressions).
    """
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string-to-time conversion as a cast to TIMESTAMP."""
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string-to-date conversion as a cast to DATE."""
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render ENCODE/DECODE-style calls, warning on non-utf-8 charsets.

    Args:
        name: The function name to emit.
        replace: Whether to forward the expression's `replace` arg.
    """
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render MIN for the aggregate form and LEAST for the variadic scalar form."""
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render MAX for the aggregate form and GREATEST for the variadic scalar form."""
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0))."""
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        # DISTINCT cannot be expressed through SUM; keep the inner condition.
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM, using the verbose SQL-standard form only when needed."""
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collate = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not chars and not collate:
        return self.trim_sql(expression)

    position_part = f"{position} " if position else ""
    chars_part = f"{chars} " if chars else ""
    from_part = "FROM " if position_part or chars_part else ""
    collate_part = f" COLLATE {collate}" if collate else ""
    return f"TRIM({position_part}{chars_part}{from_part}{target}{collate_part})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render string-to-time parsing as STRPTIME(value, format)."""
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    """Rewrite CONCAT(a, b, c) as a chain of `||` operators."""
    joined = reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)
    return self.sql(joined)


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Rewrite CONCAT_WS(sep, a, b) as `a || sep || b` using `||` chains."""
    delim, *rest_args = expression.expressions
    joined = reduce(
        lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
        rest_args,
    )
    return self.sql(joined)


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT, warning about args the dialect can't express."""
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE, warning about args the dialect can't express."""
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the output column names produced by PIVOT aggregations."""
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Aggregations without aliases are used as name suffixes (e.g. col_avg(foo)).
        # Unquote identifiers here because they're quoted later in the base parser's
        # `_parse_pivot` via `to_identifier`; otherwise we'd get doubly-quoted names
        # like col_avg(`foo`).
        unquoted = agg.transform(
            lambda node: (
                exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
        )
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser that maps a two-argument function call onto a binary node."""

    def _builder(args: t.List) -> B:
        return expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

    return _builder


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, value) with the unit first."""
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render ANY_VALUE as MAX for dialects without ANY_VALUE support."""
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Expand boolean XOR into AND/OR/NOT: (a AND NOT b) OR (NOT a AND b)."""
    left = self.sql(expression.left)
    right = self.sql(expression.right)
    return f"({left} AND (NOT {right})) OR ((NOT {left}) AND {right})"


def is_parse_json(expression: exp.Expression) -> bool:
    """Return True for PARSE_JSON calls or casts to the JSON type."""
    if isinstance(expression, exp.ParseJSON):
        return True
    return isinstance(expression, exp.Cast) and expression.is_type("json")


def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Parse ISNULL(x) into the parenthesized predicate (x IS NULL)."""
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
-> str: 937 start = self.sql(expression, "start") or "1" 938 increment = self.sql(expression, "increment") or "1" 939 return f"IDENTITY({start}, {increment})" 940 941 942def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 943 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 944 if expression.args.get("count"): 945 self.unsupported(f"Only two arguments are supported in function {name}.") 946 947 return self.func(name, expression.this, expression.expression) 948 949 return _arg_max_or_min_sql 950 951 952def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 953 this = expression.this.copy() 954 955 return_type = expression.return_type 956 if return_type.is_type(exp.DataType.Type.DATE): 957 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 958 # can truncate timestamp strings, because some dialects can't cast them to DATE 959 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 960 961 expression.this.replace(exp.cast(this, return_type)) 962 return expression 963 964 965def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 966 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 967 if cast and isinstance(expression, exp.TsOrDsAdd): 968 expression = ts_or_ds_add_cast(expression) 969 970 return self.func( 971 name, 972 exp.var(expression.text("unit").upper() or "DAY"), 973 expression.expression, 974 expression.this, 975 ) 976 977 return _delta_sql 978 979 980def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 981 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 982 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 983 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 984 985 return self.sql(exp.cast(minus_one_day, "date")) 986 987 988def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 989 """Remove 
table refs from columns in when statements.""" 990 alias = expression.this.args.get("alias") 991 992 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 993 return self.dialect.normalize_identifier(identifier).name if identifier else None 994 995 targets = {normalize(expression.this.this)} 996 997 if alias: 998 targets.add(normalize(alias.this)) 999 1000 for when in expression.expressions: 1001 when.transform( 1002 lambda node: ( 1003 exp.column(node.this) 1004 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1005 else node 1006 ), 1007 copy=False, 1008 ) 1009 1010 return self.merge_sql(expression) 1011 1012 1013def parse_json_extract_path( 1014 expr_type: t.Type[F], zero_based_indexing: bool = True 1015) -> t.Callable[[t.List], F]: 1016 def _parse_json_extract_path(args: t.List) -> F: 1017 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1018 for arg in args[1:]: 1019 if not isinstance(arg, exp.Literal): 1020 # We use the fallback parser because we can't really transpile non-literals safely 1021 return expr_type.from_arg_list(args) 1022 1023 text = arg.name 1024 if is_int(text): 1025 index = int(text) 1026 segments.append( 1027 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1028 ) 1029 else: 1030 segments.append(exp.JSONPathKey(this=text)) 1031 1032 # This is done to avoid failing in the expression validator due to the arg count 1033 del args[2:] 1034 return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments)) 1035 1036 return _parse_json_extract_path 1037 1038 1039def json_extract_segments( 1040 name: str, quoted_index: bool = True 1041) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1042 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1043 path = expression.expression 1044 if not isinstance(path, exp.JSONPath): 1045 return rename_func(name)(self, expression) 1046 1047 segments = [] 1048 for segment in 
path.expressions: 1049 path = self.sql(segment) 1050 if path: 1051 if isinstance(segment, exp.JSONPathPart) and ( 1052 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1053 ): 1054 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1055 1056 segments.append(path) 1057 1058 return self.func(name, expression.this, *segments) 1059 1060 return _json_extract_segments 1061 1062 1063def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str: 1064 if isinstance(expression.this, exp.JSONPathWildcard): 1065 self.unsupported("Unsupported wildcard in JSONPathKey expression") 1066 1067 return expression.name
30class Dialects(str, Enum): 31 """Dialects supported by SQLGLot.""" 32 33 DIALECT = "" 34 35 BIGQUERY = "bigquery" 36 CLICKHOUSE = "clickhouse" 37 DATABRICKS = "databricks" 38 DORIS = "doris" 39 DRILL = "drill" 40 DUCKDB = "duckdb" 41 HIVE = "hive" 42 MYSQL = "mysql" 43 ORACLE = "oracle" 44 POSTGRES = "postgres" 45 PRESTO = "presto" 46 REDSHIFT = "redshift" 47 SNOWFLAKE = "snowflake" 48 SPARK = "spark" 49 SPARK2 = "spark2" 50 SQLITE = "sqlite" 51 STARROCKS = "starrocks" 52 TABLEAU = "tableau" 53 TERADATA = "teradata" 54 TRINO = "trino" 55 TSQL = "tsql"
Dialects supported by SQLGlot.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
58class NormalizationStrategy(str, AutoName): 59 """Specifies the strategy according to which identifiers should be normalized.""" 60 61 LOWERCASE = auto() 62 """Unquoted identifiers are lowercased.""" 63 64 UPPERCASE = auto() 65 """Unquoted identifiers are uppercased.""" 66 67 CASE_SENSITIVE = auto() 68 """Always case-sensitive, regardless of quotes.""" 69 70 CASE_INSENSITIVE = auto() 71 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
150class Dialect(metaclass=_Dialect): 151 INDEX_OFFSET = 0 152 """Determines the base index offset for arrays.""" 153 154 WEEK_OFFSET = 0 155 """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 156 157 UNNEST_COLUMN_ONLY = False 158 """Determines whether or not `UNNEST` table aliases are treated as column aliases.""" 159 160 ALIAS_POST_TABLESAMPLE = False 161 """Determines whether or not the table alias comes after tablesample.""" 162 163 TABLESAMPLE_SIZE_IS_PERCENT = False 164 """Determines whether or not a size in the table sample clause represents percentage.""" 165 166 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 167 """Specifies the strategy according to which identifiers should be normalized.""" 168 169 IDENTIFIERS_CAN_START_WITH_DIGIT = False 170 """Determines whether or not an unquoted identifier can start with a digit.""" 171 172 DPIPE_IS_STRING_CONCAT = True 173 """Determines whether or not the DPIPE token (`||`) is a string concatenation operator.""" 174 175 STRICT_STRING_CONCAT = False 176 """Determines whether or not `CONCAT`'s arguments must be strings.""" 177 178 SUPPORTS_USER_DEFINED_TYPES = True 179 """Determines whether or not user-defined data types are supported.""" 180 181 SUPPORTS_SEMI_ANTI_JOIN = True 182 """Determines whether or not `SEMI` or `ANTI` joins are supported.""" 183 184 NORMALIZE_FUNCTIONS: bool | str = "upper" 185 """Determines how function names are going to be normalized.""" 186 187 LOG_BASE_FIRST = True 188 """Determines whether the base comes first in the `LOG` function.""" 189 190 NULL_ORDERING = "nulls_are_small" 191 """ 192 Indicates the default `NULL` ordering method to use if not explicitly set. 193 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 194 """ 195 196 TYPED_DIVISION = False 197 """ 198 Whether the behavior of `a / b` depends on the types of `a` and `b`. 199 False means `a / b` is always float division. 
200 True means `a / b` is integer division if both `a` and `b` are integers. 201 """ 202 203 SAFE_DIVISION = False 204 """Determines whether division by zero throws an error (`False`) or returns NULL (`True`).""" 205 206 CONCAT_COALESCE = False 207 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 208 209 DATE_FORMAT = "'%Y-%m-%d'" 210 DATEINT_FORMAT = "'%Y%m%d'" 211 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 212 213 TIME_MAPPING: t.Dict[str, str] = {} 214 """Associates this dialect's time formats with their equivalent Python `strftime` format.""" 215 216 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 217 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 218 FORMAT_MAPPING: t.Dict[str, str] = {} 219 """ 220 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 221 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 222 """ 223 224 ESCAPE_SEQUENCES: t.Dict[str, str] = {} 225 """Mapping of an unescaped escape sequence to the corresponding character.""" 226 227 PSEUDOCOLUMNS: t.Set[str] = set() 228 """ 229 Columns that are auto-generated by the engine corresponding to this dialect. 230 For example, such columns may be excluded from `SELECT *` queries. 231 """ 232 233 PREFER_CTE_ALIAS_COLUMN = False 234 """ 235 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 236 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 237 any projection aliases in the subquery. 
238 239 For example, 240 WITH y(c) AS ( 241 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 242 ) SELECT c FROM y; 243 244 will be rewritten as 245 246 WITH y(c) AS ( 247 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 248 ) SELECT c FROM y; 249 """ 250 251 # --- Autofilled --- 252 253 tokenizer_class = Tokenizer 254 parser_class = Parser 255 generator_class = Generator 256 257 # A trie of the time_mapping keys 258 TIME_TRIE: t.Dict = {} 259 FORMAT_TRIE: t.Dict = {} 260 261 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 262 INVERSE_TIME_TRIE: t.Dict = {} 263 264 INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {} 265 266 # Delimiters for string literals and identifiers 267 QUOTE_START = "'" 268 QUOTE_END = "'" 269 IDENTIFIER_START = '"' 270 IDENTIFIER_END = '"' 271 272 # Delimiters for bit, hex, byte and unicode literals 273 BIT_START: t.Optional[str] = None 274 BIT_END: t.Optional[str] = None 275 HEX_START: t.Optional[str] = None 276 HEX_END: t.Optional[str] = None 277 BYTE_START: t.Optional[str] = None 278 BYTE_END: t.Optional[str] = None 279 UNICODE_START: t.Optional[str] = None 280 UNICODE_END: t.Optional[str] = None 281 282 @classmethod 283 def get_or_raise(cls, dialect: DialectType) -> Dialect: 284 """ 285 Look up a dialect in the global dialect registry and return it if it exists. 286 287 Args: 288 dialect: The target dialect. If this is a string, it can be optionally followed by 289 additional key-value pairs that are separated by commas and are used to specify 290 dialect settings, such as whether the dialect's identifiers are case-sensitive. 291 292 Example: 293 >>> dialect = dialect_class = get_or_raise("duckdb") 294 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 295 296 Returns: 297 The corresponding Dialect instance. 
298 """ 299 300 if not dialect: 301 return cls() 302 if isinstance(dialect, _Dialect): 303 return dialect() 304 if isinstance(dialect, Dialect): 305 return dialect 306 if isinstance(dialect, str): 307 try: 308 dialect_name, *kv_pairs = dialect.split(",") 309 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 310 except ValueError: 311 raise ValueError( 312 f"Invalid dialect format: '{dialect}'. " 313 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 314 ) 315 316 result = cls.get(dialect_name.strip()) 317 if not result: 318 from difflib import get_close_matches 319 320 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 321 if similar: 322 similar = f" Did you mean {similar}?" 323 324 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 325 326 return result(**kwargs) 327 328 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 329 330 @classmethod 331 def format_time( 332 cls, expression: t.Optional[str | exp.Expression] 333 ) -> t.Optional[exp.Expression]: 334 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 335 if isinstance(expression, str): 336 return exp.Literal.string( 337 # the time formats are quoted 338 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 339 ) 340 341 if expression and expression.is_string: 342 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 343 344 return expression 345 346 def __init__(self, **kwargs) -> None: 347 normalization_strategy = kwargs.get("normalization_strategy") 348 349 if normalization_strategy is None: 350 self.normalization_strategy = self.NORMALIZATION_STRATEGY 351 else: 352 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 353 354 def __eq__(self, other: t.Any) -> bool: 355 # Does not currently take dialect state into account 356 return type(self) == other 357 358 def __hash__(self) -> int: 
359 # Does not currently take dialect state into account 360 return hash(type(self)) 361 362 def normalize_identifier(self, expression: E) -> E: 363 """ 364 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 365 366 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 367 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 368 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 369 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 370 371 There are also dialects like Spark, which are case-insensitive even when quotes are 372 present, and dialects like MySQL, whose resolution rules match those employed by the 373 underlying operating system, for example they may always be case-sensitive in Linux. 374 375 Finally, the normalization behavior of some engines can even be controlled through flags, 376 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 377 378 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 379 that it can analyze queries in the optimizer and successfully capture their semantics. 
380 """ 381 if ( 382 isinstance(expression, exp.Identifier) 383 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 384 and ( 385 not expression.quoted 386 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 387 ) 388 ): 389 expression.set( 390 "this", 391 ( 392 expression.this.upper() 393 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 394 else expression.this.lower() 395 ), 396 ) 397 398 return expression 399 400 def case_sensitive(self, text: str) -> bool: 401 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 402 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 403 return False 404 405 unsafe = ( 406 str.islower 407 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 408 else str.isupper 409 ) 410 return any(unsafe(char) for char in text) 411 412 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 413 """Checks if text can be identified given an identify option. 414 415 Args: 416 text: The text to check. 417 identify: 418 `"always"` or `True`: Always returns `True`. 419 `"safe"`: Only returns `True` if the identifier is case-insensitive. 420 421 Returns: 422 Whether or not the given text can be identified. 423 """ 424 if identify is True or identify == "always": 425 return True 426 427 if identify == "safe": 428 return not self.case_sensitive(text) 429 430 return False 431 432 def quote_identifier(self, expression: E, identify: bool = True) -> E: 433 """ 434 Adds quotes to a given identifier. 435 436 Args: 437 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 438 identify: If set to `False`, the quotes will only be added if the identifier is deemed 439 "unsafe", with respect to its characters and this dialect's normalization strategy. 
440 """ 441 if isinstance(expression, exp.Identifier): 442 name = expression.this 443 expression.set( 444 "quoted", 445 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 446 ) 447 448 return expression 449 450 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 451 if isinstance(path, exp.Literal): 452 path_text = path.name 453 if path.is_number: 454 path_text = f"[{path_text}]" 455 456 try: 457 return parse_json_path(path_text) 458 except ParseError as e: 459 logger.warning(f"Invalid JSON path syntax. {str(e)}") 460 461 return path 462 463 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 464 return self.parser(**opts).parse(self.tokenize(sql), sql) 465 466 def parse_into( 467 self, expression_type: exp.IntoType, sql: str, **opts 468 ) -> t.List[t.Optional[exp.Expression]]: 469 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 470 471 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 472 return self.generator(**opts).generate(expression, copy=copy) 473 474 def transpile(self, sql: str, **opts) -> t.List[str]: 475 return [ 476 self.generate(expression, copy=False, **opts) if expression else "" 477 for expression in self.parse(sql) 478 ] 479 480 def tokenize(self, sql: str) -> t.List[Token]: 481 return self.tokenizer.tokenize(sql) 482 483 @property 484 def tokenizer(self) -> Tokenizer: 485 if not hasattr(self, "_tokenizer"): 486 self._tokenizer = self.tokenizer_class(dialect=self) 487 return self._tokenizer 488 489 def parser(self, **opts) -> Parser: 490 return self.parser_class(dialect=self, **opts) 491 492 def generator(self, **opts) -> Generator: 493 return self.generator_class(dialect=self, **opts)
346 def __init__(self, **kwargs) -> None: 347 normalization_strategy = kwargs.get("normalization_strategy") 348 349 if normalization_strategy is None: 350 self.normalization_strategy = self.NORMALIZATION_STRATEGY 351 else: 352 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Determines whether or not `UNNEST` table aliases are treated as column aliases.
Determines whether or not a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines whether or not an unquoted identifier can start with a digit.
Determines whether or not the DPIPE token (`||`) is a string concatenation operator.
Indicates the default `NULL` ordering method to use if not explicitly set.
Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
Whether the behavior of `a / b` depends on the types of `a` and `b`.
False means `a / b` is always float division.
True means `a / b` is integer division if both `a` and `b` are integers.
Determines whether division by zero throws an error (`False`) or returns NULL (`True`).
A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python `strftime` format.
Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
Mapping of an unescaped escape sequence to the corresponding character.
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from `SELECT *` queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example,

WITH y(c) AS (
  SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;
will be rewritten as
WITH y(c) AS (
SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
282 @classmethod 283 def get_or_raise(cls, dialect: DialectType) -> Dialect: 284 """ 285 Look up a dialect in the global dialect registry and return it if it exists. 286 287 Args: 288 dialect: The target dialect. If this is a string, it can be optionally followed by 289 additional key-value pairs that are separated by commas and are used to specify 290 dialect settings, such as whether the dialect's identifiers are case-sensitive. 291 292 Example: 293 >>> dialect = dialect_class = get_or_raise("duckdb") 294 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 295 296 Returns: 297 The corresponding Dialect instance. 298 """ 299 300 if not dialect: 301 return cls() 302 if isinstance(dialect, _Dialect): 303 return dialect() 304 if isinstance(dialect, Dialect): 305 return dialect 306 if isinstance(dialect, str): 307 try: 308 dialect_name, *kv_pairs = dialect.split(",") 309 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 310 except ValueError: 311 raise ValueError( 312 f"Invalid dialect format: '{dialect}'. " 313 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 314 ) 315 316 result = cls.get(dialect_name.strip()) 317 if not result: 318 from difflib import get_close_matches 319 320 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 321 if similar: 322 similar = f" Did you mean {similar}?" 323 324 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 325 326 return result(**kwargs) 327 328 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb") >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
330 @classmethod 331 def format_time( 332 cls, expression: t.Optional[str | exp.Expression] 333 ) -> t.Optional[exp.Expression]: 334 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 335 if isinstance(expression, str): 336 return exp.Literal.string( 337 # the time formats are quoted 338 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 339 ) 340 341 if expression and expression.is_string: 342 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 343 344 return expression
Converts a time format in this dialect to its equivalent Python `strftime` format.
362 def normalize_identifier(self, expression: E) -> E: 363 """ 364 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 365 366 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 367 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 368 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 369 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 370 371 There are also dialects like Spark, which are case-insensitive even when quotes are 372 present, and dialects like MySQL, whose resolution rules match those employed by the 373 underlying operating system, for example they may always be case-sensitive in Linux. 374 375 Finally, the normalization behavior of some engines can even be controlled through flags, 376 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 377 378 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 379 that it can analyze queries in the optimizer and successfully capture their semantics. 380 """ 381 if ( 382 isinstance(expression, exp.Identifier) 383 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 384 and ( 385 not expression.quoted 386 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 387 ) 388 ): 389 expression.set( 390 "this", 391 ( 392 expression.this.upper() 393 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 394 else expression.this.lower() 395 ), 396 ) 397 398 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
400 def case_sensitive(self, text: str) -> bool: 401 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 402 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 403 return False 404 405 unsafe = ( 406 str.islower 407 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 408 else str.isupper 409 ) 410 return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
412 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 413 """Checks if text can be identified given an identify option. 414 415 Args: 416 text: The text to check. 417 identify: 418 `"always"` or `True`: Always returns `True`. 419 `"safe"`: Only returns `True` if the identifier is case-insensitive. 420 421 Returns: 422 Whether or not the given text can be identified. 423 """ 424 if identify is True or identify == "always": 425 return True 426 427 if identify == "safe": 428 return not self.case_sensitive(text) 429 430 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
`"always"` or `True`: Always returns `True`.
`"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether or not the given text can be identified.
432 def quote_identifier(self, expression: E, identify: bool = True) -> E: 433 """ 434 Adds quotes to a given identifier. 435 436 Args: 437 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 438 identify: If set to `False`, the quotes will only be added if the identifier is deemed 439 "unsafe", with respect to its characters and this dialect's normalization strategy. 440 """ 441 if isinstance(expression, exp.Identifier): 442 name = expression.this 443 expression.set( 444 "quoted", 445 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 446 ) 447 448 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed
  "unsafe", with respect to its characters and this dialect's normalization strategy.
450 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 451 if isinstance(path, exp.Literal): 452 path_text = path.name 453 if path.is_number: 454 path_text = f"[{path_text}]" 455 456 try: 457 return parse_json_path(path_text) 458 except ParseError as e: 459 logger.warning(f"Invalid JSON path syntax. {str(e)}") 460 461 return path
509def if_sql( 510 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 511) -> t.Callable[[Generator, exp.If], str]: 512 def _if_sql(self: Generator, expression: exp.If) -> str: 513 return self.func( 514 name, 515 expression.this, 516 expression.args.get("true"), 517 expression.args.get("false") or false_value, 518 ) 519 520 return _if_sql
523def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 524 this = expression.this 525 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 526 this.replace(exp.cast(this, "json")) 527 528 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
585def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: 586 this = self.sql(expression, "this") 587 substr = self.sql(expression, "substr") 588 position = self.sql(expression, "position") 589 if position: 590 return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1" 591 return f"STRPOS({this}, {substr})"
600def var_map_sql( 601 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 602) -> str: 603 keys = expression.args["keys"] 604 values = expression.args["values"] 605 606 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 607 self.unsupported("Cannot convert array columns into map.") 608 return self.func(map_func_name, keys, values) 609 610 args = [] 611 for key, value in zip(keys.expressions, values.expressions): 612 args.append(self.sql(key)) 613 args.append(self.sql(value)) 614 615 return self.func(map_func_name, *args)
618def format_time_lambda( 619 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 620) -> t.Callable[[t.List], E]: 621 """Helper used for time expressions. 622 623 Args: 624 exp_class: the expression class to instantiate. 625 dialect: target sql dialect. 626 default: the default format, True being time. 627 628 Returns: 629 A callable that can be used to return the appropriately formatted time expression. 630 """ 631 632 def _format_time(args: t.List): 633 return exp_class( 634 this=seq_get(args, 0), 635 format=Dialect[dialect].format_time( 636 seq_get(args, 1) 637 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 638 ), 639 ) 640 641 return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default time format; if True, the dialect's default TIME_FORMAT is used.
Returns:
A callable that can be used to return the appropriately formatted time expression.
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a transform that returns an expression's time format, or None
    when it matches the default TIME_FORMAT of the given dialect."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        fmt = self.format_time(expression)
        if fmt == Dialect.get_or_raise(dialect).TIME_FORMAT:
            return None
        return fmt

    return _time_format
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    if isinstance(expression.this, exp.Schema) and expression.args.get("kind") in (
        "TABLE",
        "VIEW",
    ):
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            partition_names = {v.name.upper() for v in prop.this.expressions}
            partition_cols = [
                col for col in schema.expressions if col.name.upper() in partition_names
            ]
            # Move the partition columns out of the schema and into a
            # PartitionedByProperty schema of their own.
            schema.set("expressions", [e for e in schema.expressions if e not in partition_cols])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_cols)))
            expression.set("this", schema)

    return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date add/diff functions.

    A three-argument call is unit-based (unit, expression, this); otherwise
    the unit defaults to DAY. `unit_mapping` optionally renames units
    (looked up by lowercased name).
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            this, unit = args[2], args[0]
        else:
            this, unit = seq_get(args, 0), exp.Literal.string("DAY")

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date arithmetic whose second argument must be an
    INTERVAL literal; its amount and unit are unpacked onto the expression."""

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            # Normalize string interval amounts (e.g. '5') into numbers.
            amount = exp.Literal.number(amount.this)

        return expression_class(
            this=args[0],
            expression=amount,
            unit=exp.Literal.string(interval.text("unit")),
        )

    return func
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a transform rendering date arithmetic as
    `<data_type>_<kind>(this, INTERVAL expr unit)` (unit defaults to DAY)."""

    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit_arg = expression.args.get("unit")
        unit = exp.var(unit_arg.name.upper() if unit_arg else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    """Render TIMESTAMP(x[, tz]) for dialects lacking the function: a bare
    argument becomes CAST(... AS TIMESTAMP); a known timezone becomes
    AT TIME ZONE; anything else falls back to the generic function form."""
    if not expression.expression:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))

    if expression.text("expression").lower() in TIMEZONES:
        casted = exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)
        return self.sql(exp.AtTimeZone(this=casted, zone=expression.expression))

    return self.function_fallback_sql(expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render ENCODE/DECODE-style calls, warning when the charset isn't utf-8.

    Args:
        name: the function name to emit.
        replace: whether to forward the expression's "replace" argument.
    """
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replace_arg = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replace_arg)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0)); DISTINCT is unsupported
    and is unwrapped with a warning."""
    cond = expression.this
    if isinstance(cond, exp.Distinct):
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
        cond = cond.expressions[0]

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM with explicit position/characters/collation when present.

    When the expression carries no dialect-specific pieces, defer to the
    generator's standard TRIM/LTRIM/RTRIM rendering.
    """
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not chars and not collation:
        return self.trim_sql(expression)

    position_part = f"{position} " if position else ""
    chars_part = f"{chars} " if chars else ""
    from_part = "FROM " if position_part or chars_part else ""
    collation_part = f" COLLATE {collation}" if collation else ""
    return f"TRIM({position_part}{chars_part}{from_part}{target}{collation_part})"
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT, warning about args the target can't express."""
    bad_args = [key for key in ("position", "occurrence", "parameters") if expression.args.get(key)]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE, warning about args the target can't express."""
    bad_args = [
        key
        for key in ("position", "occurrence", "parameters", "modifiers")
        if expression.args.get(key)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the output column names contributed by a PIVOT's aggregations.

    Aliased aggregations contribute their alias; unaliased ones contribute
    their unquoted, lowercase-function SQL text.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            #
            # (This used to be a bare triple-quoted string, i.e. an expression statement
            # evaluated and discarded on every iteration — now a proper comment.)
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    """Build a transform for ARG_MAX/ARG_MIN in dialects limited to two
    arguments; a third "count" argument triggers an unsupported warning."""

    def _render(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")
        return self.func(name, expression.this, expression.expression)

    return _render
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    """Wrap a TsOrDsAdd's operand in a CAST to the expression's return type."""
    return_type = expression.return_type
    operand = expression.this.copy()

    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        operand = exp.cast(operand, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(operand, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    """Build a transform rendering date add/diff as `name(UNIT, expr, this)`.

    Args:
        name: the function name to emit.
        cast: when True, TsOrDsAdd operands are first wrapped via
            `ts_or_ds_add_cast`.
    """

    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        unit = expression.text("unit").upper() or "DAY"
        return self.func(name, exp.var(unit), expression.expression, expression.this)

    return _delta_sql
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    """Emulate LAST_DAY: truncate to month start, add one month, subtract a day."""
    month_start = exp.func("date_trunc", "month", expression.this)
    next_month = exp.func("date_add", month_start, 1, "month")
    last_day = exp.func("date_sub", next_month, 1, "day")

    return self.sql(exp.cast(last_day, "date"))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        # Normalization mirrors how the dialect would case the identifier.
        if not identifier:
            return None
        return self.dialect.normalize_identifier(identifier).name

    targets = {normalize(expression.this.this)}
    alias = expression.this.args.get("alias")
    if alias:
        targets.add(normalize(alias.this))

    def strip_target_ref(node: exp.Expression) -> exp.Expression:
        # Drop the table qualifier from columns that reference the merge target.
        if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets:
            return exp.column(node.this)
        return node

    for when in expression.expressions:
        when.transform(strip_target_ref, copy=False)

    return self.merge_sql(expression)
Remove table refs from columns in when statements.
def parse_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True
) -> t.Callable[[t.List], F]:
    """Build a parser that converts JSON_EXTRACT-style varargs into a JSONPath.

    Integer literal args become subscripts (optionally shifted for 1-based
    dialects); other literals become keys; any non-literal arg falls back to
    the default arg-list constructor.
    """

    def _parse_json_extract_path(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if not is_int(text):
                segments.append(exp.JSONPathKey(this=text))
            else:
                index = int(text) if zero_based_indexing else int(text) - 1
                segments.append(exp.JSONPathSubscript(this=index))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))

    return _parse_json_extract_path
def json_extract_segments(
    name: str, quoted_index: bool = True
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    """Build a transform rendering JSON extraction as `name(this, seg, ...)`,
    quoting path segments with the dialect's identifier quotes as needed."""

    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        json_path = expression.expression
        if not isinstance(json_path, exp.JSONPath):
            return rename_func(name)(self, expression)

        rendered = []
        for segment in json_path.expressions:
            segment_sql = self.sql(segment)
            if not segment_sql:
                continue

            needs_quotes = isinstance(segment, exp.JSONPathPart) and (
                quoted_index or not isinstance(segment, exp.JSONPathSubscript)
            )
            if needs_quotes:
                segment_sql = f"{self.dialect.QUOTE_START}{segment_sql}{self.dialect.QUOTE_END}"

            rendered.append(segment_sql)

        return self.func(name, expression.this, *rendered)

    return _json_extract_segments