sqlglot.dialects.dialect
from __future__ import annotations

import typing as t
from enum import Enum

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

B = t.TypeVar("B", bound=exp.Binary)


class Dialects(str, Enum):
    """Names of the built-in dialects; each member's value is its registry key."""

    # The empty string is the name of the base dialect.
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class _Dialect(type):
    """Metaclass that registers every dialect class and fills in its derived attributes."""

    # Registry of dialect classes, keyed by dialect name.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        """Compare a dialect class with a class, a dialect name or a dialect instance."""
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        """Return the dialect class registered under `key`; raises KeyError if unknown."""
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        """Return the dialect class registered under `key`, or `default` if unknown."""
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        """Create the dialect class, register it and compute its autofilled attributes."""
        klass = super().__new__(cls, clsname, bases, attrs)

        # Known dialects are registered under their enum value, others under the
        # lower-cased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        # A dialect may nest its own Tokenizer / Parser / Generator; fall back to the defaults.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first configured delimiter pair is used as the dialect's canonical one.
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # First (start, end) delimiter pair registered for `token_type`, if any.
            return next(
                (
                    (start, end)
                    for start, (end, kind) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if kind == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
            "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
        }

        # `enum` is None for unregistered names, so custom dialects also take this branch.
        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        if not klass.STRICT_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        klass.generator_class.can_identify = klass.can_identify

        return klass


class Dialect(metaclass=_Dialect):
    """Base dialect: configuration flags plus tokenize / parse / generate entry points."""

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        # Delegates to the metaclass comparison, so names, classes and instances all work.
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve a dialect name, class or instance to a dialect class.

        A falsy `dialect` resolves to `cls`.

        Raises:
            ValueError: if `dialect` is a name that is not registered.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a time format through this dialect's TIME_MAPPING into a string literal.

        Plain strings are expected to be quoted; anything that is neither a string nor
        a string literal is returned unchanged.
        """
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Set the `quoted` flag of an identifier in place and return the expression.

        The flag is truthy when `identify` is set, when the name is case sensitive for
        this dialect, or when it does not match `exp.SAFE_IDENTIFIER_RE`.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql`; `opts` are forwarded to the parser constructor."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse it into `expression_type`."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Generate SQL for `expression`; `opts` are forwarded to the generator constructor."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and generate one SQL string per parsed statement.

        `opts` are only passed to the generator, not to the parser.
        """
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Created lazily on first access and then reused by this instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Return a generator function that emits `name(...)` with all of the expression's args."""
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Emit APPROX_COUNT_DISTINCT, flagging the `accuracy` argument as unsupported."""
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(self: Generator, expression: exp.If) -> str:
    """Emit `IF(condition, true, false)`."""
    return self.func(
        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
    )


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Emit JSON extraction with the `->` operator."""
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Emit scalar JSON extraction with the `->>` operator."""
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Emit an array as a bracketed list: `[a, b, ...]`."""
    return f"[{self.expressions(expression)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emit ILIKE as LIKE over the lower-cased left operand.

    NOTE(review): only the left operand is wrapped in LOWER; the pattern is emitted
    as written. Confirm this is intended for patterns containing upper-case characters.
    """
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Emit CURRENT_DATE without parentheses, with AT TIME ZONE when a zone is given."""
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Emit a WITH clause without RECURSIVE, flagging recursion as unsupported.

    Note that the `recursive` arg of the given expression is reset in place.
    """
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emit a safe division as `IF(d <> 0, n / d, NULL)`."""
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop TABLESAMPLE (flagged as unsupported) and emit only the sampled expression."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Flag PIVOT as unsupported and emit nothing."""
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Emit TRY_CAST through the generator's regular cast handler."""
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Flag properties as unsupported and emit nothing."""
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Flag column comment constraints as unsupported and emit nothing."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Flag MAP_FROM_ENTRIES as unsupported and emit nothing."""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Emit a substring search with STRPOS, emulating the optional start position.

    With a start position the search runs on `SUBSTR(this, position)` and the match
    index is shifted back by `position - 1`.
    """
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if not position:
        return f"STRPOS({this}, {substr})"

    strpos = f"STRPOS(SUBSTR({this}, {position}), {substr})"
    # STRPOS yields 0 when nothing is found; adding the offset unconditionally would
    # turn that into `position - 1`, so the offset is applied to real matches only.
    return f"CASE WHEN {strpos} = 0 THEN 0 ELSE {strpos} + {position} - 1 END"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Emit struct field access as `struct.<quoted key>`."""
    this = self.sql(expression, "this")
    struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
    return f"{this}.{struct_key}"


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Emit a map as `MAP(k1, v1, k2, v2, ...)`.

    When keys or values are not array literals they cannot be interleaved; this is
    flagged as unsupported and `MAP(keys, values)` is emitted instead.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        # Work on a copy so the caller's expression is left untouched.
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date-delta functions.

    Three arguments are read as `(unit, expression, this)`; otherwise as
    `(this, expression)` with the unit defaulting to the string literal DAY.
    When `unit_mapping` is given, the unit name is translated through it.
    """

    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date-delta functions whose second argument is an interval."""

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        expression = interval.this
        # A string interval value is re-emitted as a number literal.
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse `(unit, this)` into DateTrunc when `this` is a cast to date, else TimestampTrunc."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Emit `DATE_TRUNC('<unit>', this)`, with the unit defaulting to 'day'."""
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
    )


def locate_to_strposition(args: t.List) -> exp.Expression:
    """Parse `(substr, this[, position])` arguments into a StrPosition expression."""
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Emit a StrPosition expression as `LOCATE(substr, this[, position])`."""
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Emit LEFT(this, n) as a substring starting at 1 with length n."""
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Emit RIGHT(this, n) as a substring starting at `LENGTH(this) - (n - 1)`."""
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Emit `CAST(this AS TIMESTAMP)`."""
    return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Emit `CAST(this AS DATE)`."""
    return f"CAST({self.sql(expression, 'this')} AS DATE)"


def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Emit LEAST when extra expressions are present, MIN otherwise."""
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Emit GREATEST when extra expressions are present, MAX otherwise."""
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Emit COUNT_IF(cond) as `sum(if(cond, 1, 0))`; DISTINCT is flagged as unsupported."""
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Emit TRIM in the `TRIM([position] [chars] FROM target [COLLATE c])` form."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Emit `STRPTIME(this, <format>)` using the generator's time format conversion."""
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator function that casts a timestamp or date string to DATE.

    A format other than the dialect's TIME_FORMAT / DATE_FORMAT is parsed with
    STRPTIME before the cast.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _ts_or_ds_to_date_sql


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Emit CONCAT(a, b, ...) as a left-nested chain of DPipe expressions."""
    this, *rest_args = expression.expressions
    for arg in rest_args:
        this = exp.DPipe(this=this, expression=arg)

    return self.sql(this)


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Emit REGEXP_EXTRACT(this, pattern[, group]), flagging any other argument."""
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Emit REGEXP_REPLACE(this, pattern, replacement), flagging any other argument."""
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Return the name each pivot aggregation contributes to the output column names.

    Aliased aggregations contribute their alias; the others contribute their own SQL
    with unquoted identifiers and lower-cased function names.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the nested quoting).
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Return a parser that builds `expr_type` from a function's first two arguments."""
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
class Dialects(str, Enum):
    """Names of the supported SQL dialects.

    Members are also plain strings, so each one compares equal to its
    lower-case dialect name.
    """

    # The empty string is the name of the base dialect.
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
An enumeration.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class Dialect(metaclass=_Dialect):
    """Base SQL dialect: behavior flags plus tokenize / parse / generate entry points."""

    # Base index offset used for arrays.
    INDEX_OFFSET = 0

    # When true, UNNEST table aliases are treated purely as column aliases.
    UNNEST_COLUMN_ONLY = False

    # Whether the table alias is placed after TABLESAMPLE.
    ALIAS_POST_TABLESAMPLE = False

    # Whether unquoted identifiers resolve to uppercase.
    # None means the dialect treats every identifier as case-insensitive.
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Whether an unquoted identifier may begin with a digit.
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Whether CONCAT requires its arguments to be strings.
    STRICT_STRING_CONCAT = False

    # How function names get normalized.
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Default null ordering when none is given explicitly.
    # One of: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Dialect time format (key) -> python time format (value).
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # The special syntax cast(x as date format 'yyyy') falls back to TIME_MAPPING.
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # Tries built over the keys of the mappings above.
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        # Compare through the class so the metaclass comparison rules apply.
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Turn a dialect name, class or instance into a dialect class.

        Falsy input yields `cls`; an unregistered name raises ValueError.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        found = cls.get(dialect)
        if found:
            return found

        raise ValueError(f"Unknown dialect '{dialect}'")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Map a time format through TIME_MAPPING and return it as a string literal.

        Input that is neither a string nor a string literal is returned as is.
        """
        if isinstance(expression, str):
            # the time formats are quoted
            raw_format = expression[1:-1]
        elif expression and expression.is_string:
            raw_format = expression.this
        else:
            return expression

        return exp.Literal.string(format_time(raw_format, cls.TIME_MAPPING, cls.TIME_TRIE))

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Lower- or upper-cases an unquoted identifier in place, which makes it effectively
        case-insensitive. Dialects that treat every identifier as case-insensitive get
        quoted identifiers normalized as well.
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        resolves_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if expression.quoted and resolves_upper is not None:
            return expression

        if resolves_upper:
            normalized = expression.this.upper()
        else:
            normalized = expression.this.lower()

        expression.set("this", normalized)
        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Tells whether text holds any character that is case sensitive for this dialect."""
        resolves_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if resolves_upper is None:
            return False

        is_unsafe = str.islower if resolves_upper else str.isupper
        return any(map(is_unsafe, text))

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Tells whether text can be identified under the given identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        return identify == "safe" and not cls.case_sensitive(text)

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Sets the `quoted` flag on an identifier in place and returns the expression."""
        if not isinstance(expression, exp.Identifier):
            return expression

        name = expression.this
        needs_quotes = (
            identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
        )
        expression.set("quoted", needs_quotes)
        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenizes and parses sql; opts go to the parser constructor."""
        parser = self.parser(**opts)
        tokens = self.tokenize(sql)
        return parser.parse(tokens, sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenizes sql and parses it into the requested expression type."""
        parser = self.parser(**opts)
        tokens = self.tokenize(sql)
        return parser.parse_into(expression_type, tokens, sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Generates SQL for expression; opts go to the generator constructor."""
        generator = self.generator(**opts)
        return generator.generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parses sql and generates one string per statement; opts reach the generator only."""
        statements = []
        for expression in self.parse(sql):
            statements.append(self.generate(expression, **opts))
        return statements

    def tokenize(self, sql: str) -> t.List[Token]:
        tokenizer = self.tokenizer
        return tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Built on first access, then cached on the instance.
        try:
            return self._tokenizer
        except AttributeError:
            pass

        self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)
@classmethod
def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
    """Turn a dialect name, class or instance into a dialect class.

    Falsy input yields `cls`; an unregistered name raises ValueError.
    """
    if not dialect:
        return cls
    if isinstance(dialect, _Dialect):
        return dialect
    if isinstance(dialect, Dialect):
        return dialect.__class__

    found = cls.get(dialect)
    if found:
        return found

    raise ValueError(f"Unknown dialect '{dialect}'")
@classmethod
def format_time(
    cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
    """Map a time format through TIME_MAPPING and return it as a string literal.

    Input that is neither a string nor a string literal is returned as is.
    """
    if isinstance(expression, str):
        # the time formats are quoted
        raw_format = expression[1:-1]
    elif expression and expression.is_string:
        raw_format = expression.this
    else:
        return expression

    return exp.Literal.string(format_time(raw_format, cls.TIME_MAPPING, cls.TIME_TRIE))
@classmethod
def normalize_identifier(cls, expression: E) -> E:
    """
    Lower- or upper-cases an unquoted identifier in place, which makes it effectively
    case-insensitive. Dialects that treat every identifier as case-insensitive get
    quoted identifiers normalized as well.
    """
    if not isinstance(expression, exp.Identifier):
        return expression

    resolves_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
    if expression.quoted and resolves_upper is not None:
        return expression

    if resolves_upper:
        normalized = expression.this.upper()
    else:
        normalized = expression.this.lower()

    expression.set("this", normalized)
    return expression
Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized regardless of being quoted or not.
@classmethod
def case_sensitive(cls, text: str) -> bool:
    """Tells whether text holds any character that is case sensitive for this dialect."""
    resolves_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
    if resolves_upper is None:
        return False

    is_unsafe = str.islower if resolves_upper else str.isupper
    return any(map(is_unsafe, text))
Checks if text contains any case sensitive characters, based on the dialect's rules.
@classmethod
def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
    """Checks if text can be identified given an identify option.

    Args:
        text: The text to check.
        identify:
            "always" or `True`: Always returns true.
            "safe": True if the identifier is case-insensitive.
            Any other value: returns false.

    Returns:
        Whether or not the given text can be identified.
    """
    unconditional = identify is True or identify == "always"
    if unconditional:
        return True

    # Only the "safe" mode consults the dialect's case rules.
    return identify == "safe" and not cls.case_sensitive(text)
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: "always" or `True`: always returns true. "safe": returns true only if the identifier is case-insensitive. Any other value: returns false.
Returns:
Whether or not the given text can be identified.
@classmethod
def quote_identifier(cls, expression: E, identify: bool = True) -> E:
    """Set the `quoted` flag on an identifier and return the expression.

    The flag is set to `identify` when that is truthy; otherwise it is set
    according to whether the name is case sensitive for this dialect or fails
    to match `exp.SAFE_IDENTIFIER_RE`. Non-identifiers are returned unchanged.
    """
    if not isinstance(expression, exp.Identifier):
        return expression

    text = expression.this
    needs_quotes = (
        identify or cls.case_sensitive(text) or not exp.SAFE_IDENTIFIER_RE.match(text)
    )
    expression.set("quoted", needs_quotes)
    return expression
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a string-position expression as a `STRPOS` call.

    Without a start position this is `STRPOS(this, substr)`. With one, the search
    runs over `SUBSTR(this, position)` and the offset is added back to the result.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"

    # NOTE(review): if STRPOS yields 0 for "not found", this evaluates to
    # position - 1 rather than 0 -- confirm whether target engines need a guard.
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map expression as a call taking alternating key/value arguments.

    Args:
        expression: the map expression; its "keys" and "values" args are read.
        map_func_name: name of the function to emit.

    Returns:
        `map_func_name(k1, v1, k2, v2, ...)` when both sides are array literals.
        Otherwise the conversion is reported as unsupported and the function is
        emitted with the keys and values passed through as two arguments.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    both_arrays = isinstance(keys, exp.Array) and isinstance(values, exp.Array)
    if not both_arrays:
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave: key, value, key, value, ...
    pairs = zip(keys.expressions, values.expressions)
    interleaved = [self.sql(item) for pair in pairs for item in pair]
    return self.func(map_func_name, *interleaved)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the format used when the call has no second argument; `True`
            selects the dialect's `TIME_FORMAT`.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        # The dialect is looked up at call time, not when the helper is created.
        dialect_cls = Dialect[dialect]

        fmt = seq_get(args, 1)
        if not fmt:
            fmt = dialect_cls.TIME_FORMAT if default is True else (default or None)

        return exp_class(this=seq_get(args, 0), format=dialect_cls.format_time(fmt))

    return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format to use when none is given; `True` selects the dialect's `TIME_FORMAT`.
Returns:
A callable that can be used to return the appropriately formatted time expression.
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.

    The rewrite is applied to a copy, and only for TABLE / VIEW creates that have a schema.
    """
    kind = expression.args.get("kind")
    if not isinstance(expression.this, exp.Schema) or kind not in ("TABLE", "VIEW"):
        return self.create_sql(expression)

    # Work on a copy so the caller's tree is left untouched.
    expression = expression.copy()
    prop = expression.find(exp.PartitionedByProperty)

    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return self.create_sql(expression)

    schema = expression.this
    # Column names are matched case-insensitively.
    partition_names = {item.name.upper() for item in prop.this.expressions}
    partitions = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining = [column for column in schema.expressions if column not in partitions]

    schema.set("expressions", remaining)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
    expression.set("this", schema)

    return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser callback for date-delta functions.

    Args:
        exp_class: the expression class to instantiate.
        unit_mapping: optional map from lowercased unit names to replacement names.

    Returns:
        A callable taking the argument list. With exactly three arguments they are
        read as (unit, expression, this); otherwise as (this, expression) with the
        unit defaulting to the string literal "DAY".
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            unit, this = args[0], args[2]
        else:
            this = seq_get(args, 0)
            unit = exp.Literal.string("DAY")

        if unit_mapping:
            # Unknown units fall back to their own name.
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser callback for date-delta functions whose second argument is an interval.

    Returns:
        A callable taking the argument list. It yields None when fewer than two
        arguments are given; otherwise it instantiates `expression_class` from the
        first argument, the interval's value (string values become number literals)
        and the interval's unit as a string literal.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        target = args[0]
        interval = args[1]

        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=target, expression=amount, unit=unit)

    return func
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render COUNT_IF as a `sum` over an `if(cond, 1, 0)` expression.

    A DISTINCT argument is reported as unsupported and its first expression is
    used as the condition.
    """
    argument = expression.this

    if isinstance(argument, exp.Distinct):
        condition = argument.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
    else:
        condition = argument

    return self.func("sum", exp.func("if", condition, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render a TRIM expression in the `TRIM([position] [chars] FROM target [COLLATE c])` form.

    When neither characters to remove nor a collation are present, rendering is
    delegated to the generator's own `trim_sql`.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (remove_chars or collation):
        return self.trim_sql(expression)

    pieces = []
    if trim_type:
        pieces.append(f"{trim_type} ")
    if remove_chars:
        pieces.append(f"{remove_chars} ")
    if pieces:
        # FROM is only needed when something precedes the target.
        pieces.append("FROM ")
    pieces.append(target)
    if collation:
        pieces.append(f" COLLATE {collation}")

    inner = "".join(pieces)
    return f"TRIM({inner})"
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator callback that renders `TsOrDsToDate` as `CAST(... AS DATE)`.

    The operand is parsed with `str_to_time_sql` first when the expression has a
    format other than the dialect's `TIME_FORMAT` / `DATE_FORMAT`; otherwise the
    expression's `this` is cast directly.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        dialect_cls = Dialect.get_or_raise(dialect)
        fmt = self.format_time(expression)

        if fmt and fmt not in (dialect_cls.TIME_FORMAT, dialect_cls.DATE_FORMAT):
            operand = str_to_time_sql(self, expression)
        else:
            operand = self.sql(expression, "this")

        return f"CAST({operand} AS DATE)"

    return _ts_or_ds_to_date_sql
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render `REGEXP_EXTRACT(this, expression[, group])`.

    The "position", "occurrence" and "parameters" args are not emitted; any that
    are set are reported through `self.unsupported`.
    """
    ignored = ("position", "occurrence", "parameters")
    bad_args = [name for name in ignored if expression.args.get(name)]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    group = expression.args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render `REGEXP_REPLACE(this, expression, replacement)`.

    The "position", "occurrence" and "parameters" args are not emitted; any that
    are set are reported through `self.unsupported`. The "replacement" arg must
    be present in `expression.args`.
    """
    ignored = ("position", "occurrence", "parameters")
    bad_args = [name for name in ignored if expression.args.get(name)]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    replacement = expression.args["replacement"]
    return self.func("REGEXP_REPLACE", expression.this, expression.expression, replacement)
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute one column name per pivot aggregation.

    Aliased aggregations contribute their alias. Unaliased ones contribute their
    SQL text in the given dialect, with function names lowercased and every
    identifier unquoted.
    """

    def _unquote(node):
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # This case corresponds to aggregations without aliases being used as suffixes
        # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
        # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
        # Otherwise, we'd end up with a quoted name containing nested quotes around `foo`.
        unquoted = agg.transform(_unquote)
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names