sqlglot.dialects.dialect
from __future__ import annotations

import typing as t
from enum import Enum

from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

if t.TYPE_CHECKING:
    from sqlglot._typing import E


# Only Snowflake is currently known to resolve unquoted identifiers as uppercase.
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"}


class Dialects(str, Enum):
    """Names of all SQL dialects supported by sqlglot.

    Each member's value is the lowercase string key used to look the
    dialect up in ``Dialect.classes`` (``DIALECT`` is the empty-string
    key for the base dialect).
    """

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class _Dialect(type):
    """Metaclass for ``Dialect``.

    On creation of every ``Dialect`` subclass it:
    - registers the class in ``classes`` under its ``Dialects`` enum value
      (or its lowercased class name),
    - builds the time-format tries from ``TIME_MAPPING``/``FORMAT_MAPPING``,
    - resolves the nested ``Tokenizer``/``Parser``/``Generator`` classes, and
    - copies the dialect's non-callable settings onto those classes.
    """

    # Registry of all known dialects, keyed by lowercase name ("" = base Dialect).
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        # Makes a Dialect *class* comparable to a name string, another
        # Dialect class, or a Dialect instance.
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        # Consistent with __eq__'s case-insensitive name comparison.
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Register under the matching Dialects enum value when one exists,
        # otherwise fall back to the lowercased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Autofill the time-format tries; FORMAT_TRIE falls back to TIME_TRIE
        # when no dedicated FORMAT_MAPPING is defined.
        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        # Resolve the dialect's nested component classes, defaulting to the base ones.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first entry of each tokenizer mapping is treated as the canonical pair.
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # Find the (start, end) delimiters registered for the given
            # format-string token type, or (None, None) when unsupported.
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.RAW_START, klass.RAW_END = get_start_end(TokenType.RAW_STRING)

        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
            "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
        }

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        # Lenient dialects parse || into the "safe" concat variant.
        if not klass.STRICT_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        return klass
class Dialect(metaclass=_Dialect):
    """Base class for all SQL dialects.

    Subclassing automatically registers the dialect (see ``_Dialect``) and
    autofills the tokenizer/parser/generator classes and time-format tries.
    """

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        # Instances compare equal to their class; combined with
        # _Dialect.__eq__ this also makes them comparable to name strings.
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve `dialect` (name, class, or instance) to a Dialect class.

        Falsy input returns `cls` itself; an unknown name raises ValueError.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a dialect time-format string (or string literal) into a
        python-strftime string literal, using this dialect's TIME_MAPPING."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql` into a list of syntax trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse `sql`, coercing the result into `expression_type` nodes."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Render a syntax tree back into SQL for this dialect."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and regenerate each statement in this dialect."""
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily instantiate and cache one tokenizer per Dialect instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)


# Anything accepted wherever a dialect is expected: a name, a class,
# an instance, or None (meaning the default/base dialect).
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Return a generator transform that renders the expression as a call
    to `name`, flattening all of its args into positional arguments."""
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    # The `accuracy` argument has no equivalent here, so it's dropped with a warning.
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(self: Generator, expression: exp.If) -> str:
    """Render a conditional as IF(cond, true, false)."""
    return self.func(
        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
    )


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON extraction with the `->` arrow operator."""
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON extraction with the `->>` arrow operator."""
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array literal as [elem, ...]."""
    return f"[{self.expressions(expression)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE as LOWER(this) LIKE pattern for dialects without ILIKE.

    NOTE(review): only the left side is lowercased — assumes the pattern is
    already lowercase; verify against callers.
    """
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, with an optional time zone."""
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Strip the RECURSIVE keyword (with a warning) for dialects lacking it."""
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE via IF(d <> 0, n / d, NULL)."""
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop TABLESAMPLE (with a warning), rendering just the sampled table."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Drop PIVOT entirely (with a warning) for dialects lacking it."""
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TRY_CAST as a plain CAST for dialects lacking it."""
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Drop table properties (with a warning) for dialects lacking them."""
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Drop column COMMENT constraints (with a warning)."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as STRPOS, emulating the optional start position
    with SUBSTR plus an offset correction."""
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as `this`.`key` with a quoted key."""
    this = self.sql(expression, "this")
    struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
    return f"{this}.{struct_key}"


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as MAP(k1, v1, k2, v2, ...), interleaving keys and values.

    Falls back to MAP(keys, values) with a warning when the keys/values are
    not array literals (e.g. columns) and so can't be interleaved.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                # True means "use the dialect's TIME_FORMAT" as the default.
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        # Work on a copy so the caller's tree is left untouched.
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Case-insensitive match between schema columns and partition columns.
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date-delta functions.

    Supports both the 3-arg form (unit, delta, date) and the 2-arg form
    (date, delta), defaulting the unit to DAY; `unit_mapping` optionally
    normalizes unit names.
    """
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date arithmetic expressed as (date, INTERVAL n unit)."""
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        expression = interval.this
        # Interval quantities parsed as strings are coerced to numbers.
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, value) into DateTrunc when the value is an
    explicit date cast, otherwise into TimestampTrunc."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('unit', value), defaulting to 'day'."""
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
    )


def locate_to_strposition(args: t.List) -> exp.Expression:
    """Parse LOCATE(substr, str[, pos]) into StrPosition (note the swapped
    first two arguments relative to STRPOS)."""
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, str[, pos]) — inverse of
    `locate_to_strposition`."""
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Emulate LEFT(s, n) as SUBSTRING(s, 1, n)."""
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Emulate RIGHT(s, n) as SUBSTRING(s, LENGTH(s) - (n - 1))."""
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render TimeStrToTime as a plain CAST to TIMESTAMP."""
    return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render DateStrToDate as a plain CAST to DATE."""
    return f"CAST({self.sql(expression, 'this')} AS DATE)"


def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render as LEAST when there are multiple arguments, else as MIN."""
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render as GREATEST when there are multiple arguments, else as MAX."""
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Emulate COUNT_IF(cond) as SUM(IF(cond, 1, 0)); DISTINCT is unwrapped
    with a warning since it can't be preserved in this form."""
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM with the verbose `TRIM(BOTH 'x' FROM s COLLATE c)` syntax
    when trim characters or a collation are present."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    # NOTE(review): delegates to the generator's built-in trim_sql — assumes this
    # function is installed as a transform, not as the method itself; confirm.
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render string-to-time parsing as STRPTIME(value, format)."""
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a TsOrDsToDate renderer for `dialect`: parse with STRPTIME when a
    non-default format is present, otherwise just CAST to DATE."""
    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _ts_or_ds_to_date_sql


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Render CONCAT(a, b, c, ...) as a left-nested chain of || operators."""
    this, *rest_args = expression.expressions
    for arg in rest_args:
        this = exp.DPipe(this=this, expression=arg)

    return self.sql(this)


# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the output column name suffixes for a PIVOT's aggregations."""
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
class
Dialects(builtins.str, enum.Enum):
24class Dialects(str, Enum): 25 DIALECT = "" 26 27 BIGQUERY = "bigquery" 28 CLICKHOUSE = "clickhouse" 29 DATABRICKS = "databricks" 30 DRILL = "drill" 31 DUCKDB = "duckdb" 32 HIVE = "hive" 33 MYSQL = "mysql" 34 ORACLE = "oracle" 35 POSTGRES = "postgres" 36 PRESTO = "presto" 37 REDSHIFT = "redshift" 38 SNOWFLAKE = "snowflake" 39 SPARK = "spark" 40 SPARK2 = "spark2" 41 SQLITE = "sqlite" 42 STARROCKS = "starrocks" 43 TABLEAU = "tableau" 44 TERADATA = "teradata" 45 TRINO = "trino" 46 TSQL = "tsql"
Enumeration of the SQL dialect names supported by sqlglot; each member's value is the lowercase key used to look the dialect class up.
DIALECT =
<Dialects.DIALECT: ''>
BIGQUERY =
<Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE =
<Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS =
<Dialects.DATABRICKS: 'databricks'>
DRILL =
<Dialects.DRILL: 'drill'>
DUCKDB =
<Dialects.DUCKDB: 'duckdb'>
HIVE =
<Dialects.HIVE: 'hive'>
MYSQL =
<Dialects.MYSQL: 'mysql'>
ORACLE =
<Dialects.ORACLE: 'oracle'>
POSTGRES =
<Dialects.POSTGRES: 'postgres'>
PRESTO =
<Dialects.PRESTO: 'presto'>
REDSHIFT =
<Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE =
<Dialects.SNOWFLAKE: 'snowflake'>
SPARK =
<Dialects.SPARK: 'spark'>
SPARK2 =
<Dialects.SPARK2: 'spark2'>
SQLITE =
<Dialects.SQLITE: 'sqlite'>
STARROCKS =
<Dialects.STARROCKS: 'starrocks'>
TABLEAU =
<Dialects.TABLEAU: 'tableau'>
TERADATA =
<Dialects.TERADATA: 'teradata'>
TRINO =
<Dialects.TRINO: 'trino'>
TSQL =
<Dialects.TSQL: 'tsql'>
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class
Dialect:
133class Dialect(metaclass=_Dialect): 134 # Determines the base index offset for arrays 135 INDEX_OFFSET = 0 136 137 # If true unnest table aliases are considered only as column aliases 138 UNNEST_COLUMN_ONLY = False 139 140 # Determines whether or not the table alias comes after tablesample 141 ALIAS_POST_TABLESAMPLE = False 142 143 # Determines whether or not an unquoted identifier can start with a digit 144 IDENTIFIERS_CAN_START_WITH_DIGIT = False 145 146 # Determines whether or not CONCAT's arguments must be strings 147 STRICT_STRING_CONCAT = False 148 149 # Determines how function names are going to be normalized 150 NORMALIZE_FUNCTIONS: bool | str = "upper" 151 152 # Indicates the default null ordering method to use if not explicitly set 153 # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last" 154 NULL_ORDERING = "nulls_are_small" 155 156 DATE_FORMAT = "'%Y-%m-%d'" 157 DATEINT_FORMAT = "'%Y%m%d'" 158 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 159 160 # Custom time mappings in which the key represents dialect time format 161 # and the value represents a python time format 162 TIME_MAPPING: t.Dict[str, str] = {} 163 164 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 165 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 166 # special syntax cast(x as date format 'yyyy') defaults to time_mapping 167 FORMAT_MAPPING: t.Dict[str, str] = {} 168 169 # Autofilled 170 tokenizer_class = Tokenizer 171 parser_class = Parser 172 generator_class = Generator 173 174 # A trie of the time_mapping keys 175 TIME_TRIE: t.Dict = {} 176 FORMAT_TRIE: t.Dict = {} 177 178 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 179 INVERSE_TIME_TRIE: t.Dict = {} 180 181 def __eq__(self, other: t.Any) -> bool: 182 return type(self) == other 183 184 def 
__hash__(self) -> int: 185 return hash(type(self)) 186 187 @classmethod 188 def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: 189 if not dialect: 190 return cls 191 if isinstance(dialect, _Dialect): 192 return dialect 193 if isinstance(dialect, Dialect): 194 return dialect.__class__ 195 196 result = cls.get(dialect) 197 if not result: 198 raise ValueError(f"Unknown dialect '{dialect}'") 199 200 return result 201 202 @classmethod 203 def format_time( 204 cls, expression: t.Optional[str | exp.Expression] 205 ) -> t.Optional[exp.Expression]: 206 if isinstance(expression, str): 207 return exp.Literal.string( 208 # the time formats are quoted 209 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 210 ) 211 212 if expression and expression.is_string: 213 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 214 215 return expression 216 217 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 218 return self.parser(**opts).parse(self.tokenize(sql), sql) 219 220 def parse_into( 221 self, expression_type: exp.IntoType, sql: str, **opts 222 ) -> t.List[t.Optional[exp.Expression]]: 223 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 224 225 def generate(self, expression: t.Optional[exp.Expression], **opts) -> str: 226 return self.generator(**opts).generate(expression) 227 228 def transpile(self, sql: str, **opts) -> t.List[str]: 229 return [self.generate(expression, **opts) for expression in self.parse(sql)] 230 231 def tokenize(self, sql: str) -> t.List[Token]: 232 return self.tokenizer.tokenize(sql) 233 234 @property 235 def tokenizer(self) -> Tokenizer: 236 if not hasattr(self, "_tokenizer"): 237 self._tokenizer = self.tokenizer_class() 238 return self._tokenizer 239 240 def parser(self, **opts) -> Parser: 241 return self.parser_class(**opts) 242 243 def generator(self, **opts) -> Generator: 244 return self.generator_class(**opts)
@classmethod
def
get_or_raise( cls, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> Type[sqlglot.dialects.dialect.Dialect]:
187 @classmethod 188 def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: 189 if not dialect: 190 return cls 191 if isinstance(dialect, _Dialect): 192 return dialect 193 if isinstance(dialect, Dialect): 194 return dialect.__class__ 195 196 result = cls.get(dialect) 197 if not result: 198 raise ValueError(f"Unknown dialect '{dialect}'") 199 200 return result
@classmethod
def
format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
202 @classmethod 203 def format_time( 204 cls, expression: t.Optional[str | exp.Expression] 205 ) -> t.Optional[exp.Expression]: 206 if isinstance(expression, str): 207 return exp.Literal.string( 208 # the time formats are quoted 209 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 210 ) 211 212 if expression and expression.is_string: 213 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 214 215 return expression
def
parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
def
rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def
approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
def
arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
def
arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
def
inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
def
no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
def
no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
def
no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
def
no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
def
no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
def
no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
def
no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
def
str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
330def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: 331 this = self.sql(expression, "this") 332 substr = self.sql(expression, "substr") 333 position = self.sql(expression, "position") 334 if position: 335 return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1" 336 return f"STRPOS({this}, {substr})"
def
struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
def
var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
345def var_map_sql( 346 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 347) -> str: 348 keys = expression.args["keys"] 349 values = expression.args["values"] 350 351 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 352 self.unsupported("Cannot convert array columns into map.") 353 return self.func(map_func_name, keys, values) 354 355 args = [] 356 for key, value in zip(keys.expressions, values.expressions): 357 args.append(self.sql(key)) 358 args.append(self.sql(value)) 359 360 return self.func(map_func_name, *args)
def
format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
363def format_time_lambda( 364 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 365) -> t.Callable[[t.List], E]: 366 """Helper used for time expressions. 367 368 Args: 369 exp_class: the expression class to instantiate. 370 dialect: target sql dialect. 371 default: the default format, True being time. 372 373 Returns: 374 A callable that can be used to return the appropriately formatted time expression. 375 """ 376 377 def _format_time(args: t.List): 378 return exp_class( 379 this=seq_get(args, 0), 380 format=Dialect[dialect].format_time( 381 seq_get(args, 1) 382 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 383 ), 384 ) 385 386 return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
def
create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
389def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: 390 """ 391 In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the 392 PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding 393 columns are removed from the create statement. 394 """ 395 has_schema = isinstance(expression.this, exp.Schema) 396 is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW") 397 398 if has_schema and is_partitionable: 399 expression = expression.copy() 400 prop = expression.find(exp.PartitionedByProperty) 401 if prop and prop.this and not isinstance(prop.this, exp.Schema): 402 schema = expression.this 403 columns = {v.name.upper() for v in prop.this.expressions} 404 partitions = [col for col in schema.expressions if col.name.upper() in columns] 405 schema.set("expressions", [e for e in schema.expressions if e not in partitions]) 406 prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) 407 expression.set("this", schema) 408 409 return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def
parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser callable for date-delta functions (e.g. DATE_ADD).

    Args:
        exp_class: the expression class to instantiate.
        unit_mapping: optional mapping used to normalize the unit name.

    Returns:
        A callable that builds `exp_class` from a parsed argument list.
    """

    def inner_func(args: t.List) -> E:
        # Three arguments means the unit leads: (unit, delta, this).
        has_unit = len(args) == 3
        if has_unit:
            this = args[2]
            unit = args[0]
        else:
            this = seq_get(args, 0)
            unit = exp.Literal.string("DAY")

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def
parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser callable for date-delta functions whose second argument
    is an INTERVAL expression. Returns None when too few arguments were parsed."""

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        delta = interval.this
        # String interval amounts are converted to numeric literals.
        if delta and delta.is_string:
            delta = exp.Literal.number(delta.this)

        return expression_class(
            this=args[0],
            expression=delta,
            unit=exp.Literal.string(interval.text("unit")),
        )

    return func
def
date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
def
timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
def
strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def
left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
def
right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Right) -> str:
def
timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
def
datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
def
max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
def
count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Transpile COUNT_IF(cond) into SUM(IF(cond, 1, 0))."""
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        # DISTINCT cannot be represented in the SUM/IF form; warn and unwrap.
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
        condition = condition.expressions[0]

    return self.func("sum", exp.func("if", condition, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM using the verbose `TRIM([pos] [chars] FROM target [COLLATE c])`
    form when trim characters or a collation are present; otherwise delegate to
    the generator's plain TRIM/LTRIM/RTRIM rendering."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    pieces = []
    if trim_type:
        pieces.append(f"{trim_type} ")
    if remove_chars:
        pieces.append(f"{remove_chars} ")
    if pieces:
        pieces.append("FROM ")
    pieces.append(target)
    if collation:
        pieces.append(f" COLLATE {collation}")

    return f"TRIM({''.join(pieces)})"
def
str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
def
ts_or_ds_to_date_sql(dialect: str) -> Callable:
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator transform that casts TS_OR_DS_TO_DATE to DATE,
    parsing via a str-to-time conversion when a custom format is present."""

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target_dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)

        # No format, or one of the dialect's standard formats: a plain cast suffices.
        if not time_format or time_format in (
            target_dialect.TIME_FORMAT,
            target_dialect.DATE_FORMAT,
        ):
            return f"CAST({self.sql(expression, 'this')} AS DATE)"

        return f"CAST({str_to_time_sql(self, expression)} AS DATE)"

    return _ts_or_ds_to_date_sql
def
concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat | sqlglot.expressions.SafeConcat) -> str:
def
pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> List[str]:
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Generate names for the columns produced by a PIVOT's aggregations.

    Aliased aggregations contribute their alias; unaliased ones are rendered to SQL
    (with all identifiers unquoted) so the text can be used as a column-name suffix.

    Args:
        aggregations: the PIVOT aggregation expressions.
        dialect: dialect used to render unaliased aggregations.

    Returns:
        The list of generated column names, in input order.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as
            # suffixes (e.g. col_avg(foo)). We need to unquote identifiers because
            # they're going to be quoted in the base parser's `_parse_pivot` method,
            # due to `to_identifier`. Otherwise, we'd end up with `col_avg(`foo`)`
            # (notice the double quotes).
            # NOTE(fix): this was previously a bare triple-quoted string — a runtime
            # expression statement, not a comment/docstring — now a real comment.
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names