# sqlglot.dialects.dialect
1from __future__ import annotations 2 3import logging 4import typing as t 5from enum import Enum, auto 6from functools import reduce 7 8from sqlglot import exp 9from sqlglot.errors import ParseError 10from sqlglot.generator import Generator 11from sqlglot.helper import AutoName, flatten, is_int, seq_get 12from sqlglot.jsonpath import parse as parse_json_path 13from sqlglot.parser import Parser 14from sqlglot.time import TIMEZONES, format_time 15from sqlglot.tokens import Token, Tokenizer, TokenType 16from sqlglot.trie import new_trie 17 18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff] 19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] 20 21if t.TYPE_CHECKING: 22 from sqlglot._typing import B, E 23 24logger = logging.getLogger("sqlglot") 25 26 27class Dialects(str, Enum): 28 """Dialects supported by SQLGLot.""" 29 30 DIALECT = "" 31 32 BIGQUERY = "bigquery" 33 CLICKHOUSE = "clickhouse" 34 DATABRICKS = "databricks" 35 DORIS = "doris" 36 DRILL = "drill" 37 DUCKDB = "duckdb" 38 HIVE = "hive" 39 MYSQL = "mysql" 40 ORACLE = "oracle" 41 POSTGRES = "postgres" 42 PRESTO = "presto" 43 REDSHIFT = "redshift" 44 SNOWFLAKE = "snowflake" 45 SPARK = "spark" 46 SPARK2 = "spark2" 47 SQLITE = "sqlite" 48 STARROCKS = "starrocks" 49 TABLEAU = "tableau" 50 TERADATA = "teradata" 51 TRINO = "trino" 52 TSQL = "tsql" 53 54 55class NormalizationStrategy(str, AutoName): 56 """Specifies the strategy according to which identifiers should be normalized.""" 57 58 LOWERCASE = auto() 59 """Unquoted identifiers are lowercased.""" 60 61 UPPERCASE = auto() 62 """Unquoted identifiers are uppercased.""" 63 64 CASE_SENSITIVE = auto() 65 """Always case-sensitive, regardless of quotes.""" 66 67 CASE_INSENSITIVE = auto() 68 """Always case-insensitive, regardless of quotes.""" 69 70 71class _Dialect(type): 72 classes: t.Dict[str, t.Type[Dialect]] = {} 73 74 def __eq__(cls, other: t.Any) -> bool: 75 if cls is other: 76 return True 77 if isinstance(other, 
str): 78 return cls is cls.get(other) 79 if isinstance(other, Dialect): 80 return cls is type(other) 81 82 return False 83 84 def __hash__(cls) -> int: 85 return hash(cls.__name__.lower()) 86 87 @classmethod 88 def __getitem__(cls, key: str) -> t.Type[Dialect]: 89 return cls.classes[key] 90 91 @classmethod 92 def get( 93 cls, key: str, default: t.Optional[t.Type[Dialect]] = None 94 ) -> t.Optional[t.Type[Dialect]]: 95 return cls.classes.get(key, default) 96 97 def __new__(cls, clsname, bases, attrs): 98 klass = super().__new__(cls, clsname, bases, attrs) 99 enum = Dialects.__members__.get(clsname.upper()) 100 cls.classes[enum.value if enum is not None else clsname.lower()] = klass 101 102 klass.TIME_TRIE = new_trie(klass.TIME_MAPPING) 103 klass.FORMAT_TRIE = ( 104 new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE 105 ) 106 klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} 107 klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING) 108 109 klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()} 110 111 klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer) 112 klass.parser_class = getattr(klass, "Parser", Parser) 113 klass.generator_class = getattr(klass, "Generator", Generator) 114 115 klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0] 116 klass.IDENTIFIER_START, klass.IDENTIFIER_END = list( 117 klass.tokenizer_class._IDENTIFIERS.items() 118 )[0] 119 120 def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]: 121 return next( 122 ( 123 (s, e) 124 for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items() 125 if t == token_type 126 ), 127 (None, None), 128 ) 129 130 klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING) 131 klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) 132 klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) 133 klass.UNICODE_START, 
klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING) 134 135 if enum not in ("", "bigquery"): 136 klass.generator_class.SELECT_KINDS = () 137 138 if not klass.SUPPORTS_SEMI_ANTI_JOIN: 139 klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { 140 TokenType.ANTI, 141 TokenType.SEMI, 142 } 143 144 return klass 145 146 147class Dialect(metaclass=_Dialect): 148 INDEX_OFFSET = 0 149 """Determines the base index offset for arrays.""" 150 151 WEEK_OFFSET = 0 152 """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 153 154 UNNEST_COLUMN_ONLY = False 155 """Determines whether or not `UNNEST` table aliases are treated as column aliases.""" 156 157 ALIAS_POST_TABLESAMPLE = False 158 """Determines whether or not the table alias comes after tablesample.""" 159 160 TABLESAMPLE_SIZE_IS_PERCENT = False 161 """Determines whether or not a size in the table sample clause represents percentage.""" 162 163 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 164 """Specifies the strategy according to which identifiers should be normalized.""" 165 166 IDENTIFIERS_CAN_START_WITH_DIGIT = False 167 """Determines whether or not an unquoted identifier can start with a digit.""" 168 169 DPIPE_IS_STRING_CONCAT = True 170 """Determines whether or not the DPIPE token (`||`) is a string concatenation operator.""" 171 172 STRICT_STRING_CONCAT = False 173 """Determines whether or not `CONCAT`'s arguments must be strings.""" 174 175 SUPPORTS_USER_DEFINED_TYPES = True 176 """Determines whether or not user-defined data types are supported.""" 177 178 SUPPORTS_SEMI_ANTI_JOIN = True 179 """Determines whether or not `SEMI` or `ANTI` joins are supported.""" 180 181 NORMALIZE_FUNCTIONS: bool | str = "upper" 182 """Determines how function names are going to be normalized.""" 183 184 LOG_BASE_FIRST = True 185 """Determines whether the base comes first in the `LOG` function.""" 186 187 NULL_ORDERING = "nulls_are_small" 188 """ 
189 Indicates the default `NULL` ordering method to use if not explicitly set. 190 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 191 """ 192 193 TYPED_DIVISION = False 194 """ 195 Whether the behavior of `a / b` depends on the types of `a` and `b`. 196 False means `a / b` is always float division. 197 True means `a / b` is integer division if both `a` and `b` are integers. 198 """ 199 200 SAFE_DIVISION = False 201 """Determines whether division by zero throws an error (`False`) or returns NULL (`True`).""" 202 203 CONCAT_COALESCE = False 204 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 205 206 DATE_FORMAT = "'%Y-%m-%d'" 207 DATEINT_FORMAT = "'%Y%m%d'" 208 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 209 210 TIME_MAPPING: t.Dict[str, str] = {} 211 """Associates this dialect's time formats with their equivalent Python `strftime` format.""" 212 213 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 214 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 215 FORMAT_MAPPING: t.Dict[str, str] = {} 216 """ 217 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 218 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 219 """ 220 221 ESCAPE_SEQUENCES: t.Dict[str, str] = {} 222 """Mapping of an unescaped escape sequence to the corresponding character.""" 223 224 PSEUDOCOLUMNS: t.Set[str] = set() 225 """ 226 Columns that are auto-generated by the engine corresponding to this dialect. 227 For example, such columns may be excluded from `SELECT *` queries. 
228 """ 229 230 PREFER_CTE_ALIAS_COLUMN = False 231 """ 232 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 233 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 234 any projection aliases in the subquery. 235 236 For example, 237 WITH y(c) AS ( 238 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 239 ) SELECT c FROM y; 240 241 will be rewritten as 242 243 WITH y(c) AS ( 244 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 245 ) SELECT c FROM y; 246 """ 247 248 # --- Autofilled --- 249 250 tokenizer_class = Tokenizer 251 parser_class = Parser 252 generator_class = Generator 253 254 # A trie of the time_mapping keys 255 TIME_TRIE: t.Dict = {} 256 FORMAT_TRIE: t.Dict = {} 257 258 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 259 INVERSE_TIME_TRIE: t.Dict = {} 260 261 INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {} 262 263 # Delimiters for string literals and identifiers 264 QUOTE_START = "'" 265 QUOTE_END = "'" 266 IDENTIFIER_START = '"' 267 IDENTIFIER_END = '"' 268 269 # Delimiters for bit, hex, byte and unicode literals 270 BIT_START: t.Optional[str] = None 271 BIT_END: t.Optional[str] = None 272 HEX_START: t.Optional[str] = None 273 HEX_END: t.Optional[str] = None 274 BYTE_START: t.Optional[str] = None 275 BYTE_END: t.Optional[str] = None 276 UNICODE_START: t.Optional[str] = None 277 UNICODE_END: t.Optional[str] = None 278 279 @classmethod 280 def get_or_raise(cls, dialect: DialectType) -> Dialect: 281 """ 282 Look up a dialect in the global dialect registry and return it if it exists. 283 284 Args: 285 dialect: The target dialect. If this is a string, it can be optionally followed by 286 additional key-value pairs that are separated by commas and are used to specify 287 dialect settings, such as whether the dialect's identifiers are case-sensitive. 
288 289 Example: 290 >>> dialect = dialect_class = get_or_raise("duckdb") 291 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 292 293 Returns: 294 The corresponding Dialect instance. 295 """ 296 297 if not dialect: 298 return cls() 299 if isinstance(dialect, _Dialect): 300 return dialect() 301 if isinstance(dialect, Dialect): 302 return dialect 303 if isinstance(dialect, str): 304 try: 305 dialect_name, *kv_pairs = dialect.split(",") 306 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 307 except ValueError: 308 raise ValueError( 309 f"Invalid dialect format: '{dialect}'. " 310 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 311 ) 312 313 result = cls.get(dialect_name.strip()) 314 if not result: 315 from difflib import get_close_matches 316 317 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 318 if similar: 319 similar = f" Did you mean {similar}?" 320 321 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 322 323 return result(**kwargs) 324 325 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 326 327 @classmethod 328 def format_time( 329 cls, expression: t.Optional[str | exp.Expression] 330 ) -> t.Optional[exp.Expression]: 331 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 332 if isinstance(expression, str): 333 return exp.Literal.string( 334 # the time formats are quoted 335 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 336 ) 337 338 if expression and expression.is_string: 339 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 340 341 return expression 342 343 def __init__(self, **kwargs) -> None: 344 normalization_strategy = kwargs.get("normalization_strategy") 345 346 if normalization_strategy is None: 347 self.normalization_strategy = self.NORMALIZATION_STRATEGY 348 else: 349 self.normalization_strategy = 
NormalizationStrategy(normalization_strategy.upper()) 350 351 def __eq__(self, other: t.Any) -> bool: 352 # Does not currently take dialect state into account 353 return type(self) == other 354 355 def __hash__(self) -> int: 356 # Does not currently take dialect state into account 357 return hash(type(self)) 358 359 def normalize_identifier(self, expression: E) -> E: 360 """ 361 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 362 363 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 364 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 365 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 366 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 367 368 There are also dialects like Spark, which are case-insensitive even when quotes are 369 present, and dialects like MySQL, whose resolution rules match those employed by the 370 underlying operating system, for example they may always be case-sensitive in Linux. 371 372 Finally, the normalization behavior of some engines can even be controlled through flags, 373 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 374 375 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 376 that it can analyze queries in the optimizer and successfully capture their semantics. 
377 """ 378 if ( 379 isinstance(expression, exp.Identifier) 380 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 381 and ( 382 not expression.quoted 383 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 384 ) 385 ): 386 expression.set( 387 "this", 388 ( 389 expression.this.upper() 390 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 391 else expression.this.lower() 392 ), 393 ) 394 395 return expression 396 397 def case_sensitive(self, text: str) -> bool: 398 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 399 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 400 return False 401 402 unsafe = ( 403 str.islower 404 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 405 else str.isupper 406 ) 407 return any(unsafe(char) for char in text) 408 409 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 410 """Checks if text can be identified given an identify option. 411 412 Args: 413 text: The text to check. 414 identify: 415 `"always"` or `True`: Always returns `True`. 416 `"safe"`: Only returns `True` if the identifier is case-insensitive. 417 418 Returns: 419 Whether or not the given text can be identified. 420 """ 421 if identify is True or identify == "always": 422 return True 423 424 if identify == "safe": 425 return not self.case_sensitive(text) 426 427 return False 428 429 def quote_identifier(self, expression: E, identify: bool = True) -> E: 430 """ 431 Adds quotes to a given identifier. 432 433 Args: 434 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 435 identify: If set to `False`, the quotes will only be added if the identifier is deemed 436 "unsafe", with respect to its characters and this dialect's normalization strategy. 
437 """ 438 if isinstance(expression, exp.Identifier): 439 name = expression.this 440 expression.set( 441 "quoted", 442 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 443 ) 444 445 return expression 446 447 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 448 if isinstance(path, exp.Literal): 449 path_text = path.name 450 if path.is_number: 451 path_text = f"[{path_text}]" 452 453 try: 454 return parse_json_path(path_text) 455 except ParseError as e: 456 logger.warning(f"Invalid JSON path syntax. {str(e)}") 457 458 return path 459 460 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 461 return self.parser(**opts).parse(self.tokenize(sql), sql) 462 463 def parse_into( 464 self, expression_type: exp.IntoType, sql: str, **opts 465 ) -> t.List[t.Optional[exp.Expression]]: 466 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 467 468 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 469 return self.generator(**opts).generate(expression, copy=copy) 470 471 def transpile(self, sql: str, **opts) -> t.List[str]: 472 return [ 473 self.generate(expression, copy=False, **opts) if expression else "" 474 for expression in self.parse(sql) 475 ] 476 477 def tokenize(self, sql: str) -> t.List[Token]: 478 return self.tokenizer.tokenize(sql) 479 480 @property 481 def tokenizer(self) -> Tokenizer: 482 if not hasattr(self, "_tokenizer"): 483 self._tokenizer = self.tokenizer_class(dialect=self) 484 return self._tokenizer 485 486 def parser(self, **opts) -> Parser: 487 return self.parser_class(dialect=self, **opts) 488 489 def generator(self, **opts) -> Generator: 490 return self.generator_class(dialect=self, **opts) 491 492 493DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 494 495 496def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 497 return lambda self, expression: self.func(name, 
*flatten(expression.args.values())) 498 499 500def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 501 if expression.args.get("accuracy"): 502 self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") 503 return self.func("APPROX_COUNT_DISTINCT", expression.this) 504 505 506def if_sql( 507 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 508) -> t.Callable[[Generator, exp.If], str]: 509 def _if_sql(self: Generator, expression: exp.If) -> str: 510 return self.func( 511 name, 512 expression.this, 513 expression.args.get("true"), 514 expression.args.get("false") or false_value, 515 ) 516 517 return _if_sql 518 519 520def arrow_json_extract_sql( 521 self: Generator, expression: exp.JSONExtract | exp.JSONExtractScalar 522) -> str: 523 this = expression.this 524 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 525 this.replace(exp.cast(this, "json")) 526 527 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") 528 529 530def inline_array_sql(self: Generator, expression: exp.Array) -> str: 531 return f"[{self.expressions(expression, flat=True)}]" 532 533 534def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 535 return self.like_sql( 536 exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression) 537 ) 538 539 540def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 541 zone = self.sql(expression, "this") 542 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 543 544 545def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 546 if expression.args.get("recursive"): 547 self.unsupported("Recursive CTEs are unsupported") 548 expression.args["recursive"] = False 549 return self.with_sql(expression) 550 551 552def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: 553 n = self.sql(expression, "this") 554 d = 
self.sql(expression, "expression") 555 return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)" 556 557 558def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 559 self.unsupported("TABLESAMPLE unsupported") 560 return self.sql(expression.this) 561 562 563def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 564 self.unsupported("PIVOT unsupported") 565 return "" 566 567 568def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 569 return self.cast_sql(expression) 570 571 572def no_comment_column_constraint_sql( 573 self: Generator, expression: exp.CommentColumnConstraint 574) -> str: 575 self.unsupported("CommentColumnConstraint unsupported") 576 return "" 577 578 579def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 580 self.unsupported("MAP_FROM_ENTRIES unsupported") 581 return "" 582 583 584def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: 585 this = self.sql(expression, "this") 586 substr = self.sql(expression, "substr") 587 position = self.sql(expression, "position") 588 if position: 589 return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1" 590 return f"STRPOS({this}, {substr})" 591 592 593def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 594 return ( 595 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 596 ) 597 598 599def var_map_sql( 600 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 601) -> str: 602 keys = expression.args["keys"] 603 values = expression.args["values"] 604 605 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 606 self.unsupported("Cannot convert array columns into map.") 607 return self.func(map_func_name, keys, values) 608 609 args = [] 610 for key, value in zip(keys.expressions, values.expressions): 611 args.append(self.sql(key)) 612 args.append(self.sql(value)) 613 614 return self.func(map_func_name, 
*args) 615 616 617def format_time_lambda( 618 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 619) -> t.Callable[[t.List], E]: 620 """Helper used for time expressions. 621 622 Args: 623 exp_class: the expression class to instantiate. 624 dialect: target sql dialect. 625 default: the default format, True being time. 626 627 Returns: 628 A callable that can be used to return the appropriately formatted time expression. 629 """ 630 631 def _format_time(args: t.List): 632 return exp_class( 633 this=seq_get(args, 0), 634 format=Dialect[dialect].format_time( 635 seq_get(args, 1) 636 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 637 ), 638 ) 639 640 return _format_time 641 642 643def time_format( 644 dialect: DialectType = None, 645) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 646 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 647 """ 648 Returns the time format for a given expression, unless it's equivalent 649 to the default time format of the dialect of interest. 650 """ 651 time_format = self.format_time(expression) 652 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 653 654 return _time_format 655 656 657def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: 658 """ 659 In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the 660 PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding 661 columns are removed from the create statement. 
662 """ 663 has_schema = isinstance(expression.this, exp.Schema) 664 is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW") 665 666 if has_schema and is_partitionable: 667 prop = expression.find(exp.PartitionedByProperty) 668 if prop and prop.this and not isinstance(prop.this, exp.Schema): 669 schema = expression.this 670 columns = {v.name.upper() for v in prop.this.expressions} 671 partitions = [col for col in schema.expressions if col.name.upper() in columns] 672 schema.set("expressions", [e for e in schema.expressions if e not in partitions]) 673 prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) 674 expression.set("this", schema) 675 676 return self.create_sql(expression) 677 678 679def parse_date_delta( 680 exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None 681) -> t.Callable[[t.List], E]: 682 def inner_func(args: t.List) -> E: 683 unit_based = len(args) == 3 684 this = args[2] if unit_based else seq_get(args, 0) 685 unit = args[0] if unit_based else exp.Literal.string("DAY") 686 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 687 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 688 689 return inner_func 690 691 692def parse_date_delta_with_interval( 693 expression_class: t.Type[E], 694) -> t.Callable[[t.List], t.Optional[E]]: 695 def func(args: t.List) -> t.Optional[E]: 696 if len(args) < 2: 697 return None 698 699 interval = args[1] 700 701 if not isinstance(interval, exp.Interval): 702 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 703 704 expression = interval.this 705 if expression and expression.is_string: 706 expression = exp.Literal.number(expression.this) 707 708 return expression_class( 709 this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit")) 710 ) 711 712 return func 713 714 715def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 716 unit = 
seq_get(args, 0) 717 this = seq_get(args, 1) 718 719 if isinstance(this, exp.Cast) and this.is_type("date"): 720 return exp.DateTrunc(unit=unit, this=this) 721 return exp.TimestampTrunc(this=this, unit=unit) 722 723 724def date_add_interval_sql( 725 data_type: str, kind: str 726) -> t.Callable[[Generator, exp.Expression], str]: 727 def func(self: Generator, expression: exp.Expression) -> str: 728 this = self.sql(expression, "this") 729 unit = expression.args.get("unit") 730 unit = exp.var(unit.name.upper() if unit else "DAY") 731 interval = exp.Interval(this=expression.expression, unit=unit) 732 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 733 734 return func 735 736 737def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 738 return self.func( 739 "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this 740 ) 741 742 743def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 744 if not expression.expression: 745 return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)) 746 if expression.text("expression").lower() in TIMEZONES: 747 return self.sql( 748 exp.AtTimeZone( 749 this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP), 750 zone=expression.expression, 751 ) 752 ) 753 return self.function_fallback_sql(expression) 754 755 756def locate_to_strposition(args: t.List) -> exp.Expression: 757 return exp.StrPosition( 758 this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) 759 ) 760 761 762def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: 763 return self.func( 764 "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") 765 ) 766 767 768def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: 769 return self.sql( 770 exp.Substring( 771 this=expression.this, start=exp.Literal.number(1), length=expression.expression 772 ) 773 ) 774 775 776def 
right_to_substring_sql(self: Generator, expression: exp.Left) -> str: 777 return self.sql( 778 exp.Substring( 779 this=expression.this, 780 start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1), 781 ) 782 ) 783 784 785def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: 786 return self.sql(exp.cast(expression.this, "timestamp")) 787 788 789def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: 790 return self.sql(exp.cast(expression.this, "date")) 791 792 793# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 794def encode_decode_sql( 795 self: Generator, expression: exp.Expression, name: str, replace: bool = True 796) -> str: 797 charset = expression.args.get("charset") 798 if charset and charset.name.lower() != "utf-8": 799 self.unsupported(f"Expected utf-8 character set, got {charset}.") 800 801 return self.func(name, expression.this, expression.args.get("replace") if replace else None) 802 803 804def min_or_least(self: Generator, expression: exp.Min) -> str: 805 name = "LEAST" if expression.expressions else "MIN" 806 return rename_func(name)(self, expression) 807 808 809def max_or_greatest(self: Generator, expression: exp.Max) -> str: 810 name = "GREATEST" if expression.expressions else "MAX" 811 return rename_func(name)(self, expression) 812 813 814def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 815 cond = expression.this 816 817 if isinstance(expression.this, exp.Distinct): 818 cond = expression.this.expressions[0] 819 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 820 821 return self.func("sum", exp.func("if", cond, 1, 0)) 822 823 824def trim_sql(self: Generator, expression: exp.Trim) -> str: 825 target = self.sql(expression, "this") 826 trim_type = self.sql(expression, "position") 827 remove_chars = self.sql(expression, "expression") 828 collation = self.sql(expression, "collation") 829 830 # 
# NOTE(review): the head of this function lies before the start of this chunk;
# the signature and leading assignments below are reconstructed from the
# visible tail — confirm against the upstream source.
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render a TRIM expression, using plain TRIM syntax where possible."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a STR_TO_TIME-style expression as STRPTIME."""
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    """Rewrite CONCAT(a, b, ...) into a chain of `||` (DPipe) operators."""
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Rewrite CONCAT_WS(sep, a, b, ...) into `a || sep || b || sep || ...`."""
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT, warning about args the target dialect can't express."""
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE, warning about args the target dialect can't express."""
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the output column names for a PIVOT's aggregation expressions."""
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            #
            # NOTE: this explanation was previously a free-standing triple-quoted string,
            # i.e. an executed (no-op) expression statement rather than a comment.
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser mapping a two-argument function call onto a binary expression."""
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render ANY_VALUE as MAX for dialects that lack ANY_VALUE."""
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Expand a boolean XOR into its AND/OR/NOT equivalent."""
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    """Whether the expression parses a string into JSON (PARSE_JSON or CAST(... AS JSON))."""
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Convert ISNULL(x) into (x IS NULL)."""
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    """Render a generated-as-identity constraint as IDENTITY(start, increment)."""
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    """Build a renderer for ARG_MAX/ARG_MIN in dialects that only accept two arguments."""

    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    """Wrap TS_OR_DS_ADD's operand in a cast to the expression's return type."""
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    """Build a renderer for NAME(unit, expression, this) style date arithmetic."""

    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    """Emulate LAST_DAY via DATE_TRUNC(month) + 1 month - 1 day."""
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)


def parse_json_extract_path(
    expr_type: t.Type[E],
    supports_null_if_invalid: bool = False,
) -> t.Callable[[t.List], E]:
    """Build a parser for JSON_EXTRACT_PATH-style functions."""

    def _parse_json_extract_path(args: t.List) -> E:
        null_if_invalid = None

        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if isinstance(arg, exp.Literal):
                text = arg.name
                if is_int(text):
                    segments.append(exp.JSONPathSubscript(this=int(text)))
                else:
                    segments.append(exp.JSONPathKey(this=text))
            elif supports_null_if_invalid:
                null_if_invalid = arg

        this = seq_get(args, 0)
        jsonpath = exp.JSONPath(expressions=segments)

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]

        if expr_type is exp.JSONExtractScalar:
            return expr_type(this=this, expression=jsonpath, null_if_invalid=null_if_invalid)

        return expr_type(this=this, expression=jsonpath)

    return _parse_json_extract_path


def json_path_segments(self: Generator, expression: exp.JSONPath) -> t.List[str]:
    """Render each JSON path segment, quoted with the dialect's quote delimiters."""
    segments = []
    for segment in expression.expressions:
        path = self.sql(segment)
        if path:
            segments.append(f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}")

    return segments
28class Dialects(str, Enum): 29 """Dialects supported by SQLGLot.""" 30 31 DIALECT = "" 32 33 BIGQUERY = "bigquery" 34 CLICKHOUSE = "clickhouse" 35 DATABRICKS = "databricks" 36 DORIS = "doris" 37 DRILL = "drill" 38 DUCKDB = "duckdb" 39 HIVE = "hive" 40 MYSQL = "mysql" 41 ORACLE = "oracle" 42 POSTGRES = "postgres" 43 PRESTO = "presto" 44 REDSHIFT = "redshift" 45 SNOWFLAKE = "snowflake" 46 SPARK = "spark" 47 SPARK2 = "spark2" 48 SQLITE = "sqlite" 49 STARROCKS = "starrocks" 50 TABLEAU = "tableau" 51 TERADATA = "teradata" 52 TRINO = "trino" 53 TSQL = "tsql"
Dialects supported by SQLGlot.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
56class NormalizationStrategy(str, AutoName): 57 """Specifies the strategy according to which identifiers should be normalized.""" 58 59 LOWERCASE = auto() 60 """Unquoted identifiers are lowercased.""" 61 62 UPPERCASE = auto() 63 """Unquoted identifiers are uppercased.""" 64 65 CASE_SENSITIVE = auto() 66 """Always case-sensitive, regardless of quotes.""" 67 68 CASE_INSENSITIVE = auto() 69 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """Determines the base index offset for arrays."""

    WEEK_OFFSET = 0
    """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Determines whether or not `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Determines whether or not the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Determines whether or not a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Determines whether or not an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Determines whether or not `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Determines whether or not user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Determines whether or not `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """Determines how function names are going to be normalized."""

    LOG_BASE_FIRST = True
    """Determines whether the base comes first in the `LOG` function."""

    NULL_ORDERING = "nulls_are_small"
    """
    Indicates the default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` format."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an unescaped escape sequence to the corresponding character."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)
344 def __init__(self, **kwargs) -> None: 345 normalization_strategy = kwargs.get("normalization_strategy") 346 347 if normalization_strategy is None: 348 self.normalization_strategy = self.NORMALIZATION_STRATEGY 349 else: 350 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Determines whether or not `UNNEST` table aliases are treated as column aliases.
Determines whether or not a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines whether or not an unquoted identifier can start with a digit.
Determines whether or not the DPIPE token (`||`) is a string concatenation operator.
Indicates the default `NULL` ordering method to use if not explicitly set. Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`.
Whether the behavior of `a / b` depends on the types of `a` and `b`. False means `a / b` is always float division; True means `a / b` is integer division if both `a` and `b` are integers.
Determines whether division by zero throws an error (`False`) or returns NULL (`True`).
A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python `strftime` format.
Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
Mapping of an unescaped escape sequence to the corresponding character.
Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from `SELECT *` queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example, WITH y(c) AS ( SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 ) SELECT c FROM y;
will be rewritten as
WITH y(c) AS (
SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
280 @classmethod 281 def get_or_raise(cls, dialect: DialectType) -> Dialect: 282 """ 283 Look up a dialect in the global dialect registry and return it if it exists. 284 285 Args: 286 dialect: The target dialect. If this is a string, it can be optionally followed by 287 additional key-value pairs that are separated by commas and are used to specify 288 dialect settings, such as whether the dialect's identifiers are case-sensitive. 289 290 Example: 291 >>> dialect = dialect_class = get_or_raise("duckdb") 292 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 293 294 Returns: 295 The corresponding Dialect instance. 296 """ 297 298 if not dialect: 299 return cls() 300 if isinstance(dialect, _Dialect): 301 return dialect() 302 if isinstance(dialect, Dialect): 303 return dialect 304 if isinstance(dialect, str): 305 try: 306 dialect_name, *kv_pairs = dialect.split(",") 307 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 308 except ValueError: 309 raise ValueError( 310 f"Invalid dialect format: '{dialect}'. " 311 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 312 ) 313 314 result = cls.get(dialect_name.strip()) 315 if not result: 316 from difflib import get_close_matches 317 318 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 319 if similar: 320 similar = f" Did you mean {similar}?" 321 322 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 323 324 return result(**kwargs) 325 326 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb") >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
328 @classmethod 329 def format_time( 330 cls, expression: t.Optional[str | exp.Expression] 331 ) -> t.Optional[exp.Expression]: 332 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 333 if isinstance(expression, str): 334 return exp.Literal.string( 335 # the time formats are quoted 336 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 337 ) 338 339 if expression and expression.is_string: 340 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 341 342 return expression
Converts a time format in this dialect to its equivalent Python `strftime` format.
360 def normalize_identifier(self, expression: E) -> E: 361 """ 362 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 363 364 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 365 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 366 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 367 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 368 369 There are also dialects like Spark, which are case-insensitive even when quotes are 370 present, and dialects like MySQL, whose resolution rules match those employed by the 371 underlying operating system, for example they may always be case-sensitive in Linux. 372 373 Finally, the normalization behavior of some engines can even be controlled through flags, 374 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 375 376 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 377 that it can analyze queries in the optimizer and successfully capture their semantics. 378 """ 379 if ( 380 isinstance(expression, exp.Identifier) 381 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 382 and ( 383 not expression.quoted 384 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 385 ) 386 ): 387 expression.set( 388 "this", 389 ( 390 expression.this.upper() 391 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 392 else expression.this.lower() 393 ), 394 ) 395 396 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
398 def case_sensitive(self, text: str) -> bool: 399 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 400 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 401 return False 402 403 unsafe = ( 404 str.islower 405 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 406 else str.isupper 407 ) 408 return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
410 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 411 """Checks if text can be identified given an identify option. 412 413 Args: 414 text: The text to check. 415 identify: 416 `"always"` or `True`: Always returns `True`. 417 `"safe"`: Only returns `True` if the identifier is case-insensitive. 418 419 Returns: 420 Whether or not the given text can be identified. 421 """ 422 if identify is True or identify == "always": 423 return True 424 425 if identify == "safe": 426 return not self.case_sensitive(text) 427 428 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
"always"
orTrue
: Always returnsTrue
."safe"
: Only returnsTrue
if the identifier is case-insensitive.
Returns:
Whether or not the given text can be identified.
430 def quote_identifier(self, expression: E, identify: bool = True) -> E: 431 """ 432 Adds quotes to a given identifier. 433 434 Args: 435 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 436 identify: If set to `False`, the quotes will only be added if the identifier is deemed 437 "unsafe", with respect to its characters and this dialect's normalization strategy. 438 """ 439 if isinstance(expression, exp.Identifier): 440 name = expression.this 441 expression.set( 442 "quoted", 443 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 444 ) 445 446 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
448 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 449 if isinstance(path, exp.Literal): 450 path_text = path.name 451 if path.is_number: 452 path_text = f"[{path_text}]" 453 454 try: 455 return parse_json_path(path_text) 456 except ParseError as e: 457 logger.warning(f"Invalid JSON path syntax. {str(e)}") 458 459 return path
507def if_sql( 508 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 509) -> t.Callable[[Generator, exp.If], str]: 510 def _if_sql(self: Generator, expression: exp.If) -> str: 511 return self.func( 512 name, 513 expression.this, 514 expression.args.get("true"), 515 expression.args.get("false") or false_value, 516 ) 517 518 return _if_sql
521def arrow_json_extract_sql( 522 self: Generator, expression: exp.JSONExtract | exp.JSONExtractScalar 523) -> str: 524 this = expression.this 525 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 526 this.replace(exp.cast(this, "json")) 527 528 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
585def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: 586 this = self.sql(expression, "this") 587 substr = self.sql(expression, "substr") 588 position = self.sql(expression, "position") 589 if position: 590 return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1" 591 return f"STRPOS({this}, {substr})"
600def var_map_sql( 601 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 602) -> str: 603 keys = expression.args["keys"] 604 values = expression.args["values"] 605 606 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 607 self.unsupported("Cannot convert array columns into map.") 608 return self.func(map_func_name, keys, values) 609 610 args = [] 611 for key, value in zip(keys.expressions, values.expressions): 612 args.append(self.sql(key)) 613 args.append(self.sql(value)) 614 615 return self.func(map_func_name, *args)
618def format_time_lambda( 619 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 620) -> t.Callable[[t.List], E]: 621 """Helper used for time expressions. 622 623 Args: 624 exp_class: the expression class to instantiate. 625 dialect: target sql dialect. 626 default: the default format, True being time. 627 628 Returns: 629 A callable that can be used to return the appropriately formatted time expression. 630 """ 631 632 def _format_time(args: t.List): 633 return exp_class( 634 this=seq_get(args, 0), 635 format=Dialect[dialect].format_time( 636 seq_get(args, 1) 637 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 638 ), 639 ) 640 641 return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
644def time_format( 645 dialect: DialectType = None, 646) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 647 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 648 """ 649 Returns the time format for a given expression, unless it's equivalent 650 to the default time format of the dialect of interest. 651 """ 652 time_format = self.format_time(expression) 653 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 654 655 return _time_format
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    if isinstance(expression.this, exp.Schema) and expression.args.get("kind") in ("TABLE", "VIEW"):
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Match partition columns by case-insensitive name.
            partition_names = {col.name.upper() for col in prop.this.expressions}
            partition_cols = [c for c in schema.expressions if c.name.upper() in partition_names]
            # Move the partition columns out of the table schema and into the property.
            schema.set("expressions", [c for c in schema.expressions if c not in partition_cols])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_cols)))
            expression.set("this", schema)

    return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema, and the corresponding columns are removed from the CREATE statement.
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date add/sub functions, optionally remapping unit names.

    With three args the call is unit-based: (unit, amount, target); with fewer
    it is (target, amount) and the unit defaults to DAY.
    """

    def inner_func(args: t.List) -> E:
        has_unit = len(args) == 3
        unit = args[0] if has_unit else exp.Literal.string("DAY")
        if unit_mapping:
            # Translate dialect-specific unit spellings to canonical ones.
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(
            this=args[2] if has_unit else seq_get(args, 0),
            expression=seq_get(args, 1),
            unit=unit,
        )

    return inner_func
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date arithmetic whose second argument must be an INTERVAL.

    Raises:
        ParseError: if the second argument is not an interval expression.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            # Normalize string interval amounts (e.g. '5') into numeric literals.
            amount = exp.Literal.number(amount.this)

        return expression_class(
            this=args[0], expression=amount, unit=exp.Literal.string(interval.text("unit"))
        )

    return func
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator emitting `<DATA_TYPE>_<KIND>(this, INTERVAL ...)` date arithmetic."""

    def func(self: Generator, expression: exp.Expression) -> str:
        unit = expression.args.get("unit")
        # Default the unit to DAY when none was parsed.
        unit_var = exp.var(unit.name.upper()) if unit else exp.var("DAY")
        interval = exp.Interval(this=expression.expression, unit=unit_var)
        return f"{data_type}_{kind}({self.sql(expression, 'this')}, {self.sql(interval)})"

    return func
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    """Transpile TIMESTAMP(x[, tz]) for dialects lacking a TIMESTAMP function.

    A bare TIMESTAMP(x) becomes a cast; a known timezone argument becomes
    AT TIME ZONE; anything else falls back to the generic function renderer.
    """
    zone = expression.expression

    if not zone:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))

    if expression.text("expression").lower() in TIMEZONES:
        casted = exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)
        return self.sql(exp.AtTimeZone(this=casted, zone=zone))

    return self.function_fallback_sql(expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Generate ENCODE/DECODE-style calls, warning when the charset isn't utf-8."""
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    # Only forward the replace argument when the target dialect supports it.
    replace_arg = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replace_arg)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0))."""
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        # DISTINCT has no equivalent under the SUM rewrite; warn and unwrap.
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
        cond = cond.expressions[0]

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    # Renders TRIM with optional position (e.g. LEADING/TRAILING/BOTH), characters
    # to remove, and collation: TRIM(<position> <chars> FROM <target> COLLATE <c>).
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        # NOTE(review): this invokes the Generator's own `trim_sql` method (not this
        # module-level function), which presumably emits the plain TRIM form — verify
        # against the Generator class.
        return self.trim_sql(expression)

    # Assemble the fully-qualified TRIM form; each fragment carries its own
    # trailing/leading spacing so empty parts collapse cleanly.
    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Generate REGEXP_EXTRACT, warning about args the target dialect can't express."""
    bad_args = [arg for arg in ("position", "occurrence", "parameters") if expression.args.get(arg)]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Generate REGEXP_REPLACE, warning about args the target dialect can't express."""
    bad_args = [
        arg
        for arg in ("position", "occurrence", "parameters", "modifiers")
        if expression.args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the column-name suffixes produced by a PIVOT's aggregations.

    Args:
        aggregations: the aggregation expressions in the PIVOT clause.
        dialect: dialect used to render unaliased aggregations into names.

    Returns:
        One name per aggregation: the alias when present, otherwise the SQL of
        the aggregation with every identifier unquoted.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            #
            # (Previously this explanation was a bare triple-quoted string — an
            # expression statement evaluated as a no-op on every iteration.)
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    """Build a generator for two-argument ARG_MAX/ARG_MIN, warning if a count is present."""

    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")
        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    """Wrap a TS_OR_DS_ADD's first argument in a cast to the expression's return type."""
    target_type = expression.return_type
    inner = expression.this.copy()

    if target_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        inner = exp.cast(inner, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(inner, target_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    """Build a generator for date add/diff functions shaped as NAME(unit, amount, this)."""

    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        # Missing unit defaults to DAY ("".upper() is still falsy).
        unit = expression.text("unit").upper() or "DAY"
        return self.func(name, exp.var(unit), expression.expression, expression.this)

    return _delta_sql
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    """Emulate LAST_DAY(d) as CAST(DATE_SUB(DATE_ADD(DATE_TRUNC(month, d), 1 month), 1 day) AS date)."""
    month_start = exp.func("date_trunc", "month", expression.this)
    next_month_start = exp.func("date_add", month_start, 1, "month")
    last_day = exp.func("date_sub", next_month_start, 1, "day")

    return self.sql(exp.cast(last_day, "date"))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        # Normalize per the dialect's case rules so name comparisons match.
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    # Names that refer to the merge target: the target table and (if set) its alias.
    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        # Strip the table qualifier from any column referencing the target.
        # copy=False mutates the WHEN clause in place.
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)
Remove table references from the columns used in the WHEN clauses of a MERGE statement.
def parse_json_extract_path(
    expr_type: t.Type[E],
    supports_null_if_invalid: bool = False,
) -> t.Callable[[t.List], E]:
    """Build a parser that converts JSON_EXTRACT_PATH-style args into a JSONPath expression.

    Args:
        expr_type: the expression class to instantiate.
        supports_null_if_invalid: whether a trailing non-literal argument is treated
            as the null_if_invalid flag (only forwarded for JSONExtractScalar).
    """

    def _parse_json_extract_path(args: t.List) -> E:
        null_if_invalid = None

        # Path always starts at the JSON root ($).
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if isinstance(arg, exp.Literal):
                text = arg.name
                # Integer-looking literals are array subscripts; others are object keys.
                if is_int(text):
                    segments.append(exp.JSONPathSubscript(this=int(text)))
                else:
                    segments.append(exp.JSONPathKey(this=text))
            elif supports_null_if_invalid:
                null_if_invalid = arg

        this = seq_get(args, 0)
        jsonpath = exp.JSONPath(expressions=segments)

        # This is done to avoid failing in the expression validator due to the arg count
        # NOTE: intentionally mutates the caller's args list.
        del args[2:]

        if expr_type is exp.JSONExtractScalar:
            return expr_type(this=this, expression=jsonpath, null_if_invalid=null_if_invalid)

        return expr_type(this=this, expression=jsonpath)

    return _parse_json_extract_path
def json_path_segments(self: Generator, expression: exp.JSONPath) -> t.List[str]:
    """Render each non-empty JSONPath segment wrapped in the dialect's quote characters."""
    quote_start = self.dialect.QUOTE_START
    quote_end = self.dialect.QUOTE_END
    rendered = (self.sql(segment) for segment in expression.expressions)
    return [f"{quote_start}{path}{quote_end}" for path in rendered if path]