diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-01-31 05:44:41 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-01-31 05:44:41 +0000 |
commit | 376de8b6892deca7dc5d83035c047f1e13eb67ea (patch) | |
tree | 334a1753cd914294aa99128fac3fb59bf14dc10f /sqlglot | |
parent | Releasing debian version 20.9.0-1. (diff) | |
download | sqlglot-376de8b6892deca7dc5d83035c047f1e13eb67ea.tar.xz sqlglot-376de8b6892deca7dc5d83035c047f1e13eb67ea.zip |
Merging upstream version 20.11.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
37 files changed, 822 insertions, 280 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 6cf9949..d71c06d 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -87,13 +87,11 @@ def parse( @t.overload -def parse_one(sql: str, *, into: t.Type[E], **opts) -> E: - ... +def parse_one(sql: str, *, into: t.Type[E], **opts) -> E: ... @t.overload -def parse_one(sql: str, **opts) -> Expression: - ... +def parse_one(sql: str, **opts) -> Expression: ... def parse_one( diff --git a/sqlglot/_typing.py b/sqlglot/_typing.py index 86d965a..65f307e 100644 --- a/sqlglot/_typing.py +++ b/sqlglot/_typing.py @@ -4,10 +4,13 @@ import typing as t import sqlglot +if t.TYPE_CHECKING: + from typing_extensions import Literal as Lit # noqa + # A little hack for backwards compatibility with Python 3.7. # For example, we might want a TypeVar for objects that support comparison e.g. SupportsRichComparisonT from typeshed. # But Python 3.7 doesn't support Protocols, so we'd also need typing_extensions, which we don't want as a dependency. A = t.TypeVar("A", bound=t.Any) - +B = t.TypeVar("B", bound="sqlglot.exp.Binary") E = t.TypeVar("E", bound="sqlglot.exp.Expression") T = t.TypeVar("T") diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index ca85376..724c5bf 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -144,9 +144,11 @@ class Column: ) -> Column: ensured_column = None if column is None else cls.ensure_col(column) ensure_expression_values = { - k: [Column.ensure_col(x).expression for x in v] - if is_iterable(v) - else Column.ensure_col(v).expression + k: ( + [Column.ensure_col(x).expression for x in v] + if is_iterable(v) + else Column.ensure_col(v).expression + ) for k, v in kwargs.items() if v is not None } diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 68d36fe..0bacbf9 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -140,12 +140,10 @@ class DataFrame: return cte, name @t.overload - def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: - ... + def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ... @t.overload - def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: - ... + def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ... def _ensure_list_of_columns(self, cols): return Column.ensure_cols(ensure_list(cols)) @@ -496,9 +494,11 @@ class DataFrame: join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs] # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list select_column_names = [ - column.alias_or_name - if not isinstance(column.expression.this, exp.Star) - else column.sql() + ( + column.alias_or_name + if not isinstance(column.expression.this, exp.Star) + else column.sql() + ) for column in self_columns + other_columns ] select_column_names = [ @@ -552,9 +552,11 @@ class DataFrame: ), "The length of items in ascending must equal the number of columns provided" col_and_ascending = list(zip(columns, ascending)) order_by_columns = [ - exp.Ordered(this=col.expression, desc=not asc) - if i not in pre_ordered_col_indexes - else columns[i].column_expression + ( + exp.Ordered(this=col.expression, desc=not asc) + if i not in pre_ordered_col_indexes + else columns[i].column_expression + ) for i, (col, asc) in enumerate(col_and_ascending) ] return self.copy(expression=self.expression.order_by(*order_by_columns)) diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 141a302..a388cb4 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -661,7 +661,7 @@ def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: def to_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: tz_column = tz if isinstance(tz, Column) else lit(tz) - return Column.invoke_anonymous_function(timestamp, "TO_UTC_TIMESTAMP", tz_column) + return Column.invoke_expression_over_column(timestamp, expression.FromTimeZone, zone=tz_column) def timestamp_seconds(col: ColumnOrName) -> Column: diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py index f68bacb..b246641 100644 --- a/sqlglot/dataframe/sql/normalize.py +++ b/sqlglot/dataframe/sql/normalize.py @@ -7,11 +7,11 @@ from sqlglot.dataframe.sql.column import Column from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join from sqlglot.helper import ensure_list -NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column]) - if t.TYPE_CHECKING: from sqlglot.dataframe.sql.session import SparkSession + NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column]) + def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]): expr = ensure_list(expr) diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py index 4a33ef9..f518ac2 100644 --- a/sqlglot/dataframe/sql/session.py +++ b/sqlglot/dataframe/sql/session.py @@ -82,9 +82,11 @@ class SparkSession: ] sel_columns = [ - F.col(name).cast(data_type).alias(name).expression - if data_type is not None - else F.col(name).expression + ( + F.col(name).cast(data_type).alias(name).expression + if data_type is not None + else F.col(name).expression + ) for name, data_type in column_mapping.items() ] diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py index c1d913f..9e2fabd 100644 --- a/sqlglot/dataframe/sql/window.py +++ b/sqlglot/dataframe/sql/window.py @@ -90,9 +90,11 @@ class WindowSpec: **kwargs, **{ "start_side": "PRECEDING", - "start": "UNBOUNDED" - if start <= Window.unboundedPreceding - else F.lit(start).expression, + "start": ( + "UNBOUNDED" + if start <= Window.unboundedPreceding + else F.lit(start).expression + ), }, } if end == Window.currentRow: @@ -102,9 +104,9 @@ class WindowSpec: **kwargs, **{ "end_side": "FOLLOWING", - "end": "UNBOUNDED" - if end >= Window.unboundedFollowing - else F.lit(end).expression, + "end": ( + "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression + ), }, } return kwargs diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 0151e6c..771ae1a 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -5,7 +5,6 @@ import re import typing as t from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot._typing import E from sqlglot.dialects.dialect import ( Dialect, NormalizationStrategy, @@ -30,7 +29,7 @@ from sqlglot.helper import seq_get, split_num_words from sqlglot.tokens import TokenType if t.TYPE_CHECKING: - from typing_extensions import Literal + from sqlglot._typing import E, Lit logger = logging.getLogger("sqlglot") @@ -47,9 +46,11 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va exp.alias_(value, column_name) for value, column_name in zip( t.expressions, - alias.columns - if alias and alias.columns - else (f"_c{i}" for i in range(len(t.expressions))), + ( + alias.columns + if alias and alias.columns + else (f"_c{i}" for i in range(len(t.expressions))) + ), ) ] ) @@ -473,12 +474,10 @@ class BigQuery(Dialect): return table @t.overload - def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject: - ... + def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ... @t.overload - def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg: - ... + def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ... def _parse_json_object(self, agg=False): json_object = super()._parse_json_object() @@ -546,9 +545,11 @@ class BigQuery(Dialect): exp.ArrayContains: _array_contains_sql, exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]), - exp.CollateProperty: lambda self, e: f"DEFAULT COLLATE {self.sql(e, 'this')}" - if e.args.get("default") - else f"COLLATE {self.sql(e, 'this')}", + exp.CollateProperty: lambda self, e: ( + f"DEFAULT COLLATE {self.sql(e, 'this')}" + if e.args.get("default") + else f"COLLATE {self.sql(e, 'this')}" + ), exp.CountIf: rename_func("COUNTIF"), exp.Create: _create_sql, exp.CTE: transforms.preprocess([_pushdown_cte_column_names]), @@ -560,6 +561,9 @@ class BigQuery(Dialect): exp.DatetimeAdd: date_add_interval_sql("DATETIME", "ADD"), exp.DatetimeSub: date_add_interval_sql("DATETIME", "SUB"), exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")), + exp.FromTimeZone: lambda self, e: self.func( + "DATETIME", self.func("TIMESTAMP", e.this, e.args.get("zone")), "'UTC'" + ), exp.GenerateSeries: rename_func("GENERATE_ARRAY"), exp.GetPath: path_to_jsonpath(), exp.GroupConcat: rename_func("STRING_AGG"), @@ -595,9 +599,9 @@ class BigQuery(Dialect): exp.SHA2: lambda self, e: self.func( f"SHA256" if e.text("length") == "256" else "SHA512", e.this ), - exp.StabilityProperty: lambda self, e: f"DETERMINISTIC" - if e.name == "IMMUTABLE" - else "NOT DETERMINISTIC", + exp.StabilityProperty: lambda self, e: ( + f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC" + ), exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})", exp.StrToTime: lambda self, e: self.func( "PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone") diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index f2e4fe1..1248edc 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -88,6 +88,8 @@ class ClickHouse(Dialect): "UINT8": TokenType.UTINYINT, "IPV4": TokenType.IPV4, "IPV6": TokenType.IPV6, + "AGGREGATEFUNCTION": TokenType.AGGREGATEFUNCTION, + "SIMPLEAGGREGATEFUNCTION": TokenType.SIMPLEAGGREGATEFUNCTION, } SINGLE_TOKENS = { @@ -548,6 +550,8 @@ class ClickHouse(Dialect): exp.DataType.Type.UTINYINT: "UInt8", exp.DataType.Type.IPV4: "IPv4", exp.DataType.Type.IPV6: "IPv6", + exp.DataType.Type.AGGREGATEFUNCTION: "AggregateFunction", + exp.DataType.Type.SIMPLEAGGREGATEFUNCTION: "SimpleAggregateFunction", } TRANSFORMS = { @@ -651,12 +655,16 @@ class ClickHouse(Dialect): def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]: return super().after_limit_modifiers(expression) + [ - self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True) - if expression.args.get("settings") - else "", - self.seg("FORMAT ") + self.sql(expression, "format") - if expression.args.get("format") - else "", + ( + self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True) + if expression.args.get("settings") + else "" + ), + ( + self.seg("FORMAT ") + self.sql(expression, "format") + if expression.args.get("format") + else "" + ), ] def parameterizedagg_sql(self, expression: exp.ParameterizedAgg) -> str: diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 7664c40..6be991b 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -5,7 +5,6 @@ from enum import Enum, auto from functools import reduce from sqlglot import exp -from sqlglot._typing import E from sqlglot.errors import ParseError from sqlglot.generator import Generator from sqlglot.helper import AutoName, flatten, seq_get @@ -14,11 +13,12 @@ from sqlglot.time import TIMEZONES, format_time from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import new_trie -B = t.TypeVar("B", bound=exp.Binary) - DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff] DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] +if t.TYPE_CHECKING: + from sqlglot._typing import B, E + class Dialects(str, Enum): """Dialects supported by SQLGLot.""" @@ -381,9 +381,11 @@ class Dialect(metaclass=_Dialect): ): expression.set( "this", - expression.this.upper() - if self.normalization_strategy is NormalizationStrategy.UPPERCASE - else expression.this.lower(), + ( + expression.this.upper() + if self.normalization_strategy is NormalizationStrategy.UPPERCASE + else expression.this.lower() + ), ) return expression @@ -877,9 +879,11 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). """ agg_all_unquoted = agg.transform( - lambda node: exp.Identifier(this=node.name, quoted=False) - if isinstance(node, exp.Identifier) - else node + lambda node: ( + exp.Identifier(this=node.name, quoted=False) + if isinstance(node, exp.Identifier) + else node + ) ) names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) @@ -999,10 +1003,8 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: """Remove table refs from columns in when statements.""" alias = expression.this.args.get("alias") - normalize = ( - lambda identifier: self.dialect.normalize_identifier(identifier).name - if identifier - else None + normalize = lambda identifier: ( + self.dialect.normalize_identifier(identifier).name if identifier else None ) targets = {normalize(expression.this.this)} @@ -1012,9 +1014,11 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: for when in expression.expressions: when.transform( - lambda node: exp.column(node.this) - if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets - else node, + lambda node: ( + exp.column(node.this) + if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets + else node + ), copy=False, ) diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 2343b35..f55ad70 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -148,8 +148,8 @@ def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str def _rename_unless_within_group( a: str, b: str ) -> t.Callable[[DuckDB.Generator, exp.Expression], str]: - return ( - lambda self, expression: self.func(a, *flatten(expression.args.values())) + return lambda self, expression: ( + self.func(a, *flatten(expression.args.values())) if isinstance(expression.find_ancestor(exp.Select, exp.WithinGroup), exp.WithinGroup) else self.func(b, *flatten(expression.args.values())) ) @@ -273,9 +273,11 @@ class DuckDB(Dialect): PLACEHOLDER_PARSERS = { **parser.Parser.PLACEHOLDER_PARSERS, - TokenType.PARAMETER: lambda self: self.expression(exp.Placeholder, this=self._prev.text) - if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS) - else None, + TokenType.PARAMETER: lambda self: ( + self.expression(exp.Placeholder, this=self._prev.text) + if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS) + else None + ), } def _parse_types( @@ -321,9 +323,11 @@ class DuckDB(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, - exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0]) - if e.expressions and e.expressions[0].find(exp.Select) - else inline_array_sql(self, e), + exp.Array: lambda self, e: ( + self.func("ARRAY", e.expressions[0]) + if e.expressions and e.expressions[0].find(exp.Select) + else inline_array_sql(self, e) + ), exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"), exp.ArgMin: arg_max_or_min_no_count("ARG_MIN"), diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index dffa41e..060f9bd 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -397,9 +397,11 @@ class Hive(Dialect): if this and not schema: return this.transform( - lambda node: node.replace(exp.DataType.build("text")) - if isinstance(node, exp.DataType) and node.is_type("char", "varchar") - else node, + lambda node: ( + node.replace(exp.DataType.build("text")) + if isinstance(node, exp.DataType) and node.is_type("char", "varchar") + else node + ), copy=False, ) @@ -409,9 +411,11 @@ class Hive(Dialect): self, ) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]: return ( - self._parse_csv(self._parse_conjunction) - if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY}) - else [], + ( + self._parse_csv(self._parse_conjunction) + if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY}) + else [] + ), super()._parse_order(skip_order_token=self._match(TokenType.SORT_BY)), ) @@ -483,9 +487,9 @@ class Hive(Dialect): exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)), exp.Min: min_or_least, exp.MonthsBetween: lambda self, e: self.func("MONTHS_BETWEEN", e.this, e.expression), - exp.NotNullColumnConstraint: lambda self, e: "" - if e.args.get("allow_null") - else "NOT NULL", + exp.NotNullColumnConstraint: lambda self, e: ( + "" if e.args.get("allow_null") else "NOT NULL" + ), exp.VarMap: var_map_sql, exp.Create: _create_sql, exp.Quantile: rename_func("PERCENTILE"), diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 6ad3718..4591d59 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -166,6 +166,7 @@ class Oracle(Dialect): TABLESAMPLE_KEYWORDS = "SAMPLE" LAST_DAY_SUPPORTS_DATE_PART = False SUPPORTS_SELECT_INTO = True + TZ_TO_WITH_TIME_ZONE = True TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -179,6 +180,8 @@ class Oracle(Dialect): exp.DataType.Type.NVARCHAR: "NVARCHAR2", exp.DataType.Type.NCHAR: "NCHAR", exp.DataType.Type.TEXT: "CLOB", + exp.DataType.Type.TIMETZ: "TIME", + exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", exp.DataType.Type.BINARY: "BLOB", exp.DataType.Type.VARBINARY: "BLOB", } diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 1ca0a78..87f6b02 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -282,6 +282,12 @@ class Postgres(Dialect): VAR_SINGLE_TOKENS = {"$"} class Parser(parser.Parser): + PROPERTY_PARSERS = { + **parser.Parser.PROPERTY_PARSERS, + "SET": lambda self: self.expression(exp.SetConfigProperty, this=self._parse_set()), + } + PROPERTY_PARSERS.pop("INPUT", None) + FUNCTIONS = { **parser.Parser.FUNCTIONS, "DATE_TRUNC": parse_timestamp_trunc, @@ -385,9 +391,11 @@ class Postgres(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.AnyValue: any_value_to_max_sql, - exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})" - if isinstance(seq_get(e.expressions, 0), exp.Select) - else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]", + exp.Array: lambda self, e: ( + f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})" + if isinstance(seq_get(e.expressions, 0), exp.Select) + else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]" + ), exp.ArrayConcat: rename_func("ARRAY_CAT"), exp.ArrayContained: lambda self, e: self.binary(e, "<@"), exp.ArrayContains: lambda self, e: self.binary(e, "@>"), @@ -396,6 +404,7 @@ class Postgres(Dialect): exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]), exp.CurrentDate: no_paren_current_date_sql, exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", + exp.CurrentUser: lambda *_: "CURRENT_USER", exp.DateAdd: _date_add_sql("+"), exp.DateDiff: _date_diff_sql, exp.DateStrToDate: datestrtodate_sql, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 9b421e7..6cc6030 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -356,6 +356,7 @@ class Presto(Dialect): exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"), exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", exp.First: _first_last_sql, + exp.FromTimeZone: lambda self, e: f"WITH_TIMEZONE({self.sql(e, 'this')}, {self.sql(e, 'zone')}) AT TIME ZONE 'UTC'", exp.GetPath: path_to_jsonpath(), exp.Group: transforms.preprocess([transforms.unalias_group]), exp.GroupConcat: lambda self, e: self.func( diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index a8e4a42..281167d 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -3,7 +3,6 @@ from __future__ import annotations import typing as t from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot._typing import E from sqlglot.dialects.dialect import ( Dialect, NormalizationStrategy, @@ -25,6 +24,9 @@ from sqlglot.expressions import Literal from sqlglot.helper import seq_get from sqlglot.tokens import TokenType +if t.TYPE_CHECKING: + from sqlglot._typing import E + def _check_int(s: str) -> bool: if s[0] in ("-", "+"): @@ -297,10 +299,7 @@ def _parse_colon_get_path( if not self._match(TokenType.COLON): break - if self._match_set(self.RANGE_PARSERS): - this = self.RANGE_PARSERS[self._prev.token_type](self, this) or this - - return this + return self._parse_range(this) def _parse_timestamp_from_parts(args: t.List) -> exp.Func: @@ -376,7 +375,7 @@ class Snowflake(Dialect): and isinstance(expression.parent, exp.Table) and expression.name.lower() == "dual" ): - return t.cast(E, expression) + return expression # type: ignore return super().quote_identifier(expression, identify=identify) @@ -471,6 +470,10 @@ class Snowflake(Dialect): } SHOW_PARSERS = { + "SCHEMAS": _show_parser("SCHEMAS"), + "TERSE SCHEMAS": _show_parser("SCHEMAS"), + "OBJECTS": _show_parser("OBJECTS"), + "TERSE OBJECTS": _show_parser("OBJECTS"), "PRIMARY KEYS": _show_parser("PRIMARY KEYS"), "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"), "COLUMNS": _show_parser("COLUMNS"), @@ -580,21 +583,35 @@ class Snowflake(Dialect): scope = None scope_kind = None + # will identity SHOW TERSE SCHEMAS but not SHOW TERSE PRIMARY KEYS + # which is syntactically valid but has no effect on the output + terse = self._tokens[self._index - 2].text.upper() == "TERSE" + like = self._parse_string() if self._match(TokenType.LIKE) else None if self._match(TokenType.IN): if self._match_text_seq("ACCOUNT"): scope_kind = "ACCOUNT" elif self._match_set(self.DB_CREATABLES): - scope_kind = self._prev.text + scope_kind = self._prev.text.upper() if self._curr: - scope = self._parse_table() + scope = self._parse_table_parts() elif self._curr: - scope_kind = "TABLE" - scope = self._parse_table() + scope_kind = "SCHEMA" if this == "OBJECTS" else "TABLE" + scope = self._parse_table_parts() return self.expression( - exp.Show, this=this, like=like, scope=scope, scope_kind=scope_kind + exp.Show, + **{ + "terse": terse, + "this": this, + "like": like, + "scope": scope, + "scope_kind": scope_kind, + "starts_with": self._match_text_seq("STARTS", "WITH") and self._parse_string(), + "limit": self._parse_limit(), + "from": self._parse_string() if self._match(TokenType.FROM) else None, + }, ) def _parse_alter_table_swap(self) -> exp.SwapTable: @@ -690,6 +707,9 @@ class Snowflake(Dialect): exp.DayOfYear: rename_func("DAYOFYEAR"), exp.Explode: rename_func("FLATTEN"), exp.Extract: rename_func("DATE_PART"), + exp.FromTimeZone: lambda self, e: self.func( + "CONVERT_TIMEZONE", e.args.get("zone"), "'UTC'", e.this + ), exp.GenerateSeries: lambda self, e: self.func( "ARRAY_GENERATE_RANGE", e.args["start"], e.args["end"] + 1, e.args.get("step") ), @@ -820,6 +840,7 @@ class Snowflake(Dialect): return f"{explode}{alias}" def show_sql(self, expression: exp.Show) -> str: + terse = "TERSE " if expression.args.get("terse") else "" like = self.sql(expression, "like") like = f" LIKE {like}" if like else "" @@ -830,7 +851,19 @@ class Snowflake(Dialect): if scope_kind: scope_kind = f" IN {scope_kind}" - return f"SHOW {expression.name}{like}{scope_kind}{scope}" + starts_with = self.sql(expression, "starts_with") + if starts_with: + starts_with = f" STARTS WITH {starts_with}" + + limit = self.sql(expression, "limit") + + from_ = self.sql(expression, "from") + if from_: + from_ = f" FROM {from_}" + + return ( + f"SHOW {terse}{expression.name}{like}{scope_kind}{scope}{starts_with}{limit}{from_}" + ) def regexpextract_sql(self, expression: exp.RegexpExtract) -> str: # Other dialects don't support all of the following parameters, so we need to @@ -884,3 +917,6 @@ class Snowflake(Dialect): def with_properties(self, properties: exp.Properties) -> str: return self.properties(properties, wrapped=False, prefix=self.seg(""), sep=" ") + + def cluster_sql(self, expression: exp.Cluster) -> str: + return f"CLUSTER BY ({self.expressions(expression, flat=True)})" diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index ba73ac0..624f76e 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -80,9 +80,9 @@ class Spark(Spark2): exp.TimestampAdd: lambda self, e: self.func( "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this ), - exp.TryCast: lambda self, e: self.trycast_sql(e) - if e.args.get("safe") - else self.cast_sql(e), + exp.TryCast: lambda self, e: ( + self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e) + ), } TRANSFORMS.pop(exp.AnyValue) TRANSFORMS.pop(exp.DateDiff) diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index e27ba18..e4bb30e 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -129,10 +129,20 @@ class Spark2(Hive): "SHIFTRIGHT": binary_from_function(exp.BitwiseRightShift), "STRING": _parse_as_cast("string"), "TIMESTAMP": _parse_as_cast("timestamp"), - "TO_TIMESTAMP": lambda args: _parse_as_cast("timestamp")(args) - if len(args) == 1 - else format_time_lambda(exp.StrToTime, "spark")(args), + "TO_TIMESTAMP": lambda args: ( + _parse_as_cast("timestamp")(args) + if len(args) == 1 + else format_time_lambda(exp.StrToTime, "spark")(args) + ), "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list, + "TO_UTC_TIMESTAMP": lambda args: exp.FromTimeZone( + this=exp.cast_unless( + seq_get(args, 0) or exp.Var(this=""), + exp.DataType.build("timestamp"), + exp.DataType.build("timestamp"), + ), + zone=seq_get(args, 1), + ), "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)), "WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), } @@ -188,6 +198,7 @@ class Spark2(Hive): exp.DayOfYear: rename_func("DAYOFYEAR"), exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}", exp.From: transforms.preprocess([_unalias_pivot]), + exp.FromTimeZone: lambda self, e: f"TO_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})", exp.LogicalAnd: rename_func("BOOL_AND"), exp.LogicalOr: rename_func("BOOL_OR"), exp.Map: _map_sql, @@ -255,10 +266,12 @@ class Spark2(Hive): def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str: return super().columndef_sql( expression, - sep=": " - if isinstance(expression.parent, exp.DataType) - and expression.parent.is_type("struct") - else sep, + sep=( + ": " + if isinstance(expression.parent, exp.DataType) + and expression.parent.is_type("struct") + else sep + ), ) class Tokenizer(Hive.Tokenizer): diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py index 33ec7e1..3795045 100644 --- a/sqlglot/dialects/tableau.py +++ b/sqlglot/dialects/tableau.py @@ -38,3 +38,4 @@ class Tableau(Dialect): **parser.Parser.FUNCTIONS, "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)), } + NO_PAREN_IF_COMMANDS = False diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index b9c347c..a5e04da 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -76,9 +76,11 @@ def _format_time_lambda( format=exp.Literal.string( format_time( args[0].name.lower(), - {**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING} - if full_format_mapping - else TSQL.TIME_MAPPING, + ( + {**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING} + if full_format_mapping + else TSQL.TIME_MAPPING + ), ) ), ) @@ -264,6 +266,15 @@ def _parse_timefromparts(args: t.List) -> exp.TimeFromParts: ) +def _parse_len(args: t.List) -> exp.Length: + this = seq_get(args, 0) + + if this and not this.is_string: + this = exp.cast(this, exp.DataType.Type.TEXT) + + return exp.Length(this=this) + + class TSQL(Dialect): NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'" @@ -431,7 +442,7 @@ class TSQL(Dialect): "IIF": exp.If.from_arg_list, "ISNULL": exp.Coalesce.from_arg_list, "JSON_VALUE": exp.JSONExtractScalar.from_arg_list, - "LEN": exp.Length.from_arg_list, + "LEN": _parse_len, "REPLICATE": exp.Repeat.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "SYSDATETIME": exp.CurrentTimestamp.from_arg_list, @@ -469,6 +480,7 @@ class TSQL(Dialect): ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False STRING_ALIASES = True + NO_PAREN_IF_COMMANDS = False def _parse_projections(self) -> t.List[exp.Expression]: """ @@ -478,9 +490,11 @@ class TSQL(Dialect): See: https://learn.microsoft.com/en-us/sql/t-sql/queries/select-clause-transact-sql?view=sql-server-ver16#syntax """ return [ - exp.alias_(projection.expression, projection.this.this, copy=False) - if isinstance(projection, exp.EQ) and isinstance(projection.this, exp.Column) - else projection + ( + exp.alias_(projection.expression, projection.this.this, copy=False) + if isinstance(projection, exp.EQ) and isinstance(projection.this, exp.Column) + else projection + ) for projection in super()._parse_projections() ] @@ -702,7 +716,6 @@ class TSQL(Dialect): exp.GroupConcat: _string_agg_sql, exp.If: rename_func("IIF"), exp.LastDay: lambda self, e: self.func("EOMONTH", e.this), - exp.Length: rename_func("LEN"), exp.Max: max_or_greatest, exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this), exp.Min: min_or_least, @@ -922,3 +935,11 @@ class TSQL(Dialect): this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True, sep=" ") return f"CONSTRAINT {this} {expressions}" + + def length_sql(self, expression: exp.Length) -> str: + this = expression.this + if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.TEXT): + this_sql = self.sql(this, "this") + else: + this_sql = self.sql(this) + return self.func("LEN", this_sql) diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index 3277e65..7ff9608 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -392,9 +392,9 @@ def _lambda_sql(self, e: exp.Lambda) -> str: names = {e.name.lower() for e in e.expressions} e = e.transform( - lambda n: exp.var(n.name) - if isinstance(n, exp.Identifier) and n.name.lower() in names - else n + lambda n: ( + exp.var(n.name) if isinstance(n, exp.Identifier) and n.name.lower() in names else n + ) ) return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}" @@ -438,9 +438,9 @@ class Python(Dialect): exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})", exp.In: lambda self, e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}", exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')", - exp.Is: lambda self, e: self.binary(e, "==") - if isinstance(e.this, exp.Literal) - else self.binary(e, "is"), + exp.Is: lambda self, e: ( + self.binary(e, "==") if isinstance(e.this, exp.Literal) else self.binary(e, "is") + ), exp.Lambda: _lambda_sql, exp.Not: lambda self, e: f"not {self.sql(e.this)}", exp.Null: lambda *_: "None", diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index ddad8f8..a95a73e 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -23,7 +23,6 @@ from copy import deepcopy from enum import auto from functools import reduce -from sqlglot._typing import E from sqlglot.errors import ErrorLevel, ParseError from sqlglot.helper import ( AutoName, @@ -36,8 +35,7 @@ from sqlglot.helper import ( from sqlglot.tokens import Token if t.TYPE_CHECKING: - from typing_extensions import Literal as Lit - + from sqlglot._typing import E, Lit from sqlglot.dialects.dialect import DialectType @@ -389,7 +387,7 @@ class Expression(metaclass=_Expression): ancestor = self.parent while ancestor and not isinstance(ancestor, expression_types): ancestor = ancestor.parent - return t.cast(E, ancestor) + return ancestor # type: ignore @property def parent_select(self) -> t.Optional[Select]: @@ -555,12 +553,10 @@ class Expression(metaclass=_Expression): return new_node @t.overload - def replace(self, expression: E) -> E: - ... + def replace(self, expression: E) -> E: ... @t.overload - def replace(self, expression: None) -> None: - ... + def replace(self, expression: None) -> None: ... def replace(self, expression): """ @@ -781,13 +777,16 @@ class Expression(metaclass=_Expression): this=maybe_copy(self, copy), expressions=[convert(e, copy=copy) for e in expressions], query=maybe_parse(query, copy=copy, **opts) if query else None, - unnest=Unnest( - expressions=[ - maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) for e in ensure_list(unnest) - ] - ) - if unnest - else None, + unnest=( + Unnest( + expressions=[ + maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) + for e in ensure_list(unnest) + ] + ) + if unnest + else None + ), ) def between(self, low: t.Any, high: t.Any, copy: bool = True, **opts) -> Between: @@ -926,7 +925,7 @@ class DerivedTable(Expression): class Unionable(Expression): def union( self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts - ) -> Unionable: + ) -> Union: """ Builds a UNION expression. @@ -1134,9 +1133,12 @@ class SetItem(Expression): class Show(Expression): arg_types = { "this": True, + "terse": False, "target": False, "offset": False, + "starts_with": False, "limit": False, + "from": False, "like": False, "where": False, "db": False, @@ -1274,9 +1276,14 @@ class AlterColumn(Expression): "using": False, "default": False, "drop": False, + "comment": False, } +class RenameColumn(Expression): + arg_types = {"this": True, "to": True, "exists": False} + + class RenameTable(Expression): pass @@ -1402,7 +1409,7 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): class GeneratedAsRowColumnConstraint(ColumnConstraintKind): - arg_types = {"start": True, "hidden": False} + arg_types = {"start": False, "hidden": False} # https://dev.mysql.com/doc/refman/8.0/en/create-table.html @@ -1667,6 +1674,7 @@ class Index(Expression): "unique": False, "primary": False, "amp": False, # teradata + "include": False, "partition_by": False, # teradata "where": False, # postgres partial indexes } @@ -2016,7 +2024,13 @@ class AutoRefreshProperty(Property): class BlockCompressionProperty(Property): - arg_types = {"autotemp": False, "always": False, "default": True, "manual": True, "never": True} + arg_types = { + "autotemp": False, + "always": False, + "default": False, + "manual": False, + "never": False, + } class CharacterSetProperty(Property): @@ -2089,6 +2103,10 @@ class FreespaceProperty(Property): arg_types = {"this": True, "percent": False} +class InheritsProperty(Property): + arg_types = {"expressions": True} + + class InputModelProperty(Property): arg_types = {"this": True} @@ -2099,11 +2117,11 @@ class OutputModelProperty(Property): class IsolatedLoadingProperty(Property): arg_types = { - "no": True, - "concurrent": True, - "for_all": True, - "for_insert": True, - "for_none": True, + "no": False, + "concurrent": False, + "for_all": False, + "for_insert": False, + "for_none": False, } @@ -2264,6 +2282,10 @@ class SetProperty(Property): arg_types = {"multi": True} +class SetConfigProperty(Property): + arg_types = {"this": True} + + class SettingsProperty(Property): arg_types = {"expressions": True} @@ -2407,13 +2429,16 @@ class Tuple(Expression): this=maybe_copy(self, copy), expressions=[convert(e, copy=copy) for e in expressions], query=maybe_parse(query, copy=copy, **opts) if query else None, - unnest=Unnest( - expressions=[ - maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) for e in ensure_list(unnest) - ] - ) - if unnest - else None, + unnest=( + Unnest( + expressions=[ + maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) + for e in ensure_list(unnest) + ] + ) + if unnest + else None + ), ) @@ -3631,6 +3656,8 @@ class DataType(Expression): class Type(AutoName): ARRAY = auto() + AGGREGATEFUNCTION = auto() + SIMPLEAGGREGATEFUNCTION = auto() BIGDECIMAL = auto() BIGINT = auto() BIGSERIAL = auto() @@ -4162,6 +4189,10 @@ class AtTimeZone(Expression): arg_types = {"this": True, "zone": True} +class FromTimeZone(Expression): + arg_types = {"this": True, "zone": True} + + class Between(Predicate): arg_types = {"this": True, "low": True, "high": True} @@ -5456,8 +5487,7 @@ def maybe_parse( prefix: t.Optional[str] = None, copy: bool = False, **opts, -) -> E: - ... +) -> E: ... @t.overload @@ -5469,8 +5499,7 @@ def maybe_parse( prefix: t.Optional[str] = None, copy: bool = False, **opts, -) -> E: - ... +) -> E: ... def maybe_parse( @@ -5522,13 +5551,11 @@ def maybe_parse( @t.overload -def maybe_copy(instance: None, copy: bool = True) -> None: - ... +def maybe_copy(instance: None, copy: bool = True) -> None: ... @t.overload -def maybe_copy(instance: E, copy: bool = True) -> E: - ... +def maybe_copy(instance: E, copy: bool = True) -> E: ... def maybe_copy(instance, copy=True): @@ -6151,15 +6178,13 @@ SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$") @t.overload -def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None: - ... +def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None: ... @t.overload def to_identifier( name: str | Identifier, quoted: t.Optional[bool] = None, copy: bool = True -) -> Identifier: - ... +) -> Identifier: ... def to_identifier(name, quoted=None, copy=True): @@ -6231,13 +6256,11 @@ def to_interval(interval: str | Literal) -> Interval: @t.overload -def to_table(sql_path: str | Table, **kwargs) -> Table: - ... +def to_table(sql_path: str | Table, **kwargs) -> Table: ... @t.overload -def to_table(sql_path: None, **kwargs) -> None: - ... +def to_table(sql_path: None, **kwargs) -> None: ... def to_table( @@ -6562,6 +6585,34 @@ def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable: ) +def rename_column( + table_name: str | Table, + old_column_name: str | Column, + new_column_name: str | Column, + exists: t.Optional[bool] = None, +) -> AlterTable: + """Build ALTER TABLE... RENAME COLUMN... expression + + Args: + table_name: Name of the table + old_column: The old name of the column + new_column: The new name of the column + exists: Whether or not to add the `IF EXISTS` clause + + Returns: + Alter table expression + """ + table = to_table(table_name) + old_column = to_column(old_column_name) + new_column = to_column(new_column_name) + return AlterTable( + this=table, + actions=[ + RenameColumn(this=old_column, to=new_column, exists=exists), + ], + ) + + def convert(value: t.Any, copy: bool = False) -> Expression: """Convert a python value into an expression object. @@ -6581,7 +6632,7 @@ def convert(value: t.Any, copy: bool = False) -> Expression: if isinstance(value, bool): return Boolean(this=value) if value is None or (isinstance(value, float) and math.isnan(value)): - return NULL + return null() if isinstance(value, numbers.Number): return Literal.number(value) if isinstance(value, datetime.datetime): @@ -6674,9 +6725,11 @@ def table_name(table: Table | str, dialect: DialectType = None, identify: bool = raise ValueError(f"Cannot parse {table}") return ".".join( - part.sql(dialect=dialect, identify=True, copy=False) - if identify or not SAFE_IDENTIFIER_RE.match(part.name) - else part.name + ( + part.sql(dialect=dialect, identify=True, copy=False) + if identify or not SAFE_IDENTIFIER_RE.match(part.name) + else part.name + ) for part in table.parts ) @@ -6942,9 +6995,3 @@ def null() -> Null: Returns a Null expression. """ return Null() - - -# TODO: deprecate this -TRUE = Boolean(this=True) -FALSE = Boolean(this=False) -NULL = Null() diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 977185f..8e3ff9b 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -77,6 +77,7 @@ class Generator: exp.ExecuteAsProperty: lambda self, e: self.naked_property(e), exp.ExternalProperty: lambda self, e: "EXTERNAL", exp.HeapProperty: lambda self, e: "HEAP", + exp.InheritsProperty: lambda self, e: f"INHERITS ({self.expressions(e, flat=True)})", exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}", exp.InputModelProperty: lambda self, e: f"INPUT{self.sql(e, 'this')}", exp.IntervalSpan: lambda self, e: f"{self.sql(e, 'this')} TO {self.sql(e, 'expression')}", @@ -96,6 +97,7 @@ class Generator: exp.ReturnsProperty: lambda self, e: self.naked_property(e), exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}", exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET", + exp.SetConfigProperty: lambda self, e: self.sql(e, "this"), exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}", exp.SqlReadWriteProperty: lambda self, e: e.name, exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", @@ -323,6 +325,7 @@ class Generator: exp.FileFormatProperty: exp.Properties.Location.POST_WITH, exp.FreespaceProperty: exp.Properties.Location.POST_NAME, exp.HeapProperty: exp.Properties.Location.POST_WITH, + exp.InheritsProperty: exp.Properties.Location.POST_SCHEMA, exp.InputModelProperty: exp.Properties.Location.POST_SCHEMA, exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME, exp.JournalProperty: exp.Properties.Location.POST_NAME, @@ -353,6 +356,7 @@ class Generator: exp.Set: exp.Properties.Location.POST_SCHEMA, exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA, exp.SetProperty: exp.Properties.Location.POST_CREATE, + exp.SetConfigProperty: exp.Properties.Location.POST_SCHEMA, exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA, exp.SqlReadWriteProperty: exp.Properties.Location.POST_SCHEMA, exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE, @@ -568,9 +572,11 @@ class Generator: def wrap(self, expression: exp.Expression | str) -> str: this_sql = self.indent( - self.sql(expression) - if isinstance(expression, (exp.Select, exp.Union)) - else self.sql(expression, "this"), + ( + self.sql(expression) + if isinstance(expression, (exp.Select, exp.Union)) + else self.sql(expression, "this") + ), level=1, pad=0, ) @@ -605,9 +611,11 @@ class Generator: lines = sql.split("\n") return "\n".join( - line - if (skip_first and i == 0) or (skip_last and i == len(lines) - 1) - else f"{' ' * (level * self._indent + pad)}{line}" + ( + line + if (skip_first and i == 0) or (skip_last and i == len(lines) - 1) + else f"{' ' * (level * self._indent + pad)}{line}" + ) for i, line in enumerate(lines) ) @@ -775,7 +783,7 @@ class Generator: def generatedasrowcolumnconstraint_sql( self, expression: exp.GeneratedAsRowColumnConstraint ) -> str: - start = "START" if expression.args["start"] else "END" + start = "START" if expression.args.get("start") else "END" hidden = " HIDDEN" if expression.args.get("hidden") else "" return f"GENERATED ALWAYS AS ROW {start}{hidden}" @@ -1111,7 +1119,10 @@ class Generator: partition_by = self.expressions(expression, key="partition_by", flat=True) partition_by = f" PARTITION BY {partition_by}" if partition_by else "" where = self.sql(expression, "where") - return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{partition_by}{where}" + include = self.expressions(expression, key="include", flat=True) + if include: + include = f" INCLUDE ({include})" + return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{include}{partition_by}{where}" def identifier_sql(self, expression: exp.Identifier) -> str: text = expression.name @@ -2017,9 +2028,11 @@ class Generator: def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]: return [ self.sql(expression, "qualify"), - self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True) - if expression.args.get("windows") - else "", + ( + self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True) + if expression.args.get("windows") + else "" + ), self.sql(expression, "distribute"), self.sql(expression, "sort"), self.sql(expression, "cluster"), @@ -2552,6 +2565,11 @@ class Generator: zone = self.sql(expression, "zone") return f"{this} AT TIME ZONE {zone}" + def fromtimezone_sql(self, expression: exp.FromTimeZone) -> str: + this = self.sql(expression, "this") + zone = self.sql(expression, "zone") + return f"{this} AT TIME ZONE {zone} AT TIME ZONE 'UTC'" + def add_sql(self, expression: exp.Add) -> str: return self.binary(expression, "+") @@ -2669,6 +2687,10 @@ class Generator: if default: return f"ALTER COLUMN {this} SET DEFAULT {default}" + comment = self.sql(expression, "comment") + if comment: + return f"ALTER COLUMN {this} COMMENT {comment}" + if not expression.args.get("drop"): self.unsupported("Unsupported ALTER COLUMN syntax") @@ -2683,6 +2705,12 @@ class Generator: this = self.sql(expression, "this") return f"RENAME TO {this}" + def renamecolumn_sql(self, expression: exp.RenameColumn) -> str: + exists = " IF EXISTS" if expression.args.get("exists") else "" + old_column = self.sql(expression, "this") + new_column = self.sql(expression, "to") + return f"RENAME COLUMN{exists} {old_column} TO {new_column}" + def altertable_sql(self, expression: exp.AlterTable) -> str: actions = expression.args["actions"] diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 349c8c8..de737be 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -53,13 +53,11 @@ def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]: @t.overload -def ensure_list(value: t.Collection[T]) -> t.List[T]: - ... +def ensure_list(value: t.Collection[T]) -> t.List[T]: ... @t.overload -def ensure_list(value: T) -> t.List[T]: - ... +def ensure_list(value: T) -> t.List[T]: ... def ensure_list(value): @@ -81,13 +79,11 @@ def ensure_list(value): @t.overload -def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: - ... +def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ... @t.overload -def ensure_collection(value: T) -> t.Collection[T]: - ... +def ensure_collection(value: T) -> t.Collection[T]: ... def ensure_collection(value): diff --git a/sqlglot/jsonpath.py b/sqlglot/jsonpath.py new file mode 100644 index 0000000..c410d11 --- /dev/null +++ b/sqlglot/jsonpath.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +import typing as t + +from sqlglot.errors import ParseError +from sqlglot.expressions import SAFE_IDENTIFIER_RE +from sqlglot.tokens import Token, Tokenizer, TokenType + +if t.TYPE_CHECKING: + from sqlglot._typing import Lit + + +class JSONPathTokenizer(Tokenizer): + SINGLE_TOKENS = { + "(": TokenType.L_PAREN, + ")": TokenType.R_PAREN, + "[": TokenType.L_BRACKET, + "]": TokenType.R_BRACKET, + ":": TokenType.COLON, + ",": TokenType.COMMA, + "-": TokenType.DASH, + ".": TokenType.DOT, + "?": TokenType.PLACEHOLDER, + "@": TokenType.PARAMETER, + "'": TokenType.QUOTE, + '"': TokenType.QUOTE, + "$": TokenType.DOLLAR, + "*": TokenType.STAR, + } + + KEYWORDS = { + "..": TokenType.DOT, + } + + IDENTIFIER_ESCAPES = ["\\"] + STRING_ESCAPES = ["\\"] + + +JSONPathNode = t.Dict[str, t.Any] + + +def _node(kind: str, value: t.Any = None, **kwargs: t.Any) -> JSONPathNode: + node = {"kind": kind, **kwargs} + + if value is not None: + node["value"] = value + + return node + + +def parse(path: str) -> t.List[JSONPathNode]: + """Takes in a JSONPath string and converts into a list of nodes.""" + tokens = JSONPathTokenizer().tokenize(path) + size = len(tokens) + + i = 0 + + def _curr() -> t.Optional[TokenType]: + return tokens[i].token_type if i < size else None + + def _prev() -> Token: + return tokens[i - 1] + + def _advance() -> Token: + nonlocal i + i += 1 + return _prev() + + def _error(msg: str) -> str: + return f"{msg} at index {i}: {path}" + + @t.overload + def _match(token_type: TokenType, raise_unmatched: Lit[True] = True) -> Token: + pass + + @t.overload + def _match(token_type: TokenType, raise_unmatched: Lit[False] = False) -> t.Optional[Token]: + pass + + def _match(token_type, raise_unmatched=False): + if _curr() == token_type: + return _advance() + if raise_unmatched: + raise ParseError(_error(f"Expected {token_type}")) + return None + + def _parse_literal() -> t.Any: + token = _match(TokenType.STRING) or _match(TokenType.IDENTIFIER) + if token: + return token.text + if _match(TokenType.STAR): + return _node("wildcard") + if _match(TokenType.PLACEHOLDER) or _match(TokenType.L_PAREN): + script = _prev().text == "(" + start = i + + while True: + if _match(TokenType.L_BRACKET): + _parse_bracket() # nested call which we can throw away + if _curr() in (TokenType.R_BRACKET, None): + break + _advance() + return _node( + "script" if script else "filter", path[tokens[start].start : tokens[i].end] + ) + + number = "-" if _match(TokenType.DASH) else "" + + token = _match(TokenType.NUMBER) + if token: + number += token.text + + if number: + return int(number) + return False + + def _parse_slice() -> t.Any: + start = _parse_literal() + end = _parse_literal() if _match(TokenType.COLON) else None + step = _parse_literal() if _match(TokenType.COLON) else None + + if end is None and step is None: + return start + return _node("slice", start=start, end=end, step=step) + + def _parse_bracket() -> JSONPathNode: + literal = _parse_slice() + + if isinstance(literal, str) or literal is not False: + indexes = [literal] + while _match(TokenType.COMMA): + literal = _parse_slice() + + if literal: + indexes.append(literal) + + if len(indexes) == 1: + if isinstance(literal, str): + node = _node("key", indexes[0]) + elif isinstance(literal, dict) and literal["kind"] in ("script", "filter"): + node = _node("selector", indexes[0]) + else: + node = _node("subscript", indexes[0]) + else: + node = _node("union", indexes) + else: + raise ParseError(_error("Cannot have empty segment")) + + _match(TokenType.R_BRACKET, raise_unmatched=True) + + return node + + nodes = [] + + while _curr(): + if _match(TokenType.DOLLAR): + nodes.append(_node("root")) + elif _match(TokenType.DOT): + recursive = _prev().text == ".." + value = _match(TokenType.VAR) or _match(TokenType.STAR) + nodes.append( + _node("recursive" if recursive else "child", value=value.text if value else None) + ) + elif _match(TokenType.L_BRACKET): + nodes.append(_parse_bracket()) + elif _match(TokenType.VAR): + nodes.append(_node("key", _prev().text)) + elif _match(TokenType.STAR): + nodes.append(_node("wildcard")) + elif _match(TokenType.PARAMETER): + nodes.append(_node("current")) + else: + raise ParseError(_error(f"Unexpected {tokens[i].token_type}")) + + return nodes + + +MAPPING = { + "child": lambda n: f".{n['value']}" if n.get("value") is not None else "", + "filter": lambda n: f"?{n['value']}", + "key": lambda n: ( + f".{n['value']}" if SAFE_IDENTIFIER_RE.match(n["value"]) else f'[{generate([n["value"]])}]' + ), + "recursive": lambda n: f"..{n['value']}" if n.get("value") is not None else "..", + "root": lambda _: "$", + "script": lambda n: f"({n['value']}", + "slice": lambda n: ":".join( + "" if p is False else generate([p]) + for p in [n["start"], n["end"], n["step"]] + if p is not None + ), + "selector": lambda n: f"[{generate([n['value']])}]", + "subscript": lambda n: f"[{generate([n['value']])}]", + "union": lambda n: f"[{','.join(generate([p]) for p in n['value'])}]", + "wildcard": lambda _: "*", +} + + +def generate( + nodes: t.List[JSONPathNode], + mapping: t.Optional[t.Dict[str, t.Callable[[JSONPathNode], str]]] = None, +) -> str: + mapping = MAPPING if mapping is None else mapping + path = [] + + for node in nodes: + if isinstance(node, dict): + path.append(mapping[node["kind"]](node)) + elif isinstance(node, str): + escaped = node.replace('"', '\\"') + path.append(f'"{escaped}"') + else: + path.append(str(node)) + + return "".join(path) diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index 09bf201..bdd1d14 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -41,9 +41,9 @@ class Node: else: label = node.expression.sql(pretty=True, dialect=dialect) source = node.source.transform( - lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>") - if n is node.expression - else n, + lambda n: ( + exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n + ), copy=False, ).sql(pretty=True, dialect=dialect) title = f"<pre>{source}</pre>" diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index d0168d5..a2a86cd 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -4,7 +4,6 @@ import functools import typing as t from sqlglot import exp -from sqlglot._typing import E from sqlglot.helper import ( ensure_list, is_date_unit, @@ -17,7 +16,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import Schema, ensure_schema if t.TYPE_CHECKING: - B = t.TypeVar("B", bound=exp.Binary) + from sqlglot._typing import B, E BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type] BinaryCoercions = t.Dict[ @@ -480,6 +479,20 @@ class TypeAnnotator(metaclass=_TypeAnnotator): return self._annotate_args(expression) @t.no_type_check + def _annotate_struct_value( + self, expression: exp.Expression + ) -> t.Optional[exp.DataType] | exp.ColumnDef: + alias = expression.args.get("alias") + if alias: + return exp.ColumnDef(this=alias.copy(), kind=expression.type) + + # Case: key = value or key := value + if expression.expression: + return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type) + + return expression.type + + @t.no_type_check def _annotate_by_args( self, expression: E, @@ -516,16 +529,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator): ) if struct: - expressions = [ - expr.type - if not expr.args.get("alias") - else exp.ColumnDef(this=expr.args["alias"].copy(), kind=expr.type) - for expr in expressions - ] - self._set_type( expression, - exp.DataType(this=exp.DataType.Type.STRUCT, expressions=expressions, nested=True), + exp.DataType( + this=exp.DataType.Type.STRUCT, + expressions=[self._annotate_struct_value(expr) for expr in expressions], + nested=True, + ), ) return expression diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py index 3361a33..f2a0990 100644 --- a/sqlglot/optimizer/normalize_identifiers.py +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -3,18 +3,18 @@ from __future__ import annotations import typing as t from sqlglot import exp -from sqlglot._typing import E from sqlglot.dialects.dialect import Dialect, DialectType +if t.TYPE_CHECKING: + from sqlglot._typing import E + @t.overload -def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: - ... +def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ... @t.overload -def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: - ... +def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ... def normalize_identifiers(expression, dialect=None): diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index a6397ae..1656727 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -4,7 +4,6 @@ import itertools import typing as t from sqlglot import alias, exp -from sqlglot._typing import E from sqlglot.dialects.dialect import Dialect, DialectType from sqlglot.errors import OptimizeError from sqlglot.helper import seq_get @@ -12,6 +11,9 @@ from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_ from sqlglot.optimizer.simplify import simplify_parens from sqlglot.schema import Schema, ensure_schema +if t.TYPE_CHECKING: + from sqlglot._typing import E + def qualify_columns( expression: exp.Expression, @@ -210,7 +212,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: if not node: return - for column, *_ in walk_in_scope(node): + for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star): if not isinstance(column, exp.Column): continue @@ -525,6 +527,7 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: selection = alias( selection, alias=selection.output_name or f"_col_{i}", + copy=False, ) if aliased_column: selection.set("alias", exp.to_identifier(aliased_column)) diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index e0fe641..d460e81 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -4,12 +4,14 @@ import itertools import typing as t from sqlglot import alias, exp -from sqlglot._typing import E from sqlglot.dialects.dialect import DialectType from sqlglot.helper import csv_reader, name_sequence from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import Schema +if t.TYPE_CHECKING: + from sqlglot._typing import E + def qualify_tables( expression: E, @@ -46,6 +48,18 @@ def qualify_tables( db = exp.parse_identifier(db, dialect=dialect) if db else None catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None + def _qualify(table: exp.Table) -> None: + if isinstance(table.this, exp.Identifier): + if not table.args.get("db"): + table.set("db", db) + if not table.args.get("catalog") and table.args.get("db"): + table.set("catalog", catalog) + + if not isinstance(expression, exp.Subqueryable): + for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)): + if isinstance(node, exp.Table): + _qualify(node) + for scope in traverse_scope(expression): for derived_table in itertools.chain(scope.ctes, scope.derived_tables): if isinstance(derived_table, exp.Subquery): @@ -66,11 +80,7 @@ def qualify_tables( for name, source in scope.sources.items(): if isinstance(source, exp.Table): - if isinstance(source.this, exp.Identifier): - if not source.args.get("db"): - source.set("db", db) - if not source.args.get("catalog") and source.args.get("db"): - source.set("catalog", catalog) + _qualify(source) pivots = pivots = source.args.get("pivots") if not source.alias: @@ -107,5 +117,14 @@ def qualify_tables( if isinstance(udtf, exp.Values) and not table_alias.columns: for i, e in enumerate(udtf.expressions[0].expressions): table_alias.append("columns", exp.to_identifier(f"_col_{i}")) + else: + for node, parent, _ in scope.walk(): + if ( + isinstance(node, exp.Table) + and not node.alias + and isinstance(parent, (exp.From, exp.Join)) + ): + # Mutates the table by attaching an alias to it + alias(node, node.name, copy=False, table=True) return expression diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index a3f08d5..16cd548 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -323,9 +323,14 @@ class Scope: sources in the current scope. """ if self._external_columns is None: - self._external_columns = [ - c for c in self.columns if c.table not in self.selected_sources - ] + if isinstance(self.expression, exp.Union): + left, right = self.union_scopes + self._external_columns = left.external_columns + right.external_columns + else: + self._external_columns = [ + c for c in self.columns if c.table not in self.selected_sources + ] + return self._external_columns @property @@ -477,11 +482,12 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]: Args: expression (exp.Expression): expression to traverse + Returns: list[Scope]: scope instances """ if isinstance(expression, exp.Unionable) or ( - isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable) + isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Unionable) ): return list(_traverse_scope(Scope(expression))) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 25d4e75..d5b9119 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -1068,9 +1068,11 @@ def extract_interval(expression): def date_literal(date): return exp.cast( exp.Literal.string(date), - exp.DataType.Type.DATETIME - if isinstance(date, datetime.datetime) - else exp.DataType.Type.DATE, + ( + exp.DataType.Type.DATETIME + if isinstance(date, datetime.datetime) + else exp.DataType.Type.DATE + ), ) diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 4d35175..26f4159 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -50,11 +50,12 @@ def unnest(select, parent_select, next_alias_name): ): return + clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join) + # This subquery returns a scalar and can just be converted to a cross join if not isinstance(predicate, (exp.In, exp.Any)): column = exp.column(select.selects[0].alias_or_name, alias) - clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join) clause_parent_select = clause.parent_select if clause else None if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or ( @@ -84,12 +85,18 @@ def unnest(select, parent_select, next_alias_name): column = _other_operand(predicate) value = select.selects[0] - on = exp.condition(f'{column} = "{alias}"."{value.alias}"') - _replace(predicate, f"NOT {on.right} IS NULL") + join_key = exp.column(value.alias, alias) + join_key_not_null = join_key.is_(exp.null()).not_() + + if isinstance(clause, exp.Join): + _replace(predicate, exp.true()) + parent_select.where(join_key_not_null, copy=False) + else: + _replace(predicate, join_key_not_null) parent_select.join( select.group_by(value.this, copy=False), - on=on, + on=column.eq(join_key), join_type="LEFT", join_alias=alias, copy=False, diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 790ee0d..c091605 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -12,9 +12,7 @@ from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import TrieResult, in_trie, new_trie if t.TYPE_CHECKING: - from typing_extensions import Literal - - from sqlglot._typing import E + from sqlglot._typing import E, Lit from sqlglot.dialects.dialect import Dialect, DialectType logger = logging.getLogger("sqlglot") @@ -148,6 +146,11 @@ class Parser(metaclass=_Parser): TokenType.ENUM16, } + AGGREGATE_TYPE_TOKENS = { + TokenType.AGGREGATEFUNCTION, + TokenType.SIMPLEAGGREGATEFUNCTION, + } + TYPE_TOKENS = { TokenType.BIT, TokenType.BOOLEAN, @@ -241,6 +244,7 @@ class Parser(metaclass=_Parser): TokenType.NULL, *ENUM_TYPE_TOKENS, *NESTED_TYPE_TOKENS, + *AGGREGATE_TYPE_TOKENS, } SIGNED_TO_UNSIGNED_TYPE_TOKEN = { @@ -653,9 +657,11 @@ class Parser(metaclass=_Parser): PLACEHOLDER_PARSERS = { TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder), TokenType.PARAMETER: lambda self: self._parse_parameter(), - TokenType.COLON: lambda self: self.expression(exp.Placeholder, this=self._prev.text) - if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS) - else None, + TokenType.COLON: lambda self: ( + self.expression(exp.Placeholder, this=self._prev.text) + if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS) + else None + ), } RANGE_PARSERS = { @@ -705,6 +711,9 @@ class Parser(metaclass=_Parser): "IMMUTABLE": lambda self: self.expression( exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE") ), + "INHERITS": lambda self: self.expression( + exp.InheritsProperty, expressions=self._parse_wrapped_csv(self._parse_table) + ), "INPUT": lambda self: self.expression(exp.InputModelProperty, this=self._parse_schema()), "JOURNAL": lambda self, **kwargs: self._parse_journal(**kwargs), "LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty), @@ -822,6 +831,7 @@ class Parser(metaclass=_Parser): ALTER_PARSERS = { "ADD": lambda self: self._parse_alter_table_add(), "ALTER": lambda self: self._parse_alter_table_alter(), + "CLUSTER BY": lambda self: self._parse_cluster(wrapped=True), "DELETE": lambda self: self.expression(exp.Delete, where=self._parse_where()), "DROP": lambda self: self._parse_alter_table_drop(), "RENAME": lambda self: self._parse_alter_table_rename(), @@ -973,6 +983,9 @@ class Parser(metaclass=_Parser): MODIFIERS_ATTACHED_TO_UNION = True UNION_MODIFIERS = {"order", "limit", "offset"} + # parses no parenthesis if statements as commands + NO_PAREN_IF_COMMANDS = True + __slots__ = ( "error_level", "error_message_context", @@ -1207,7 +1220,20 @@ class Parser(metaclass=_Parser): if index != self._index: self._advance(index - self._index) + def _warn_unsupported(self) -> None: + if len(self._tokens) <= 1: + return + + # We use _find_sql because self.sql may comprise multiple chunks, and we're only + # interested in emitting a warning for the one being currently processed. + sql = self._find_sql(self._tokens[0], self._tokens[-1])[: self.error_message_context] + + logger.warning( + f"'{sql}' contains unsupported syntax. Falling back to parsing as a 'Command'." + ) + def _parse_command(self) -> exp.Command: + self._warn_unsupported() return self.expression( exp.Command, this=self._prev.text.upper(), expression=self._parse_string() ) @@ -1329,8 +1355,10 @@ class Parser(metaclass=_Parser): start = self._prev comments = self._prev_comments - replace = start.text.upper() == "REPLACE" or self._match_pair( - TokenType.OR, TokenType.REPLACE + replace = ( + start.token_type == TokenType.REPLACE + or self._match_pair(TokenType.OR, TokenType.REPLACE) + or self._match_pair(TokenType.OR, TokenType.ALTER) ) unique = self._match(TokenType.UNIQUE) @@ -1440,6 +1468,9 @@ class Parser(metaclass=_Parser): exp.Clone, this=self._parse_table(schema=True), shallow=shallow, copy=copy ) + if self._curr: + return self._parse_as_command(start) + return self.expression( exp.Create, comments=comments, @@ -1516,11 +1547,13 @@ class Parser(metaclass=_Parser): return self.expression( exp.FileFormatProperty, - this=self.expression( - exp.InputOutputFormat, input_format=input_format, output_format=output_format - ) - if input_format or output_format - else self._parse_var_or_string() or self._parse_number() or self._parse_id_var(), + this=( + self.expression( + exp.InputOutputFormat, input_format=input_format, output_format=output_format + ) + if input_format or output_format + else self._parse_var_or_string() or self._parse_number() or self._parse_id_var() + ), ) def _parse_property_assignment(self, exp_class: t.Type[E], **kwargs: t.Any) -> E: @@ -1632,8 +1665,15 @@ class Parser(metaclass=_Parser): return self.expression(exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT)) - def _parse_cluster(self) -> exp.Cluster: - return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered)) + def _parse_cluster(self, wrapped: bool = False) -> exp.Cluster: + return self.expression( + exp.Cluster, + expressions=( + self._parse_wrapped_csv(self._parse_ordered) + if wrapped + else self._parse_csv(self._parse_ordered) + ), + ) def _parse_clustered_by(self) -> exp.ClusteredByProperty: self._match_text_seq("BY") @@ -2681,6 +2721,8 @@ class Parser(metaclass=_Parser): else: columns = None + include = self._parse_wrapped_id_vars() if self._match_text_seq("INCLUDE") else None + return self.expression( exp.Index, this=index, @@ -2690,6 +2732,7 @@ class Parser(metaclass=_Parser): unique=unique, primary=primary, amp=amp, + include=include, partition_by=self._parse_partition_by(), where=self._parse_where(), ) @@ -3380,8 +3423,8 @@ class Parser(metaclass=_Parser): def _parse_comparison(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_range, self.COMPARISON) - def _parse_range(self) -> t.Optional[exp.Expression]: - this = self._parse_bitwise() + def _parse_range(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: + this = this or self._parse_bitwise() negate = self._match(TokenType.NOT) if self._match_set(self.RANGE_PARSERS): @@ -3535,14 +3578,21 @@ class Parser(metaclass=_Parser): return self._parse_tokens(self._parse_factor, self.TERM) def _parse_factor(self) -> t.Optional[exp.Expression]: - if self.EXPONENT: - factor = self._parse_tokens(self._parse_exponent, self.FACTOR) - else: - factor = self._parse_tokens(self._parse_unary, self.FACTOR) - if isinstance(factor, exp.Div): - factor.args["typed"] = self.dialect.TYPED_DIVISION - factor.args["safe"] = self.dialect.SAFE_DIVISION - return factor + parse_method = self._parse_exponent if self.EXPONENT else self._parse_unary + this = parse_method() + + while self._match_set(self.FACTOR): + this = self.expression( + self.FACTOR[self._prev.token_type], + this=this, + comments=self._prev_comments, + expression=parse_method(), + ) + if isinstance(this, exp.Div): + this.args["typed"] = self.dialect.TYPED_DIVISION + this.args["safe"] = self.dialect.SAFE_DIVISION + + return this def _parse_exponent(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_unary, self.EXPONENT) @@ -3617,6 +3667,7 @@ class Parser(metaclass=_Parser): return exp.DataType.build(type_name, udt=True) else: + self._retreat(self._index - 1) return None else: return None @@ -3631,6 +3682,7 @@ class Parser(metaclass=_Parser): nested = type_token in self.NESTED_TYPE_TOKENS is_struct = type_token in self.STRUCT_TYPE_TOKENS + is_aggregate = type_token in self.AGGREGATE_TYPE_TOKENS expressions = None maybe_func = False @@ -3645,6 +3697,18 @@ class Parser(metaclass=_Parser): ) elif type_token in self.ENUM_TYPE_TOKENS: expressions = self._parse_csv(self._parse_equality) + elif is_aggregate: + func_or_ident = self._parse_function(anonymous=True) or self._parse_id_var( + any_token=False, tokens=(TokenType.VAR,) + ) + if not func_or_ident or not self._match(TokenType.COMMA): + return None + expressions = self._parse_csv( + lambda: self._parse_types( + check_func=check_func, schema=schema, allow_identifiers=allow_identifiers + ) + ) + expressions.insert(0, func_or_ident) else: expressions = self._parse_csv(self._parse_type_size) @@ -4413,6 +4477,10 @@ class Parser(metaclass=_Parser): self._match_r_paren() else: index = self._index - 1 + + if self.NO_PAREN_IF_COMMANDS and index == 0: + return self._parse_as_command(self._prev) + condition = self._parse_conjunction() if not condition: @@ -4624,12 +4692,10 @@ class Parser(metaclass=_Parser): return None @t.overload - def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject: - ... + def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ... @t.overload - def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg: - ... + def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ... def _parse_json_object(self, agg=False): star = self._parse_star() @@ -4974,11 +5040,12 @@ class Parser(metaclass=_Parser): if alias: this = self.expression(exp.Alias, comments=comments, this=this, alias=alias) + column = this.this # Moves the comment next to the alias in `expr /* comment */ AS alias` - if not this.comments and this.this.comments: - this.comments = this.this.comments - this.this.comments = None + if not this.comments and column and column.comments: + this.comments = column.comments + column.comments = None return this @@ -5244,7 +5311,7 @@ class Parser(metaclass=_Parser): if self._match_text_seq("CHECK"): expression = self._parse_wrapped(self._parse_conjunction) - enforced = self._match_text_seq("ENFORCED") + enforced = self._match_text_seq("ENFORCED") or False return self.expression( exp.AddConstraint, this=this, expression=expression, enforced=enforced @@ -5278,6 +5345,8 @@ class Parser(metaclass=_Parser): return self.expression(exp.AlterColumn, this=column, drop=True) if self._match_pair(TokenType.SET, TokenType.DEFAULT): return self.expression(exp.AlterColumn, this=column, default=self._parse_conjunction()) + if self._match(TokenType.COMMENT): + return self.expression(exp.AlterColumn, this=column, comment=self._parse_string()) self._match_text_seq("SET", "DATA") return self.expression( @@ -5298,7 +5367,18 @@ class Parser(metaclass=_Parser): self._retreat(index) return self._parse_csv(self._parse_drop_column) - def _parse_alter_table_rename(self) -> exp.RenameTable: + def _parse_alter_table_rename(self) -> t.Optional[exp.RenameTable | exp.RenameColumn]: + if self._match(TokenType.COLUMN): + exists = self._parse_exists() + old_column = self._parse_column() + to = self._match_text_seq("TO") + new_column = self._parse_column() + + if old_column is None or to is None or new_column is None: + return None + + return self.expression(exp.RenameColumn, this=old_column, to=new_column, exists=exists) + self._match_text_seq("TO") return self.expression(exp.RenameTable, this=self._parse_table(schema=True)) @@ -5319,7 +5399,7 @@ class Parser(metaclass=_Parser): if parser: actions = ensure_list(parser(self)) - if not self._curr: + if not self._curr and actions: return self.expression( exp.AlterTable, this=this, @@ -5467,6 +5547,7 @@ class Parser(metaclass=_Parser): self._advance() text = self._find_sql(start, self._prev) size = len(start.text) + self._warn_unsupported() return exp.Command(this=text[:size], expression=text[size:]) def _parse_dict_property(self, this: str) -> exp.DictProperty: @@ -5634,7 +5715,7 @@ class Parser(metaclass=_Parser): if advance: self._advance() return True - return False + return None def _match_text_seq(self, *texts, advance=True): index = self._index @@ -5643,7 +5724,7 @@ class Parser(metaclass=_Parser): self._advance() else: self._retreat(index) - return False + return None if not advance: self._retreat(index) @@ -5651,14 +5732,12 @@ class Parser(metaclass=_Parser): return True @t.overload - def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: - ... + def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: ... @t.overload def _replace_columns_with_dots( self, this: t.Optional[exp.Expression] - ) -> t.Optional[exp.Expression]: - ... + ) -> t.Optional[exp.Expression]: ... def _replace_columns_with_dots(self, this): if isinstance(this, exp.Dot): diff --git a/sqlglot/schema.py b/sqlglot/schema.py index 8acd89f..13f72d6 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -106,6 +106,19 @@ class Schema(abc.ABC): name = column if isinstance(column, str) else column.name return name in self.column_names(table, dialect=dialect, normalize=normalize) + @abc.abstractmethod + def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]: + """ + Returns the schema of a given table. + + Args: + table: the target table. + raise_on_missing: whether or not to raise in case the schema is not found. + + Returns: + The schema of the target table. + """ + @property @abc.abstractmethod def supported_table_args(self) -> t.Tuple[str, ...]: @@ -156,11 +169,9 @@ class AbstractMappingSchema: return [table.this.name] return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)] - def find( - self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True - ) -> t.Optional[t.Any]: + def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]: parts = self.table_parts(table)[0 : len(self.supported_table_args)] - value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) + value, trie = in_trie(self.mapping_trie, parts) if value == TrieResult.FAILED: return None diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index d8fb98b..8a363d2 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -191,6 +191,8 @@ class TokenType(AutoName): FIXEDSTRING = auto() LOWCARDINALITY = auto() NESTED = auto() + AGGREGATEFUNCTION = auto() + SIMPLEAGGREGATEFUNCTION = auto() UNKNOWN = auto() # keywords |