sqlglot.dialects.dialect
1from __future__ import annotations 2 3import typing as t 4from enum import Enum 5 6from sqlglot import exp 7from sqlglot.generator import Generator 8from sqlglot.helper import flatten, seq_get 9from sqlglot.parser import Parser 10from sqlglot.time import format_time 11from sqlglot.tokens import Token, Tokenizer, TokenType 12from sqlglot.trie import new_trie 13 14if t.TYPE_CHECKING: 15 from sqlglot._typing import E 16 17 18class Dialects(str, Enum): 19 DIALECT = "" 20 21 BIGQUERY = "bigquery" 22 CLICKHOUSE = "clickhouse" 23 DUCKDB = "duckdb" 24 HIVE = "hive" 25 MYSQL = "mysql" 26 ORACLE = "oracle" 27 POSTGRES = "postgres" 28 PRESTO = "presto" 29 REDSHIFT = "redshift" 30 SNOWFLAKE = "snowflake" 31 SPARK = "spark" 32 SPARK2 = "spark2" 33 SQLITE = "sqlite" 34 STARROCKS = "starrocks" 35 TABLEAU = "tableau" 36 TRINO = "trino" 37 TSQL = "tsql" 38 DATABRICKS = "databricks" 39 DRILL = "drill" 40 TERADATA = "teradata" 41 42 43class _Dialect(type): 44 classes: t.Dict[str, t.Type[Dialect]] = {} 45 46 def __eq__(cls, other: t.Any) -> bool: 47 if cls is other: 48 return True 49 if isinstance(other, str): 50 return cls is cls.get(other) 51 if isinstance(other, Dialect): 52 return cls is type(other) 53 54 return False 55 56 def __hash__(cls) -> int: 57 return hash(cls.__name__.lower()) 58 59 @classmethod 60 def __getitem__(cls, key: str) -> t.Type[Dialect]: 61 return cls.classes[key] 62 63 @classmethod 64 def get( 65 cls, key: str, default: t.Optional[t.Type[Dialect]] = None 66 ) -> t.Optional[t.Type[Dialect]]: 67 return cls.classes.get(key, default) 68 69 def __new__(cls, clsname, bases, attrs): 70 klass = super().__new__(cls, clsname, bases, attrs) 71 enum = Dialects.__members__.get(clsname.upper()) 72 cls.classes[enum.value if enum is not None else clsname.lower()] = klass 73 74 klass.time_trie = new_trie(klass.time_mapping) 75 klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()} 76 klass.inverse_time_trie = new_trie(klass.inverse_time_mapping) 77 78 
klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer) 79 klass.parser_class = getattr(klass, "Parser", Parser) 80 klass.generator_class = getattr(klass, "Generator", Generator) 81 82 klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0] 83 klass.identifier_start, klass.identifier_end = list( 84 klass.tokenizer_class._IDENTIFIERS.items() 85 )[0] 86 87 def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]: 88 return next( 89 ( 90 (s, e) 91 for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items() 92 if t == token_type 93 ), 94 (None, None), 95 ) 96 97 klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING) 98 klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING) 99 klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING) 100 klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING) 101 102 return klass 103 104 105class Dialect(metaclass=_Dialect): 106 index_offset = 0 107 unnest_column_only = False 108 alias_post_tablesample = False 109 normalize_functions: t.Optional[str] = "upper" 110 null_ordering = "nulls_are_small" 111 112 date_format = "'%Y-%m-%d'" 113 dateint_format = "'%Y%m%d'" 114 time_format = "'%Y-%m-%d %H:%M:%S'" 115 time_mapping: t.Dict[str, str] = {} 116 117 # autofilled 118 quote_start = None 119 quote_end = None 120 identifier_start = None 121 identifier_end = None 122 123 time_trie = None 124 inverse_time_mapping = None 125 inverse_time_trie = None 126 tokenizer_class = None 127 parser_class = None 128 generator_class = None 129 130 def __eq__(self, other: t.Any) -> bool: 131 return type(self) == other 132 133 def __hash__(self) -> int: 134 return hash(type(self)) 135 136 @classmethod 137 def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: 138 if not dialect: 139 return cls 140 if isinstance(dialect, _Dialect): 141 return dialect 142 if isinstance(dialect, Dialect): 143 return dialect.__class__ 144 145 result = 
cls.get(dialect) 146 if not result: 147 raise ValueError(f"Unknown dialect '{dialect}'") 148 149 return result 150 151 @classmethod 152 def format_time( 153 cls, expression: t.Optional[str | exp.Expression] 154 ) -> t.Optional[exp.Expression]: 155 if isinstance(expression, str): 156 return exp.Literal.string( 157 format_time( 158 expression[1:-1], # the time formats are quoted 159 cls.time_mapping, 160 cls.time_trie, 161 ) 162 ) 163 if expression and expression.is_string: 164 return exp.Literal.string( 165 format_time( 166 expression.this, 167 cls.time_mapping, 168 cls.time_trie, 169 ) 170 ) 171 return expression 172 173 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 174 return self.parser(**opts).parse(self.tokenize(sql), sql) 175 176 def parse_into( 177 self, expression_type: exp.IntoType, sql: str, **opts 178 ) -> t.List[t.Optional[exp.Expression]]: 179 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 180 181 def generate(self, expression: t.Optional[exp.Expression], **opts) -> str: 182 return self.generator(**opts).generate(expression) 183 184 def transpile(self, sql: str, **opts) -> t.List[str]: 185 return [self.generate(expression, **opts) for expression in self.parse(sql)] 186 187 def tokenize(self, sql: str) -> t.List[Token]: 188 return self.tokenizer.tokenize(sql) 189 190 @property 191 def tokenizer(self) -> Tokenizer: 192 if not hasattr(self, "_tokenizer"): 193 self._tokenizer = self.tokenizer_class() # type: ignore 194 return self._tokenizer 195 196 def parser(self, **opts) -> Parser: 197 return self.parser_class( # type: ignore 198 **{ 199 "index_offset": self.index_offset, 200 "unnest_column_only": self.unnest_column_only, 201 "alias_post_tablesample": self.alias_post_tablesample, 202 "null_ordering": self.null_ordering, 203 **opts, 204 }, 205 ) 206 207 def generator(self, **opts) -> Generator: 208 return self.generator_class( # type: ignore 209 **{ 210 "quote_start": self.quote_start, 211 
"quote_end": self.quote_end, 212 "bit_start": self.bit_start, 213 "bit_end": self.bit_end, 214 "hex_start": self.hex_start, 215 "hex_end": self.hex_end, 216 "byte_start": self.byte_start, 217 "byte_end": self.byte_end, 218 "raw_start": self.raw_start, 219 "raw_end": self.raw_end, 220 "identifier_start": self.identifier_start, 221 "identifier_end": self.identifier_end, 222 "string_escape": self.tokenizer_class.STRING_ESCAPES[0], 223 "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0], 224 "index_offset": self.index_offset, 225 "time_mapping": self.inverse_time_mapping, 226 "time_trie": self.inverse_time_trie, 227 "unnest_column_only": self.unnest_column_only, 228 "alias_post_tablesample": self.alias_post_tablesample, 229 "normalize_functions": self.normalize_functions, 230 "null_ordering": self.null_ordering, 231 **opts, 232 } 233 ) 234 235 236DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 237 238 239def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 240 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 241 242 243def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 244 if expression.args.get("accuracy"): 245 self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") 246 return self.func("APPROX_COUNT_DISTINCT", expression.this) 247 248 249def if_sql(self: Generator, expression: exp.If) -> str: 250 return self.func( 251 "IF", expression.this, expression.args.get("true"), expression.args.get("false") 252 ) 253 254 255def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str: 256 return self.binary(expression, "->") 257 258 259def arrow_json_extract_scalar_sql( 260 self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar 261) -> str: 262 return self.binary(expression, "->>") 263 264 265def inline_array_sql(self: Generator, expression: exp.Array) -> str: 266 return 
f"[{self.expressions(expression)}]" 267 268 269def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 270 return self.like_sql( 271 exp.Like( 272 this=exp.Lower(this=expression.this), 273 expression=expression.args["expression"], 274 ) 275 ) 276 277 278def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 279 zone = self.sql(expression, "this") 280 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 281 282 283def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 284 if expression.args.get("recursive"): 285 self.unsupported("Recursive CTEs are unsupported") 286 expression.args["recursive"] = False 287 return self.with_sql(expression) 288 289 290def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: 291 n = self.sql(expression, "this") 292 d = self.sql(expression, "expression") 293 return f"IF({d} <> 0, {n} / {d}, NULL)" 294 295 296def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 297 self.unsupported("TABLESAMPLE unsupported") 298 return self.sql(expression.this) 299 300 301def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 302 self.unsupported("PIVOT unsupported") 303 return "" 304 305 306def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 307 return self.cast_sql(expression) 308 309 310def no_properties_sql(self: Generator, expression: exp.Properties) -> str: 311 self.unsupported("Properties unsupported") 312 return "" 313 314 315def no_comment_column_constraint_sql( 316 self: Generator, expression: exp.CommentColumnConstraint 317) -> str: 318 self.unsupported("CommentColumnConstraint unsupported") 319 return "" 320 321 322def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: 323 this = self.sql(expression, "this") 324 substr = self.sql(expression, "substr") 325 position = self.sql(expression, "position") 326 if position: 327 return f"STRPOS(SUBSTR({this}, {position}), {substr}) + 
{position} - 1" 328 return f"STRPOS({this}, {substr})" 329 330 331def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 332 this = self.sql(expression, "this") 333 struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True)) 334 return f"{this}.{struct_key}" 335 336 337def var_map_sql( 338 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 339) -> str: 340 keys = expression.args["keys"] 341 values = expression.args["values"] 342 343 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 344 self.unsupported("Cannot convert array columns into map.") 345 return self.func(map_func_name, keys, values) 346 347 args = [] 348 for key, value in zip(keys.expressions, values.expressions): 349 args.append(self.sql(key)) 350 args.append(self.sql(value)) 351 return self.func(map_func_name, *args) 352 353 354def format_time_lambda( 355 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 356) -> t.Callable[[t.List], E]: 357 """Helper used for time expressions. 358 359 Args: 360 exp_class: the expression class to instantiate. 361 dialect: target sql dialect. 362 default: the default format, True being time. 363 364 Returns: 365 A callable that can be used to return the appropriately formatted time expression. 366 """ 367 368 def _format_time(args: t.List): 369 return exp_class( 370 this=seq_get(args, 0), 371 format=Dialect[dialect].format_time( 372 seq_get(args, 1) 373 or (Dialect[dialect].time_format if default is True else default or None) 374 ), 375 ) 376 377 return _format_time 378 379 380def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: 381 """ 382 In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the 383 PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding 384 columns are removed from the create statement. 
385 """ 386 has_schema = isinstance(expression.this, exp.Schema) 387 is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW") 388 389 if has_schema and is_partitionable: 390 expression = expression.copy() 391 prop = expression.find(exp.PartitionedByProperty) 392 if prop and prop.this and not isinstance(prop.this, exp.Schema): 393 schema = expression.this 394 columns = {v.name.upper() for v in prop.this.expressions} 395 partitions = [col for col in schema.expressions if col.name.upper() in columns] 396 schema.set("expressions", [e for e in schema.expressions if e not in partitions]) 397 prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) 398 expression.set("this", schema) 399 400 return self.create_sql(expression) 401 402 403def parse_date_delta( 404 exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None 405) -> t.Callable[[t.List], E]: 406 def inner_func(args: t.List) -> E: 407 unit_based = len(args) == 3 408 this = args[2] if unit_based else seq_get(args, 0) 409 unit = args[0] if unit_based else exp.Literal.string("DAY") 410 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 411 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 412 413 return inner_func 414 415 416def parse_date_delta_with_interval( 417 expression_class: t.Type[E], 418) -> t.Callable[[t.List], t.Optional[E]]: 419 def func(args: t.List) -> t.Optional[E]: 420 if len(args) < 2: 421 return None 422 423 interval = args[1] 424 expression = interval.this 425 if expression and expression.is_string: 426 expression = exp.Literal.number(expression.this) 427 428 return expression_class( 429 this=args[0], 430 expression=expression, 431 unit=exp.Literal.string(interval.text("unit")), 432 ) 433 434 return func 435 436 437def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 438 unit = seq_get(args, 0) 439 this = seq_get(args, 1) 440 441 if isinstance(this, exp.Cast) and 
this.is_type(exp.DataType.Type.DATE): 442 return exp.DateTrunc(unit=unit, this=this) 443 return exp.TimestampTrunc(this=this, unit=unit) 444 445 446def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 447 return self.func( 448 "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this 449 ) 450 451 452def locate_to_strposition(args: t.List) -> exp.Expression: 453 return exp.StrPosition( 454 this=seq_get(args, 1), 455 substr=seq_get(args, 0), 456 position=seq_get(args, 2), 457 ) 458 459 460def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: 461 return self.func( 462 "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") 463 ) 464 465 466def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: 467 return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)" 468 469 470def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: 471 return f"CAST({self.sql(expression, 'this')} AS DATE)" 472 473 474def min_or_least(self: Generator, expression: exp.Min) -> str: 475 name = "LEAST" if expression.expressions else "MIN" 476 return rename_func(name)(self, expression) 477 478 479def max_or_greatest(self: Generator, expression: exp.Max) -> str: 480 name = "GREATEST" if expression.expressions else "MAX" 481 return rename_func(name)(self, expression) 482 483 484def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 485 cond = expression.this 486 487 if isinstance(expression.this, exp.Distinct): 488 cond = expression.this.expressions[0] 489 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 490 491 return self.func("sum", exp.func("if", cond, 1, 0)) 492 493 494def trim_sql(self: Generator, expression: exp.Trim) -> str: 495 target = self.sql(expression, "this") 496 trim_type = self.sql(expression, "position") 497 remove_chars = self.sql(expression, "expression") 498 collation = 
self.sql(expression, "collation") 499 500 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 501 if not remove_chars and not collation: 502 return self.trim_sql(expression) 503 504 trim_type = f"{trim_type} " if trim_type else "" 505 remove_chars = f"{remove_chars} " if remove_chars else "" 506 from_part = "FROM " if trim_type or remove_chars else "" 507 collation = f" COLLATE {collation}" if collation else "" 508 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" 509 510 511def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: 512 return self.func("STRPTIME", expression.this, self.format_time(expression)) 513 514 515def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: 516 def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str: 517 _dialect = Dialect.get_or_raise(dialect) 518 time_format = self.format_time(expression) 519 if time_format and time_format not in (_dialect.time_format, _dialect.date_format): 520 return f"CAST({str_to_time_sql(self, expression)} AS DATE)" 521 return f"CAST({self.sql(expression, 'this')} AS DATE)" 522 523 return _ts_or_ds_to_date_sql 524 525 526# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator 527def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 528 names = [] 529 for agg in aggregations: 530 if isinstance(agg, exp.Alias): 531 names.append(agg.alias) 532 else: 533 """ 534 This case corresponds to aggregations without aliases being used as suffixes 535 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 536 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 537 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 
538 """ 539 agg_all_unquoted = agg.transform( 540 lambda node: exp.Identifier(this=node.name, quoted=False) 541 if isinstance(node, exp.Identifier) 542 else node 543 ) 544 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 545 546 return names
class
Dialects(builtins.str, enum.Enum):
19class Dialects(str, Enum): 20 DIALECT = "" 21 22 BIGQUERY = "bigquery" 23 CLICKHOUSE = "clickhouse" 24 DUCKDB = "duckdb" 25 HIVE = "hive" 26 MYSQL = "mysql" 27 ORACLE = "oracle" 28 POSTGRES = "postgres" 29 PRESTO = "presto" 30 REDSHIFT = "redshift" 31 SNOWFLAKE = "snowflake" 32 SPARK = "spark" 33 SPARK2 = "spark2" 34 SQLITE = "sqlite" 35 STARROCKS = "starrocks" 36 TABLEAU = "tableau" 37 TRINO = "trino" 38 TSQL = "tsql" 39 DATABRICKS = "databricks" 40 DRILL = "drill" 41 TERADATA = "teradata"
An enumeration.
DIALECT =
<Dialects.DIALECT: ''>
BIGQUERY =
<Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE =
<Dialects.CLICKHOUSE: 'clickhouse'>
DUCKDB =
<Dialects.DUCKDB: 'duckdb'>
HIVE =
<Dialects.HIVE: 'hive'>
MYSQL =
<Dialects.MYSQL: 'mysql'>
ORACLE =
<Dialects.ORACLE: 'oracle'>
POSTGRES =
<Dialects.POSTGRES: 'postgres'>
PRESTO =
<Dialects.PRESTO: 'presto'>
REDSHIFT =
<Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE =
<Dialects.SNOWFLAKE: 'snowflake'>
SPARK =
<Dialects.SPARK: 'spark'>
SPARK2 =
<Dialects.SPARK2: 'spark2'>
SQLITE =
<Dialects.SQLITE: 'sqlite'>
STARROCKS =
<Dialects.STARROCKS: 'starrocks'>
TABLEAU =
<Dialects.TABLEAU: 'tableau'>
TRINO =
<Dialects.TRINO: 'trino'>
TSQL =
<Dialects.TSQL: 'tsql'>
DATABRICKS =
<Dialects.DATABRICKS: 'databricks'>
DRILL =
<Dialects.DRILL: 'drill'>
TERADATA =
<Dialects.TERADATA: 'teradata'>
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class
Dialect:
106class Dialect(metaclass=_Dialect): 107 index_offset = 0 108 unnest_column_only = False 109 alias_post_tablesample = False 110 normalize_functions: t.Optional[str] = "upper" 111 null_ordering = "nulls_are_small" 112 113 date_format = "'%Y-%m-%d'" 114 dateint_format = "'%Y%m%d'" 115 time_format = "'%Y-%m-%d %H:%M:%S'" 116 time_mapping: t.Dict[str, str] = {} 117 118 # autofilled 119 quote_start = None 120 quote_end = None 121 identifier_start = None 122 identifier_end = None 123 124 time_trie = None 125 inverse_time_mapping = None 126 inverse_time_trie = None 127 tokenizer_class = None 128 parser_class = None 129 generator_class = None 130 131 def __eq__(self, other: t.Any) -> bool: 132 return type(self) == other 133 134 def __hash__(self) -> int: 135 return hash(type(self)) 136 137 @classmethod 138 def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: 139 if not dialect: 140 return cls 141 if isinstance(dialect, _Dialect): 142 return dialect 143 if isinstance(dialect, Dialect): 144 return dialect.__class__ 145 146 result = cls.get(dialect) 147 if not result: 148 raise ValueError(f"Unknown dialect '{dialect}'") 149 150 return result 151 152 @classmethod 153 def format_time( 154 cls, expression: t.Optional[str | exp.Expression] 155 ) -> t.Optional[exp.Expression]: 156 if isinstance(expression, str): 157 return exp.Literal.string( 158 format_time( 159 expression[1:-1], # the time formats are quoted 160 cls.time_mapping, 161 cls.time_trie, 162 ) 163 ) 164 if expression and expression.is_string: 165 return exp.Literal.string( 166 format_time( 167 expression.this, 168 cls.time_mapping, 169 cls.time_trie, 170 ) 171 ) 172 return expression 173 174 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 175 return self.parser(**opts).parse(self.tokenize(sql), sql) 176 177 def parse_into( 178 self, expression_type: exp.IntoType, sql: str, **opts 179 ) -> t.List[t.Optional[exp.Expression]]: 180 return self.parser(**opts).parse_into(expression_type, 
self.tokenize(sql), sql) 181 182 def generate(self, expression: t.Optional[exp.Expression], **opts) -> str: 183 return self.generator(**opts).generate(expression) 184 185 def transpile(self, sql: str, **opts) -> t.List[str]: 186 return [self.generate(expression, **opts) for expression in self.parse(sql)] 187 188 def tokenize(self, sql: str) -> t.List[Token]: 189 return self.tokenizer.tokenize(sql) 190 191 @property 192 def tokenizer(self) -> Tokenizer: 193 if not hasattr(self, "_tokenizer"): 194 self._tokenizer = self.tokenizer_class() # type: ignore 195 return self._tokenizer 196 197 def parser(self, **opts) -> Parser: 198 return self.parser_class( # type: ignore 199 **{ 200 "index_offset": self.index_offset, 201 "unnest_column_only": self.unnest_column_only, 202 "alias_post_tablesample": self.alias_post_tablesample, 203 "null_ordering": self.null_ordering, 204 **opts, 205 }, 206 ) 207 208 def generator(self, **opts) -> Generator: 209 return self.generator_class( # type: ignore 210 **{ 211 "quote_start": self.quote_start, 212 "quote_end": self.quote_end, 213 "bit_start": self.bit_start, 214 "bit_end": self.bit_end, 215 "hex_start": self.hex_start, 216 "hex_end": self.hex_end, 217 "byte_start": self.byte_start, 218 "byte_end": self.byte_end, 219 "raw_start": self.raw_start, 220 "raw_end": self.raw_end, 221 "identifier_start": self.identifier_start, 222 "identifier_end": self.identifier_end, 223 "string_escape": self.tokenizer_class.STRING_ESCAPES[0], 224 "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0], 225 "index_offset": self.index_offset, 226 "time_mapping": self.inverse_time_mapping, 227 "time_trie": self.inverse_time_trie, 228 "unnest_column_only": self.unnest_column_only, 229 "alias_post_tablesample": self.alias_post_tablesample, 230 "normalize_functions": self.normalize_functions, 231 "null_ordering": self.null_ordering, 232 **opts, 233 } 234 )
@classmethod
def
get_or_raise( cls, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> Type[sqlglot.dialects.dialect.Dialect]:
137 @classmethod 138 def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: 139 if not dialect: 140 return cls 141 if isinstance(dialect, _Dialect): 142 return dialect 143 if isinstance(dialect, Dialect): 144 return dialect.__class__ 145 146 result = cls.get(dialect) 147 if not result: 148 raise ValueError(f"Unknown dialect '{dialect}'") 149 150 return result
@classmethod
def
format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
152 @classmethod 153 def format_time( 154 cls, expression: t.Optional[str | exp.Expression] 155 ) -> t.Optional[exp.Expression]: 156 if isinstance(expression, str): 157 return exp.Literal.string( 158 format_time( 159 expression[1:-1], # the time formats are quoted 160 cls.time_mapping, 161 cls.time_trie, 162 ) 163 ) 164 if expression and expression.is_string: 165 return exp.Literal.string( 166 format_time( 167 expression.this, 168 cls.time_mapping, 169 cls.time_trie, 170 ) 171 ) 172 return expression
def
parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
197 def parser(self, **opts) -> Parser: 198 return self.parser_class( # type: ignore 199 **{ 200 "index_offset": self.index_offset, 201 "unnest_column_only": self.unnest_column_only, 202 "alias_post_tablesample": self.alias_post_tablesample, 203 "null_ordering": self.null_ordering, 204 **opts, 205 }, 206 )
208 def generator(self, **opts) -> Generator: 209 return self.generator_class( # type: ignore 210 **{ 211 "quote_start": self.quote_start, 212 "quote_end": self.quote_end, 213 "bit_start": self.bit_start, 214 "bit_end": self.bit_end, 215 "hex_start": self.hex_start, 216 "hex_end": self.hex_end, 217 "byte_start": self.byte_start, 218 "byte_end": self.byte_end, 219 "raw_start": self.raw_start, 220 "raw_end": self.raw_end, 221 "identifier_start": self.identifier_start, 222 "identifier_end": self.identifier_end, 223 "string_escape": self.tokenizer_class.STRING_ESCAPES[0], 224 "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0], 225 "index_offset": self.index_offset, 226 "time_mapping": self.inverse_time_mapping, 227 "time_trie": self.inverse_time_trie, 228 "unnest_column_only": self.unnest_column_only, 229 "alias_post_tablesample": self.alias_post_tablesample, 230 "normalize_functions": self.normalize_functions, 231 "null_ordering": self.null_ordering, 232 **opts, 233 } 234 )
def
rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def
approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
def
arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
def
arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
def
inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
def
no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
def
no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
def
no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
def
no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
def
no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
def
no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
def
no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
def
str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
323def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: 324 this = self.sql(expression, "this") 325 substr = self.sql(expression, "substr") 326 position = self.sql(expression, "position") 327 if position: 328 return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1" 329 return f"STRPOS({this}, {substr})"
def
struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
def
var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
338def var_map_sql( 339 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 340) -> str: 341 keys = expression.args["keys"] 342 values = expression.args["values"] 343 344 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 345 self.unsupported("Cannot convert array columns into map.") 346 return self.func(map_func_name, keys, values) 347 348 args = [] 349 for key, value in zip(keys.expressions, values.expressions): 350 args.append(self.sql(key)) 351 args.append(self.sql(value)) 352 return self.func(map_func_name, *args)
def
format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[bool, str, NoneType] = None) -> Callable[[List], ~E]:
355def format_time_lambda( 356 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 357) -> t.Callable[[t.List], E]: 358 """Helper used for time expressions. 359 360 Args: 361 exp_class: the expression class to instantiate. 362 dialect: target sql dialect. 363 default: the default format, True being time. 364 365 Returns: 366 A callable that can be used to return the appropriately formatted time expression. 367 """ 368 369 def _format_time(args: t.List): 370 return exp_class( 371 this=seq_get(args, 0), 372 format=Dialect[dialect].format_time( 373 seq_get(args, 1) 374 or (Dialect[dialect].time_format if default is True else default or None) 375 ), 376 ) 377 378 return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
def
create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
381def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: 382 """ 383 In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the 384 PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding 385 columns are removed from the create statement. 386 """ 387 has_schema = isinstance(expression.this, exp.Schema) 388 is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW") 389 390 if has_schema and is_partitionable: 391 expression = expression.copy() 392 prop = expression.find(exp.PartitionedByProperty) 393 if prop and prop.this and not isinstance(prop.this, exp.Schema): 394 schema = expression.this 395 columns = {v.name.upper() for v in prop.this.expressions} 396 partitions = [col for col in schema.expressions if col.name.upper() in columns] 397 schema.set("expressions", [e for e in schema.expressions if e not in partitions]) 398 prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) 399 expression.set("this", schema) 400 401 return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def
parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser callback for date-delta functions (e.g. DATE_ADD).

    Args:
        exp_class: the date-delta expression class to instantiate.
        unit_mapping: optional mapping used to canonicalize the unit name.

    Returns:
        A callable turning the parsed argument list into an `exp_class` node.
    """

    def _parse(args: t.List) -> E:
        # Three arguments means the unit comes first: (unit, delta, this).
        if len(args) == 3:
            this, unit = args[2], args[0]
        else:
            this, unit = seq_get(args, 0), exp.Literal.string("DAY")

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _parse
def
parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser callback for date-delta functions whose second argument is an INTERVAL.

    Args:
        expression_class: the date-delta expression class to instantiate.

    Returns:
        A callable producing an `expression_class` node, or None when fewer
        than two arguments were parsed.
    """

    def _parse(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        delta = interval.this

        # String interval payloads (e.g. INTERVAL '5' DAY) become numeric literals.
        if delta and delta.is_string:
            delta = exp.Literal.number(delta.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=args[0], expression=delta, unit=unit)

    return _parse
def
date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Map a parsed (unit, value) pair to DateTrunc or TimestampTrunc.

    DateTrunc is produced only when the value is an explicit CAST to DATE;
    everything else truncates as a timestamp.
    """
    unit, this = seq_get(args, 0), seq_get(args, 1)

    is_date_cast = isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE)
    klass = exp.DateTrunc if is_date_cast else exp.TimestampTrunc
    return klass(this=this, unit=unit)
def
timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
def
strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def
timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
def
datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
def
max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
def
count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render COUNT_IF(cond) as SUM(IF(cond, 1, 0)) for dialects lacking COUNT_IF."""
    cond = expression.this

    # DISTINCT has no SUM/IF equivalent; warn and fall back to the bare condition.
    if isinstance(cond, exp.Distinct):
        cond = cond.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM with the full `TRIM([position] [chars] FROM target [COLLATE ...])` syntax.

    Delegates to the generator's plain TRIM/LTRIM/RTRIM rendering when no
    database-specific pieces (chars to remove, collation) are present.
    """
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not chars and not collation:
        return self.trim_sql(expression)

    pieces = []
    if position:
        pieces.append(f"{position} ")
    if chars:
        pieces.append(f"{chars} ")
    if pieces:
        # FROM is only required when a position and/or char list precedes the target.
        pieces.append("FROM ")
    pieces.append(target)
    if collation:
        pieces.append(f" COLLATE {collation}")

    return f"TRIM({''.join(pieces)})"
def
str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
def
ts_or_ds_to_date_sql(dialect: str) -> Callable:
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator callback that casts a TS_OR_DS value to DATE for `dialect`."""

    def _sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target_dialect = Dialect.get_or_raise(dialect)
        fmt = self.format_time(expression)

        # A custom format (anything other than the dialect's standard time/date
        # formats) means the string must be parsed before casting.
        needs_str_to_time = fmt and fmt not in (
            target_dialect.time_format,
            target_dialect.date_format,
        )
        inner = (
            str_to_time_sql(self, expression)
            if needs_str_to_time
            else self.sql(expression, "this")
        )
        return f"CAST({inner} AS DATE)"

    return _sql
def
pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> List[str]:
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the column names produced by a PIVOT's aggregation expressions.

    Args:
        aggregations: the aggregation expressions inside the PIVOT clause.
        dialect: the dialect used to render unaliased aggregations as SQL text.

    Returns:
        One name per aggregation: the alias when one is given, otherwise the
        aggregation's SQL text with all identifiers unquoted.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # Unaliased aggregations are used as name suffixes (e.g. col_avg(foo)).
            # Identifiers must be unquoted here because the base parser's
            # `_parse_pivot` method quotes them again via `to_identifier`;
            # otherwise we'd end up with doubly-quoted names like col_avg(`foo`).
            # (This was previously a bare triple-quoted string statement, which is
            # executed and discarded on every iteration — now a real comment.)
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names