sqlglot.dialects.dialect
from __future__ import annotations

import typing as t
from enum import Enum

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie


class Dialects(str, Enum):
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.RAW_START, klass.RAW_END = get_start_end(TokenType.RAW_STRING)

        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
            "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
        }

        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        if not klass.STRICT_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        klass.generator_class.can_identify = klass.can_identify

        return klass


class Dialect(metaclass=_Dialect):
    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(self: Generator, expression: exp.If) -> str:
    return self.func(
        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
    )


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    this = self.sql(expression, "this")
    struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
    return f"{this}.{struct_key}"


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
    )


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return f"CAST({self.sql(expression, 'this')} AS DATE)"


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _ts_or_ds_to_date_sql


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    this, *rest_args = expression.expressions
    for arg in rest_args:
        this = exp.DPipe(this=this, expression=arg)

    return self.sql(this)


# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
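The listing above shows how the _Dialect metaclass registers every Dialect subclass and wires its Tokenizer, Parser and Generator together. The sketch below is illustrative only: MyDialect is a made-up dialect, and it assumes a sqlglot installation matching the source shown here.

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.generator import Generator
from sqlglot.tokens import Tokenizer


class MyDialect(Dialect):  # hypothetical dialect, defined only for illustration
    class Tokenizer(Tokenizer):
        QUOTES = ["'"]       # string literal delimiters
        IDENTIFIERS = ['"']  # identifier delimiters

    class Generator(Generator):
        # Render arrays with bracket syntax, mirroring inline_array_sql above
        TRANSFORMS = {exp.Array: lambda self, e: f"[{self.expressions(e)}]"}


# The metaclass registered the subclass under its lower-cased name and
# auto-filled tokenizer_class / parser_class / generator_class.
assert Dialect.get_or_raise("mydialect") is MyDialect
print(MyDialect().transpile("SELECT ARRAY(1, 2)")[0])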
class Dialects(str, Enum):
An enumeration of the SQL dialects that ship with sqlglot.
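As a quick, non-normative example, the enum values double as the lookup keys of the dialect registry built by the metaclass above (Dialect.get_or_raise and the metaclass __getitem__ both appear in the source listing):

import sqlglot  # importing sqlglot loads the bundled dialect modules, registering them
from sqlglot.dialects.dialect import Dialect, Dialects

duckdb = Dialect.get_or_raise(Dialects.DUCKDB.value)  # look up by string value
assert Dialect["duckdb"] is duckdb                     # metaclass __getitem__

try:
    Dialect.get_or_raise("not-a-dialect")
except ValueError as err:
    print(err)  # Unknown dialect 'not-a-dialect'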
class Dialect(metaclass=_Dialect):
@classmethod
def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
@classmethod
def format_time(cls, expression: t.Optional[str | exp.Expression]) -> t.Optional[exp.Expression]:
@classmethod
def normalize_identifier(cls, expression: E) -> E:
Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized regardless of being quoted or not.
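A minimal sketch of this behaviour under the base Dialect defaults (lower-case normalization); UpperDialect is a toy subclass invented here purely to show the upper-casing path:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect


class UpperDialect(Dialect):  # hypothetical dialect, for illustration only
    RESOLVES_IDENTIFIERS_AS_UPPERCASE = True


print(Dialect.normalize_identifier(exp.Identifier(this="MyCol")).this)       # mycol
print(UpperDialect.normalize_identifier(exp.Identifier(this="MyCol")).this)  # MYCOL
# Quoted identifiers are left alone unless the dialect is fully case-insensitive
print(Dialect.normalize_identifier(exp.Identifier(this="MyCol", quoted=True)).this)  # MyCol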
@classmethod
def case_sensitive(cls, text: str) -> bool:
Checks if text contains any case sensitive characters, based on the dialect's rules.
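For instance, under the base Dialect's rules (unquoted identifiers normalize to lower case), any upper-case character makes the text case sensitive:

from sqlglot.dialects.dialect import Dialect

print(Dialect.case_sensitive("events"))   # False
print(Dialect.case_sensitive("Events"))   # True: the capital "E" would be lost without quoting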
@classmethod
def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: "always" or
True
: Always returns true. "safe": True if the identifier is case-insensitive.
Returns:
Whether or not the given text can be identified.
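A small usage example of the options above, using the base Dialect (illustrative only):

from sqlglot.dialects.dialect import Dialect

print(Dialect.can_identify("events", "safe"))    # True: quoting cannot change how it resolves
print(Dialect.can_identify("Events", "safe"))    # False: contains case-sensitive characters
print(Dialect.can_identify("Events", "always"))  # True: always identify
print(Dialect.can_identify("Events", False))     # False: identification disabled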
@classmethod
def quote_identifier(cls, expression: E, identify: bool = True) -> E:
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
def var_map_sql(self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP") -> str:
def format_time_lambda(exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None) -> t.Callable[[t.List], E]:
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
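A hedged sketch of how this helper is typically used when registering parser functions; exp.StrToTime and the "hive" dialect are chosen only for illustration, and the exact mapped format depends on that dialect's TIME_MAPPING:

from sqlglot import exp
from sqlglot.dialects.dialect import format_time_lambda

parse_to_time = format_time_lambda(exp.StrToTime, "hive", default=True)

# args as the parser would pass them: the value plus an optional format literal
node = parse_to_time([exp.Literal.string("2023-01-01"), exp.Literal.string("yyyy-MM-dd")])
print(node.args["format"])  # the Hive format converted to a python time format

# with no explicit format, default=True falls back to the dialect's TIME_FORMAT
print(parse_to_time([exp.Literal.string("2023-01-01")]).args["format"])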
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
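For illustration, transpiling a Spark-style statement where PARTITIONED BY lists bare column names into Hive applies exactly this rewrite; the output shown is approximate and may differ between sqlglot versions:

import sqlglot

sql = "CREATE TABLE t (a INT, b INT) PARTITIONED BY (b)"
print(sqlglot.transpile(sql, read="spark", write="hive")[0])
# roughly: CREATE TABLE t (a INT) PARTITIONED BY (b INT)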
def parse_date_delta(exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None) -> t.Callable[[t.List], E]:
def parse_date_delta_with_interval(expression_class: t.Type[E]) -> t.Callable[[t.List], t.Optional[E]]:
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
def trim_sql(self: Generator, expression: exp.Trim) -> str:
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: