Diffstat (limited to 'sqlglot/dialects')
 sqlglot/dialects/bigquery.py   |  54
 sqlglot/dialects/clickhouse.py |   2
 sqlglot/dialects/dialect.py    | 124
 sqlglot/dialects/drill.py      |  33
 sqlglot/dialects/duckdb.py     |   8
 sqlglot/dialects/hive.py       |   2
 sqlglot/dialects/mysql.py      |   7
 sqlglot/dialects/postgres.py   |   3
 sqlglot/dialects/redshift.py   |   3
 sqlglot/dialects/snowflake.py  |  13
 sqlglot/dialects/spark.py      |   1
 sqlglot/dialects/sqlite.py     |   1
12 files changed, 152 insertions, 99 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 90ae229..6a19b46 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -2,6 +2,8 @@
 
 from __future__ import annotations
 
+import typing as t
+
 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
@@ -14,8 +16,10 @@ from sqlglot.dialects.dialect import (
 from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType
 
+E = t.TypeVar("E", bound=exp.Expression)
+
 
-def _date_add(expression_class):
+def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]:
     def func(args):
         interval = seq_get(args, 1)
         return expression_class(
@@ -27,26 +31,26 @@ def _date_add(expression_class):
     return func
 
 
-def _date_trunc(args):
+def _date_trunc(args: t.Sequence) -> exp.Expression:
     unit = seq_get(args, 1)
     if isinstance(unit, exp.Column):
         unit = exp.Var(this=unit.name)
     return exp.DateTrunc(this=seq_get(args, 0), expression=unit)
 
 
-def _date_add_sql(data_type, kind):
+def _date_add_sql(
+    data_type: str, kind: str
+) -> t.Callable[[generator.Generator, exp.Expression], str]:
     def func(self, expression):
         this = self.sql(expression, "this")
-        unit = self.sql(expression, "unit") or "'day'"
-        expression = self.sql(expression, "expression")
-        return f"{data_type}_{kind}({this}, INTERVAL {expression} {unit})"
+        return f"{data_type}_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=expression.args.get('unit') or exp.Literal.string('day')))})"
 
     return func
 
 
-def _derived_table_values_to_unnest(self, expression):
+def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str:
     if not isinstance(expression.unnest().parent, exp.From):
-        expression = transforms.remove_precision_parameterized_types(expression)
+        expression = t.cast(exp.Values, transforms.remove_precision_parameterized_types(expression))
         return self.values_sql(expression)
     rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)]
     structs = []
@@ -60,7 +64,7 @@ def _derived_table_values_to_unnest(self, expression):
     return self.unnest_sql(unnest_exp)
 
 
-def _returnsproperty_sql(self, expression):
+def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str:
     this = expression.this
     if isinstance(this, exp.Schema):
         this = f"{this.this} <{self.expressions(this)}>"
@@ -69,8 +73,8 @@ def _returnsproperty_sql(self, expression):
     return f"RETURNS {this}"
 
 
-def _create_sql(self, expression):
-    kind = expression.args.get("kind")
+def _create_sql(self: generator.Generator, expression: exp.Create) -> str:
+    kind = expression.args["kind"]
     returns = expression.find(exp.ReturnsProperty)
     if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"):
         expression = expression.copy()
@@ -89,6 +93,29 @@ def _create_sql(self, expression):
     return self.create_sql(expression)
 
 
+def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
+    """Remove references to unnest table aliases since bigquery doesn't allow them.
+
+    These are added by the optimizer's qualify_column step.
+    """
+    if isinstance(expression, exp.Select):
+        unnests = {
+            unnest.alias
+            for unnest in expression.args.get("from", exp.From(expressions=[])).expressions
+            if isinstance(unnest, exp.Unnest) and unnest.alias
+        }
+
+        if unnests:
+            expression = expression.copy()
+
+            for select in expression.expressions:
+                for column in select.find_all(exp.Column):
+                    if column.table in unnests:
+                        column.set("table", None)
+
+    return expression
+
+
 class BigQuery(Dialect):
     unnest_column_only = True
     time_mapping = {
@@ -110,7 +137,7 @@ class BigQuery(Dialect):
         ]
         COMMENTS = ["--", "#", ("/*", "*/")]
         IDENTIFIERS = ["`"]
-        ESCAPES = ["\\"]
+        STRING_ESCAPES = ["\\"]
         HEX_STRINGS = [("0x", ""), ("0X", "")]
 
         KEYWORDS = {
@@ -190,6 +217,9 @@ class BigQuery(Dialect):
            exp.GroupConcat: rename_func("STRING_AGG"),
            exp.ILike: no_ilike_sql,
            exp.IntDiv: rename_func("DIV"),
+            exp.Select: transforms.preprocess(
+                [_unqualify_unnest], transforms.delegate("select_sql")
+            ),
            exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
            exp.TimeAdd: _date_add_sql("TIME", "ADD"),
            exp.TimeSub: _date_add_sql("TIME", "SUB"),
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 9e8c691..b553df2 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -9,7 +9,7 @@ from sqlglot.parser import parse_var_map
 from sqlglot.tokens import TokenType
 
 
-def _lower_func(sql):
+def _lower_func(sql: str) -> str:
     index = sql.index("(")
     return sql[:index].lower() + sql[index:]
 
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 1b20e0a..176a8ce 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -11,6 +11,8 @@ from sqlglot.time import format_time
 from sqlglot.tokens import Tokenizer
 from sqlglot.trie import new_trie
 
+E = t.TypeVar("E", bound=exp.Expression)
+
 
 class Dialects(str, Enum):
     DIALECT = ""
@@ -37,14 +39,16 @@ class Dialects(str, Enum):
 
 
 class _Dialect(type):
-    classes: t.Dict[str, Dialect] = {}
+    classes: t.Dict[str, t.Type[Dialect]] = {}
 
     @classmethod
-    def __getitem__(cls, key):
+    def __getitem__(cls, key: str) -> t.Type[Dialect]:
         return cls.classes[key]
 
     @classmethod
-    def get(cls, key, default=None):
+    def get(
+        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
+    ) -> t.Optional[t.Type[Dialect]]:
         return cls.classes.get(key, default)
 
     def __new__(cls, clsname, bases, attrs):
@@ -119,7 +123,7 @@ class Dialect(metaclass=_Dialect):
     generator_class = None
 
     @classmethod
-    def get_or_raise(cls, dialect):
+    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
         if not dialect:
             return cls
         if isinstance(dialect, _Dialect):
@@ -134,7 +138,9 @@ class Dialect(metaclass=_Dialect):
         return result
 
     @classmethod
-    def format_time(cls, expression):
+    def format_time(
+        cls, expression: t.Optional[str | exp.Expression]
+    ) -> t.Optional[exp.Expression]:
         if isinstance(expression, str):
             return exp.Literal.string(
                 format_time(
@@ -153,26 +159,28 @@ class Dialect(metaclass=_Dialect):
         )
         return expression
 
-    def parse(self, sql, **opts):
+    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
         return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql)
 
-    def parse_into(self, expression_type, sql, **opts):
+    def parse_into(
+        self, expression_type: exp.IntoType, sql: str, **opts
+    ) -> t.List[t.Optional[exp.Expression]]:
         return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql)
 
-    def generate(self, expression, **opts):
+    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
         return self.generator(**opts).generate(expression)
 
-    def transpile(self, code, **opts):
-        return self.generate(self.parse(code), **opts)
+    def transpile(self, sql: str, **opts) -> t.List[str]:
+        return [self.generate(expression, **opts) for expression in self.parse(sql)]
 
     @property
-    def tokenizer(self):
+    def tokenizer(self) -> Tokenizer:
         if not hasattr(self, "_tokenizer"):
-            self._tokenizer = self.tokenizer_class()
+            self._tokenizer = self.tokenizer_class()  # type: ignore
         return self._tokenizer
 
-    def parser(self, **opts):
-        return self.parser_class(
+    def parser(self, **opts) -> Parser:
+        return self.parser_class(  # type: ignore
             **{
                 "index_offset": self.index_offset,
                 "unnest_column_only": self.unnest_column_only,
@@ -182,14 +190,15 @@ class Dialect(metaclass=_Dialect):
             },
         )
 
-    def generator(self, **opts):
-        return self.generator_class(
+    def generator(self, **opts) -> Generator:
+        return self.generator_class(  # type: ignore
             **{
                 "quote_start": self.quote_start,
                 "quote_end": self.quote_end,
                 "identifier_start": self.identifier_start,
                 "identifier_end": self.identifier_end,
-                "escape": self.tokenizer_class.ESCAPES[0],
+                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
+                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
                 "index_offset": self.index_offset,
                 "time_mapping": self.inverse_time_mapping,
                 "time_trie": self.inverse_time_trie,
@@ -202,11 +211,10 @@ class Dialect(metaclass=_Dialect):
         )
 
 
-if t.TYPE_CHECKING:
-    DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
+DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 
 
-def rename_func(name):
+def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
     def _rename(self, expression):
         args = flatten(expression.args.values())
         return f"{self.normalize_func(name)}({self.format_args(*args)})"
@@ -214,32 +222,34 @@ def rename_func(name):
     return _rename
 
 
-def approx_count_distinct_sql(self, expression):
+def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
     if expression.args.get("accuracy"):
         self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
     return f"APPROX_COUNT_DISTINCT({self.format_args(expression.this)})"
 
 
-def if_sql(self, expression):
+def if_sql(self: Generator, expression: exp.If) -> str:
     expressions = self.format_args(
         expression.this, expression.args.get("true"), expression.args.get("false")
     )
     return f"IF({expressions})"
 
 
-def arrow_json_extract_sql(self, expression):
+def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
     return self.binary(expression, "->")
 
 
-def arrow_json_extract_scalar_sql(self, expression):
+def arrow_json_extract_scalar_sql(
+    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
+) -> str:
     return self.binary(expression, "->>")
 
 
-def inline_array_sql(self, expression):
+def inline_array_sql(self: Generator, expression: exp.Array) -> str:
     return f"[{self.expressions(expression)}]"
 
 
-def no_ilike_sql(self, expression):
+def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
     return self.like_sql(
         exp.Like(
             this=exp.Lower(this=expression.this),
@@ -248,44 +258,44 @@ def no_ilike_sql(self, expression):
     )
 
 
-def no_paren_current_date_sql(self, expression):
+def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
     zone = self.sql(expression, "this")
     return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
 
 
-def no_recursive_cte_sql(self, expression):
+def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
     if expression.args.get("recursive"):
         self.unsupported("Recursive CTEs are unsupported")
         expression.args["recursive"] = False
     return self.with_sql(expression)
 
 
-def no_safe_divide_sql(self, expression):
+def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
     n = self.sql(expression, "this")
     d = self.sql(expression, "expression")
     return f"IF({d} <> 0, {n} / {d}, NULL)"
 
 
-def no_tablesample_sql(self, expression):
+def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
     self.unsupported("TABLESAMPLE unsupported")
     return self.sql(expression.this)
 
 
-def no_pivot_sql(self, expression):
+def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
     self.unsupported("PIVOT unsupported")
     return self.sql(expression)
 
 
-def no_trycast_sql(self, expression):
+def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
     return self.cast_sql(expression)
 
 
-def no_properties_sql(self, expression):
+def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
     self.unsupported("Properties unsupported")
     return ""
 
 
-def str_position_sql(self, expression):
+def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
     this = self.sql(expression, "this")
     substr = self.sql(expression, "substr")
     position = self.sql(expression, "position")
@@ -294,13 +304,15 @@ def str_position_sql(self, expression):
     return f"STRPOS({this}, {substr})"
 
 
-def struct_extract_sql(self, expression):
+def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
     this = self.sql(expression, "this")
     struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
     return f"{this}.{struct_key}"
 
 
-def var_map_sql(self, expression, map_func_name="MAP"):
+def var_map_sql(
+    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
+) -> str:
     keys = expression.args["keys"]
     values = expression.args["values"]
 
@@ -315,27 +327,33 @@ def var_map_sql(self, expression, map_func_name="MAP"):
     return f"{map_func_name}({self.format_args(*args)})"
 
 
-def format_time_lambda(exp_class, dialect, default=None):
+def format_time_lambda(
+    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
+) -> t.Callable[[t.Sequence], E]:
     """Helper used for time expressions.
 
-    Args
-        exp_class (Class): the expression class to instantiate
-        dialect (string): sql dialect
-        default (Option[bool | str]): the default format, True being time
+    Args:
+        exp_class: the expression class to instantiate.
+        dialect: target sql dialect.
+        default: the default format, True being time.
+
+    Returns:
+        A callable that can be used to return the appropriately formatted time expression.
     """
 
-    def _format_time(args):
+    def _format_time(args: t.Sequence):
         return exp_class(
             this=seq_get(args, 0),
             format=Dialect[dialect].format_time(
-                seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
+                seq_get(args, 1)
+                or (Dialect[dialect].time_format if default is True else default or None)
             ),
         )
 
     return _format_time
 
 
-def create_with_partitions_sql(self, expression):
+def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
     """
     In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
     PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
@@ -359,19 +377,21 @@ def create_with_partitions_sql(self, expression):
     return self.create_sql(expression)
 
 
-def parse_date_delta(exp_class, unit_mapping=None):
-    def inner_func(args):
+def parse_date_delta(
+    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
+) -> t.Callable[[t.Sequence], E]:
+    def inner_func(args: t.Sequence) -> E:
         unit_based = len(args) == 3
         this = seq_get(args, 2) if unit_based else seq_get(args, 0)
         expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
         unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
-        unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
+        unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit  # type: ignore
         return exp_class(this=this, expression=expression, unit=unit)
 
     return inner_func
 
 
-def locate_to_strposition(args):
+def locate_to_strposition(args: t.Sequence) -> exp.Expression:
     return exp.StrPosition(
         this=seq_get(args, 1),
         substr=seq_get(args, 0),
@@ -379,22 +399,22 @@ def locate_to_strposition(args):
     )
 
 
-def strposition_to_locate_sql(self, expression):
+def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
     args = self.format_args(
         expression.args.get("substr"), expression.this, expression.args.get("position")
     )
     return f"LOCATE({args})"
 
 
-def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str:
+def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
     return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
 
 
-def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str:
+def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
     return f"CAST({self.sql(expression, 'this')} AS DATE)"
 
 
-def trim_sql(self, expression):
+def trim_sql(self: Generator, expression: exp.Trim) -> str:
     target = self.sql(expression, "this")
     trim_type = self.sql(expression, "position")
     remove_chars = self.sql(expression, "expression")
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index d0a0251..1730eaf 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import re
+import typing as t
 
 from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
@@ -16,35 +17,29 @@ from sqlglot.dialects.dialect import (
 )
 
 
-def _to_timestamp(args):
-    # TO_TIMESTAMP accepts either a single double argument or (text, text)
-    if len(args) == 1 and args[0].is_number:
-        return exp.UnixToTime.from_arg_list(args)
-    return format_time_lambda(exp.StrToTime, "drill")(args)
-
-
-def _str_to_time_sql(self, expression):
+def _str_to_time_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
     return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
 
 
-def _ts_or_ds_to_date_sql(self, expression):
+def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
     time_format = self.format_time(expression)
     if time_format and time_format not in (Drill.time_format, Drill.date_format):
         return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
     return f"CAST({self.sql(expression, 'this')} AS DATE)"
 
 
-def _date_add_sql(kind):
-    def func(self, expression):
+def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
+    def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
         this = self.sql(expression, "this")
-        unit = expression.text("unit").upper() or "DAY"
-        expression = self.sql(expression, "expression")
-        return f"DATE_{kind}({this}, INTERVAL '{expression}' {unit})"
+        unit = exp.Var(this=expression.text("unit").upper() or "DAY")
+        return (
+            f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
+        )
 
     return func
 
 
-def if_sql(self, expression):
+def if_sql(self: generator.Generator, expression: exp.If) -> str:
     """
     Drill requires backticks around certain SQL reserved words, IF being one of them, This function
     adds the backticks around the keyword IF.
@@ -61,7 +56,7 @@ def if_sql(self, expression):
     return f"`IF`({expressions})"
 
 
-def _str_to_date(self, expression):
+def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
     if time_format == Drill.date_format:
@@ -111,7 +106,7 @@ class Drill(Dialect):
     class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'"]
         IDENTIFIERS = ["`"]
-        ESCAPES = ["\\"]
+        STRING_ESCAPES = ["\\"]
         ENCODE = "utf-8"
 
     class Parser(parser.Parser):
@@ -168,10 +163,10 @@ class Drill(Dialect):
            exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
            exp.TryCast: no_trycast_sql,
-            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), INTERVAL '{self.sql(e, 'expression')}' DAY)",
+            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.Var(this='DAY')))})",
            exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
            exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
        }
 
-        def normalize_func(self, name):
+        def normalize_func(self, name: str) -> str:
            return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`"
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 95ff95c..959e5e2 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -25,10 +25,9 @@ def _str_to_time_sql(self, expression):
 
 
 def _ts_or_ds_add(self, expression):
-    this = self.sql(expression, "this")
-    e = self.sql(expression, "expression")
+    this = expression.args.get("this")
     unit = self.sql(expression, "unit").strip("'") or "DAY"
-    return f"CAST({this} AS DATE) + INTERVAL {e} {unit}"
+    return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
 
 
 def _ts_or_ds_to_date_sql(self, expression):
@@ -40,9 +39,8 @@ def _ts_or_ds_to_date_sql(self, expression):
 
 def _date_add(self, expression):
     this = self.sql(expression, "this")
-    e = self.sql(expression, "expression")
     unit = self.sql(expression, "unit").strip("'") or "DAY"
-    return f"{this} + INTERVAL {e} {unit}"
+    return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
 
 
 def _array_sort_sql(self, expression):
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index f2b6eaa..c558b70 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -172,7 +172,7 @@ class Hive(Dialect):
     class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", '"']
         IDENTIFIERS = ["`"]
-        ESCAPES = ["\\"]
+        STRING_ESCAPES = ["\\"]
         ENCODE = "utf-8"
 
         KEYWORDS = {
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index a5bd86b..c2c2c8c 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -89,8 +89,9 @@ def _date_add_sql(kind):
     def func(self, expression):
         this = self.sql(expression, "this")
         unit = expression.text("unit").upper() or "DAY"
-        expression = self.sql(expression, "expression")
-        return f"DATE_{kind}({this}, INTERVAL {expression} {unit})"
+        return (
+            f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
+        )
 
     return func
 
@@ -117,7 +118,7 @@ class MySQL(Dialect):
         QUOTES = ["'", '"']
         COMMENTS = ["--", "#", ("/*", "*/")]
         IDENTIFIERS = ["`"]
-        ESCAPES = ["'", "\\"]
+        STRING_ESCAPES = ["'", "\\"]
 
         BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
         HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 6418032..c709665 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -40,8 +40,7 @@ def _date_add_sql(kind):
 
         expression = expression.copy()
         expression.args["is_string"] = True
-        expression = self.sql(expression)
-        return f"{this} {kind} INTERVAL {expression} {unit}"
+        return f"{this} {kind} {self.sql(exp.Interval(this=expression, unit=unit))}"
 
     return func
 
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index c3c99eb..813ee5f 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -37,11 +37,10 @@ class Redshift(Postgres):
         return this
 
     class Tokenizer(Postgres.Tokenizer):
-        ESCAPES = ["\\"]
+        STRING_ESCAPES = ["\\"]
 
         KEYWORDS = {
             **Postgres.Tokenizer.KEYWORDS,  # type: ignore
-            "COPY": TokenType.COMMAND,
             "ENCODE": TokenType.ENCODE,
             "GEOMETRY": TokenType.GEOMETRY,
             "GEOGRAPHY": TokenType.GEOGRAPHY,
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 3b83b02..55a6bd3 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -180,7 +180,7 @@ class Snowflake(Dialect):
 
     class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", "$$"]
-        ESCAPES = ["\\", "'"]
+        STRING_ESCAPES = ["\\", "'"]
 
         SINGLE_TOKENS = {
             **tokens.Tokenizer.SINGLE_TOKENS,
@@ -191,6 +191,7 @@ class Snowflake(Dialect):
             **tokens.Tokenizer.KEYWORDS,
             "EXCLUDE": TokenType.EXCEPT,
             "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
+            "PUT": TokenType.COMMAND,
             "RENAME": TokenType.REPLACE,
             "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
             "TIMESTAMP_NTZ": TokenType.TIMESTAMP,
@@ -222,6 +223,7 @@ class Snowflake(Dialect):
             exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
             exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
             exp.UnixToTime: _unix_to_time_sql,
+            exp.DayOfWeek: rename_func("DAYOFWEEK"),
         }
 
         TYPE_MAPPING = {
@@ -294,3 +296,12 @@ class Snowflake(Dialect):
             kind = f" {kind_value}" if kind_value else ""
             this = f" {self.sql(expression, 'this')}"
             return f"DESCRIBE{kind}{this}"
+
+        def generatedasidentitycolumnconstraint_sql(
+            self, expression: exp.GeneratedAsIdentityColumnConstraint
+        ) -> str:
+            start = expression.args.get("start")
+            start = f" START {start}" if start else ""
+            increment = expression.args.get("increment")
+            increment = f" INCREMENT {increment}" if increment else ""
+            return f"AUTOINCREMENT{start}{increment}"
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 8ef4a87..03ec211 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -157,6 +157,7 @@ class Spark(Hive):
         TRANSFORMS.pop(exp.ILike)
 
         WRAP_DERIVED_VALUES = False
+        CREATE_FUNCTION_AS = False
 
         def cast_sql(self, expression: exp.Cast) -> str:
             if isinstance(expression.this, exp.Cast) and expression.this.is_type(
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 1b39449..a428dd5 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -49,7 +49,6 @@ class SQLite(Dialect):
 
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
-            "AUTOINCREMENT": TokenType.AUTO_INCREMENT,
         }
 
     class Parser(parser.Parser):
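The common thread in the dialect changes above is that date arithmetic is now rendered through an exp.Interval node instead of hand-built f-strings. A minimal sketch of how that surfaces through the public API, assuming a sqlglot build containing this patch; the dialect names and the rough shape of the output are illustrative, not quoted from the commit's tests:

import sqlglot

# Round-trip a BigQuery DATE_ADD call: bigquery.py's _date_add_sql now emits the
# amount and unit via exp.Interval, defaulting the unit to 'day' when it is absent,
# so the result should keep the form DATE_ADD(..., INTERVAL 5 DAY).
sql = "SELECT DATE_ADD(CURRENT_DATE(), INTERVAL 5 DAY)"
print(sqlglot.transpile(sql, read="bigquery", write="bigquery")[0])

# Cross-dialect: drill.py builds the same exp.Interval for DateAdd, so the Drill
# output should contain an INTERVAL expression rather than a quoted amount string.
print(sqlglot.transpile(sql, read="bigquery", write="drill")[0])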