From 6a22086850fc960715b618e82f4c2e43a4529146 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 20 Feb 2023 09:50:35 +0100 Subject: Merging upstream version 11.2.0. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 2 +- sqlglot/dataframe/sql/dataframe.py | 7 +++---- sqlglot/dialects/bigquery.py | 2 ++ sqlglot/dialects/dialect.py | 15 +++++++++++++++ sqlglot/dialects/drill.py | 3 ++- sqlglot/dialects/duckdb.py | 19 +++++-------------- sqlglot/dialects/snowflake.py | 2 ++ sqlglot/expressions.py | 8 ++++++++ sqlglot/generator.py | 9 +++++++++ sqlglot/lineage.py | 2 +- sqlglot/optimizer/annotate_types.py | 2 +- sqlglot/optimizer/qualify_columns.py | 27 +++++++++++++++++++-------- sqlglot/parser.py | 14 ++++++++++++++ 13 files changed, 82 insertions(+), 30 deletions(-) (limited to 'sqlglot') diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index c17a703..7bcaa22 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -40,7 +40,7 @@ if t.TYPE_CHECKING: T = t.TypeVar("T", bound=Expression) -__version__ = "11.1.3" +__version__ = "11.2.0" pretty = False """Whether to format generated SQL by default.""" diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 65a37f5..93ca45a 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -18,7 +18,6 @@ from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join from sqlglot.dataframe.sql.window import Window from sqlglot.helper import ensure_list, object_to_dict from sqlglot.optimizer import optimize as optimize_func -from sqlglot.optimizer.qualify_columns import qualify_columns if t.TYPE_CHECKING: from sqlglot.dataframe.sql._typing import ( @@ -376,9 +375,9 @@ class DataFrame: else: table_identifier = ctes_with_column[0].args["alias"].this ambiguous_col.expression.set("table", table_identifier) - expression = self.expression.select(*[x.expression for x in cols], **kwargs) - qualify_columns(expression, sqlglot.schema) - return self.copy(expression=expression, **kwargs) + return self.copy( + expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs + ) @operation(Operation.NO_OP) def alias(self, name: str, **kwargs) -> DataFrame: diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 7fd9e35..a75e802 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import ( no_ilike_sql, rename_func, timestrtotime_sql, + ts_or_ds_to_date_sql, ) from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -233,6 +234,7 @@ class BigQuery(Dialect): exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"), exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"), exp.TimeStrToTime: timestrtotime_sql, + exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.VariancePop: rename_func("VAR_POP"), exp.Values: _derived_table_values_to_unnest, diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index f4e8fd4..af36256 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -423,3 +423,18 @@ def trim_sql(self: Generator, expression: exp.Trim) -> str: from_part = "FROM " if trim_type or remove_chars else "" collation = f" COLLATE {collation}" if collation else "" return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" + + +def str_to_time_sql(self, expression: exp.Expression) -> str: + return self.func("STRPTIME", expression.this, self.format_time(expression)) + + +def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: + def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str: + _dialect = Dialect.get_or_raise(dialect) + time_format = self.format_time(expression) + if time_format and time_format not in (_dialect.time_format, _dialect.date_format): + return f"CAST({str_to_time_sql(self, expression)} AS DATE)" + return f"CAST({self.sql(expression, 'this')} AS DATE)" + + return _ts_or_ds_to_date_sql diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index e9c42e1..afcf4d0 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -14,6 +14,7 @@ from sqlglot.dialects.dialect import ( rename_func, str_position_sql, timestrtotime_sql, + ts_or_ds_to_date_sql, ) @@ -147,7 +148,7 @@ class Drill(Dialect): exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), exp.TryCast: no_trycast_sql, exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.Var(this='DAY')))})", - exp.TsOrDsToDate: _ts_or_ds_to_date_sql, + exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", } diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index cfec9a4..6144101 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -14,29 +14,20 @@ from sqlglot.dialects.dialect import ( no_tablesample_sql, rename_func, str_position_sql, + str_to_time_sql, timestrtotime_sql, + ts_or_ds_to_date_sql, ) from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _str_to_time_sql(self, expression): - return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})" - - def _ts_or_ds_add(self, expression): this = expression.args.get("this") unit = self.sql(expression, "unit").strip("'") or "DAY" return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}" -def _ts_or_ds_to_date_sql(self, expression): - time_format = self.format_time(expression) - if time_format and time_format not in (DuckDB.time_format, DuckDB.date_format): - return f"CAST({_str_to_time_sql(self, expression)} AS DATE)" - return f"CAST({self.sql(expression, 'this')} AS DATE)" - - def _date_add(self, expression): this = self.sql(expression, "this") unit = self.sql(expression, "unit").strip("'") or "DAY" @@ -159,8 +150,8 @@ class DuckDB(Dialect): exp.Split: rename_func("STR_SPLIT"), exp.SortArray: _sort_array_sql, exp.StrPosition: str_position_sql, - exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)", - exp.StrToTime: _str_to_time_sql, + exp.StrToDate: lambda self, e: f"CAST({str_to_time_sql(self, e)} AS DATE)", + exp.StrToTime: str_to_time_sql, exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))", exp.Struct: _struct_sql, exp.TableSample: no_tablesample_sql, @@ -171,7 +162,7 @@ class DuckDB(Dialect): exp.TimeToUnix: rename_func("EPOCH"), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)", exp.TsOrDsAdd: _ts_or_ds_add, - exp.TsOrDsToDate: _ts_or_ds_to_date_sql, + exp.TsOrDsToDate: ts_or_ds_to_date_sql("duckdb"), exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})", exp.UnixToTime: rename_func("TO_TIMESTAMP"), exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)", diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index bb46135..9342865 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import ( inline_array_sql, rename_func, timestrtotime_sql, + ts_or_ds_to_date_sql, var_map_sql, ) from sqlglot.expressions import Literal @@ -236,6 +237,7 @@ class Snowflake(Dialect): exp.TimeStrToTime: timestrtotime_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), + exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), exp.UnixToTime: _unix_to_time_sql, exp.DayOfWeek: rename_func("DAYOFWEEK"), } diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 42652a6..a29aeb4 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -946,6 +946,10 @@ class CommentColumnConstraint(ColumnConstraintKind): pass +class CompressColumnConstraint(ColumnConstraintKind): + pass + + class DateFormatColumnConstraint(ColumnConstraintKind): arg_types = {"this": True} @@ -970,6 +974,10 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): } +class InlineLengthColumnConstraint(ColumnConstraintKind): + pass + + class NotNullColumnConstraint(ColumnConstraintKind): arg_types = {"allow_null": False} diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 1479e28..18ae42a 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -84,6 +84,7 @@ class Generator: exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}", exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}", exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}", + exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}", } # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed @@ -445,6 +446,14 @@ class Generator: def autoincrementcolumnconstraint_sql(self, _) -> str: return self.token_sql(TokenType.AUTO_INCREMENT) + def compresscolumnconstraint_sql(self, expression: exp.CompressColumnConstraint) -> str: + if isinstance(expression.this, list): + this = self.wrap(self.expressions(expression, key="this", flat=True)) + else: + this = self.sql(expression, "this") + + return f"COMPRESS {this}" + def generatedasidentitycolumnconstraint_sql( self, expression: exp.GeneratedAsIdentityColumnConstraint ) -> str: diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index 908f126..2e563ae 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -113,7 +113,7 @@ def lineage( ) else: if table not in tables: - tables[table] = Node(name=table, source=source, expression=source) + tables[table] = Node(name=c.sql(), source=source, expression=source) node.downstream.append(tables[table]) return node diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index be65ab9..ca2131c 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -286,7 +286,7 @@ class TypeAnnotator: source = scope.sources.get(col.table) if isinstance(source, exp.Table): col.type = self.schema.get_column_type(source, col) - elif source and col.table in selects: + elif source and col.table in selects and col.name in selects[col.table]: col.type = selects[col.table][col.name].type # Then (possibly) annotate the remaining expressions in the scope self._maybe_annotate(scope.expression) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index a7bd9b5..e793e31 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -146,7 +146,7 @@ def _expand_group_by(scope, resolver): # Source columns get priority over select aliases if table: - node.set("table", exp.to_identifier(table)) + node.set("table", table) return node selects = {s.alias_or_name: s for s in scope.selects} @@ -212,7 +212,7 @@ def _qualify_columns(scope, resolver): # column_table can be a '' because bigquery unnest has no table alias if column_table: - column.set("table", exp.to_identifier(column_table)) + column.set("table", column_table) columns_missing_from_scope = [] # Determine whether each reference in the order by clause is to a column or an alias. @@ -239,7 +239,7 @@ def _qualify_columns(scope, resolver): column_table = resolver.get_table(column.name) if column_table: - column.set("table", exp.to_identifier(column_table)) + column.set("table", column_table) def _expand_stars(scope, resolver): @@ -340,7 +340,7 @@ class Resolver: self._unambiguous_columns = None self._all_columns = None - def get_table(self, column_name: str) -> t.Optional[str]: + def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: """ Get the table for a column name. @@ -354,18 +354,29 @@ class Resolver: self._get_all_source_columns() ) - table = self._unambiguous_columns.get(column_name) + table_name = self._unambiguous_columns.get(column_name) - if not table: + if not table_name: sources_without_schema = tuple( source for source, columns in self._get_all_source_columns().items() if not columns or "*" in columns ) if len(sources_without_schema) == 1: - return sources_without_schema[0] + table_name = sources_without_schema[0] - return table + if table_name not in self.scope.selected_sources: + return exp.to_identifier(table_name) + + node, _ = self.scope.selected_sources.get(table_name) + + if isinstance(node, exp.Subqueryable): + while node and node.alias != table_name: + node = node.parent + node_alias = node.args.get("alias") + if node_alias: + return node_alias.this + return exp.to_identifier(table_name) @property def all_columns(self): diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 9bde696..f92f5ac 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -594,6 +594,7 @@ class Parser(metaclass=_Parser): "COMMENT": lambda self: self.expression( exp.CommentColumnConstraint, this=self._parse_string() ), + "COMPRESS": lambda self: self._parse_compress(), "DEFAULT": lambda self: self.expression( exp.DefaultColumnConstraint, this=self._parse_bitwise() ), @@ -604,6 +605,7 @@ class Parser(metaclass=_Parser): ), "GENERATED": lambda self: self._parse_generated_as_identity(), "IDENTITY": lambda self: self._parse_auto_increment(), + "INLINE": lambda self: self._parse_inline(), "LIKE": lambda self: self._parse_create_like(), "NOT": lambda self: self._parse_not_constraint(), "NULL": lambda self: self.expression(exp.NotNullColumnConstraint, allow_null=True), @@ -2882,6 +2884,14 @@ class Parser(metaclass=_Parser): return exp.AutoIncrementColumnConstraint() + def _parse_compress(self) -> exp.Expression: + if self._match(TokenType.L_PAREN, advance=False): + return self.expression( + exp.CompressColumnConstraint, this=self._parse_wrapped_csv(self._parse_bitwise) + ) + + return self.expression(exp.CompressColumnConstraint, this=self._parse_bitwise()) + def _parse_generated_as_identity(self) -> exp.Expression: if self._match(TokenType.BY_DEFAULT): this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=False) @@ -2909,6 +2919,10 @@ class Parser(metaclass=_Parser): return this + def _parse_inline(self) -> t.Optional[exp.Expression]: + self._match_text_seq("LENGTH") + return self.expression(exp.InlineLengthColumnConstraint, this=self._parse_bitwise()) + def _parse_not_constraint(self) -> t.Optional[exp.Expression]: if self._match_text_seq("NULL"): return self.expression(exp.NotNullColumnConstraint) -- cgit v1.2.3