diff options
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r-- | sqlglot/dialects/dialect.py | 99 |
1 file changed, 76 insertions(+), 23 deletions(-)
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 71269f2..890a3c3 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -8,10 +8,16 @@ from sqlglot.generator import Generator from sqlglot.helper import flatten, seq_get from sqlglot.parser import Parser from sqlglot.time import format_time -from sqlglot.tokens import Token, Tokenizer +from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import new_trie -E = t.TypeVar("E", bound=exp.Expression) +if t.TYPE_CHECKING: + from sqlglot._typing import E + + +# Only Snowflake is currently known to resolve unquoted identifiers as uppercase. +# https://docs.snowflake.com/en/sql-reference/identifiers-syntax +RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"} class Dialects(str, Enum): @@ -42,6 +48,19 @@ class Dialects(str, Enum): class _Dialect(type): classes: t.Dict[str, t.Type[Dialect]] = {} + def __eq__(cls, other: t.Any) -> bool: + if cls is other: + return True + if isinstance(other, str): + return cls is cls.get(other) + if isinstance(other, Dialect): + return cls is type(other) + + return False + + def __hash__(cls) -> int: + return hash(cls.__name__.lower()) + @classmethod def __getitem__(cls, key: str) -> t.Type[Dialect]: return cls.classes[key] @@ -70,17 +89,20 @@ class _Dialect(type): klass.tokenizer_class._IDENTIFIERS.items() )[0] - klass.bit_start, klass.bit_end = seq_get( - list(klass.tokenizer_class._BIT_STRINGS.items()), 0 - ) or (None, None) - - klass.hex_start, klass.hex_end = seq_get( - list(klass.tokenizer_class._HEX_STRINGS.items()), 0 - ) or (None, None) + def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]: + return next( + ( + (s, e) + for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items() + if t == token_type + ), + (None, None), + ) - klass.byte_start, klass.byte_end = seq_get( - list(klass.tokenizer_class._BYTE_STRINGS.items()), 0 - ) or (None, None) + klass.bit_start, klass.bit_end = 
get_start_end(TokenType.BIT_STRING) + klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING) + klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING) + klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING) return klass @@ -110,6 +132,12 @@ class Dialect(metaclass=_Dialect): parser_class = None generator_class = None + def __eq__(self, other: t.Any) -> bool: + return type(self) == other + + def __hash__(self) -> int: + return hash(type(self)) + @classmethod def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: if not dialect: @@ -192,6 +220,8 @@ class Dialect(metaclass=_Dialect): "hex_end": self.hex_end, "byte_start": self.byte_start, "byte_end": self.byte_end, + "raw_start": self.raw_start, + "raw_end": self.raw_end, "identifier_start": self.identifier_start, "identifier_end": self.identifier_end, "string_escape": self.tokenizer_class.STRING_ESCAPES[0], @@ -275,7 +305,7 @@ def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: self.unsupported("PIVOT unsupported") - return self.sql(expression) + return "" def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: @@ -328,7 +358,7 @@ def var_map_sql( def format_time_lambda( exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None -) -> t.Callable[[t.Sequence], E]: +) -> t.Callable[[t.List], E]: """Helper used for time expressions. Args: @@ -340,7 +370,7 @@ def format_time_lambda( A callable that can be used to return the appropriately formatted time expression. 
""" - def _format_time(args: t.Sequence): + def _format_time(args: t.List): return exp_class( this=seq_get(args, 0), format=Dialect[dialect].format_time( @@ -377,12 +407,12 @@ def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: def parse_date_delta( exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None -) -> t.Callable[[t.Sequence], E]: - def inner_func(args: t.Sequence) -> E: +) -> t.Callable[[t.List], E]: + def inner_func(args: t.List) -> E: unit_based = len(args) == 3 this = args[2] if unit_based else seq_get(args, 0) unit = args[0] if unit_based else exp.Literal.string("DAY") - unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit + unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit return exp_class(this=this, expression=seq_get(args, 1), unit=unit) return inner_func @@ -390,8 +420,8 @@ def parse_date_delta( def parse_date_delta_with_interval( expression_class: t.Type[E], -) -> t.Callable[[t.Sequence], t.Optional[E]]: - def func(args: t.Sequence) -> t.Optional[E]: +) -> t.Callable[[t.List], t.Optional[E]]: + def func(args: t.List) -> t.Optional[E]: if len(args) < 2: return None @@ -409,7 +439,7 @@ def parse_date_delta_with_interval( return func -def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc: +def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: unit = seq_get(args, 0) this = seq_get(args, 1) @@ -424,7 +454,7 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: ) -def locate_to_strposition(args: t.Sequence) -> exp.Expression: +def locate_to_strposition(args: t.List) -> exp.Expression: return exp.StrPosition( this=seq_get(args, 1), substr=seq_get(args, 0), @@ -483,7 +513,7 @@ def trim_sql(self: Generator, expression: exp.Trim) -> str: return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" -def str_to_time_sql(self, expression: exp.Expression) -> str: +def 
str_to_time_sql(self: Generator, expression: exp.Expression) -> str: return self.func("STRPTIME", expression.this, self.format_time(expression)) @@ -496,3 +526,26 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: return f"CAST({self.sql(expression, 'this')} AS DATE)" return _ts_or_ds_to_date_sql + + +# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator +def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: + names = [] + for agg in aggregations: + if isinstance(agg, exp.Alias): + names.append(agg.alias) + else: + """ + This case corresponds to aggregations without aliases being used as suffixes + (e.g. col_avg(foo)). We need to unquote identifiers because they're going to + be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. + Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). + """ + agg_all_unquoted = agg.transform( + lambda node: exp.Identifier(this=node.name, quoted=False) + if isinstance(node, exp.Identifier) + else node + ) + names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) + + return names |