diff options
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r-- | sqlglot/dialects/dialect.py | 124 |
1 files changed, 72 insertions, 52 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 1b20e0a..176a8ce 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -11,6 +11,8 @@ from sqlglot.time import format_time from sqlglot.tokens import Tokenizer from sqlglot.trie import new_trie +E = t.TypeVar("E", bound=exp.Expression) + class Dialects(str, Enum): DIALECT = "" @@ -37,14 +39,16 @@ class Dialects(str, Enum): class _Dialect(type): - classes: t.Dict[str, Dialect] = {} + classes: t.Dict[str, t.Type[Dialect]] = {} @classmethod - def __getitem__(cls, key): + def __getitem__(cls, key: str) -> t.Type[Dialect]: return cls.classes[key] @classmethod - def get(cls, key, default=None): + def get( + cls, key: str, default: t.Optional[t.Type[Dialect]] = None + ) -> t.Optional[t.Type[Dialect]]: return cls.classes.get(key, default) def __new__(cls, clsname, bases, attrs): @@ -119,7 +123,7 @@ class Dialect(metaclass=_Dialect): generator_class = None @classmethod - def get_or_raise(cls, dialect): + def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: if not dialect: return cls if isinstance(dialect, _Dialect): @@ -134,7 +138,9 @@ class Dialect(metaclass=_Dialect): return result @classmethod - def format_time(cls, expression): + def format_time( + cls, expression: t.Optional[str | exp.Expression] + ) -> t.Optional[exp.Expression]: if isinstance(expression, str): return exp.Literal.string( format_time( @@ -153,26 +159,28 @@ class Dialect(metaclass=_Dialect): ) return expression - def parse(self, sql, **opts): + def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql) - def parse_into(self, expression_type, sql, **opts): + def parse_into( + self, expression_type: exp.IntoType, sql: str, **opts + ) -> t.List[t.Optional[exp.Expression]]: return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql) - def generate(self, expression, **opts): + def generate(self, expression: t.Optional[exp.Expression], **opts) -> str: return self.generator(**opts).generate(expression) - def transpile(self, code, **opts): - return self.generate(self.parse(code), **opts) + def transpile(self, sql: str, **opts) -> t.List[str]: + return [self.generate(expression, **opts) for expression in self.parse(sql)] @property - def tokenizer(self): + def tokenizer(self) -> Tokenizer: if not hasattr(self, "_tokenizer"): - self._tokenizer = self.tokenizer_class() + self._tokenizer = self.tokenizer_class() # type: ignore return self._tokenizer - def parser(self, **opts): - return self.parser_class( + def parser(self, **opts) -> Parser: + return self.parser_class( # type: ignore **{ "index_offset": self.index_offset, "unnest_column_only": self.unnest_column_only, @@ -182,14 +190,15 @@ class Dialect(metaclass=_Dialect): }, ) - def generator(self, **opts): - return self.generator_class( + def generator(self, **opts) -> Generator: + return self.generator_class( # type: ignore **{ "quote_start": self.quote_start, "quote_end": self.quote_end, "identifier_start": self.identifier_start, "identifier_end": self.identifier_end, - "escape": self.tokenizer_class.ESCAPES[0], + "string_escape": self.tokenizer_class.STRING_ESCAPES[0], + "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0], "index_offset": self.index_offset, "time_mapping": self.inverse_time_mapping, "time_trie": self.inverse_time_trie, @@ -202,11 +211,10 @@ class Dialect(metaclass=_Dialect): ) -if t.TYPE_CHECKING: - DialectType = t.Union[str, Dialect, t.Type[Dialect], None] +DialectType = t.Union[str, Dialect, t.Type[Dialect], None] -def rename_func(name): +def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: def _rename(self, expression): args = flatten(expression.args.values()) return f"{self.normalize_func(name)}({self.format_args(*args)})" @@ -214,32 +222,34 @@ def rename_func(name): return _rename -def approx_count_distinct_sql(self, expression): +def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: if expression.args.get("accuracy"): self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") return f"APPROX_COUNT_DISTINCT({self.format_args(expression.this)})" -def if_sql(self, expression): +def if_sql(self: Generator, expression: exp.If) -> str: expressions = self.format_args( expression.this, expression.args.get("true"), expression.args.get("false") ) return f"IF({expressions})" -def arrow_json_extract_sql(self, expression): +def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str: return self.binary(expression, "->") -def arrow_json_extract_scalar_sql(self, expression): +def arrow_json_extract_scalar_sql( + self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar +) -> str: return self.binary(expression, "->>") -def inline_array_sql(self, expression): +def inline_array_sql(self: Generator, expression: exp.Array) -> str: return f"[{self.expressions(expression)}]" -def no_ilike_sql(self, expression): +def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: return self.like_sql( exp.Like( this=exp.Lower(this=expression.this), @@ -248,44 +258,44 @@ def no_ilike_sql(self, expression): ) -def no_paren_current_date_sql(self, expression): +def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: zone = self.sql(expression, "this") return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" -def no_recursive_cte_sql(self, expression): +def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: if expression.args.get("recursive"): self.unsupported("Recursive CTEs are unsupported") expression.args["recursive"] = False return self.with_sql(expression) -def no_safe_divide_sql(self, expression): +def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: n = self.sql(expression, "this") d = self.sql(expression, "expression") return f"IF({d} <> 0, {n} / {d}, NULL)" -def no_tablesample_sql(self, expression): +def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: self.unsupported("TABLESAMPLE unsupported") return self.sql(expression.this) -def no_pivot_sql(self, expression): +def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: self.unsupported("PIVOT unsupported") return self.sql(expression) -def no_trycast_sql(self, expression): +def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: return self.cast_sql(expression) -def no_properties_sql(self, expression): +def no_properties_sql(self: Generator, expression: exp.Properties) -> str: self.unsupported("Properties unsupported") return "" -def str_position_sql(self, expression): +def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: this = self.sql(expression, "this") substr = self.sql(expression, "substr") position = self.sql(expression, "position") @@ -294,13 +304,15 @@ def str_position_sql(self, expression): return f"STRPOS({this}, {substr})" -def struct_extract_sql(self, expression): +def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: this = self.sql(expression, "this") struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True)) return f"{this}.{struct_key}" -def var_map_sql(self, expression, map_func_name="MAP"): +def var_map_sql( + self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" +) -> str: keys = expression.args["keys"] values = expression.args["values"] @@ -315,27 +327,33 @@ def var_map_sql(self, expression, map_func_name="MAP"): return f"{map_func_name}({self.format_args(*args)})" -def format_time_lambda(exp_class, dialect, default=None): +def format_time_lambda( + exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None +) -> t.Callable[[t.Sequence], E]: """Helper used for time expressions. - Args - exp_class (Class): the expression class to instantiate - dialect (string): sql dialect - default (Option[bool | str]): the default format, True being time + Args: + exp_class: the expression class to instantiate. + dialect: target sql dialect. + default: the default format, True being time. + + Returns: + A callable that can be used to return the appropriately formatted time expression. """ - def _format_time(args): + def _format_time(args: t.Sequence): return exp_class( this=seq_get(args, 0), format=Dialect[dialect].format_time( - seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default) + seq_get(args, 1) + or (Dialect[dialect].time_format if default is True else default or None) ), ) return _format_time -def create_with_partitions_sql(self, expression): +def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: """ In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding @@ -359,19 +377,21 @@ def create_with_partitions_sql(self, expression): return self.create_sql(expression) -def parse_date_delta(exp_class, unit_mapping=None): - def inner_func(args): +def parse_date_delta( + exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None +) -> t.Callable[[t.Sequence], E]: + def inner_func(args: t.Sequence) -> E: unit_based = len(args) == 3 this = seq_get(args, 2) if unit_based else seq_get(args, 0) expression = seq_get(args, 1) if unit_based else seq_get(args, 1) unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY") - unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit + unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit # type: ignore return exp_class(this=this, expression=expression, unit=unit) return inner_func -def locate_to_strposition(args): +def locate_to_strposition(args: t.Sequence) -> exp.Expression: return exp.StrPosition( this=seq_get(args, 1), substr=seq_get(args, 0), @@ -379,22 +399,22 @@ def locate_to_strposition(args): ) -def strposition_to_locate_sql(self, expression): +def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: args = self.format_args( expression.args.get("substr"), expression.this, expression.args.get("position") ) return f"LOCATE({args})" -def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str: +def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)" -def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str: +def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: return f"CAST({self.sql(expression, 'this')} AS DATE)" -def trim_sql(self, expression): +def trim_sql(self: Generator, expression: exp.Trim) -> str: target = self.sql(expression, "this") trim_type = self.sql(expression, "position") remove_chars = self.sql(expression, "expression") |