sqlglot.dialects.dialect
1from __future__ import annotations 2 3import typing as t 4from enum import Enum 5 6from sqlglot import exp 7from sqlglot._typing import E 8from sqlglot.generator import Generator 9from sqlglot.helper import flatten, seq_get 10from sqlglot.parser import Parser 11from sqlglot.time import format_time 12from sqlglot.tokens import Token, Tokenizer, TokenType 13from sqlglot.trie import new_trie 14 15 16class Dialects(str, Enum): 17 DIALECT = "" 18 19 BIGQUERY = "bigquery" 20 CLICKHOUSE = "clickhouse" 21 DATABRICKS = "databricks" 22 DRILL = "drill" 23 DUCKDB = "duckdb" 24 HIVE = "hive" 25 MYSQL = "mysql" 26 ORACLE = "oracle" 27 POSTGRES = "postgres" 28 PRESTO = "presto" 29 REDSHIFT = "redshift" 30 SNOWFLAKE = "snowflake" 31 SPARK = "spark" 32 SPARK2 = "spark2" 33 SQLITE = "sqlite" 34 STARROCKS = "starrocks" 35 TABLEAU = "tableau" 36 TERADATA = "teradata" 37 TRINO = "trino" 38 TSQL = "tsql" 39 40 41class _Dialect(type): 42 classes: t.Dict[str, t.Type[Dialect]] = {} 43 44 def __eq__(cls, other: t.Any) -> bool: 45 if cls is other: 46 return True 47 if isinstance(other, str): 48 return cls is cls.get(other) 49 if isinstance(other, Dialect): 50 return cls is type(other) 51 52 return False 53 54 def __hash__(cls) -> int: 55 return hash(cls.__name__.lower()) 56 57 @classmethod 58 def __getitem__(cls, key: str) -> t.Type[Dialect]: 59 return cls.classes[key] 60 61 @classmethod 62 def get( 63 cls, key: str, default: t.Optional[t.Type[Dialect]] = None 64 ) -> t.Optional[t.Type[Dialect]]: 65 return cls.classes.get(key, default) 66 67 def __new__(cls, clsname, bases, attrs): 68 klass = super().__new__(cls, clsname, bases, attrs) 69 enum = Dialects.__members__.get(clsname.upper()) 70 cls.classes[enum.value if enum is not None else clsname.lower()] = klass 71 72 klass.TIME_TRIE = new_trie(klass.TIME_MAPPING) 73 klass.FORMAT_TRIE = ( 74 new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE 75 ) 76 klass.INVERSE_TIME_MAPPING = {v: k for k, v in 
klass.TIME_MAPPING.items()} 77 klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING) 78 79 klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer) 80 klass.parser_class = getattr(klass, "Parser", Parser) 81 klass.generator_class = getattr(klass, "Generator", Generator) 82 83 klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0] 84 klass.IDENTIFIER_START, klass.IDENTIFIER_END = list( 85 klass.tokenizer_class._IDENTIFIERS.items() 86 )[0] 87 88 def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]: 89 return next( 90 ( 91 (s, e) 92 for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items() 93 if t == token_type 94 ), 95 (None, None), 96 ) 97 98 klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING) 99 klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) 100 klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) 101 klass.RAW_START, klass.RAW_END = get_start_end(TokenType.RAW_STRING) 102 103 dialect_properties = { 104 **{ 105 k: v 106 for k, v in vars(klass).items() 107 if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__") 108 }, 109 "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0], 110 "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0], 111 } 112 113 if enum not in ("", "bigquery"): 114 dialect_properties["SELECT_KINDS"] = () 115 116 # Pass required dialect properties to the tokenizer, parser and generator classes 117 for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class): 118 for name, value in dialect_properties.items(): 119 if hasattr(subclass, name): 120 setattr(subclass, name, value) 121 122 if not klass.STRICT_STRING_CONCAT: 123 klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe 124 125 klass.generator_class.can_identify = klass.can_identify 126 127 return klass 128 129 130class Dialect(metaclass=_Dialect): 131 # Determines the base index offset for 
arrays 132 INDEX_OFFSET = 0 133 134 # If true unnest table aliases are considered only as column aliases 135 UNNEST_COLUMN_ONLY = False 136 137 # Determines whether or not the table alias comes after tablesample 138 ALIAS_POST_TABLESAMPLE = False 139 140 # Determines whether or not unquoted identifiers are resolved as uppercase 141 # When set to None, it means that the dialect treats all identifiers as case-insensitive 142 RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False 143 144 # Determines whether or not an unquoted identifier can start with a digit 145 IDENTIFIERS_CAN_START_WITH_DIGIT = False 146 147 # Determines whether or not CONCAT's arguments must be strings 148 STRICT_STRING_CONCAT = False 149 150 # Determines how function names are going to be normalized 151 NORMALIZE_FUNCTIONS: bool | str = "upper" 152 153 # Indicates the default null ordering method to use if not explicitly set 154 # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last" 155 NULL_ORDERING = "nulls_are_small" 156 157 DATE_FORMAT = "'%Y-%m-%d'" 158 DATEINT_FORMAT = "'%Y%m%d'" 159 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 160 161 # Custom time mappings in which the key represents dialect time format 162 # and the value represents a python time format 163 TIME_MAPPING: t.Dict[str, str] = {} 164 165 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 166 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 167 # special syntax cast(x as date format 'yyyy') defaults to time_mapping 168 FORMAT_MAPPING: t.Dict[str, str] = {} 169 170 # Autofilled 171 tokenizer_class = Tokenizer 172 parser_class = Parser 173 generator_class = Generator 174 175 # A trie of the time_mapping keys 176 TIME_TRIE: t.Dict = {} 177 FORMAT_TRIE: t.Dict = {} 178 179 INVERSE_TIME_MAPPING: 
t.Dict[str, str] = {} 180 INVERSE_TIME_TRIE: t.Dict = {} 181 182 def __eq__(self, other: t.Any) -> bool: 183 return type(self) == other 184 185 def __hash__(self) -> int: 186 return hash(type(self)) 187 188 @classmethod 189 def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: 190 if not dialect: 191 return cls 192 if isinstance(dialect, _Dialect): 193 return dialect 194 if isinstance(dialect, Dialect): 195 return dialect.__class__ 196 197 result = cls.get(dialect) 198 if not result: 199 raise ValueError(f"Unknown dialect '{dialect}'") 200 201 return result 202 203 @classmethod 204 def format_time( 205 cls, expression: t.Optional[str | exp.Expression] 206 ) -> t.Optional[exp.Expression]: 207 if isinstance(expression, str): 208 return exp.Literal.string( 209 # the time formats are quoted 210 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 211 ) 212 213 if expression and expression.is_string: 214 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 215 216 return expression 217 218 @classmethod 219 def normalize_identifier(cls, expression: E) -> E: 220 """ 221 Normalizes an unquoted identifier to either lower or upper case, thus essentially 222 making it case-insensitive. If a dialect treats all identifiers as case-insensitive, 223 they will be normalized regardless of being quoted or not. 
224 """ 225 if isinstance(expression, exp.Identifier) and ( 226 not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None 227 ): 228 expression.set( 229 "this", 230 expression.this.upper() 231 if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE 232 else expression.this.lower(), 233 ) 234 235 return expression 236 237 @classmethod 238 def case_sensitive(cls, text: str) -> bool: 239 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 240 if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None: 241 return False 242 243 unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper 244 return any(unsafe(char) for char in text) 245 246 @classmethod 247 def can_identify(cls, text: str, identify: str | bool = "safe") -> bool: 248 """Checks if text can be identified given an identify option. 249 250 Args: 251 text: The text to check. 252 identify: 253 "always" or `True`: Always returns true. 254 "safe": True if the identifier is case-insensitive. 255 256 Returns: 257 Whether or not the given text can be identified. 
258 """ 259 if identify is True or identify == "always": 260 return True 261 262 if identify == "safe": 263 return not cls.case_sensitive(text) 264 265 return False 266 267 @classmethod 268 def quote_identifier(cls, expression: E, identify: bool = True) -> E: 269 if isinstance(expression, exp.Identifier): 270 name = expression.this 271 expression.set( 272 "quoted", 273 identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 274 ) 275 276 return expression 277 278 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 279 return self.parser(**opts).parse(self.tokenize(sql), sql) 280 281 def parse_into( 282 self, expression_type: exp.IntoType, sql: str, **opts 283 ) -> t.List[t.Optional[exp.Expression]]: 284 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 285 286 def generate(self, expression: t.Optional[exp.Expression], **opts) -> str: 287 return self.generator(**opts).generate(expression) 288 289 def transpile(self, sql: str, **opts) -> t.List[str]: 290 return [self.generate(expression, **opts) for expression in self.parse(sql)] 291 292 def tokenize(self, sql: str) -> t.List[Token]: 293 return self.tokenizer.tokenize(sql) 294 295 @property 296 def tokenizer(self) -> Tokenizer: 297 if not hasattr(self, "_tokenizer"): 298 self._tokenizer = self.tokenizer_class() 299 return self._tokenizer 300 301 def parser(self, **opts) -> Parser: 302 return self.parser_class(**opts) 303 304 def generator(self, **opts) -> Generator: 305 return self.generator_class(**opts) 306 307 308DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 309 310 311def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 312 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 313 314 315def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 316 if expression.args.get("accuracy"): 317 self.unsupported("APPROX_COUNT_DISTINCT does not support 
accuracy") 318 return self.func("APPROX_COUNT_DISTINCT", expression.this) 319 320 321def if_sql(self: Generator, expression: exp.If) -> str: 322 return self.func( 323 "IF", expression.this, expression.args.get("true"), expression.args.get("false") 324 ) 325 326 327def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str: 328 return self.binary(expression, "->") 329 330 331def arrow_json_extract_scalar_sql( 332 self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar 333) -> str: 334 return self.binary(expression, "->>") 335 336 337def inline_array_sql(self: Generator, expression: exp.Array) -> str: 338 return f"[{self.expressions(expression)}]" 339 340 341def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 342 return self.like_sql( 343 exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression) 344 ) 345 346 347def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 348 zone = self.sql(expression, "this") 349 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 350 351 352def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 353 if expression.args.get("recursive"): 354 self.unsupported("Recursive CTEs are unsupported") 355 expression.args["recursive"] = False 356 return self.with_sql(expression) 357 358 359def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: 360 n = self.sql(expression, "this") 361 d = self.sql(expression, "expression") 362 return f"IF({d} <> 0, {n} / {d}, NULL)" 363 364 365def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 366 self.unsupported("TABLESAMPLE unsupported") 367 return self.sql(expression.this) 368 369 370def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 371 self.unsupported("PIVOT unsupported") 372 return "" 373 374 375def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 376 return 
self.cast_sql(expression) 377 378 379def no_properties_sql(self: Generator, expression: exp.Properties) -> str: 380 self.unsupported("Properties unsupported") 381 return "" 382 383 384def no_comment_column_constraint_sql( 385 self: Generator, expression: exp.CommentColumnConstraint 386) -> str: 387 self.unsupported("CommentColumnConstraint unsupported") 388 return "" 389 390 391def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: 392 this = self.sql(expression, "this") 393 substr = self.sql(expression, "substr") 394 position = self.sql(expression, "position") 395 if position: 396 return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1" 397 return f"STRPOS({this}, {substr})" 398 399 400def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 401 this = self.sql(expression, "this") 402 struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True)) 403 return f"{this}.{struct_key}" 404 405 406def var_map_sql( 407 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 408) -> str: 409 keys = expression.args["keys"] 410 values = expression.args["values"] 411 412 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 413 self.unsupported("Cannot convert array columns into map.") 414 return self.func(map_func_name, keys, values) 415 416 args = [] 417 for key, value in zip(keys.expressions, values.expressions): 418 args.append(self.sql(key)) 419 args.append(self.sql(value)) 420 421 return self.func(map_func_name, *args) 422 423 424def format_time_lambda( 425 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 426) -> t.Callable[[t.List], E]: 427 """Helper used for time expressions. 428 429 Args: 430 exp_class: the expression class to instantiate. 431 dialect: target sql dialect. 432 default: the default format, True being time. 433 434 Returns: 435 A callable that can be used to return the appropriately formatted time expression. 
436 """ 437 438 def _format_time(args: t.List): 439 return exp_class( 440 this=seq_get(args, 0), 441 format=Dialect[dialect].format_time( 442 seq_get(args, 1) 443 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 444 ), 445 ) 446 447 return _format_time 448 449 450def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: 451 """ 452 In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the 453 PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding 454 columns are removed from the create statement. 455 """ 456 has_schema = isinstance(expression.this, exp.Schema) 457 is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW") 458 459 if has_schema and is_partitionable: 460 expression = expression.copy() 461 prop = expression.find(exp.PartitionedByProperty) 462 if prop and prop.this and not isinstance(prop.this, exp.Schema): 463 schema = expression.this 464 columns = {v.name.upper() for v in prop.this.expressions} 465 partitions = [col for col in schema.expressions if col.name.upper() in columns] 466 schema.set("expressions", [e for e in schema.expressions if e not in partitions]) 467 prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) 468 expression.set("this", schema) 469 470 return self.create_sql(expression) 471 472 473def parse_date_delta( 474 exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None 475) -> t.Callable[[t.List], E]: 476 def inner_func(args: t.List) -> E: 477 unit_based = len(args) == 3 478 this = args[2] if unit_based else seq_get(args, 0) 479 unit = args[0] if unit_based else exp.Literal.string("DAY") 480 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 481 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 482 483 return inner_func 484 485 486def parse_date_delta_with_interval( 487 expression_class: 
t.Type[E], 488) -> t.Callable[[t.List], t.Optional[E]]: 489 def func(args: t.List) -> t.Optional[E]: 490 if len(args) < 2: 491 return None 492 493 interval = args[1] 494 expression = interval.this 495 if expression and expression.is_string: 496 expression = exp.Literal.number(expression.this) 497 498 return expression_class( 499 this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit")) 500 ) 501 502 return func 503 504 505def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 506 unit = seq_get(args, 0) 507 this = seq_get(args, 1) 508 509 if isinstance(this, exp.Cast) and this.is_type("date"): 510 return exp.DateTrunc(unit=unit, this=this) 511 return exp.TimestampTrunc(this=this, unit=unit) 512 513 514def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 515 return self.func( 516 "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this 517 ) 518 519 520def locate_to_strposition(args: t.List) -> exp.Expression: 521 return exp.StrPosition( 522 this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) 523 ) 524 525 526def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: 527 return self.func( 528 "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") 529 ) 530 531 532def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: 533 expression = expression.copy() 534 return self.sql( 535 exp.Substring( 536 this=expression.this, start=exp.Literal.number(1), length=expression.expression 537 ) 538 ) 539 540 541def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: 542 expression = expression.copy() 543 return self.sql( 544 exp.Substring( 545 this=expression.this, 546 start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1), 547 ) 548 ) 549 550 551def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: 552 return 
f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)" 553 554 555def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: 556 return f"CAST({self.sql(expression, 'this')} AS DATE)" 557 558 559def min_or_least(self: Generator, expression: exp.Min) -> str: 560 name = "LEAST" if expression.expressions else "MIN" 561 return rename_func(name)(self, expression) 562 563 564def max_or_greatest(self: Generator, expression: exp.Max) -> str: 565 name = "GREATEST" if expression.expressions else "MAX" 566 return rename_func(name)(self, expression) 567 568 569def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 570 cond = expression.this 571 572 if isinstance(expression.this, exp.Distinct): 573 cond = expression.this.expressions[0] 574 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 575 576 return self.func("sum", exp.func("if", cond, 1, 0)) 577 578 579def trim_sql(self: Generator, expression: exp.Trim) -> str: 580 target = self.sql(expression, "this") 581 trim_type = self.sql(expression, "position") 582 remove_chars = self.sql(expression, "expression") 583 collation = self.sql(expression, "collation") 584 585 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 586 if not remove_chars and not collation: 587 return self.trim_sql(expression) 588 589 trim_type = f"{trim_type} " if trim_type else "" 590 remove_chars = f"{remove_chars} " if remove_chars else "" 591 from_part = "FROM " if trim_type or remove_chars else "" 592 collation = f" COLLATE {collation}" if collation else "" 593 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" 594 595 596def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: 597 return self.func("STRPTIME", expression.this, self.format_time(expression)) 598 599 600def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: 601 def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str: 602 _dialect = 
Dialect.get_or_raise(dialect) 603 time_format = self.format_time(expression) 604 if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT): 605 return f"CAST({str_to_time_sql(self, expression)} AS DATE)" 606 return f"CAST({self.sql(expression, 'this')} AS DATE)" 607 608 return _ts_or_ds_to_date_sql 609 610 611def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str: 612 this, *rest_args = expression.expressions 613 for arg in rest_args: 614 this = exp.DPipe(this=this, expression=arg) 615 616 return self.sql(this) 617 618 619# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator 620def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 621 names = [] 622 for agg in aggregations: 623 if isinstance(agg, exp.Alias): 624 names.append(agg.alias) 625 else: 626 """ 627 This case corresponds to aggregations without aliases being used as suffixes 628 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 629 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 630 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 631 """ 632 agg_all_unquoted = agg.transform( 633 lambda node: exp.Identifier(this=node.name, quoted=False) 634 if isinstance(node, exp.Identifier) 635 else node 636 ) 637 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 638 639 return names
class Dialects(str, Enum):
    """String-valued enumeration of the dialect names listed on this page."""

    # The empty name stands for the base dialect
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
An enumeration.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class Dialect(metaclass=_Dialect):
    """Base dialect: bundles the dialect settings with its tokenizer, parser and generator."""

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        # Comparison is delegated to the class of this instance
        own_class = type(self)
        return own_class == other

    def __hash__(self) -> int:
        own_class = type(self)
        return hash(own_class)

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """
        Resolves a dialect name, class or instance to a dialect class.

        Falsy input yields `cls`; a name that is not registered raises `ValueError`.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        found = cls.get(dialect)
        if found:
            return found

        raise ValueError(f"Unknown dialect '{dialect}'")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """
        Translates a time format through `TIME_MAPPING` and wraps it in a string literal.

        Inputs that are neither a string nor a string-literal expression are returned as is.
        """
        if isinstance(expression, str):
            # the time formats are quoted
            raw_format = expression[1:-1]
        elif expression and expression.is_string:
            raw_format = expression.this
        else:
            return expression

        return exp.Literal.string(format_time(raw_format, cls.TIME_MAPPING, cls.TIME_TRIE))

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Lower- or upper-cases an identifier in place so that it compares case-insensitively.

        Quoted identifiers are left alone unless the dialect treats every identifier as
        case-insensitive (`RESOLVES_IDENTIFIERS_AS_UPPERCASE is None`).
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        if expression.quoted and cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is not None:
            return expression

        name = expression.this
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE:
            normalized = name.upper()
        else:
            normalized = name.lower()
        expression.set("this", normalized)

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Tells whether `text` has characters whose case matters under the dialect's rules."""
        resolves_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if resolves_upper is None:
            return False

        if resolves_upper:
            return any(map(str.islower, text))
        return any(map(str.isupper, text))

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        always = identify is True or identify == "always"
        return always or (identify == "safe" and not cls.case_sensitive(text))

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """
        Sets the `quoted` flag of an identifier in place and returns the expression.

        The flag is truthy when `identify` is set, when the name is case sensitive, or
        when the name does not match `exp.SAFE_IDENTIFIER_RE`.
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        name = expression.this
        needs_quotes = (
            identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
        )
        expression.set("quoted", needs_quotes)

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        parser = self.parser(**opts)
        tokens = self.tokenize(sql)
        return parser.parse(tokens, sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        parser = self.parser(**opts)
        tokens = self.tokenize(sql)
        return parser.parse_into(expression_type, tokens, sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        generator = self.generator(**opts)
        return generator.generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        statements = []
        for expression in self.parse(sql):
            statements.append(self.generate(expression, **opts))
        return statements

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Created on first access, then reused by this instance
        if hasattr(self, "_tokenizer"):
            return self._tokenizer

        self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)
@classmethod
def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
    """
    Resolves a dialect name, class or instance to a dialect class.

    Falsy input yields `cls`; a name that is not registered raises `ValueError`.
    """
    if not dialect:
        return cls
    if isinstance(dialect, _Dialect):
        return dialect
    if isinstance(dialect, Dialect):
        return dialect.__class__

    found = cls.get(dialect)
    if found:
        return found

    raise ValueError(f"Unknown dialect '{dialect}'")
@classmethod
def format_time(
    cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
    """
    Translates a time format through `TIME_MAPPING` and wraps it in a string literal.

    Inputs that are neither a string nor a string-literal expression are returned as is.
    """
    if isinstance(expression, str):
        # the time formats are quoted
        raw_format = expression[1:-1]
    elif expression and expression.is_string:
        raw_format = expression.this
    else:
        return expression

    return exp.Literal.string(format_time(raw_format, cls.TIME_MAPPING, cls.TIME_TRIE))
@classmethod
def normalize_identifier(cls, expression: E) -> E:
    """
    Lower- or upper-cases an identifier in place so that it compares case-insensitively.

    Quoted identifiers are left alone unless the dialect treats every identifier as
    case-insensitive (`RESOLVES_IDENTIFIERS_AS_UPPERCASE is None`).
    """
    if not isinstance(expression, exp.Identifier):
        return expression

    if expression.quoted and cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is not None:
        return expression

    name = expression.this
    if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE:
        normalized = name.upper()
    else:
        normalized = name.lower()
    expression.set("this", normalized)

    return expression
Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized regardless of being quoted or not.
@classmethod
def case_sensitive(cls, text: str) -> bool:
    """Tells whether `text` has characters whose case matters under the dialect's rules."""
    resolves_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
    if resolves_upper is None:
        # Every identifier is case-insensitive in such dialects
        return False

    if resolves_upper:
        return any(map(str.islower, text))
    return any(map(str.isupper, text))
Checks if text contains any case sensitive characters, based on the dialect's rules.
@classmethod
def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
    """Checks if text can be identified given an identify option.

    Args:
        text: The text to check.
        identify:
            "always" or `True`: Always returns true.
            "safe": True if the identifier is case-insensitive.

    Returns:
        Whether or not the given text can be identified.
    """
    always = identify is True or identify == "always"
    return always or (identify == "safe" and not cls.case_sensitive(text))
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: "always" or `True`: Always returns true. "safe": True if the identifier is case-insensitive.
Returns:
Whether or not the given text can be identified.
@classmethod
def quote_identifier(cls, expression: E, identify: bool = True) -> E:
    """
    Sets the `quoted` flag of an identifier in place and returns the expression.

    The flag is truthy when `identify` is set, when the name is case sensitive, or
    when the name does not match `exp.SAFE_IDENTIFIER_RE`.
    """
    if not isinstance(expression, exp.Identifier):
        return expression

    name = expression.this
    needs_quotes = (
        identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
    )
    expression.set("quoted", needs_quotes)

    return expression
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """
    Render a StrPosition expression as a STRPOS call.

    Without a `position` argument this is simply `STRPOS(this, substr)`. With one,
    the search runs on `SUBSTR(this, position)` and `position - 1` is added back to
    the result.

    NOTE(review): if STRPOS yields 0 for a missing substring, the offset arithmetic
    returns `position - 1` rather than 0 -- confirm whether that is acceptable.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"

    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """
    Render a map expression as a call with interleaved key/value arguments.

    When both `keys` and `values` are `exp.Array` nodes the output is
    `map_func_name(k1, v1, k2, v2, ...)`. Otherwise the conversion is flagged as
    unsupported and the two arguments are passed to the function unchanged.

    Args:
        expression: The map expression; its `keys` and `values` args are read.
        map_func_name: Name of the function to generate.

    Returns:
        The generated SQL string.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if isinstance(keys, exp.Array) and isinstance(values, exp.Array):
        # Flatten the (key, value) pairs into k1, v1, k2, v2, ...
        interleaved = [
            self.sql(item)
            for pair in zip(keys.expressions, values.expressions)
            for item in pair
        ]
        return self.func(map_func_name, *interleaved)

    self.unsupported("Cannot convert array columns into map.")
    return self.func(map_func_name, keys, values)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the format used when the argument list carries none; `True` selects
            the dialect's `TIME_FORMAT`, and a falsy value means no format.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        dialect_class = Dialect[dialect]

        # The second positional argument is the format; fall back to the default if it is absent.
        fmt = seq_get(args, 1)
        if not fmt:
            fmt = dialect_class.TIME_FORMAT if default is True else default or None

        return exp_class(this=seq_get(args, 0), format=dialect_class.format_time(fmt))

    return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema.
    When the PARTITIONED BY value is a list of column names rather than a schema, the matching
    column definitions are moved out of the create statement's schema and into the property.
    Column names are matched case-insensitively. The input expression is not mutated; a copy
    is rewritten and generated.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
    if not (has_schema and is_partitionable):
        return self.create_sql(expression)

    expression = expression.copy()
    prop = expression.find(exp.PartitionedByProperty)

    # Nothing to rewrite if there is no property or it already holds a schema.
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return self.create_sql(expression)

    schema = expression.this
    partition_names = {column.name.upper() for column in prop.this.expressions}
    partitions = [col for col in schema.expressions if col.name.upper() in partition_names]
    remaining = [col for col in schema.expressions if col not in partitions]

    schema.set("expressions", remaining)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
    expression.set("this", schema)

    return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """
    Build a parser callback for date-delta style functions.

    The returned callable accepts either three arguments `(unit, expression, this)` or
    fewer, in which case the first argument is `this` and the unit defaults to the
    string literal "DAY". If `unit_mapping` is given, the unit name is looked up in it
    by its lowercased form, keeping the original name when there is no entry.

    Args:
        exp_class: the expression class to instantiate.
        unit_mapping: optional mapping from lowercase unit names to replacement names.

    Returns:
        A callable that turns an argument list into an `exp_class` instance.
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            this, unit = args[2], args[0]
        else:
            this, unit = seq_get(args, 0), exp.Literal.string("DAY")

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """
    Build a parser callback for functions of the form `FUNC(this, <interval>)`.

    The returned callable yields `None` when fewer than two arguments are supplied.
    Otherwise it reads the interval's value and unit from the second argument; a
    string-valued interval amount is converted into a number literal.

    Args:
        expression_class: the expression class to instantiate.

    Returns:
        A callable that turns an argument list into an `expression_class` instance or `None`.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=args[0], expression=amount, unit=unit)

    return func
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """
    Render COUNT_IF(cond) as `sum(if(cond, 1, 0))`.

    If the argument is wrapped in DISTINCT, the first wrapped expression is used as the
    condition and the conversion is flagged as unsupported.
    """
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    counted = exp.func("if", condition, 1, 0)
    return self.func("sum", counted)
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """
    Render a Trim expression using the `TRIM([position] [chars] FROM target [COLLATE c])` form.

    When neither characters to remove nor a collation are present, generation is
    delegated to the generator's own `trim_sql` method instead.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (remove_chars or collation):
        return self.trim_sql(expression)

    # Everything that precedes the target; FROM is only emitted if something does.
    prefix = ""
    if trim_type:
        prefix += f"{trim_type} "
    if remove_chars:
        prefix += f"{remove_chars} "
    if prefix:
        prefix += "FROM "

    suffix = f" COLLATE {collation}" if collation else ""
    return f"TRIM({prefix}{target}{suffix})"
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """
    Build a generator function that renders TsOrDsToDate as a `CAST(... AS DATE)`.

    If the expression carries a time format other than the dialect's `TIME_FORMAT` or
    `DATE_FORMAT`, the value is first rendered with `str_to_time_sql`; otherwise the
    operand is cast directly.

    Args:
        dialect: name of the dialect whose default formats are compared against.

    Returns:
        A callable taking the generator and the expression and returning SQL.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)

        if not time_format or time_format in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            operand = self.sql(expression, "this")
        else:
            # A non-default format is rendered through str_to_time_sql first.
            operand = str_to_time_sql(self, expression)

        return f"CAST({operand} AS DATE)"

    return _ts_or_ds_to_date_sql
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """
    Compute the name contributed by each pivot aggregation.

    Aliased aggregations contribute their alias. Unaliased ones contribute their SQL
    text in the given dialect, with function names lowercased and all identifiers unquoted.

    Args:
        aggregations: the aggregation expressions of the pivot.
        dialect: the dialect used to render unaliased aggregations.

    Returns:
        One name per aggregation, in input order.
    """

    def _unquote(node):
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # This case corresponds to aggregations without aliases being used as suffixes
        # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
        # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
        # Otherwise, we'd end up with `col_avg(`foo`)` (notice the nested quoting).
        unquoted = agg.transform(_unquote)
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names