sqlglot.dialects.dialect
"""Base dialect machinery: the `Dialect` registry plus generator/parser helper functions
shared by the concrete SQL dialect implementations."""

from __future__ import annotations

import typing as t
from enum import Enum

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

B = t.TypeVar("B", bound=exp.Binary)


class Dialects(str, Enum):
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class _Dialect(type):
    # Registry metaclass: every subclass of Dialect is auto-registered here by name.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
            "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
        }

        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        if not klass.STRICT_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        klass.generator_class.can_identify = klass.can_identify

        return klass


class Dialect(metaclass=_Dialect):
    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily instantiated and cached per Dialect instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Returns a transform that renders the expression as a call to `name`."""
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(self: Generator, expression: exp.If) -> str:
    return self.func(
        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
    )


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    # Emulate ILIKE by lowercasing the subject and using LIKE
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    this = self.sql(expression, "this")
    struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
    return f"{this}.{struct_key}"


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave keys and values: MAP(k1, v1, k2, v2, ...)
    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
    )


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    # NOTE(review): parameter was previously annotated exp.Left — it handles RIGHT()
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, "date"))


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return self.sql(exp.cast(str_to_time_sql(self, expression), "date"))

        return self.sql(exp.cast(self.sql(expression, "this"), "date"))

    return _ts_or_ds_to_date_sql


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    # Fold CONCAT(a, b, c) into a || b || c
    this, *rest_args = expression.expressions
    for arg in rest_args:
        this = exp.DPipe(this=this, expression=arg)

    return self.sql(this)


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def simplify_literal(expression: E, copy: bool = True) -> E:
    if not isinstance(expression.expression, exp.Literal):
        from sqlglot.optimizer.simplify import simplify

        expression = exp.maybe_copy(expression, copy)
        simplify(expression.expression)

    return expression


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Returns a parser that builds `expr_type` from a two-argument function call."""
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
20class Dialects(str, Enum): 21 DIALECT = "" 22 23 BIGQUERY = "bigquery" 24 CLICKHOUSE = "clickhouse" 25 DATABRICKS = "databricks" 26 DRILL = "drill" 27 DUCKDB = "duckdb" 28 HIVE = "hive" 29 MYSQL = "mysql" 30 ORACLE = "oracle" 31 POSTGRES = "postgres" 32 PRESTO = "presto" 33 REDSHIFT = "redshift" 34 SNOWFLAKE = "snowflake" 35 SPARK = "spark" 36 SPARK2 = "spark2" 37 SQLITE = "sqlite" 38 STARROCKS = "starrocks" 39 TABLEAU = "tableau" 40 TERADATA = "teradata" 41 TRINO = "trino" 42 TSQL = "tsql"
An enumeration.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
133class Dialect(metaclass=_Dialect): 134 # Determines the base index offset for arrays 135 INDEX_OFFSET = 0 136 137 # If true unnest table aliases are considered only as column aliases 138 UNNEST_COLUMN_ONLY = False 139 140 # Determines whether or not the table alias comes after tablesample 141 ALIAS_POST_TABLESAMPLE = False 142 143 # Determines whether or not unquoted identifiers are resolved as uppercase 144 # When set to None, it means that the dialect treats all identifiers as case-insensitive 145 RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False 146 147 # Determines whether or not an unquoted identifier can start with a digit 148 IDENTIFIERS_CAN_START_WITH_DIGIT = False 149 150 # Determines whether or not CONCAT's arguments must be strings 151 STRICT_STRING_CONCAT = False 152 153 # Determines how function names are going to be normalized 154 NORMALIZE_FUNCTIONS: bool | str = "upper" 155 156 # Indicates the default null ordering method to use if not explicitly set 157 # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last" 158 NULL_ORDERING = "nulls_are_small" 159 160 DATE_FORMAT = "'%Y-%m-%d'" 161 DATEINT_FORMAT = "'%Y%m%d'" 162 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 163 164 # Custom time mappings in which the key represents dialect time format 165 # and the value represents a python time format 166 TIME_MAPPING: t.Dict[str, str] = {} 167 168 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 169 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 170 # special syntax cast(x as date format 'yyyy') defaults to time_mapping 171 FORMAT_MAPPING: t.Dict[str, str] = {} 172 173 # Columns that are auto-generated by the engine corresponding to this dialect 174 # Such columns may be excluded from SELECT * queries, for example 
175 PSEUDOCOLUMNS: t.Set[str] = set() 176 177 # Autofilled 178 tokenizer_class = Tokenizer 179 parser_class = Parser 180 generator_class = Generator 181 182 # A trie of the time_mapping keys 183 TIME_TRIE: t.Dict = {} 184 FORMAT_TRIE: t.Dict = {} 185 186 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 187 INVERSE_TIME_TRIE: t.Dict = {} 188 189 def __eq__(self, other: t.Any) -> bool: 190 return type(self) == other 191 192 def __hash__(self) -> int: 193 return hash(type(self)) 194 195 @classmethod 196 def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: 197 if not dialect: 198 return cls 199 if isinstance(dialect, _Dialect): 200 return dialect 201 if isinstance(dialect, Dialect): 202 return dialect.__class__ 203 204 result = cls.get(dialect) 205 if not result: 206 raise ValueError(f"Unknown dialect '{dialect}'") 207 208 return result 209 210 @classmethod 211 def format_time( 212 cls, expression: t.Optional[str | exp.Expression] 213 ) -> t.Optional[exp.Expression]: 214 if isinstance(expression, str): 215 return exp.Literal.string( 216 # the time formats are quoted 217 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 218 ) 219 220 if expression and expression.is_string: 221 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 222 223 return expression 224 225 @classmethod 226 def normalize_identifier(cls, expression: E) -> E: 227 """ 228 Normalizes an unquoted identifier to either lower or upper case, thus essentially 229 making it case-insensitive. If a dialect treats all identifiers as case-insensitive, 230 they will be normalized regardless of being quoted or not. 
231 """ 232 if isinstance(expression, exp.Identifier) and ( 233 not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None 234 ): 235 expression.set( 236 "this", 237 expression.this.upper() 238 if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE 239 else expression.this.lower(), 240 ) 241 242 return expression 243 244 @classmethod 245 def case_sensitive(cls, text: str) -> bool: 246 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 247 if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None: 248 return False 249 250 unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper 251 return any(unsafe(char) for char in text) 252 253 @classmethod 254 def can_identify(cls, text: str, identify: str | bool = "safe") -> bool: 255 """Checks if text can be identified given an identify option. 256 257 Args: 258 text: The text to check. 259 identify: 260 "always" or `True`: Always returns true. 261 "safe": True if the identifier is case-insensitive. 262 263 Returns: 264 Whether or not the given text can be identified. 
265 """ 266 if identify is True or identify == "always": 267 return True 268 269 if identify == "safe": 270 return not cls.case_sensitive(text) 271 272 return False 273 274 @classmethod 275 def quote_identifier(cls, expression: E, identify: bool = True) -> E: 276 if isinstance(expression, exp.Identifier): 277 name = expression.this 278 expression.set( 279 "quoted", 280 identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 281 ) 282 283 return expression 284 285 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 286 return self.parser(**opts).parse(self.tokenize(sql), sql) 287 288 def parse_into( 289 self, expression_type: exp.IntoType, sql: str, **opts 290 ) -> t.List[t.Optional[exp.Expression]]: 291 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 292 293 def generate(self, expression: t.Optional[exp.Expression], **opts) -> str: 294 return self.generator(**opts).generate(expression) 295 296 def transpile(self, sql: str, **opts) -> t.List[str]: 297 return [self.generate(expression, **opts) for expression in self.parse(sql)] 298 299 def tokenize(self, sql: str) -> t.List[Token]: 300 return self.tokenizer.tokenize(sql) 301 302 @property 303 def tokenizer(self) -> Tokenizer: 304 if not hasattr(self, "_tokenizer"): 305 self._tokenizer = self.tokenizer_class() 306 return self._tokenizer 307 308 def parser(self, **opts) -> Parser: 309 return self.parser_class(**opts) 310 311 def generator(self, **opts) -> Generator: 312 return self.generator_class(**opts)
195 @classmethod 196 def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: 197 if not dialect: 198 return cls 199 if isinstance(dialect, _Dialect): 200 return dialect 201 if isinstance(dialect, Dialect): 202 return dialect.__class__ 203 204 result = cls.get(dialect) 205 if not result: 206 raise ValueError(f"Unknown dialect '{dialect}'") 207 208 return result
210 @classmethod 211 def format_time( 212 cls, expression: t.Optional[str | exp.Expression] 213 ) -> t.Optional[exp.Expression]: 214 if isinstance(expression, str): 215 return exp.Literal.string( 216 # the time formats are quoted 217 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 218 ) 219 220 if expression and expression.is_string: 221 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 222 223 return expression
225 @classmethod 226 def normalize_identifier(cls, expression: E) -> E: 227 """ 228 Normalizes an unquoted identifier to either lower or upper case, thus essentially 229 making it case-insensitive. If a dialect treats all identifiers as case-insensitive, 230 they will be normalized regardless of being quoted or not. 231 """ 232 if isinstance(expression, exp.Identifier) and ( 233 not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None 234 ): 235 expression.set( 236 "this", 237 expression.this.upper() 238 if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE 239 else expression.this.lower(), 240 ) 241 242 return expression
Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized regardless of being quoted or not.
@classmethod
def case_sensitive(cls, text: str) -> bool:
    """Return True if *text* contains any case sensitive characters, per the dialect's rules."""
    uppercase = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
    if uppercase is None:
        # Fully case-insensitive dialect: no character can be case sensitive.
        return False

    # Characters whose case differs from the dialect's normalization are unsafe.
    for char in text:
        if char.islower() if uppercase else char.isupper():
            return True

    return False
Checks if text contains any case sensitive characters, based on the dialect's rules.
@classmethod
def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
    """Check if text can be identified given an identify option.

    Args:
        text: the text to check.
        identify:
            ``"always"`` or ``True``: always allow.
            ``"safe"``: allow only when the identifier is case-insensitive.

    Returns:
        Whether or not the given text can be identified.
    """
    if identify is True or identify == "always":
        return True

    return identify == "safe" and not cls.case_sensitive(text)
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: "always" or True: always returns True. "safe": returns True if the identifier is case-insensitive.
Returns:
Whether or not the given text can be identified.
@classmethod
def quote_identifier(cls, expression: E, identify: bool = True) -> E:
    """Mark an Identifier as quoted when quoting is forced or required.

    Quoting is required when the name is case sensitive for this dialect or is
    not a safe word-like identifier; non-Identifier expressions pass through.
    """
    if isinstance(expression, exp.Identifier):
        name = expression.this
        needs_quotes = (
            identify
            or cls.case_sensitive(name)
            or not exp.SAFE_IDENTIFIER_RE.match(name)
        )
        expression.set("quoted", needs_quotes)

    return expression
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as STRPOS, emulating a start position via SUBSTR.

    With a position, the search runs on the suffix and the offset is added back:
    ``STRPOS(SUBSTR(this, pos), substr) + pos - 1``.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"

    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map constructor from parallel key/value arrays.

    When both sides are literal arrays, their entries are interleaved as
    ``MAP(k1, v1, k2, v2, ...)``; otherwise a warning is emitted and the two
    arguments are passed through unchanged.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    interleaved = []
    for pair in zip(keys.expressions, values.expressions):
        interleaved.extend(self.sql(item) for item in pair)

    return self.func(map_func_name, *interleaved)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        target = Dialect[dialect]
        # Fall back to the dialect's TIME_FORMAT (or the supplied default) when
        # no format argument is present.
        fmt = seq_get(args, 1) or (
            target.TIME_FORMAT if default is True else default or None
        )
        return exp_class(this=seq_get(args, 0), format=target.format_time(fmt))

    return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        # Work on a copy so the caller's expression tree is left untouched.
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        # Only rewrite when the property holds a plain column-name list
        # (i.e. it hasn't already been converted into a Schema).
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Partition columns are matched against the schema case-insensitively.
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            # Move the partition columns out of the main schema and into the property.
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date-delta functions, optionally remapping unit names."""

    def _builder(args: t.List) -> E:
        # Three arguments means the unit leads: FUNC(unit, delta, this);
        # otherwise the unit defaults to DAY.
        has_unit = len(args) == 3
        this = args[2] if has_unit else seq_get(args, 0)
        unit = args[0] if has_unit else exp.Literal.string("DAY")
        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date-delta functions whose second argument is an INTERVAL.

    The returned callable raises ParseError when the second argument is not an
    Interval, and normalizes string interval amounts into numeric literals.
    """

    def _parse(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            # '5' -> 5, so downstream arithmetic sees a number.
            amount = exp.Literal.number(amount.this)

        return expression_class(
            this=args[0], expression=amount, unit=exp.Literal.string(interval.text("unit"))
        )

    return _parse
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Transpile COUNT_IF(cond) into SUM(IF(cond, 1, 0))."""
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        # DISTINCT cannot be expressed through SUM; warn and drop it.
        cond = cond.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render a Trim expression.

    The verbose ``TRIM([position] [chars] FROM target [COLLATE ...])`` form is
    used only when custom characters or a collation are present; otherwise
    rendering is delegated to the generator's plain TRIM/LTRIM/RTRIM handling.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    if not (remove_chars or collation):
        # Nothing database-specific: fall back to plain TRIM/LTRIM/RTRIM syntax.
        return self.trim_sql(expression)

    parts = []
    if trim_type:
        parts.append(f"{trim_type} ")
    if remove_chars:
        parts.append(f"{remove_chars} ")
    if parts:
        parts.append("FROM ")
    parts.append(target)
    if collation:
        parts.append(f" COLLATE {collation}")

    return f"TRIM({''.join(parts)})"
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Return a generator function that renders TsOrDsToDate as a CAST to DATE
    for the given dialect, parsing the input first when a custom time format is set."""

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)

        if time_format and time_format not in (target.TIME_FORMAT, target.DATE_FORMAT):
            # A non-default format means the string must be parsed before casting.
            return self.sql(exp.cast(str_to_time_sql(self, expression), "date"))

        return self.sql(exp.cast(self.sql(expression, "this"), "date"))

    return _ts_or_ds_to_date_sql
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT(this, pattern[, group]), warning about args this
    rendering cannot express (position, occurrence, parameters)."""
    unsupported_args = [
        arg for arg in ("position", "occurrence", "parameters") if expression.args.get(arg)
    ]
    if unsupported_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {unsupported_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE(this, pattern, replacement), warning about args
    this rendering cannot express (position, occurrence, parameters)."""
    unsupported_args = [
        arg for arg in ("position", "occurrence", "parameters") if expression.args.get(arg)
    ]
    if unsupported_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {unsupported_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Produce the output column names for a PIVOT's aggregation expressions.

    Args:
        aggregations: the PIVOT aggregation expressions.
        dialect: the dialect used to render unaliased aggregations.

    Returns:
        One name per aggregation: its alias when present, otherwise its rendered SQL.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # NOTE: this used to be a bare string literal acting as a comment,
        # i.e. an expression statement evaluated on every iteration.
        # This case corresponds to aggregations without aliases being used as suffixes
        # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
        # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
        # Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
        agg_all_unquoted = agg.transform(
            lambda node: exp.Identifier(this=node.name, quoted=False)
            if isinstance(node, exp.Identifier)
            else node
        )
        names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names