From 918abde014f9e5c75dfbe21110c379f7f70435c9 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 12 Feb 2023 11:06:28 +0100 Subject: Merging upstream version 11.0.1. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 3 +- sqlglot/dialects/bigquery.py | 54 ++++++++++--- sqlglot/dialects/clickhouse.py | 2 +- sqlglot/dialects/dialect.py | 124 +++++++++++++++++------------- sqlglot/dialects/drill.py | 33 ++++---- sqlglot/dialects/duckdb.py | 8 +- sqlglot/dialects/hive.py | 2 +- sqlglot/dialects/mysql.py | 7 +- sqlglot/dialects/postgres.py | 3 +- sqlglot/dialects/redshift.py | 3 +- sqlglot/dialects/snowflake.py | 13 +++- sqlglot/dialects/spark.py | 1 + sqlglot/dialects/sqlite.py | 1 - sqlglot/diff.py | 1 + sqlglot/errors.py | 15 +++- sqlglot/executor/__init__.py | 1 + sqlglot/executor/python.py | 2 +- sqlglot/expressions.py | 107 ++++++++++++++++++-------- sqlglot/generator.py | 54 +++++++++---- sqlglot/lineage.py | 3 +- sqlglot/optimizer/annotate_types.py | 17 +++- sqlglot/optimizer/expand_laterals.py | 34 ++++++++ sqlglot/optimizer/optimizer.py | 5 +- sqlglot/optimizer/pushdown_projections.py | 6 +- sqlglot/optimizer/qualify_columns.py | 30 ++++---- sqlglot/optimizer/qualify_tables.py | 13 +++- sqlglot/optimizer/scope.py | 20 ++++- sqlglot/parser.py | 48 +++++++++--- sqlglot/tokens.py | 38 ++++++--- 29 files changed, 452 insertions(+), 196 deletions(-) create mode 100644 sqlglot/optimizer/expand_laterals.py (limited to 'sqlglot') diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 714897f..7b07ae1 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -1,5 +1,6 @@ """ .. include:: ../README.md + ---- """ @@ -39,7 +40,7 @@ if t.TYPE_CHECKING: T = t.TypeVar("T", bound=Expression) -__version__ = "10.6.3" +__version__ = "11.0.1" pretty = False """Whether to format generated SQL by default.""" diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 90ae229..6a19b46 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -2,6 +2,8 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, @@ -14,8 +16,10 @@ from sqlglot.dialects.dialect import ( from sqlglot.helper import seq_get from sqlglot.tokens import TokenType +E = t.TypeVar("E", bound=exp.Expression) + -def _date_add(expression_class): +def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]: def func(args): interval = seq_get(args, 1) return expression_class( @@ -27,26 +31,26 @@ def _date_add(expression_class): return func -def _date_trunc(args): +def _date_trunc(args: t.Sequence) -> exp.Expression: unit = seq_get(args, 1) if isinstance(unit, exp.Column): unit = exp.Var(this=unit.name) return exp.DateTrunc(this=seq_get(args, 0), expression=unit) -def _date_add_sql(data_type, kind): +def _date_add_sql( + data_type: str, kind: str +) -> t.Callable[[generator.Generator, exp.Expression], str]: def func(self, expression): this = self.sql(expression, "this") - unit = self.sql(expression, "unit") or "'day'" - expression = self.sql(expression, "expression") - return f"{data_type}_{kind}({this}, INTERVAL {expression} {unit})" + return f"{data_type}_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=expression.args.get('unit') or exp.Literal.string('day')))})" return func -def _derived_table_values_to_unnest(self, expression): +def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str: if not isinstance(expression.unnest().parent, exp.From): - expression = transforms.remove_precision_parameterized_types(expression) + expression = t.cast(exp.Values, transforms.remove_precision_parameterized_types(expression)) return self.values_sql(expression) rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)] structs = [] @@ -60,7 +64,7 @@ def _derived_table_values_to_unnest(self, expression): return self.unnest_sql(unnest_exp) -def _returnsproperty_sql(self, expression): +def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str: this = expression.this if isinstance(this, exp.Schema): this = f"{this.this} <{self.expressions(this)}>" @@ -69,8 +73,8 @@ def _returnsproperty_sql(self, expression): return f"RETURNS {this}" -def _create_sql(self, expression): - kind = expression.args.get("kind") +def _create_sql(self: generator.Generator, expression: exp.Create) -> str: + kind = expression.args["kind"] returns = expression.find(exp.ReturnsProperty) if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"): expression = expression.copy() @@ -89,6 +93,29 @@ def _create_sql(self, expression): return self.create_sql(expression) +def _unqualify_unnest(expression: exp.Expression) -> exp.Expression: + """Remove references to unnest table aliases since bigquery doesn't allow them. + + These are added by the optimizer's qualify_column step. + """ + if isinstance(expression, exp.Select): + unnests = { + unnest.alias + for unnest in expression.args.get("from", exp.From(expressions=[])).expressions + if isinstance(unnest, exp.Unnest) and unnest.alias + } + + if unnests: + expression = expression.copy() + + for select in expression.expressions: + for column in select.find_all(exp.Column): + if column.table in unnests: + column.set("table", None) + + return expression + + class BigQuery(Dialect): unnest_column_only = True time_mapping = { @@ -110,7 +137,7 @@ class BigQuery(Dialect): ] COMMENTS = ["--", "#", ("/*", "*/")] IDENTIFIERS = ["`"] - ESCAPES = ["\\"] + STRING_ESCAPES = ["\\"] HEX_STRINGS = [("0x", ""), ("0X", "")] KEYWORDS = { @@ -190,6 +217,9 @@ class BigQuery(Dialect): exp.GroupConcat: rename_func("STRING_AGG"), exp.ILike: no_ilike_sql, exp.IntDiv: rename_func("DIV"), + exp.Select: transforms.preprocess( + [_unqualify_unnest], transforms.delegate("select_sql") + ), exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})", exp.TimeAdd: _date_add_sql("TIME", "ADD"), exp.TimeSub: _date_add_sql("TIME", "SUB"), diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 9e8c691..b553df2 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -9,7 +9,7 @@ from sqlglot.parser import parse_var_map from sqlglot.tokens import TokenType -def _lower_func(sql): +def _lower_func(sql: str) -> str: index = sql.index("(") return sql[:index].lower() + sql[index:] diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 1b20e0a..176a8ce 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -11,6 +11,8 @@ from sqlglot.time import format_time from sqlglot.tokens import Tokenizer from sqlglot.trie import new_trie +E = t.TypeVar("E", bound=exp.Expression) + class Dialects(str, Enum): DIALECT = "" @@ -37,14 +39,16 @@ class Dialects(str, Enum): class _Dialect(type): - classes: t.Dict[str, Dialect] = {} + classes: t.Dict[str, t.Type[Dialect]] = {} @classmethod - def __getitem__(cls, key): + def __getitem__(cls, key: str) -> t.Type[Dialect]: return cls.classes[key] @classmethod - def get(cls, key, default=None): + def get( + cls, key: str, default: t.Optional[t.Type[Dialect]] = None + ) -> t.Optional[t.Type[Dialect]]: return cls.classes.get(key, default) def __new__(cls, clsname, bases, attrs): @@ -119,7 +123,7 @@ class Dialect(metaclass=_Dialect): generator_class = None @classmethod - def get_or_raise(cls, dialect): + def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: if not dialect: return cls if isinstance(dialect, _Dialect): @@ -134,7 +138,9 @@ class Dialect(metaclass=_Dialect): return result @classmethod - def format_time(cls, expression): + def format_time( + cls, expression: t.Optional[str | exp.Expression] + ) -> t.Optional[exp.Expression]: if isinstance(expression, str): return exp.Literal.string( format_time( @@ -153,26 +159,28 @@ class Dialect(metaclass=_Dialect): ) return expression - def parse(self, sql, **opts): + def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql) - def parse_into(self, expression_type, sql, **opts): + def parse_into( + self, expression_type: exp.IntoType, sql: str, **opts + ) -> t.List[t.Optional[exp.Expression]]: return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql) - def generate(self, expression, **opts): + def generate(self, expression: t.Optional[exp.Expression], **opts) -> str: return self.generator(**opts).generate(expression) - def transpile(self, code, **opts): - return self.generate(self.parse(code), **opts) + def transpile(self, sql: str, **opts) -> t.List[str]: + return [self.generate(expression, **opts) for expression in self.parse(sql)] @property - def tokenizer(self): + def tokenizer(self) -> Tokenizer: if not hasattr(self, "_tokenizer"): - self._tokenizer = self.tokenizer_class() + self._tokenizer = self.tokenizer_class() # type: ignore return self._tokenizer - def parser(self, **opts): - return self.parser_class( + def parser(self, **opts) -> Parser: + return self.parser_class( # type: ignore **{ "index_offset": self.index_offset, "unnest_column_only": self.unnest_column_only, @@ -182,14 +190,15 @@ class Dialect(metaclass=_Dialect): }, ) - def generator(self, **opts): - return self.generator_class( + def generator(self, **opts) -> Generator: + return self.generator_class( # type: ignore **{ "quote_start": self.quote_start, "quote_end": self.quote_end, "identifier_start": self.identifier_start, "identifier_end": self.identifier_end, - "escape": self.tokenizer_class.ESCAPES[0], + "string_escape": self.tokenizer_class.STRING_ESCAPES[0], + "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0], "index_offset": self.index_offset, "time_mapping": self.inverse_time_mapping, "time_trie": self.inverse_time_trie, @@ -202,11 +211,10 @@ class Dialect(metaclass=_Dialect): ) -if t.TYPE_CHECKING: - DialectType = t.Union[str, Dialect, t.Type[Dialect], None] +DialectType = t.Union[str, Dialect, t.Type[Dialect], None] -def rename_func(name): +def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: def _rename(self, expression): args = flatten(expression.args.values()) return f"{self.normalize_func(name)}({self.format_args(*args)})" @@ -214,32 +222,34 @@ def rename_func(name): return _rename -def approx_count_distinct_sql(self, expression): +def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: if expression.args.get("accuracy"): self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") return f"APPROX_COUNT_DISTINCT({self.format_args(expression.this)})" -def if_sql(self, expression): +def if_sql(self: Generator, expression: exp.If) -> str: expressions = self.format_args( expression.this, expression.args.get("true"), expression.args.get("false") ) return f"IF({expressions})" -def arrow_json_extract_sql(self, expression): +def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str: return self.binary(expression, "->") -def arrow_json_extract_scalar_sql(self, expression): +def arrow_json_extract_scalar_sql( + self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar +) -> str: return self.binary(expression, "->>") -def inline_array_sql(self, expression): +def inline_array_sql(self: Generator, expression: exp.Array) -> str: return f"[{self.expressions(expression)}]" -def no_ilike_sql(self, expression): +def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: return self.like_sql( exp.Like( this=exp.Lower(this=expression.this), @@ -248,44 +258,44 @@ def no_ilike_sql(self, expression): ) -def no_paren_current_date_sql(self, expression): +def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: zone = self.sql(expression, "this") return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" -def no_recursive_cte_sql(self, expression): +def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: if expression.args.get("recursive"): self.unsupported("Recursive CTEs are unsupported") expression.args["recursive"] = False return self.with_sql(expression) -def no_safe_divide_sql(self, expression): +def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: n = self.sql(expression, "this") d = self.sql(expression, "expression") return f"IF({d} <> 0, {n} / {d}, NULL)" -def no_tablesample_sql(self, expression): +def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: self.unsupported("TABLESAMPLE unsupported") return self.sql(expression.this) -def no_pivot_sql(self, expression): +def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: self.unsupported("PIVOT unsupported") return self.sql(expression) -def no_trycast_sql(self, expression): +def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: return self.cast_sql(expression) -def no_properties_sql(self, expression): +def no_properties_sql(self: Generator, expression: exp.Properties) -> str: self.unsupported("Properties unsupported") return "" -def str_position_sql(self, expression): +def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: this = self.sql(expression, "this") substr = self.sql(expression, "substr") position = self.sql(expression, "position") @@ -294,13 +304,15 @@ def str_position_sql(self, expression): return f"STRPOS({this}, {substr})" -def struct_extract_sql(self, expression): +def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: this = self.sql(expression, "this") struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True)) return f"{this}.{struct_key}" -def var_map_sql(self, expression, map_func_name="MAP"): +def var_map_sql( + self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" +) -> str: keys = expression.args["keys"] values = expression.args["values"] @@ -315,27 +327,33 @@ def var_map_sql(self, expression, map_func_name="MAP"): return f"{map_func_name}({self.format_args(*args)})" -def format_time_lambda(exp_class, dialect, default=None): +def format_time_lambda( + exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None +) -> t.Callable[[t.Sequence], E]: """Helper used for time expressions. - Args - exp_class (Class): the expression class to instantiate - dialect (string): sql dialect - default (Option[bool | str]): the default format, True being time + Args: + exp_class: the expression class to instantiate. + dialect: target sql dialect. + default: the default format, True being time. + + Returns: + A callable that can be used to return the appropriately formatted time expression. """ - def _format_time(args): + def _format_time(args: t.Sequence): return exp_class( this=seq_get(args, 0), format=Dialect[dialect].format_time( - seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default) + seq_get(args, 1) + or (Dialect[dialect].time_format if default is True else default or None) ), ) return _format_time -def create_with_partitions_sql(self, expression): +def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: """ In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding @@ -359,19 +377,21 @@ def create_with_partitions_sql(self, expression): return self.create_sql(expression) -def parse_date_delta(exp_class, unit_mapping=None): - def inner_func(args): +def parse_date_delta( + exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None +) -> t.Callable[[t.Sequence], E]: + def inner_func(args: t.Sequence) -> E: unit_based = len(args) == 3 this = seq_get(args, 2) if unit_based else seq_get(args, 0) expression = seq_get(args, 1) if unit_based else seq_get(args, 1) unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY") - unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit + unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit # type: ignore return exp_class(this=this, expression=expression, unit=unit) return inner_func -def locate_to_strposition(args): +def locate_to_strposition(args: t.Sequence) -> exp.Expression: return exp.StrPosition( this=seq_get(args, 1), substr=seq_get(args, 0), @@ -379,22 +399,22 @@ def locate_to_strposition(args): ) -def strposition_to_locate_sql(self, expression): +def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: args = self.format_args( expression.args.get("substr"), expression.this, expression.args.get("position") ) return f"LOCATE({args})" -def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str: +def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)" -def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str: +def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: return f"CAST({self.sql(expression, 'this')} AS DATE)" -def trim_sql(self, expression): +def trim_sql(self: Generator, expression: exp.Trim) -> str: target = self.sql(expression, "this") trim_type = self.sql(expression, "position") remove_chars = self.sql(expression, "expression") diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index d0a0251..1730eaf 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -1,6 +1,7 @@ from __future__ import annotations import re +import typing as t from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( @@ -16,35 +17,29 @@ from sqlglot.dialects.dialect import ( ) -def _to_timestamp(args): - # TO_TIMESTAMP accepts either a single double argument or (text, text) - if len(args) == 1 and args[0].is_number: - return exp.UnixToTime.from_arg_list(args) - return format_time_lambda(exp.StrToTime, "drill")(args) - - -def _str_to_time_sql(self, expression): +def _str_to_time_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str: return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})" -def _ts_or_ds_to_date_sql(self, expression): +def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str: time_format = self.format_time(expression) if time_format and time_format not in (Drill.time_format, Drill.date_format): return f"CAST({_str_to_time_sql(self, expression)} AS DATE)" return f"CAST({self.sql(expression, 'this')} AS DATE)" -def _date_add_sql(kind): - def func(self, expression): +def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]: + def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") - unit = expression.text("unit").upper() or "DAY" - expression = self.sql(expression, "expression") - return f"DATE_{kind}({this}, INTERVAL '{expression}' {unit})" + unit = exp.Var(this=expression.text("unit").upper() or "DAY") + return ( + f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})" + ) return func -def if_sql(self, expression): +def if_sql(self: generator.Generator, expression: exp.If) -> str: """ Drill requires backticks around certain SQL reserved words, IF being one of them, This function adds the backticks around the keyword IF. @@ -61,7 +56,7 @@ def if_sql(self, expression): return f"`IF`({expressions})" -def _str_to_date(self, expression): +def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format == Drill.date_format: @@ -111,7 +106,7 @@ class Drill(Dialect): class Tokenizer(tokens.Tokenizer): QUOTES = ["'"] IDENTIFIERS = ["`"] - ESCAPES = ["\\"] + STRING_ESCAPES = ["\\"] ENCODE = "utf-8" class Parser(parser.Parser): @@ -168,10 +163,10 @@ class Drill(Dialect): exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), exp.TryCast: no_trycast_sql, - exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), INTERVAL '{self.sql(e, 'expression')}' DAY)", + exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.Var(this='DAY')))})", exp.TsOrDsToDate: _ts_or_ds_to_date_sql, exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", } - def normalize_func(self, name): + def normalize_func(self, name: str) -> str: return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`" diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 95ff95c..959e5e2 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -25,10 +25,9 @@ def _str_to_time_sql(self, expression): def _ts_or_ds_add(self, expression): - this = self.sql(expression, "this") - e = self.sql(expression, "expression") + this = expression.args.get("this") unit = self.sql(expression, "unit").strip("'") or "DAY" - return f"CAST({this} AS DATE) + INTERVAL {e} {unit}" + return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}" def _ts_or_ds_to_date_sql(self, expression): @@ -40,9 +39,8 @@ def _ts_or_ds_to_date_sql(self, expression): def _date_add(self, expression): this = self.sql(expression, "this") - e = self.sql(expression, "expression") unit = self.sql(expression, "unit").strip("'") or "DAY" - return f"{this} + INTERVAL {e} {unit}" + return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}" def _array_sort_sql(self, expression): diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index f2b6eaa..c558b70 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -172,7 +172,7 @@ class Hive(Dialect): class Tokenizer(tokens.Tokenizer): QUOTES = ["'", '"'] IDENTIFIERS = ["`"] - ESCAPES = ["\\"] + STRING_ESCAPES = ["\\"] ENCODE = "utf-8" KEYWORDS = { diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index a5bd86b..c2c2c8c 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -89,8 +89,9 @@ def _date_add_sql(kind): def func(self, expression): this = self.sql(expression, "this") unit = expression.text("unit").upper() or "DAY" - expression = self.sql(expression, "expression") - return f"DATE_{kind}({this}, INTERVAL {expression} {unit})" + return ( + f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})" + ) return func @@ -117,7 +118,7 @@ class MySQL(Dialect): QUOTES = ["'", '"'] COMMENTS = ["--", "#", ("/*", "*/")] IDENTIFIERS = ["`"] - ESCAPES = ["'", "\\"] + STRING_ESCAPES = ["'", "\\"] BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")] HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")] diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 6418032..c709665 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -40,8 +40,7 @@ def _date_add_sql(kind): expression = expression.copy() expression.args["is_string"] = True - expression = self.sql(expression) - return f"{this} {kind} INTERVAL {expression} {unit}" + return f"{this} {kind} {self.sql(exp.Interval(this=expression, unit=unit))}" return func diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index c3c99eb..813ee5f 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -37,11 +37,10 @@ class Redshift(Postgres): return this class Tokenizer(Postgres.Tokenizer): - ESCAPES = ["\\"] + STRING_ESCAPES = ["\\"] KEYWORDS = { **Postgres.Tokenizer.KEYWORDS, # type: ignore - "COPY": TokenType.COMMAND, "ENCODE": TokenType.ENCODE, "GEOMETRY": TokenType.GEOMETRY, "GEOGRAPHY": TokenType.GEOGRAPHY, diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 3b83b02..55a6bd3 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -180,7 +180,7 @@ class Snowflake(Dialect): class Tokenizer(tokens.Tokenizer): QUOTES = ["'", "$$"] - ESCAPES = ["\\", "'"] + STRING_ESCAPES = ["\\", "'"] SINGLE_TOKENS = { **tokens.Tokenizer.SINGLE_TOKENS, @@ -191,6 +191,7 @@ class Snowflake(Dialect): **tokens.Tokenizer.KEYWORDS, "EXCLUDE": TokenType.EXCEPT, "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, + "PUT": TokenType.COMMAND, "RENAME": TokenType.REPLACE, "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, "TIMESTAMP_NTZ": TokenType.TIMESTAMP, @@ -222,6 +223,7 @@ class Snowflake(Dialect): exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})", exp.UnixToTime: _unix_to_time_sql, + exp.DayOfWeek: rename_func("DAYOFWEEK"), } TYPE_MAPPING = { @@ -294,3 +296,12 @@ class Snowflake(Dialect): kind = f" {kind_value}" if kind_value else "" this = f" {self.sql(expression, 'this')}" return f"DESCRIBE{kind}{this}" + + def generatedasidentitycolumnconstraint_sql( + self, expression: exp.GeneratedAsIdentityColumnConstraint + ) -> str: + start = expression.args.get("start") + start = f" START {start}" if start else "" + increment = expression.args.get("increment") + increment = f" INCREMENT {increment}" if increment else "" + return f"AUTOINCREMENT{start}{increment}" diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 8ef4a87..03ec211 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -157,6 +157,7 @@ class Spark(Hive): TRANSFORMS.pop(exp.ILike) WRAP_DERIVED_VALUES = False + CREATE_FUNCTION_AS = False def cast_sql(self, expression: exp.Cast) -> str: if isinstance(expression.this, exp.Cast) and expression.this.is_type( diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 1b39449..a428dd5 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -49,7 +49,6 @@ class SQLite(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, - "AUTOINCREMENT": TokenType.AUTO_INCREMENT, } class Parser(parser.Parser): diff --git a/sqlglot/diff.py b/sqlglot/diff.py index 7d5ec21..7530613 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -1,5 +1,6 @@ """ .. include:: ../posts/sql_diff.md + ---- """ diff --git a/sqlglot/errors.py b/sqlglot/errors.py index b5ef5ad..300c215 100644 --- a/sqlglot/errors.py +++ b/sqlglot/errors.py @@ -7,10 +7,17 @@ from sqlglot.helper import AutoName class ErrorLevel(AutoName): - IGNORE = auto() # Ignore any parser errors - WARN = auto() # Log any parser errors with ERROR level - RAISE = auto() # Collect all parser errors and raise a single exception - IMMEDIATE = auto() # Immediately raise an exception on the first parser error + IGNORE = auto() + """Ignore all errors.""" + + WARN = auto() + """Log all errors.""" + + RAISE = auto() + """Collect all errors and raise a single exception.""" + + IMMEDIATE = auto() + """Immediately raise an exception on the first error found.""" class SqlglotError(Exception): diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py index 67b4b00..c3d2701 100644 --- a/sqlglot/executor/__init__.py +++ b/sqlglot/executor/__init__.py @@ -1,5 +1,6 @@ """ .. include:: ../../posts/python_sql_engine.md + ---- """ diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index 29848c6..de570b0 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -408,7 +408,7 @@ def _lambda_sql(self, e: exp.Lambda) -> str: class Python(Dialect): class Tokenizer(tokens.Tokenizer): - ESCAPES = ["\\"] + STRING_ESCAPES = ["\\"] class Generator(generator.Generator): TRANSFORMS = { diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 6bb083a..6800cd5 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -6,6 +6,7 @@ Every AST node in SQLGlot is represented by a subclass of `Expression`. This module contains the implementation of all supported `Expression` types. Additionally, it exposes a number of helper functions, which are mainly used to programmatically build SQL expressions, such as `sqlglot.expressions.select`. + ---- """ @@ -137,6 +138,8 @@ class Expression(metaclass=_Expression): return field if isinstance(field, (Identifier, Literal, Var)): return field.this + if isinstance(field, (Star, Null)): + return field.name return "" @property @@ -176,13 +179,11 @@ class Expression(metaclass=_Expression): return self.text("alias") @property - def name(self): + def name(self) -> str: return self.text("this") @property def alias_or_name(self): - if isinstance(self, Null): - return "NULL" return self.alias or self.name @property @@ -589,12 +590,11 @@ class Expression(metaclass=_Expression): return load(obj) -if t.TYPE_CHECKING: - IntoType = t.Union[ - str, - t.Type[Expression], - t.Collection[t.Union[str, t.Type[Expression]]], - ] +IntoType = t.Union[ + str, + t.Type[Expression], + t.Collection[t.Union[str, t.Type[Expression]]], +] class Condition(Expression): @@ -939,7 +939,7 @@ class EncodeColumnConstraint(ColumnConstraintKind): class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): # this: True -> ALWAYS, this: False -> BY DEFAULT - arg_types = {"this": True, "start": False, "increment": False} + arg_types = {"this": False, "start": False, "increment": False} class NotNullColumnConstraint(ColumnConstraintKind): @@ -2390,7 +2390,7 @@ class Star(Expression): arg_types = {"except": False, "replace": False} @property - def name(self): + def name(self) -> str: return "*" @property @@ -2413,6 +2413,10 @@ class Placeholder(Expression): class Null(Condition): arg_types: t.Dict[str, t.Any] = {} + @property + def name(self) -> str: + return "NULL" + class Boolean(Condition): pass @@ -2644,7 +2648,9 @@ class Div(Binary): class Dot(Binary): - pass + @property + def name(self) -> str: + return self.expression.name class DPipe(Binary): @@ -2961,7 +2967,7 @@ class Cast(Func): arg_types = {"this": True, "to": True} @property - def name(self): + def name(self) -> str: return self.this.name @property @@ -4027,17 +4033,39 @@ def paren(expression) -> Paren: SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$") -def to_identifier(alias, quoted=None) -> t.Optional[Identifier]: - if alias is None: +@t.overload +def to_identifier(name: None, quoted: t.Optional[bool] = None) -> None: + ... + + +@t.overload +def to_identifier(name: str | Identifier, quoted: t.Optional[bool] = None) -> Identifier: + ... + + +def to_identifier(name, quoted=None): + """Builds an identifier. + + Args: + name: The name to turn into an identifier. + quoted: Whether or not force quote the identifier. + + Returns: + The identifier ast node. + """ + + if name is None: return None - if isinstance(alias, Identifier): - identifier = alias - elif isinstance(alias, str): - if quoted is None: - quoted = not re.match(SAFE_IDENTIFIER_RE, alias) - identifier = Identifier(this=alias, quoted=quoted) + + if isinstance(name, Identifier): + identifier = name + elif isinstance(name, str): + identifier = Identifier( + this=name, + quoted=not re.match(SAFE_IDENTIFIER_RE, name) if quoted is None else quoted, + ) else: - raise ValueError(f"Alias needs to be a string or an Identifier, got: {alias.__class__}") + raise ValueError(f"Name needs to be a string or an Identifier, got: {name.__class__}") return identifier @@ -4112,20 +4140,31 @@ def to_column(sql_path: str | Column, **kwargs) -> Column: return Column(this=column_name, table=table_name, **kwargs) -def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts): - """ - Create an Alias expression. +def alias_( + expression: str | Expression, + alias: str | Identifier, + table: bool | t.Sequence[str | Identifier] = False, + quoted: t.Optional[bool] = None, + dialect: DialectType = None, + **opts, +): + """Create an Alias expression. + Example: >>> alias_('foo', 'bar').sql() 'foo AS bar' + >>> alias_('(select 1, 2)', 'bar', table=['a', 'b']).sql() + '(SELECT 1, 2) AS bar(a, b)' + Args: - expression (str | Expression): the SQL code strings to parse. + expression: the SQL code strings to parse. If an Expression instance is passed, this is used as-is. - alias (str | Identifier): the alias name to use. If the name has + alias: the alias name to use. If the name has special characters it is quoted. - table (bool): create a table alias, default false - dialect (str): the dialect used to parse the input expression. + table: Whether or not to create a table alias, can also be a list of columns. + quoted: whether or not to quote the alias + dialect: the dialect used to parse the input expression. **opts: other options to use to parse the input expressions. Returns: @@ -4135,8 +4174,14 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts): alias = to_identifier(alias, quoted=quoted) if table: - expression.set("alias", TableAlias(this=alias)) - return expression + table_alias = TableAlias(this=alias) + exp.set("alias", table_alias) + + if not isinstance(table, bool): + for column in table: + table_alias.append("columns", to_identifier(column, quoted=quoted)) + + return exp # We don't set the "alias" arg for Window expressions, because that would add an IDENTIFIER node in # the AST, representing a "named_window" [1] construct (eg. bigquery). What we want is an ALIAS node diff --git a/sqlglot/generator.py b/sqlglot/generator.py index b95e9bc..0d72fe3 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import re import typing as t from sqlglot import exp @@ -11,6 +12,8 @@ from sqlglot.tokens import TokenType logger = logging.getLogger("sqlglot") +BACKSLASH_RE = re.compile(r"\\(?!b|f|n|r|t|0)") + class Generator: """ @@ -28,7 +31,8 @@ class Generator: identify (bool): if set to True all identifiers will be delimited by the corresponding character. normalize (bool): if set to True all identifiers will lower cased - escape (str): specifies an escape character. Default: '. + string_escape (str): specifies a string escape character. Default: '. + identifier_escape (str): specifies an identifier escape character. Default: ". pad (int): determines padding in a formatted string. Default: 2. indent (int): determines the size of indentation in a formatted string. Default: 4. unnest_column_only (bool): if true unnest table aliases are considered only as column aliases @@ -85,6 +89,9 @@ class Generator: # Wrap derived values in parens, usually standard but spark doesn't support it WRAP_DERIVED_VALUES = True + # Whether or not create function uses an AS before the def. + CREATE_FUNCTION_AS = True + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -154,7 +161,8 @@ class Generator: "identifier_end", "identify", "normalize", - "escape", + "string_escape", + "identifier_escape", "pad", "index_offset", "unnest_column_only", @@ -167,6 +175,7 @@ class Generator: "_indent", "_replace_backslash", "_escaped_quote_end", + "_escaped_identifier_end", "_leading_comma", "_max_text_width", "_comments", @@ -183,7 +192,8 @@ class Generator: identifier_end=None, identify=False, normalize=False, - escape=None, + string_escape=None, + identifier_escape=None, pad=2, indent=2, index_offset=0, @@ -208,7 +218,8 @@ class Generator: self.identifier_end = identifier_end or '"' self.identify = identify self.normalize = normalize - self.escape = escape or "'" + self.string_escape = string_escape or "'" + self.identifier_escape = identifier_escape or '"' self.pad = pad self.index_offset = index_offset self.unnest_column_only = unnest_column_only @@ -219,8 +230,9 @@ class Generator: self.max_unsupported = max_unsupported self.null_ordering = null_ordering self._indent = indent - self._replace_backslash = self.escape == "\\" - self._escaped_quote_end = self.escape + self.quote_end + self._replace_backslash = self.string_escape == "\\" + self._escaped_quote_end = self.string_escape + self.quote_end + self._escaped_identifier_end = self.identifier_escape + self.identifier_end self._leading_comma = leading_comma self._max_text_width = max_text_width self._comments = comments @@ -441,6 +453,9 @@ class Generator: def generatedasidentitycolumnconstraint_sql( self, expression: exp.GeneratedAsIdentityColumnConstraint ) -> str: + this = "" + if expression.this is not None: + this = " ALWAYS " if expression.this else " BY DEFAULT " start = expression.args.get("start") start = f"START WITH {start}" if start else "" increment = expression.args.get("increment") @@ -449,9 +464,7 @@ class Generator: if start or increment: sequence_opts = f"{start} {increment}" sequence_opts = f" ({sequence_opts.strip()})" - return ( - f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY{sequence_opts}" - ) + return f"GENERATED{this}AS IDENTITY{sequence_opts}" def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str: return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL" @@ -496,7 +509,12 @@ class Generator: properties_sql = self.sql(properties_exp, "properties") begin = " BEGIN" if expression.args.get("begin") else "" expression_sql = self.sql(expression, "expression") - expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else "" + if expression_sql: + expression_sql = f"{begin}{self.sep()}{expression_sql}" + + if self.CREATE_FUNCTION_AS or kind != "FUNCTION": + expression_sql = f" AS{expression_sql}" + temporary = " TEMPORARY" if expression.args.get("temporary") else "" transient = ( " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else "" @@ -701,6 +719,7 @@ class Generator: def identifier_sql(self, expression: exp.Identifier) -> str: text = expression.name text = text.lower() if self.normalize else text + text = text.replace(self.identifier_end, self._escaped_identifier_end) if expression.args.get("quoted") or self.identify: text = f"{self.identifier_start}{text}{self.identifier_end}" return text @@ -1121,7 +1140,7 @@ class Generator: text = expression.this or "" if expression.is_string: if self._replace_backslash: - text = text.replace("\\", "\\\\") + text = BACKSLASH_RE.sub(r"\\\\", text) text = text.replace(self.quote_end, self._escaped_quote_end) if self.pretty: text = text.replace("\n", self.SENTINEL_LINE_BREAK) @@ -1486,9 +1505,16 @@ class Generator: return f"(SELECT {self.sql(unnest)})" def interval_sql(self, expression: exp.Interval) -> str: - this = self.sql(expression, "this") - this = f" {this}" if this else "" - unit = self.sql(expression, "unit") + this = expression.args.get("this") + if this: + this = ( + f" {this}" + if isinstance(this, exp.Literal) or isinstance(this, exp.Paren) + else f" ({this})" + ) + else: + this = "" + unit = expression.args.get("unit") unit = f" {unit}" if unit else "" return f"INTERVAL{this}{unit}" diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index a39ad8c..908f126 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -6,6 +6,7 @@ from dataclasses import dataclass, field from sqlglot import Schema, exp, maybe_parse from sqlglot.optimizer import Scope, build_scope, optimize +from sqlglot.optimizer.expand_laterals import expand_laterals from sqlglot.optimizer.qualify_columns import qualify_columns from sqlglot.optimizer.qualify_tables import qualify_tables @@ -38,7 +39,7 @@ def lineage( sql: str | exp.Expression, schema: t.Optional[t.Dict | Schema] = None, sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None, - rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns), + rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns, expand_laterals), dialect: DialectType = None, ) -> Node: """Build the lineage graph for a column of a SQL query. diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index bfb2bb8..66f97a9 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -255,12 +255,23 @@ class TypeAnnotator: for name, source in scope.sources.items(): if not isinstance(source, Scope): continue - if isinstance(source.expression, exp.Values): + if isinstance(source.expression, exp.UDTF): + values = [] + + if isinstance(source.expression, exp.Lateral): + if isinstance(source.expression.this, exp.Explode): + values = [source.expression.this.this] + else: + values = source.expression.expressions[0].expressions + + if not values: + continue + selects[name] = { alias: column for alias, column in zip( source.expression.alias_column_names, - source.expression.expressions[0].expressions, + values, ) } else: @@ -272,7 +283,7 @@ class TypeAnnotator: source = scope.sources.get(col.table) if isinstance(source, exp.Table): col.type = self.schema.get_column_type(source, col) - elif source: + elif source and col.table in selects: col.type = selects[col.table][col.name].type # Then (possibly) annotate the remaining expressions in the scope self._maybe_annotate(scope.expression) diff --git a/sqlglot/optimizer/expand_laterals.py b/sqlglot/optimizer/expand_laterals.py new file mode 100644 index 0000000..59f3fec --- /dev/null +++ b/sqlglot/optimizer/expand_laterals.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import exp + + +def expand_laterals(expression: exp.Expression) -> exp.Expression: + """ + Expand lateral column alias references. + + This assumes `qualify_columns` as already run. + + Example: + >>> import sqlglot + >>> sql = "SELECT x.a + 1 AS b, b + 1 AS c FROM x" + >>> expression = sqlglot.parse_one(sql) + >>> expand_laterals(expression).sql() + 'SELECT x.a + 1 AS b, x.a + 1 + 1 AS c FROM x' + + Args: + expression: expression to optimize + Returns: + optimized expression + """ + for select in expression.find_all(exp.Select): + alias_to_expression: t.Dict[str, exp.Expression] = {} + for projection in select.expressions: + for column in projection.find_all(exp.Column): + if not column.table and column.name in alias_to_expression: + column.replace(alias_to_expression[column.name].copy()) + if isinstance(projection, exp.Alias): + alias_to_expression[projection.alias] = projection.this + return expression diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 766e059..96fd56b 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -4,6 +4,7 @@ from sqlglot.optimizer.canonicalize import canonicalize from sqlglot.optimizer.eliminate_ctes import eliminate_ctes from sqlglot.optimizer.eliminate_joins import eliminate_joins from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries +from sqlglot.optimizer.expand_laterals import expand_laterals from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects from sqlglot.optimizer.isolate_table_selects import isolate_table_selects from sqlglot.optimizer.lower_identities import lower_identities @@ -12,7 +13,7 @@ from sqlglot.optimizer.normalize import normalize from sqlglot.optimizer.optimize_joins import optimize_joins from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.pushdown_projections import pushdown_projections -from sqlglot.optimizer.qualify_columns import qualify_columns +from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns from sqlglot.optimizer.qualify_tables import qualify_tables from sqlglot.optimizer.unnest_subqueries import unnest_subqueries from sqlglot.schema import ensure_schema @@ -22,6 +23,8 @@ RULES = ( qualify_tables, isolate_table_selects, qualify_columns, + expand_laterals, + validate_qualify_columns, pushdown_projections, normalize, unnest_subqueries, diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index a73647c..54c5021 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -7,7 +7,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope SELECT_ALL = object() # Selection to use if selection list is empty -DEFAULT_SELECTION = alias("1", "_") +DEFAULT_SELECTION = lambda: alias("1", "_") def pushdown_projections(expression): @@ -93,7 +93,7 @@ def _remove_unused_selections(scope, parent_selections): # If there are no remaining selections, just select a single constant if not new_selections: - new_selections.append(DEFAULT_SELECTION.copy()) + new_selections.append(DEFAULT_SELECTION()) scope.expression.set("expressions", new_selections) if removed: @@ -106,5 +106,5 @@ def _remove_indexed_selections(scope, indexes_to_remove): selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove ] if not new_selections: - new_selections.append(DEFAULT_SELECTION.copy()) + new_selections.append(DEFAULT_SELECTION()) scope.expression.set("expressions", new_selections) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 54425a8..ab13d01 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -37,11 +37,24 @@ def qualify_columns(expression, schema): if not isinstance(scope.expression, exp.UDTF): _expand_stars(scope, resolver) _qualify_outputs(scope) - _check_unknown_tables(scope) return expression +def validate_qualify_columns(expression): + """Raise an `OptimizeError` if any columns aren't qualified""" + unqualified_columns = [] + for scope in traverse_scope(expression): + if isinstance(scope.expression, exp.Select): + unqualified_columns.extend(scope.unqualified_columns) + if scope.external_columns and not scope.is_correlated_subquery: + raise OptimizeError(f"Unknown table: {scope.external_columns[0].table}") + + if unqualified_columns: + raise OptimizeError(f"Ambiguous columns: {unqualified_columns}") + return expression + + def _pop_table_column_aliases(derived_tables): """ Remove table column aliases. @@ -199,10 +212,6 @@ def _qualify_columns(scope, resolver): if not column_table: column_table = resolver.get_table(column_name) - if not scope.is_subquery and not scope.is_udtf: - if column_table is None: - raise OptimizeError(f"Ambiguous column: {column_name}") - # column_table can be a '' because bigquery unnest has no table alias if column_table: column.set("table", exp.to_identifier(column_table)) @@ -231,10 +240,8 @@ def _qualify_columns(scope, resolver): for column in columns_missing_from_scope: column_table = resolver.get_table(column.name) - if column_table is None: - raise OptimizeError(f"Ambiguous column: {column.name}") - - column.set("table", exp.to_identifier(column_table)) + if column_table: + column.set("table", exp.to_identifier(column_table)) def _expand_stars(scope, resolver): @@ -322,11 +329,6 @@ def _qualify_outputs(scope): scope.expression.set("expressions", new_selections) -def _check_unknown_tables(scope): - if scope.external_columns and not scope.is_udtf and not scope.is_correlated_subquery: - raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}") - - class _Resolver: """ Helper for resolving columns. diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 5d8e0d9..65593bd 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -2,7 +2,7 @@ import itertools from sqlglot import alias, exp from sqlglot.helper import csv_reader -from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.scope import Scope, traverse_scope def qualify_tables(expression, db=None, catalog=None, schema=None): @@ -25,6 +25,8 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): """ sequence = itertools.count() + next_name = lambda: f"_q_{next(sequence)}" + for scope in traverse_scope(expression): for derived_table in scope.ctes + scope.derived_tables: if not derived_table.args.get("alias"): @@ -46,7 +48,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): source = source.replace( alias( source.copy(), - source.this if identifier else f"_q_{next(sequence)}", + source.this if identifier else next_name(), table=True, ) ) @@ -58,5 +60,12 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): schema.add_table( source, {k: type(v).__name__ for k, v in zip(header, columns)} ) + elif isinstance(source, Scope) and source.is_udtf: + udtf = source.expression + table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name()) + udtf.set("alias", table_alias) + + if not table_alias.name: + table_alias.set("this", next_name()) return expression diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index badbb87..8565c64 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -237,6 +237,8 @@ class Scope: ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint) if ( not ancestor + # Window functions can have an ORDER BY clause + or not isinstance(ancestor.parent, exp.Select) or column.table or (column.name not in named_selects and not isinstance(ancestor, exp.Hint)) ): @@ -479,7 +481,7 @@ def _traverse_scope(scope): elif isinstance(scope.expression, exp.Union): yield from _traverse_union(scope) elif isinstance(scope.expression, exp.UDTF): - pass + _set_udtf_scope(scope) elif isinstance(scope.expression, exp.Subquery): yield from _traverse_subqueries(scope) else: @@ -509,6 +511,22 @@ def _traverse_union(scope): scope.union_scopes = [left, right] +def _set_udtf_scope(scope): + parent = scope.expression.parent + from_ = parent.args.get("from") + + if not from_: + return + + for table in from_.expressions: + if isinstance(table, exp.Table): + scope.tables.append(table) + elif isinstance(table, exp.Subquery): + scope.subqueries.append(table) + _add_table_sources(scope) + _traverse_subqueries(scope) + + def _traverse_derived_tables(derived_tables, scope, scope_type): sources = {} is_cte = scope_type == ScopeType.CTE diff --git a/sqlglot/parser.py b/sqlglot/parser.py index e2b2c54..579c2ce 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -194,6 +194,7 @@ class Parser(metaclass=_Parser): TokenType.INTERVAL, TokenType.LAZY, TokenType.LEADING, + TokenType.LEFT, TokenType.LOCAL, TokenType.MATERIALIZED, TokenType.MERGE, @@ -208,6 +209,7 @@ class Parser(metaclass=_Parser): TokenType.PRECEDING, TokenType.RANGE, TokenType.REFERENCES, + TokenType.RIGHT, TokenType.ROW, TokenType.ROWS, TokenType.SCHEMA, @@ -237,8 +239,10 @@ class Parser(metaclass=_Parser): TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - { TokenType.APPLY, + TokenType.LEFT, TokenType.NATURAL, TokenType.OFFSET, + TokenType.RIGHT, TokenType.WINDOW, } @@ -258,6 +262,8 @@ class Parser(metaclass=_Parser): TokenType.IDENTIFIER, TokenType.INDEX, TokenType.ISNULL, + TokenType.ILIKE, + TokenType.LIKE, TokenType.MERGE, TokenType.OFFSET, TokenType.PRIMARY_KEY, @@ -971,13 +977,14 @@ class Parser(metaclass=_Parser): if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): this = self._parse_user_defined_function(kind=create_token.token_type) properties = self._parse_properties() - if self._match(TokenType.ALIAS): - begin = self._match(TokenType.BEGIN) - return_ = self._match_text_seq("RETURN") - expression = self._parse_statement() - if return_: - expression = self.expression(exp.Return, this=expression) + self._match(TokenType.ALIAS) + begin = self._match(TokenType.BEGIN) + return_ = self._match_text_seq("RETURN") + expression = self._parse_statement() + + if return_: + expression = self.expression(exp.Return, this=expression) elif create_token.token_type == TokenType.INDEX: this = self._parse_index() elif create_token.token_type in ( @@ -2163,7 +2170,9 @@ class Parser(metaclass=_Parser): ) -> t.Optional[exp.Expression]: if self._match(TokenType.TOP if top else TokenType.LIMIT): limit_paren = self._match(TokenType.L_PAREN) - limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number()) + limit_exp = self.expression( + exp.Limit, this=this, expression=self._parse_number() if top else self._parse_term() + ) if limit_paren: self._match_r_paren() @@ -2740,8 +2749,23 @@ class Parser(metaclass=_Parser): kind: exp.Expression - if self._match(TokenType.AUTO_INCREMENT): - kind = exp.AutoIncrementColumnConstraint() + if self._match_set((TokenType.AUTO_INCREMENT, TokenType.IDENTITY)): + start = None + increment = None + + if self._match(TokenType.L_PAREN, advance=False): + args = self._parse_wrapped_csv(self._parse_bitwise) + start = seq_get(args, 0) + increment = seq_get(args, 1) + elif self._match_text_seq("START"): + start = self._parse_bitwise() + self._match_text_seq("INCREMENT") + increment = self._parse_bitwise() + + if start and increment: + kind = exp.GeneratedAsIdentityColumnConstraint(start=start, increment=increment) + else: + kind = exp.AutoIncrementColumnConstraint() elif self._match(TokenType.CHECK): constraint = self._parse_wrapped(self._parse_conjunction) kind = self.expression(exp.CheckColumnConstraint, this=constraint) @@ -3294,8 +3318,8 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.EXCEPT): return None if self._match(TokenType.L_PAREN, advance=False): - return self._parse_wrapped_id_vars() - return self._parse_csv(self._parse_id_var) + return self._parse_wrapped_csv(self._parse_column) + return self._parse_csv(self._parse_column) def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: if not self._match(TokenType.REPLACE): @@ -3442,7 +3466,7 @@ class Parser(metaclass=_Parser): def _parse_alter(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.TABLE): - return None + return self._parse_as_command(self._prev) exists = self._parse_exists() this = self._parse_table(schema=True) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index e95057a..8cf17a7 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -357,7 +357,8 @@ class _Tokenizer(type): klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS) klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS) klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS) - klass._ESCAPES = set(klass.ESCAPES) + klass._STRING_ESCAPES = set(klass.STRING_ESCAPES) + klass._IDENTIFIER_ESCAPES = set(klass.IDENTIFIER_ESCAPES) klass._COMMENTS = dict( (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS @@ -429,9 +430,13 @@ class Tokenizer(metaclass=_Tokenizer): IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"'] - ESCAPES = ["'"] + STRING_ESCAPES = ["'"] - _ESCAPES: t.Set[str] = set() + _STRING_ESCAPES: t.Set[str] = set() + + IDENTIFIER_ESCAPES = ['"'] + + _IDENTIFIER_ESCAPES: t.Set[str] = set() KEYWORDS = { **{ @@ -469,6 +474,7 @@ class Tokenizer(metaclass=_Tokenizer): "ASC": TokenType.ASC, "AS": TokenType.ALIAS, "AT TIME ZONE": TokenType.AT_TIME_ZONE, + "AUTOINCREMENT": TokenType.AUTO_INCREMENT, "AUTO_INCREMENT": TokenType.AUTO_INCREMENT, "BEGIN": TokenType.BEGIN, "BETWEEN": TokenType.BETWEEN, @@ -691,6 +697,7 @@ class Tokenizer(metaclass=_Tokenizer): "ALTER VIEW": TokenType.COMMAND, "ANALYZE": TokenType.COMMAND, "CALL": TokenType.COMMAND, + "COPY": TokenType.COMMAND, "EXPLAIN": TokenType.COMMAND, "OPTIMIZE": TokenType.COMMAND, "PREPARE": TokenType.COMMAND, @@ -744,7 +751,7 @@ class Tokenizer(metaclass=_Tokenizer): ) def __init__(self) -> None: - self._replace_backslash = "\\" in self._ESCAPES + self._replace_backslash = "\\" in self._STRING_ESCAPES self.reset() def reset(self) -> None: @@ -1046,12 +1053,25 @@ class Tokenizer(metaclass=_Tokenizer): return True def _scan_identifier(self, identifier_end: str) -> None: - while self._peek != identifier_end: + text = "" + identifier_end_is_escape = identifier_end in self._IDENTIFIER_ESCAPES + + while True: if self._end: raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}") + self._advance() - self._advance() - self._add(TokenType.IDENTIFIER, self._text[1:-1]) + if self._char == identifier_end: + if identifier_end_is_escape and self._peek == identifier_end: + text += identifier_end # type: ignore + self._advance() + continue + + break + + text += self._char # type: ignore + + self._add(TokenType.IDENTIFIER, text) def _scan_var(self) -> None: while True: @@ -1072,9 +1092,9 @@ class Tokenizer(metaclass=_Tokenizer): while True: if ( - self._char in self._ESCAPES + self._char in self._STRING_ESCAPES and self._peek - and (self._peek == delimiter or self._peek in self._ESCAPES) + and (self._peek == delimiter or self._peek in self._STRING_ESCAPES) ): text += self._peek self._advance(2) -- cgit v1.2.3