Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r--  sqlglot/dialects/dialect.py | 99
1 file changed, 76 insertions(+), 23 deletions(-)
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 71269f2..890a3c3 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -8,10 +8,16 @@ from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
-from sqlglot.tokens import Token, Tokenizer
+from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie
-E = t.TypeVar("E", bound=exp.Expression)
+if t.TYPE_CHECKING:
+    from sqlglot._typing import E
+
+
+# Only Snowflake is currently known to resolve unquoted identifiers as uppercase.
+# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
+RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"}
class Dialects(str, Enum):
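The new RESOLVES_IDENTIFIERS_AS_UPPERCASE set gives downstream code a way to special-case dialects that fold unquoted identifiers to uppercase. A minimal sketch of how a normalization pass might consult it; the `normalize` helper here is hypothetical, not part of this diff:

    from sqlglot import exp, parse_one
    from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE

    def normalize(node: exp.Expression, dialect: str) -> exp.Expression:
        # Hypothetical pass: uppercase unquoted identifiers only for dialects
        # that resolve them that way (currently just Snowflake).
        if dialect not in RESOLVES_IDENTIFIERS_AS_UPPERCASE:
            return node
        return node.transform(
            lambda n: exp.Identifier(this=n.name.upper(), quoted=False)
            if isinstance(n, exp.Identifier) and not n.args.get("quoted")
            else n
        )

    print(normalize(parse_one("SELECT foo FROM bar"), "snowflake").sql())
    # SELECT FOO FROM BAR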
@@ -42,6 +48,19 @@ class Dialects(str, Enum):
class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}
+    def __eq__(cls, other: t.Any) -> bool:
+        if cls is other:
+            return True
+        if isinstance(other, str):
+            return cls is cls.get(other)
+        if isinstance(other, Dialect):
+            return cls is type(other)
+
+        return False
+
+    def __hash__(cls) -> int:
+        return hash(cls.__name__.lower())
+
    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]
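With `__eq__` and `__hash__` on the metaclass, dialect classes themselves compare equal to their registry names (and to `Dialect` instances) and can be used in sets and as dict keys. A small sketch, assuming sqlglot is installed:

    from sqlglot.dialects.dialect import Dialect

    snowflake = Dialect["snowflake"]
    assert snowflake == "snowflake"      # string comparison via _Dialect.__eq__
    assert snowflake != "duckdb"
    assert snowflake in {"snowflake"}    # hash(cls) == hash(cls.__name__.lower())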
@@ -70,17 +89,20 @@ class _Dialect(type):
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]
-        klass.bit_start, klass.bit_end = seq_get(
-            list(klass.tokenizer_class._BIT_STRINGS.items()), 0
-        ) or (None, None)
-
-        klass.hex_start, klass.hex_end = seq_get(
-            list(klass.tokenizer_class._HEX_STRINGS.items()), 0
-        ) or (None, None)
+        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
+            return next(
+                (
+                    (s, e)
+                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
+                    if t == token_type
+                ),
+                (None, None),
+            )

-        klass.byte_start, klass.byte_end = seq_get(
-            list(klass.tokenizer_class._BYTE_STRINGS.items()), 0
-        ) or (None, None)
+        klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING)
+        klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING)
+        klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
+        klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)

        return klass
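The three per-token-type dictionaries are folded into a single `_FORMAT_STRINGS` mapping, and `get_start_end` scans it for the first delimiter pair registered under a given token type. A standalone sketch of the lookup, assuming the mapping shape implied by the generator expression (start delimiter -> (end delimiter, token type)); the sample entries are assumptions:

    from sqlglot.tokens import TokenType

    FORMAT_STRINGS = {
        "b'": ("'", TokenType.BIT_STRING),   # assumed sample entries
        "x'": ("'", TokenType.HEX_STRING),
    }

    def get_start_end(token_type: TokenType):
        return next(
            ((s, e) for s, (e, tt) in FORMAT_STRINGS.items() if tt == token_type),
            (None, None),
        )

    print(get_start_end(TokenType.HEX_STRING))  # ("x'", "'")
    print(get_start_end(TokenType.RAW_STRING))  # (None, None) -- not registered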
@@ -110,6 +132,12 @@ class Dialect(metaclass=_Dialect):
    parser_class = None
    generator_class = None

+    def __eq__(self, other: t.Any) -> bool:
+        return type(self) == other
+
+    def __hash__(self) -> int:
+        return hash(type(self))
+
    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        if not dialect:
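Instance-level equality simply delegates to the class, so a `Dialect` instance now compares equal to its own class and, transitively, to its registry name. For example:

    from sqlglot.dialects.dialect import Dialect

    duckdb = Dialect.get_or_raise("duckdb")()
    assert duckdb == "duckdb"
    assert duckdb == Dialect["duckdb"]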
@@ -192,6 +220,8 @@ class Dialect(metaclass=_Dialect):
"hex_end": self.hex_end,
"byte_start": self.byte_start,
"byte_end": self.byte_end,
+ "raw_start": self.raw_start,
+ "raw_end": self.raw_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
"string_escape": self.tokenizer_class.STRING_ESCAPES[0],
@@ -275,7 +305,7 @@ def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
-    return self.sql(expression)
+    return ""
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
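Returning `self.sql(expression)` used to echo the PIVOT back into the output even though the target dialect cannot generate it; returning the empty string drops the clause and leaves only the `unsupported` message. A minimal check of the helper in isolation, assuming sqlglot is installed:

    from sqlglot import exp
    from sqlglot.generator import Generator
    from sqlglot.dialects.dialect import no_pivot_sql

    gen = Generator()  # default unsupported_level just records the message
    print(repr(no_pivot_sql(gen, exp.Pivot())))  # ''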
@@ -328,7 +358,7 @@ def var_map_sql(
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
-) -> t.Callable[[t.Sequence], E]:
+) -> t.Callable[[t.List], E]:
"""Helper used for time expressions.
Args:
@@ -340,7 +370,7 @@ def format_time_lambda(
        A callable that can be used to return the appropriately formatted time expression.
    """
-    def _format_time(args: t.Sequence):
+    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
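The annotations move from `t.Sequence` to `t.List` to match what the parser actually passes. Usage is unchanged; a small sketch, assuming Hive's time mapping covers the format tokens used here:

    from sqlglot import exp
    from sqlglot.dialects.dialect import format_time_lambda

    build = format_time_lambda(exp.StrToTime, "hive")
    node = build([exp.Literal.string("2023-01-01"), exp.Literal.string("yyyy-MM-dd")])
    print(node.args["format"].sql())  # '%Y-%m-%d'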
@@ -377,12 +407,12 @@ def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
-) -> t.Callable[[t.Sequence], E]:
-    def inner_func(args: t.Sequence) -> E:
+) -> t.Callable[[t.List], E]:
+    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
-        unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
+        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
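Besides the typing change, the mapped unit is now wrapped in `exp.var(...)`, so it renders as a bare keyword (DAY) rather than a string literal ('DAY'). For example:

    from sqlglot import exp
    from sqlglot.dialects.dialect import parse_date_delta

    build = parse_date_delta(exp.DateAdd, unit_mapping={"d": "DAY"})
    node = build([exp.Literal.string("d"), exp.Literal.number(1), exp.column("dt")])
    assert isinstance(node.args["unit"], exp.Var)
    print(node.args["unit"].name)  # DAY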
@@ -390,8 +420,8 @@ def parse_date_delta(
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
-) -> t.Callable[[t.Sequence], t.Optional[E]]:
-    def func(args: t.Sequence) -> t.Optional[E]:
+) -> t.Callable[[t.List], t.Optional[E]]:
+    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None
@@ -409,7 +439,7 @@ def parse_date_delta_with_interval(
    return func
-def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
+def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)
@@ -424,7 +454,7 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    )
-def locate_to_strposition(args: t.Sequence) -> exp.Expression:
+def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1),
        substr=seq_get(args, 0),
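As a reminder of why the indices are swapped: LOCATE takes the needle first, while `exp.StrPosition` stores the haystack as `this`. For example:

    from sqlglot import exp
    from sqlglot.dialects.dialect import locate_to_strposition

    node = locate_to_strposition([exp.Literal.string("a"), exp.column("s")])
    assert node.this.name == "s"
    assert node.args["substr"].name == "a"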
@@ -483,7 +513,7 @@ def trim_sql(self: Generator, expression: exp.Trim) -> str:
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
-def str_to_time_sql(self, expression: exp.Expression) -> str:
+def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))
@@ -496,3 +526,26 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
return f"CAST({self.sql(expression, 'this')} AS DATE)"
return _ts_or_ds_to_date_sql
+
+
+# Spark and DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator.
+def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
+    names = []
+    for agg in aggregations:
+        if isinstance(agg, exp.Alias):
+            names.append(agg.alias)
+        else:
+            # This case corresponds to aggregations without aliases being used as
+            # suffixes (e.g. col_avg(foo)). We need to unquote identifiers because
+            # they're going to be quoted in the base parser's `_parse_pivot` method,
+            # due to `to_identifier`. Otherwise, we'd end up with col_avg(`foo`)
+            # (note the quoted identifier inside the generated name).
+            agg_all_unquoted = agg.transform(
+                lambda node: exp.Identifier(this=node.name, quoted=False)
+                if isinstance(node, exp.Identifier)
+                else node
+            )
+            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
+
+    return names
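A short usage sketch of the new helper: an unaliased aggregation is rendered with lower-cased function names and unquoted identifiers before being used as a column-name suffix.

    import sqlglot
    from sqlglot.dialects.dialect import pivot_column_names

    agg = sqlglot.parse_one("AVG(`foo`)", read="spark")
    print(pivot_column_names([agg], dialect="spark"))  # ['avg(foo)']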