sqlglot.dialects.dialect
"""Base dialect machinery and generator/parser helpers shared across SQL dialects."""

from __future__ import annotations

import typing as t
from enum import Enum

from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer
from sqlglot.trie import new_trie

E = t.TypeVar("E", bound=exp.Expression)


class Dialects(str, Enum):
    """Names of the known SQL dialects; each value is a dialect registry key."""

    # The empty value is the key under which the base `Dialect` class registers itself.
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TRINO = "trino"
    TSQL = "tsql"
    DATABRICKS = "databricks"
    DRILL = "drill"
    TERADATA = "teradata"


class _Dialect(type):
    """Metaclass that registers every dialect class and fills in its derived attributes."""

    # Registry shared by all dialect classes, keyed by dialect name.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        """Return the dialect class registered under `key`; raises KeyError if absent."""
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        """Return the dialect class registered under `key`, or `default` if absent."""
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Register under the matching `Dialects` value when one exists, otherwise
        # under the lower-cased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Time-format lookup structures, in both directions.
        klass.time_trie = new_trie(klass.time_mapping)
        klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
        klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)

        # A dialect may define nested Tokenizer/Parser/Generator classes; otherwise
        # the base implementations are used.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first (start, end) pair declared by the tokenizer becomes the delimiter
        # pair handed to the generator.
        klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.identifier_start, klass.identifier_end = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        # Install default renderers for bit/hex/byte string literals when the tokenizer
        # declares delimiters for them and the generator has no transform of its own.
        if (
            klass.tokenizer_class._BIT_STRINGS
            and exp.BitString not in klass.generator_class.TRANSFORMS
        ):
            bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
            klass.generator_class.TRANSFORMS[
                exp.BitString
            ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
        if (
            klass.tokenizer_class._HEX_STRINGS
            and exp.HexString not in klass.generator_class.TRANSFORMS
        ):
            hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
            klass.generator_class.TRANSFORMS[
                exp.HexString
            ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
        if (
            klass.tokenizer_class._BYTE_STRINGS
            and exp.ByteString not in klass.generator_class.TRANSFORMS
        ):
            be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
            klass.generator_class.TRANSFORMS[
                exp.ByteString
            ] = lambda self, e: f"{be_start}{self.sql(e, 'this')}{be_end}"

        return klass


class Dialect(metaclass=_Dialect):
    """Base dialect: bundles a tokenizer, parser and generator with dialect-level settings."""

    index_offset = 0
    unnest_column_only = False
    alias_post_tablesample = False
    normalize_functions: t.Optional[str] = "upper"
    null_ordering = "nulls_are_small"

    # Default formats are stored with their surrounding single quotes.
    date_format = "'%Y-%m-%d'"
    dateint_format = "'%Y%m%d'"
    time_format = "'%Y-%m-%d %H:%M:%S'"
    time_mapping: t.Dict[str, str] = {}

    # autofilled
    quote_start = None
    quote_end = None
    identifier_start = None
    identifier_end = None

    time_trie = None
    inverse_time_mapping = None
    inverse_time_trie = None
    tokenizer_class = None
    parser_class = None
    generator_class = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve a dialect name, class or instance to a dialect class.

        A falsy `dialect` resolves to `cls`.

        Raises:
            ValueError: if `dialect` is a name with no registered class.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a time format through this dialect's `time_mapping`.

        A `str` argument is expected to be quoted; its first and last characters are
        dropped before conversion. Anything that is neither a `str` nor a string
        expression is returned unchanged.
        """
        if isinstance(expression, str):
            return exp.Literal.string(
                format_time(
                    expression[1:-1],  # the time formats are quoted
                    cls.time_mapping,
                    cls.time_trie,
                )
            )
        if expression and expression.is_string:
            return exp.Literal.string(
                format_time(
                    expression.this,
                    cls.time_mapping,
                    cls.time_trie,
                )
            )
        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql`; `opts` are forwarded to the parser."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse it into `expression_type`; `opts` go to the parser."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Generate SQL for `expression`; `opts` are forwarded to the generator."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and generate one string per parsed expression.

        `opts` are passed to the generator only; parsing uses the defaults.
        """
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenize `sql` with this instance's tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        """Tokenizer instance, created on first access and then reused."""
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()  # type: ignore
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Build a new parser from the dialect settings; `opts` override them."""
        return self.parser_class(  # type: ignore
            **{
                "index_offset": self.index_offset,
                "unnest_column_only": self.unnest_column_only,
                "alias_post_tablesample": self.alias_post_tablesample,
                "null_ordering": self.null_ordering,
                **opts,
            },
        )

    def generator(self, **opts) -> Generator:
        """Build a new generator from the dialect settings; `opts` override them."""
        return self.generator_class(  # type: ignore
            **{
                "quote_start": self.quote_start,
                "quote_end": self.quote_end,
                "identifier_start": self.identifier_start,
                "identifier_end": self.identifier_end,
                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
                "index_offset": self.index_offset,
                "time_mapping": self.inverse_time_mapping,
                "time_trie": self.inverse_time_trie,
                "unnest_column_only": self.unnest_column_only,
                "alias_post_tablesample": self.alias_post_tablesample,
                "normalize_functions": self.normalize_functions,
                "null_ordering": self.null_ordering,
                **opts,
            }
        )


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Return a transform that renders an expression as a call to `name`.

    The call's arguments are the flattened values of `expression.args`.
    """
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT(this); an `accuracy` argument is reported as unsupported."""
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(self: Generator, expression: exp.If) -> str:
    """Render a conditional as IF(condition, true, false)."""
    return self.func(
        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
    )


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render a JSON extraction with the `->` operator."""
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render a scalar JSON extraction with the `->>` operator."""
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array as a bracketed list of its expressions."""
    return f"[{self.expressions(expression)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Render ILIKE as LIKE with the left operand wrapped in LOWER."""
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this),
            expression=expression.args["expression"],
        )
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, adding AT TIME ZONE when a zone is given."""
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a WITH clause without RECURSIVE, reporting recursive CTEs as unsupported."""
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        # Clear the flag on a copy so that generating SQL does not alter the caller's tree.
        expression = expression.copy()
        expression.set("recursive", False)
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Render a safe division as IF(d <> 0, n / d, NULL)."""
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Report TABLESAMPLE as unsupported and render only the sampled expression."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Report PIVOT as unsupported and render nothing for it.

    This used to return `self.sql(expression)`. When this function is the registered
    transform for `exp.Pivot`, that call is dispatched straight back here and never
    terminates (NOTE(review): dispatch lives in `Generator.sql`, outside this file).
    """
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TRY_CAST through the generator's regular cast rendering."""
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Report properties as unsupported and render nothing."""
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Report column comment constraints as unsupported and render nothing."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a substring search with STRPOS, emulating an optional start position.

    With a start position the search runs on SUBSTR(this, position) and the result is
    shifted back by `position - 1`. The shift is applied only when STRPOS found a match,
    so a missing substring still yields 0.
    """
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if not position:
        return f"STRPOS({this}, {substr})"

    found = f"STRPOS(SUBSTR({this}, {position}), {substr})"
    # CASE is used rather than IF() because it is standard SQL.
    return f"CASE WHEN {found} = 0 THEN 0 ELSE {found} + {position} - 1 END"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as `this.<quoted key>`."""
    this = self.sql(expression, "this")
    struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
    return f"{this}.{struct_key}"


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as `map_func_name(k1, v1, k2, v2, ...)`.

    When keys or values are not array literals the map cannot be interleaved; this is
    reported as unsupported and `map_func_name(keys, values)` is rendered instead.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))
    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.Sequence], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.Sequence):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].time_format if default is True else default or None)
            ),
        )

    return _format_time


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        # Work on a copy so the caller's expression is left untouched.
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.Sequence], E]:
    """Return a parser callback for date-delta functions.

    With exactly three arguments they are read as (unit, amount, date); otherwise as
    (date, amount) with a default unit of DAY. When `unit_mapping` is given, the unit is
    looked up in it by lower-cased name and falls back to the original unit.
    """

    def inner_func(args: t.Sequence) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.Sequence], t.Optional[E]]:
    """Return a parser callback for date-delta functions whose second argument is an interval.

    The callback returns None when fewer than two arguments are supplied.
    """

    def func(args: t.Sequence) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        expression = interval.this
        # A string-valued interval amount is turned into a number literal.
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0],
            expression=expression,
            unit=exp.Literal.string(interval.text("unit")),
        )

    return func


def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
    """Build DateTrunc when the operand is a cast to DATE, otherwise TimestampTrunc.

    Arguments are read as (unit, operand).
    """
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render a timestamp truncation as DATE_TRUNC('<unit>', this); the unit defaults to 'day'."""
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
    )


def locate_to_strposition(args: t.Sequence) -> exp.Expression:
    """Build StrPosition from arguments ordered (substr, string, position)."""
    return exp.StrPosition(
        this=seq_get(args, 1),
        substr=seq_get(args, 0),
        position=seq_get(args, 2),
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, this, position)."""
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render as CAST(this AS TIMESTAMP)."""
    return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render as CAST(this AS DATE)."""
    return f"CAST({self.sql(expression, 'this')} AS DATE)"


def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render as LEAST when the expression carries extra expressions, otherwise MIN."""
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render as GREATEST when the expression carries extra expressions, otherwise MAX."""
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render COUNT_IF(cond) as sum(if(cond, 1, 0)).

    DISTINCT cannot be carried over: its first expression is used as the condition and
    the loss is reported as unsupported.
    """
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM([position ][chars ][FROM ]target[ COLLATE collation]).

    Falls back to the generator's own `trim_sql` when neither characters to remove nor a
    collation are present.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self, expression: exp.Expression) -> str:
    """Render as STRPTIME(this, <format>)."""
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Return a transform that renders TsOrDsToDate as a cast to DATE for `dialect`.

    When the expression's format is not one of the dialect's default time/date formats,
    the operand is first parsed with STRPTIME.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.time_format, _dialect.date_format):
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _ts_or_ds_to_date_sql
class
Dialects(builtins.str, enum.Enum):
class Dialects(str, Enum):
    """Names of the known SQL dialects.

    Members are also `str` instances, so each one compares equal to its plain string value.
    """

    # The empty value is the key under which the base `Dialect` class registers itself.
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TRINO = "trino"
    TSQL = "tsql"
    DATABRICKS = "databricks"
    DRILL = "drill"
    TERADATA = "teradata"
An enumeration of the SQL dialect names known to sqlglot; each member's value is the dialect's lowercase name.
DIALECT =
<Dialects.DIALECT: ''>
BIGQUERY =
<Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE =
<Dialects.CLICKHOUSE: 'clickhouse'>
DUCKDB =
<Dialects.DUCKDB: 'duckdb'>
HIVE =
<Dialects.HIVE: 'hive'>
MYSQL =
<Dialects.MYSQL: 'mysql'>
ORACLE =
<Dialects.ORACLE: 'oracle'>
POSTGRES =
<Dialects.POSTGRES: 'postgres'>
PRESTO =
<Dialects.PRESTO: 'presto'>
REDSHIFT =
<Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE =
<Dialects.SNOWFLAKE: 'snowflake'>
SPARK =
<Dialects.SPARK: 'spark'>
SPARK2 =
<Dialects.SPARK2: 'spark2'>
SQLITE =
<Dialects.SQLITE: 'sqlite'>
STARROCKS =
<Dialects.STARROCKS: 'starrocks'>
TABLEAU =
<Dialects.TABLEAU: 'tableau'>
TRINO =
<Dialects.TRINO: 'trino'>
TSQL =
<Dialects.TSQL: 'tsql'>
DATABRICKS =
<Dialects.DATABRICKS: 'databricks'>
DRILL =
<Dialects.DRILL: 'drill'>
TERADATA =
<Dialects.TERADATA: 'teradata'>
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class
Dialect:
class Dialect(metaclass=_Dialect):
    """Base SQL dialect tying together a tokenizer, a parser and a generator.

    Class attributes hold the dialect-level settings that are handed to the parser and
    generator whenever one is built.
    """

    index_offset = 0
    unnest_column_only = False
    alias_post_tablesample = False
    normalize_functions: t.Optional[str] = "upper"
    null_ordering = "nulls_are_small"

    # Default formats, stored with their surrounding single quotes.
    date_format = "'%Y-%m-%d'"
    dateint_format = "'%Y%m%d'"
    time_format = "'%Y-%m-%d %H:%M:%S'"
    time_mapping: t.Dict[str, str] = {}

    # autofilled
    quote_start = None
    quote_end = None
    identifier_start = None
    identifier_end = None

    time_trie = None
    inverse_time_mapping = None
    inverse_time_trie = None
    tokenizer_class = None
    parser_class = None
    generator_class = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve a dialect name, class or instance to a dialect class.

        A falsy `dialect` resolves to `cls`.

        Raises:
            ValueError: if `dialect` is a name with no registered class.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        registered = cls.get(dialect)
        if registered:
            return registered
        raise ValueError(f"Unknown dialect '{dialect}'")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a time format through this dialect's `time_mapping`.

        Inputs that are neither a `str` nor a string expression are returned unchanged.
        """
        if isinstance(expression, str):
            # the time formats are quoted
            raw_format = expression[1:-1]
        elif expression and expression.is_string:
            raw_format = expression.this
        else:
            return expression

        converted = format_time(raw_format, cls.time_mapping, cls.time_trie)
        return exp.Literal.string(converted)

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql`; `opts` are forwarded to the parser."""
        parser = self.parser(**opts)
        return parser.parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse it into `expression_type`; `opts` go to the parser."""
        parser = self.parser(**opts)
        return parser.parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Generate SQL for `expression`; `opts` are forwarded to the generator."""
        generator = self.generator(**opts)
        return generator.generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and generate one string per parsed expression.

        `opts` are passed to the generator only.
        """
        statements = []
        for parsed in self.parse(sql):
            statements.append(self.generate(parsed, **opts))
        return statements

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenize `sql` with this instance's tokenizer."""
        tokenizer = self.tokenizer
        return tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        """Tokenizer instance, created on first access and then reused."""
        try:
            return self._tokenizer
        except AttributeError:
            self._tokenizer = self.tokenizer_class()  # type: ignore
            return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Build a new parser from the dialect settings; `opts` override them."""
        settings = {
            "index_offset": self.index_offset,
            "unnest_column_only": self.unnest_column_only,
            "alias_post_tablesample": self.alias_post_tablesample,
            "null_ordering": self.null_ordering,
        }
        settings.update(opts)
        return self.parser_class(**settings)  # type: ignore

    def generator(self, **opts) -> Generator:
        """Build a new generator from the dialect settings; `opts` override them."""
        settings = {
            "quote_start": self.quote_start,
            "quote_end": self.quote_end,
            "identifier_start": self.identifier_start,
            "identifier_end": self.identifier_end,
            "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
            "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
            "index_offset": self.index_offset,
            "time_mapping": self.inverse_time_mapping,
            "time_trie": self.inverse_time_trie,
            "unnest_column_only": self.unnest_column_only,
            "alias_post_tablesample": self.alias_post_tablesample,
            "normalize_functions": self.normalize_functions,
            "null_ordering": self.null_ordering,
        }
        settings.update(opts)
        return self.generator_class(**settings)  # type: ignore
@classmethod
def
get_or_raise( cls, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> Type[sqlglot.dialects.dialect.Dialect]:
@classmethod
def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
    """Resolve a dialect name, class or instance to a dialect class.

    A falsy `dialect` resolves to `cls`.

    Raises:
        ValueError: if `dialect` is a name with no registered class.
    """
    if not dialect:
        return cls
    if isinstance(dialect, _Dialect):
        return dialect
    if isinstance(dialect, Dialect):
        return dialect.__class__

    registered = cls.get(dialect)
    if registered:
        return registered
    raise ValueError(f"Unknown dialect '{dialect}'")
@classmethod
def
format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
@classmethod
def format_time(
    cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
    """Convert a time format through this dialect's `time_mapping`.

    Inputs that are neither a `str` nor a string expression are returned unchanged.
    """
    if isinstance(expression, str):
        # the time formats are quoted
        raw_format = expression[1:-1]
    elif expression and expression.is_string:
        raw_format = expression.this
    else:
        return expression

    converted = format_time(raw_format, cls.time_mapping, cls.time_trie)
    return exp.Literal.string(converted)
def
parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
def parser(self, **opts) -> Parser:
    """Build a new parser from the dialect settings; `opts` override them."""
    settings = {
        "index_offset": self.index_offset,
        "unnest_column_only": self.unnest_column_only,
        "alias_post_tablesample": self.alias_post_tablesample,
        "null_ordering": self.null_ordering,
    }
    settings.update(opts)
    return self.parser_class(**settings)  # type: ignore
def generator(self, **opts) -> Generator:
    """Build a new generator from the dialect settings; `opts` override them."""
    settings = {
        "quote_start": self.quote_start,
        "quote_end": self.quote_end,
        "identifier_start": self.identifier_start,
        "identifier_end": self.identifier_end,
        "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
        "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
        "index_offset": self.index_offset,
        "time_mapping": self.inverse_time_mapping,
        "time_trie": self.inverse_time_trie,
        "unnest_column_only": self.unnest_column_only,
        "alias_post_tablesample": self.alias_post_tablesample,
        "normalize_functions": self.normalize_functions,
        "null_ordering": self.null_ordering,
    }
    settings.update(opts)
    return self.generator_class(**settings)  # type: ignore
def
rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def
approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
def
arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
def
arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
def
inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
def
no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
def
no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
def
no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
def
no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
def
no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
def
no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
def
no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
def
str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a substring search with STRPOS, emulating an optional start position.

    With a start position the search runs on SUBSTR(this, position) and the result is
    shifted back by `position - 1`. The shift is applied only when STRPOS found a match,
    so a missing substring still yields 0.

    Args:
        self: the generator rendering the expression.
        expression: node providing `this`, `substr` and optionally `position`.

    Returns:
        The SQL text of the search.
    """
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if not position:
        return f"STRPOS({this}, {substr})"

    found = f"STRPOS(SUBSTR({this}, {position}), {substr})"
    # CASE is used rather than IF() because it is standard SQL.
    return f"CASE WHEN {found} = 0 THEN 0 ELSE {found} + {position} - 1 END"
def
struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
def
var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as `map_func_name(k1, v1, k2, v2, ...)`.

    When keys or values are not array literals, the conversion is reported as
    unsupported and `map_func_name(keys, values)` is rendered instead.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    both_arrays = isinstance(keys, exp.Array) and isinstance(values, exp.Array)
    if not both_arrays:
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave the rendered keys and values: k1, v1, k2, v2, ...
    interleaved = [
        self.sql(item)
        for pair in zip(keys.expressions, values.expressions)
        for item in pair
    ]
    return self.func(map_func_name, *interleaved)
def
format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[bool, str, NoneType] = None) -> Callable[[Sequence], ~E]:
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.Sequence], E]:
    """Build a parser callback for time expressions that carry a format argument.

    Args:
        exp_class: the expression class to instantiate.
        dialect: name of the dialect whose time mapping converts the format.
        default: fallback format when none is supplied; True selects the dialect's
            `time_format`.

    Returns:
        A callable that turns an argument list into an `exp_class` instance with a
        converted `format`.
    """

    def _format_time(args: t.Sequence):
        this = seq_get(args, 0)
        dialect_class = Dialect[dialect]

        fmt = seq_get(args, 1)
        if not fmt:
            # No explicit format: fall back to the configured default, if any.
            fmt = dialect_class.time_format if default is True else (default or None)

        return exp_class(this=this, format=dialect_class.format_time(fmt))

    return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
def
create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    Render a CREATE statement, moving PARTITIONED BY columns out of the table schema.

    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is a list of column names rather than a schema, the matching column definitions
    are removed from the table schema and become the schema of the PARTITIONED BY property.
    """
    if not isinstance(expression.this, exp.Schema):
        return self.create_sql(expression)
    if expression.args.get("kind") not in ("TABLE", "VIEW"):
        return self.create_sql(expression)

    # Rewrite a copy so the caller's expression is left untouched.
    expression = expression.copy()
    prop = expression.find(exp.PartitionedByProperty)

    if prop and prop.this and not isinstance(prop.this, exp.Schema):
        schema = expression.this
        partition_names = {name_expr.name.upper() for name_expr in prop.this.expressions}
        partitions = [
            column for column in schema.expressions if column.name.upper() in partition_names
        ]
        remaining = [column for column in schema.expressions if column not in partitions]

        schema.set("expressions", remaining)
        prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
        expression.set("this", schema)

    return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def
parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[Sequence], ~E]:
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.Sequence], E]:
    """Build a parser callback for date-delta functions.

    Exactly three arguments are read as (unit, amount, date); any other count is read as
    (date, amount) with a default unit of DAY. When `unit_mapping` is given, the unit is
    looked up in it by lower-cased name, falling back to the original unit.
    """

    def inner_func(args: t.Sequence) -> E:
        if len(args) == 3:
            this = args[2]
            unit = args[0]
        else:
            this = seq_get(args, 0)
            unit = exp.Literal.string("DAY")

        if unit_mapping:
            unit = unit_mapping.get(unit.name.lower(), unit)

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def
parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[Sequence], Optional[~E]]:
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.Sequence], t.Optional[E]]:
    """Build a parser callback for date-delta functions whose second argument is an interval.

    The callback returns None when fewer than two arguments are supplied.
    """

    def func(args: t.Sequence) -> t.Optional[E]:
        if len(args) < 2:
            return None

        target, interval = args[0], args[1]

        amount = interval.this
        # A string-valued interval amount is turned into a number literal.
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=target, expression=amount, unit=unit)

    return func
def
date_trunc_to_time( args: Sequence) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
    """Build DateTrunc when the operand is a cast to DATE, otherwise TimestampTrunc.

    Arguments are read as (unit, operand).
    """
    unit, this = seq_get(args, 0), seq_get(args, 1)

    casts_to_date = isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE)
    if not casts_to_date:
        return exp.TimestampTrunc(this=this, unit=unit)
    return exp.DateTrunc(unit=unit, this=this)
def
timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
def
strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def
timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
def
datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
def
max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
def
count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render COUNT_IF(cond) as sum(if(cond, 1, 0)).

    DISTINCT cannot be carried over: its first expression is used as the condition and
    the loss is reported as unsupported.
    """
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        cond = cond.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    one_or_zero = exp.func("if", cond, 1, 0)
    return self.func("sum", one_or_zero)
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM([position ][chars ][FROM ]target[ COLLATE collation]).

    Falls back to the generator's own `trim_sql` when neither characters to remove nor a
    collation are present.
    """
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (chars or collation):
        return self.trim_sql(expression)

    pieces = []
    if position:
        pieces.append(f"{position} ")
    if chars:
        pieces.append(f"{chars} ")
    if pieces:
        # FROM is only needed when something precedes the target.
        pieces.append("FROM ")
    pieces.append(target)
    if collation:
        pieces.append(f" COLLATE {collation}")

    return f"TRIM({''.join(pieces)})"
def
ts_or_ds_to_date_sql(dialect: str) -> Callable:
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a transform that renders TsOrDsToDate as a cast to DATE for `dialect`.

    When the expression's format is not one of the dialect's default time/date formats,
    the operand is first parsed with STRPTIME.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)

        default_formats = (_dialect.time_format, _dialect.date_format)
        if time_format and time_format not in default_formats:
            operand = str_to_time_sql(self, expression)
        else:
            operand = self.sql(expression, "this")

        return f"CAST({operand} AS DATE)"

    return _ts_or_ds_to_date_sql