summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/dialect.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r--sqlglot/dialects/dialect.py124
1 files changed, 72 insertions, 52 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 1b20e0a..176a8ce 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -11,6 +11,8 @@ from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer
from sqlglot.trie import new_trie
+E = t.TypeVar("E", bound=exp.Expression)
+
class Dialects(str, Enum):
DIALECT = ""
@@ -37,14 +39,16 @@ class Dialects(str, Enum):
class _Dialect(type):
- classes: t.Dict[str, Dialect] = {}
+ classes: t.Dict[str, t.Type[Dialect]] = {}
@classmethod
- def __getitem__(cls, key):
+ def __getitem__(cls, key: str) -> t.Type[Dialect]:
return cls.classes[key]
@classmethod
- def get(cls, key, default=None):
+ def get(
+ cls, key: str, default: t.Optional[t.Type[Dialect]] = None
+ ) -> t.Optional[t.Type[Dialect]]:
return cls.classes.get(key, default)
def __new__(cls, clsname, bases, attrs):
@@ -119,7 +123,7 @@ class Dialect(metaclass=_Dialect):
generator_class = None
@classmethod
- def get_or_raise(cls, dialect):
+ def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
if not dialect:
return cls
if isinstance(dialect, _Dialect):
@@ -134,7 +138,9 @@ class Dialect(metaclass=_Dialect):
return result
@classmethod
- def format_time(cls, expression):
+ def format_time(
+ cls, expression: t.Optional[str | exp.Expression]
+ ) -> t.Optional[exp.Expression]:
if isinstance(expression, str):
return exp.Literal.string(
format_time(
@@ -153,26 +159,28 @@ class Dialect(metaclass=_Dialect):
)
return expression
- def parse(self, sql, **opts):
+ def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql)
- def parse_into(self, expression_type, sql, **opts):
+ def parse_into(
+ self, expression_type: exp.IntoType, sql: str, **opts
+ ) -> t.List[t.Optional[exp.Expression]]:
return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql)
- def generate(self, expression, **opts):
+ def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
return self.generator(**opts).generate(expression)
- def transpile(self, code, **opts):
- return self.generate(self.parse(code), **opts)
+ def transpile(self, sql: str, **opts) -> t.List[str]:
+ return [self.generate(expression, **opts) for expression in self.parse(sql)]
@property
- def tokenizer(self):
+ def tokenizer(self) -> Tokenizer:
if not hasattr(self, "_tokenizer"):
- self._tokenizer = self.tokenizer_class()
+ self._tokenizer = self.tokenizer_class() # type: ignore
return self._tokenizer
- def parser(self, **opts):
- return self.parser_class(
+ def parser(self, **opts) -> Parser:
+ return self.parser_class( # type: ignore
**{
"index_offset": self.index_offset,
"unnest_column_only": self.unnest_column_only,
@@ -182,14 +190,15 @@ class Dialect(metaclass=_Dialect):
},
)
- def generator(self, **opts):
- return self.generator_class(
+ def generator(self, **opts) -> Generator:
+ return self.generator_class( # type: ignore
**{
"quote_start": self.quote_start,
"quote_end": self.quote_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
- "escape": self.tokenizer_class.ESCAPES[0],
+ "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
+ "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
"index_offset": self.index_offset,
"time_mapping": self.inverse_time_mapping,
"time_trie": self.inverse_time_trie,
@@ -202,11 +211,10 @@ class Dialect(metaclass=_Dialect):
)
-if t.TYPE_CHECKING:
- DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
+DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
-def rename_func(name):
+def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
def _rename(self, expression):
args = flatten(expression.args.values())
return f"{self.normalize_func(name)}({self.format_args(*args)})"
@@ -214,32 +222,34 @@ def rename_func(name):
return _rename
-def approx_count_distinct_sql(self, expression):
+def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
if expression.args.get("accuracy"):
self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
return f"APPROX_COUNT_DISTINCT({self.format_args(expression.this)})"
-def if_sql(self, expression):
+def if_sql(self: Generator, expression: exp.If) -> str:
expressions = self.format_args(
expression.this, expression.args.get("true"), expression.args.get("false")
)
return f"IF({expressions})"
-def arrow_json_extract_sql(self, expression):
+def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
return self.binary(expression, "->")
-def arrow_json_extract_scalar_sql(self, expression):
+def arrow_json_extract_scalar_sql(
+ self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
+) -> str:
return self.binary(expression, "->>")
-def inline_array_sql(self, expression):
+def inline_array_sql(self: Generator, expression: exp.Array) -> str:
return f"[{self.expressions(expression)}]"
-def no_ilike_sql(self, expression):
+def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
return self.like_sql(
exp.Like(
this=exp.Lower(this=expression.this),
@@ -248,44 +258,44 @@ def no_ilike_sql(self, expression):
)
-def no_paren_current_date_sql(self, expression):
+def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
zone = self.sql(expression, "this")
return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
-def no_recursive_cte_sql(self, expression):
+def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
if expression.args.get("recursive"):
self.unsupported("Recursive CTEs are unsupported")
expression.args["recursive"] = False
return self.with_sql(expression)
-def no_safe_divide_sql(self, expression):
+def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
n = self.sql(expression, "this")
d = self.sql(expression, "expression")
return f"IF({d} <> 0, {n} / {d}, NULL)"
-def no_tablesample_sql(self, expression):
+def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
self.unsupported("TABLESAMPLE unsupported")
return self.sql(expression.this)
-def no_pivot_sql(self, expression):
+def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
self.unsupported("PIVOT unsupported")
return self.sql(expression)
-def no_trycast_sql(self, expression):
+def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
return self.cast_sql(expression)
-def no_properties_sql(self, expression):
+def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
self.unsupported("Properties unsupported")
return ""
-def str_position_sql(self, expression):
+def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
this = self.sql(expression, "this")
substr = self.sql(expression, "substr")
position = self.sql(expression, "position")
@@ -294,13 +304,15 @@ def str_position_sql(self, expression):
return f"STRPOS({this}, {substr})"
-def struct_extract_sql(self, expression):
+def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
this = self.sql(expression, "this")
struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
return f"{this}.{struct_key}"
-def var_map_sql(self, expression, map_func_name="MAP"):
+def var_map_sql(
+ self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
+) -> str:
keys = expression.args["keys"]
values = expression.args["values"]
@@ -315,27 +327,33 @@ def var_map_sql(self, expression, map_func_name="MAP"):
return f"{map_func_name}({self.format_args(*args)})"
-def format_time_lambda(exp_class, dialect, default=None):
+def format_time_lambda(
+ exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
+) -> t.Callable[[t.Sequence], E]:
"""Helper used for time expressions.
- Args
- exp_class (Class): the expression class to instantiate
- dialect (string): sql dialect
- default (Option[bool | str]): the default format, True being time
+ Args:
+ exp_class: the expression class to instantiate.
+ dialect: target sql dialect.
+ default: the default format, True being time.
+
+ Returns:
+ A callable that can be used to return the appropriately formatted time expression.
"""
- def _format_time(args):
+ def _format_time(args: t.Sequence):
return exp_class(
this=seq_get(args, 0),
format=Dialect[dialect].format_time(
- seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
+ seq_get(args, 1)
+ or (Dialect[dialect].time_format if default is True else default or None)
),
)
return _format_time
-def create_with_partitions_sql(self, expression):
+def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
"""
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
@@ -359,19 +377,21 @@ def create_with_partitions_sql(self, expression):
return self.create_sql(expression)
-def parse_date_delta(exp_class, unit_mapping=None):
- def inner_func(args):
+def parse_date_delta(
+ exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
+) -> t.Callable[[t.Sequence], E]:
+ def inner_func(args: t.Sequence) -> E:
unit_based = len(args) == 3
this = seq_get(args, 2) if unit_based else seq_get(args, 0)
expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
- unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
+ unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit # type: ignore
return exp_class(this=this, expression=expression, unit=unit)
return inner_func
-def locate_to_strposition(args):
+def locate_to_strposition(args: t.Sequence) -> exp.Expression:
return exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
@@ -379,22 +399,22 @@ def locate_to_strposition(args):
)
-def strposition_to_locate_sql(self, expression):
+def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
args = self.format_args(
expression.args.get("substr"), expression.this, expression.args.get("position")
)
return f"LOCATE({args})"
-def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str:
+def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
-def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str:
+def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
return f"CAST({self.sql(expression, 'this')} AS DATE)"
-def trim_sql(self, expression):
+def trim_sql(self: Generator, expression: exp.Trim) -> str:
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
remove_chars = self.sql(expression, "expression")