sqlglot.dialects.dialect
from __future__ import annotations

import typing as t
from enum import Enum

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie


class Dialects(str, Enum):
    """Names under which dialect classes are registered by the `_Dialect` metaclass."""

    # The empty name is used for the base `Dialect` class itself
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class _Dialect(type):
    """
    Metaclass of `Dialect`. Every class created through it is stored in `classes` and has its
    derived attributes (tries, tokenizer/parser/generator classes, delimiters) filled in.
    """

    # Registry of dialect classes keyed by dialect name
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        """A dialect class equals itself, its registered name and any of its instances."""
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        """Returns the dialect class registered under `key`, raising `KeyError` if there is none."""
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        """Returns the dialect class registered under `key`, or `default` if there is none."""
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Register under the `Dialects` value matching the class name, else the lowercased name
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        # FORMAT_TRIE falls back to TIME_TRIE when the dialect defines no FORMAT_MAPPING
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        # Nested Tokenizer / Parser / Generator classes override the defaults when present
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first entry of each tokenizer mapping provides the delimiters
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # First (start, end) delimiter pair mapped to `token_type`, or (None, None)
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        # Non-callable, non-dunder attributes defined directly on this class, plus the escapes
        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
            "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
        }

        # Every dialect except the base one and BigQuery gets SELECT_KINDS reset to empty
        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        if not klass.STRICT_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        klass.generator_class.can_identify = klass.can_identify

        return klass


class Dialect(metaclass=_Dialect):
    """Base class of all dialects; subclasses are registered by the `_Dialect` metaclass."""

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        # Delegates to the metaclass comparison of this instance's class with `other`
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """
        Resolves a dialect name, class or instance into a dialect class.

        A falsy value resolves to `cls`. Raises `ValueError` if a name is not registered.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """
        Converts a time format through this dialect's TIME_MAPPING into a string literal.

        Plain strings are expected to be quoted; anything that is neither a string nor a
        string literal expression is returned unchanged.
        """
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        # Characters in the opposite case of the one identifiers resolve to are "unsafe"
        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """
        Sets the "quoted" flag of an identifier; other expressions are returned unchanged.

        The identifier is quoted if `identify` is truthy, if its name is case sensitive for
        this dialect, or if its name does not match `exp.SAFE_IDENTIFIER_RE`.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenizes `sql` and parses it with a parser built from `opts`."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenizes `sql` and parses it into `expression_type` with a parser built from `opts`."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Generates SQL for `expression` with a generator built from `opts`."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parses `sql` and generates one SQL string per parsed expression.

        `opts` are only forwarded to the generator; parsing uses the default parser options.
        """
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Created on first access and cached on the instance
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Returns a generator callback that renders an expression as a call to `name`,
    passing the flattened values of `expression.args` as its arguments."""
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Renders APPROX_COUNT_DISTINCT(this), flagging an "accuracy" argument as unsupported."""
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(self: Generator, expression: exp.If) -> str:
    """Renders an `exp.If` as IF(condition, true, false)."""
    return self.func(
        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
    )


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Renders a JSON extraction as a binary `->` operation."""
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Renders a scalar JSON extraction as a binary `->>` operation."""
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Renders an array as its expressions wrapped in square brackets."""
    return f"[{self.expressions(expression)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Renders ILIKE as LIKE over LOWER(this).

    NOTE(review): only the left operand is lowercased; the pattern is passed through as is.
    """
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Renders CURRENT_DATE without parentheses, appending AT TIME ZONE if a zone is set."""
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Renders a WITH clause, flagging RECURSIVE as unsupported.

    Note that the "recursive" arg of the given expression is reset in place.
    """
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Renders a safe division as an IF that yields NULL when the divisor is 0."""
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Flags TABLESAMPLE as unsupported and renders only the sampled expression."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Flags PIVOT as unsupported and renders nothing."""
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Renders a TRY_CAST through the generator's regular `cast_sql`."""
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Flags properties as unsupported and renders nothing."""
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Flags column comment constraints as unsupported and renders nothing."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Flags MAP_FROM_ENTRIES as unsupported and renders nothing."""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Renders a string position lookup as STRPOS.

    With a start position, the search runs on SUBSTR(this, position) and the result is
    shifted back by `position - 1`.

    NOTE(review): the shift is applied unconditionally, so a STRPOS result of 0 would
    become `position - 1` — confirm whether this is intended.
    """
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Renders a struct field access as `this.<quoted key>`."""
    this = self.sql(expression, "this")
    struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
    return f"{this}.{struct_key}"


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Renders a map as `map_func_name(key1, value1, key2, value2, ...)`.

    If keys or values are not array literals, this is flagged as unsupported and they are
    passed to the function as two arguments instead.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave the keys with their values
    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        # The dialect is looked up at call time, not when the lambda is created
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        # Work on a copy so that the given expression is left untouched
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Column names are matched case-insensitively
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Returns a function that builds `exp_class` from date delta arguments.

    Three arguments are read as (unit, expression, this); otherwise they are read as
    (this, expression) with a default unit of "DAY". If `unit_mapping` is given, the
    lowercased unit name is translated through it.
    """

    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Returns a function that builds `expression_class` from (this, interval) arguments.

    The returned function yields None when fewer than two arguments are given.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        expression = interval.this
        # A string interval value is turned into a number literal
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Builds a truncation from (unit, this): `exp.DateTrunc` if `this` is a cast to date,
    `exp.TimestampTrunc` otherwise."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Renders a timestamp truncation as DATE_TRUNC(unit, this), with "day" as default unit."""
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
    )


def locate_to_strposition(args: t.List) -> exp.Expression:
    """Builds an `exp.StrPosition` from (substr, this[, position]) arguments."""
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Renders a string position lookup as LOCATE(substr, this, position)."""
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Renders LEFT(this, n) as a substring of length n starting at position 1."""
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Renders RIGHT(this, n) as a substring starting at LENGTH(this) - (n - 1)."""
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Renders the expression as a cast to TIMESTAMP."""
    return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Renders the expression as a cast to DATE."""
    return f"CAST({self.sql(expression, 'this')} AS DATE)"


def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Renders LEAST if the expression has additional `expressions`, MIN otherwise."""
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Renders GREATEST if the expression has additional `expressions`, MAX otherwise."""
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Renders COUNT_IF(cond) as sum(if(cond, 1, 0)).

    A DISTINCT argument is flagged as unsupported and its first expression is used as the
    condition.
    """
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Renders TRIM([position] [chars] FROM target [COLLATE collation]).

    Falls back to the generator's own `trim_sql` when neither characters to remove nor a
    collation are given.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Renders the expression as STRPTIME(this, <formatted time format>)."""
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Returns a generator callback that renders `exp.TsOrDsToDate` as a cast to DATE.

    If the expression carries a time format other than the dialect's TIME_FORMAT or
    DATE_FORMAT, the value is parsed with STRPTIME before being cast.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _ts_or_ds_to_date_sql


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Renders a concatenation as a left-nested chain of `exp.DPipe` operations."""
    this, *rest_args = expression.expressions
    for arg in rest_args:
        this = exp.DPipe(this=this, expression=arg)

    return self.sql(this)


# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Returns one output column name per aggregation: its alias if it has one, otherwise
    its SQL in `dialect` with unquoted identifiers and lowercased function names."""
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
class Dialects(str, Enum):
    """Names under which dialect classes are registered by the `_Dialect` metaclass."""

    # The empty name is used for the base `Dialect` class itself
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
An enumeration.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class Dialect(metaclass=_Dialect):
    """Base class of all dialects; subclasses are registered by the `_Dialect` metaclass."""

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        # Delegates to the metaclass comparison of this instance's class with `other`
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """
        Resolves a dialect name, class or instance into a dialect class.

        A falsy value resolves to `cls`. Raises `ValueError` if a name is not registered.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """
        Converts a time format through this dialect's TIME_MAPPING into a string literal.

        Plain strings are expected to be quoted; anything that is neither a string nor a
        string literal expression is returned unchanged.
        """
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        # Characters in the opposite case of the one identifiers resolve to are "unsafe"
        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """
        Sets the "quoted" flag of an identifier; other expressions are returned unchanged.

        The identifier is quoted if `identify` is truthy, if its name is case sensitive for
        this dialect, or if its name does not match `exp.SAFE_IDENTIFIER_RE`.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenizes `sql` and parses it with a parser built from `opts`."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenizes `sql` and parses it into `expression_type` with a parser built from `opts`."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Generates SQL for `expression` with a generator built from `opts`."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parses `sql` and generates one SQL string per parsed expression.

        `opts` are only forwarded to the generator; parsing uses the default parser options.
        """
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Created on first access and cached on the instance
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)
@classmethod
def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
    """
    Resolves a dialect name, class or instance into a dialect class.

    A falsy value resolves to `cls`. Raises `ValueError` if a name is not registered.
    """
    if not dialect:
        return cls

    # Classes are returned as they are, instances resolve to their class
    if isinstance(dialect, (_Dialect, Dialect)):
        return dialect if isinstance(dialect, _Dialect) else dialect.__class__

    dialect_class = cls.get(dialect)
    if dialect_class:
        return dialect_class

    raise ValueError(f"Unknown dialect '{dialect}'")
@classmethod
def format_time(
    cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
    """
    Converts a time format through this dialect's TIME_MAPPING into a string literal.

    Anything that is neither a string nor a string literal expression is returned unchanged.
    """
    if isinstance(expression, str):
        # the time formats are quoted
        raw_format = expression[1:-1]
    elif expression and expression.is_string:
        raw_format = expression.this
    else:
        return expression

    return exp.Literal.string(format_time(raw_format, cls.TIME_MAPPING, cls.TIME_TRIE))
@classmethod
def normalize_identifier(cls, expression: E) -> E:
    """
    Makes an identifier case-insensitive by converting its name to the case that the
    dialect resolves identifiers to. Quoted identifiers are only converted if the dialect
    treats all identifiers as case-insensitive. Other expressions are returned unchanged.
    """
    if not isinstance(expression, exp.Identifier):
        return expression

    # Quoted identifiers keep their case, unless the dialect is fully case-insensitive
    if expression.quoted and cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is not None:
        return expression

    if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE:
        normalized = expression.this.upper()
    else:
        normalized = expression.this.lower()

    expression.set("this", normalized)
    return expression
Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized regardless of being quoted or not.
@classmethod
def case_sensitive(cls, text: str) -> bool:
    """Tells whether any character of text is case sensitive under the dialect's rules."""
    resolves_as_uppercase = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
    if resolves_as_uppercase is None:
        return False

    # Characters in the opposite case of the one identifiers resolve to are "unsafe"
    is_unsafe = str.islower if resolves_as_uppercase else str.isupper
    for char in text:
        if is_unsafe(char):
            return True

    return False
Checks if text contains any case sensitive characters, based on the dialect's rules.
@classmethod
def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
    """Tells whether text can be identified under the given identify option.

    Args:
        text: The text to check.
        identify:
            "always" or `True`: Always returns true.
            "safe": True if the text has no case sensitive characters.

    Returns:
        Whether or not the given text can be identified.
    """
    if identify is True or identify == "always":
        return True

    return identify == "safe" and not cls.case_sensitive(text)
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: "always" or `True`: Always returns true. "safe": True if the identifier is case-insensitive.
Returns:
Whether or not the given text can be identified.
@classmethod
def quote_identifier(cls, expression: E, identify: bool = True) -> E:
    """
    Sets the `quoted` arg of an identifier in place and returns the given expression.

    The identifier is marked as quoted when `identify` is truthy, when its name contains
    case sensitive characters (per `case_sensitive`), or when its name does not match
    `exp.SAFE_IDENTIFIER_RE`. Non-identifier expressions are returned unchanged.
    """
    if not isinstance(expression, exp.Identifier):
        return expression

    name = expression.this
    needs_quotes = (
        identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
    )
    expression.set("quoted", needs_quotes)
    return expression
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """
    Renders a StrPosition expression as a STRPOS call.

    When a start position is present, the search runs on `SUBSTR(this, position)` and the
    result is shifted by `position - 1`.
    """
    haystack, needle, start = (
        self.sql(expression, key) for key in ("this", "substr", "position")
    )

    if not start:
        return f"STRPOS({haystack}, {needle})"

    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """
    Renders a Map / VarMap expression as a call to `map_func_name` with alternating
    key and value arguments.

    If either the keys or the values are not an `exp.Array`, the conversion is flagged as
    unsupported and the function is rendered with the two arguments as they are.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave as key1, value1, key2, value2, ...
    interleaved = [
        self.sql(node)
        for pair in zip(keys.expressions, values.expressions)
        for node in pair
    ]
    return self.func(map_func_name, *interleaved)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the format to use when none is given in the arguments; `True` selects
            the dialect's `TIME_FORMAT`.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        # The dialect is looked up at call time, not when the helper is created.
        dialect_cls = Dialect[dialect]

        fmt = seq_get(args, 1)
        if not fmt:
            fmt = dialect_cls.TIME_FORMAT if default is True else default or None

        return exp_class(this=seq_get(args, 0), format=dialect_cls.format_time(fmt))

    return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema.

    For a TABLE or VIEW whose target is a schema, if the PARTITIONED BY value is not already
    a schema, its column names (matched case-insensitively) are looked up in the table
    schema, removed from it and moved into a schema under the PARTITIONED BY property.
    The rewrite happens on a copy of the expression, which is then rendered by `create_sql`.
    """
    if not isinstance(expression.this, exp.Schema) or expression.args.get("kind") not in (
        "TABLE",
        "VIEW",
    ):
        return self.create_sql(expression)

    expression = expression.copy()
    prop = expression.find(exp.PartitionedByProperty)

    if prop and prop.this and not isinstance(prop.this, exp.Schema):
        schema = expression.this
        partition_names = {node.name.upper() for node in prop.this.expressions}
        partitions = [
            column for column in schema.expressions if column.name.upper() in partition_names
        ]
        remaining = [column for column in schema.expressions if column not in partitions]

        schema.set("expressions", remaining)
        prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
        expression.set("this", schema)

    return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """
    Builds a parser callback for date-delta functions.

    With exactly three arguments they are read as (unit, expression, this). Otherwise they
    are read as (this, expression) and the unit defaults to the string literal "DAY". If a
    `unit_mapping` is given, the unit name is translated through it (looked up lower-cased,
    falling back to the name itself) and wrapped with `exp.var`.
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            this, unit = args[2], args[0]
        else:
            this, unit = seq_get(args, 0), exp.Literal.string("DAY")

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """
    Builds a parser callback for date-delta functions whose second argument is an interval.

    The callback returns None when given fewer than two arguments. Otherwise the interval's
    value becomes the delta (a string value is turned into a number literal) and the
    interval's unit text becomes a string literal.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        this, interval = args[0], args[1]

        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=this, expression=amount, unit=unit)

    return func
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """
    Renders COUNT_IF(cond) as sum(if(cond, 1, 0)).

    A DISTINCT wrapper around the condition cannot be carried over: its first expression
    is used as the condition and the conversion is flagged as unsupported.
    """
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", condition, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """
    Renders a Trim expression as `TRIM([position ][chars ][FROM ]target[ COLLATE collation])`.

    When neither the characters to remove nor a collation are present, rendering is
    delegated to the generator's own `trim_sql`.
    """
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (chars or collation):
        return self.trim_sql(expression)

    pieces = []
    if position:
        pieces.append(f"{position} ")
    if chars:
        pieces.append(f"{chars} ")
    if pieces:
        # FROM is only needed when something precedes the target
        pieces.append("FROM ")
    pieces.append(target)
    if collation:
        pieces.append(f" COLLATE {collation}")

    arguments = "".join(pieces)
    return f"TRIM({arguments})"
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """
    Builds a generator function that renders TsOrDsToDate as a cast to DATE.

    If the expression has a time format other than the dialect's `TIME_FORMAT` or
    `DATE_FORMAT`, the value is first rendered through `str_to_time_sql`; otherwise the
    operand is cast directly.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        # The dialect is resolved on each call, not when the helper is created.
        target_dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)

        if not time_format or time_format in (
            target_dialect.TIME_FORMAT,
            target_dialect.DATE_FORMAT,
        ):
            return f"CAST({self.sql(expression, 'this')} AS DATE)"

        return f"CAST({str_to_time_sql(self, expression)} AS DATE)"

    return _ts_or_ds_to_date_sql
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """
    Computes one name per pivot aggregation.

    Aliased aggregations contribute their alias. Any other aggregation contributes its SQL
    text in the given dialect, with function names lower-cased and all identifiers unquoted.
    """

    def _unquote(node):
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # This case corresponds to aggregations without aliases being used as suffixes
        # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
        # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
        # Otherwise, we'd end up with `col_avg(`foo`)` (notice the nested quoting).
        unquoted = agg.transform(_unquote)
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names