diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-10-10 08:53:10 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-10-10 08:53:10 +0000 |
commit | f7cb7fdb0fb5a8e2d053c1aa18dd98462401a64e (patch) | |
tree | 75bbd792c82b8d1e70b5561de82a5b270b61867c /sqlglot | |
parent | Adding upstream version 18.11.2. (diff) | |
download | sqlglot-f7cb7fdb0fb5a8e2d053c1aa18dd98462401a64e.tar.xz sqlglot-f7cb7fdb0fb5a8e2d053c1aa18dd98462401a64e.zip |
Adding upstream version 18.11.6.upstream/18.11.6
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r-- | sqlglot/dataframe/sql/functions.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/bigquery.py | 7 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 4 | ||||
-rw-r--r-- | sqlglot/dialects/redshift.py | 19 | ||||
-rw-r--r-- | sqlglot/dialects/spark2.py | 4 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 1 | ||||
-rw-r--r-- | sqlglot/expressions.py | 28 | ||||
-rw-r--r-- | sqlglot/generator.py | 25 | ||||
-rw-r--r-- | sqlglot/optimizer/normalize_identifiers.py | 4 | ||||
-rw-r--r-- | sqlglot/parser.py | 39 | ||||
-rw-r--r-- | sqlglot/tokens.py | 1 |
11 files changed, 121 insertions, 13 deletions
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index d0ae50c..9ab00d5 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -656,7 +656,7 @@ def unix_timestamp( def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: tz_column = tz if isinstance(tz, Column) else lit(tz) - return Column.invoke_anonymous_function(timestamp, "FROM_UTC_TIMESTAMP", tz_column) + return Column.invoke_expression_over_column(timestamp, expression.AtTimeZone, zone=tz_column) def to_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 0d741b5..7f69dd9 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -260,15 +260,16 @@ class BigQuery(Dialect): "ANY TYPE": TokenType.VARIANT, "BEGIN": TokenType.COMMAND, "BEGIN TRANSACTION": TokenType.BEGIN, - "CURRENT_DATETIME": TokenType.CURRENT_DATETIME, "BYTES": TokenType.BINARY, + "CURRENT_DATETIME": TokenType.CURRENT_DATETIME, "DECLARE": TokenType.COMMAND, "FLOAT64": TokenType.DOUBLE, + "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT, "INT64": TokenType.BIGINT, + "MODEL": TokenType.MODEL, + "NOT DETERMINISTIC": TokenType.VOLATILE, "RECORD": TokenType.STRUCT, "TIMESTAMP": TokenType.TIMESTAMPTZ, - "NOT DETERMINISTIC": TokenType.VOLATILE, - "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT, } KEYWORDS.pop("DIV") diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 008727c..c435309 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -205,7 +205,7 @@ def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression: for when in expression.expressions: when.transform( - lambda node: exp.column(node.name) + lambda node: exp.column(node.this) if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets else node, copy=False, @@ -439,6 +439,8 @@ class Postgres(Dialect): exp.TryCast: no_trycast_sql, exp.TsOrDsToDate: ts_or_ds_to_date_sql("postgres"), exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", + exp.VariancePop: rename_func("VAR_POP"), + exp.Variance: rename_func("VAR_SAMP"), exp.Xor: bool_xor_sql, } diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 88e4448..b70a8a1 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -31,6 +31,7 @@ class Redshift(Postgres): RESOLVES_IDENTIFIERS_AS_UPPERCASE = None SUPPORTS_USER_DEFINED_TYPES = False + INDEX_OFFSET = 0 TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'" TIME_MAPPING = { @@ -57,6 +58,24 @@ class Redshift(Postgres): "STRTOL": exp.FromBase.from_arg_list, } + def _parse_table( + self, + schema: bool = False, + joins: bool = False, + alias_tokens: t.Optional[t.Collection[TokenType]] = None, + parse_bracket: bool = False, + ) -> t.Optional[exp.Expression]: + # Redshift supports UNPIVOTing SUPER objects, e.g. `UNPIVOT foo.obj[0] AS val AT attr` + unpivot = self._match(TokenType.UNPIVOT) + table = super()._parse_table( + schema=schema, + joins=joins, + alias_tokens=alias_tokens, + parse_bracket=parse_bracket, + ) + + return self.expression(exp.Pivot, this=table, unpivot=True) if unpivot else table + def _parse_types( self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True ) -> t.Optional[exp.Expression]: diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index 4130375..2fd4f4e 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -117,6 +117,10 @@ class Spark2(Hive): "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "DOUBLE": _parse_as_cast("double"), "FLOAT": _parse_as_cast("float"), + "FROM_UTC_TIMESTAMP": lambda args: exp.AtTimeZone( + this=exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("timestamp")), + zone=seq_get(args, 1), + ), "IIF": exp.If.from_arg_list, "INT": _parse_as_cast("int"), "MAP_FROM_ARRAYS": exp.Map.from_arg_list, diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 6aa49e4..d8bea6d 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -598,6 +598,7 @@ class TSQL(Dialect): exp.DataType.Type.BOOLEAN: "BIT", exp.DataType.Type.DECIMAL: "NUMERIC", exp.DataType.Type.DATETIME: "DATETIME2", + exp.DataType.Type.DOUBLE: "FLOAT", exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.TIMESTAMP: "DATETIME2", exp.DataType.Type.TIMESTAMPTZ: "DATETIMEOFFSET", diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 1e4aad6..80f1c0f 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -2040,8 +2040,12 @@ class FreespaceProperty(Property): arg_types = {"this": True, "percent": False} -class InputOutputFormat(Expression): - arg_types = {"input_format": False, "output_format": False} +class InputModelProperty(Property): + arg_types = {"this": True} + + +class OutputModelProperty(Property): + arg_types = {"this": True} class IsolatedLoadingProperty(Property): @@ -2137,6 +2141,10 @@ class PartitionedByProperty(Property): arg_types = {"this": True} +class RemoteWithConnectionModelProperty(Property): + arg_types = {"this": True} + + class ReturnsProperty(Property): arg_types = {"this": True, "is_table": False, "table": False} @@ -2211,6 +2219,10 @@ class TemporaryProperty(Property): arg_types = {} +class TransformModelProperty(Property): + arg_types = {"expressions": True} + + class TransientProperty(Property): arg_types = {"this": False} @@ -2293,6 +2305,10 @@ class Qualify(Expression): pass +class InputOutputFormat(Expression): + arg_types = {"input_format": False, "output_format": False} + + # https://www.ibm.com/docs/en/ias?topic=procedures-return-statement-in-sql class Return(Expression): pass @@ -2465,6 +2481,7 @@ class Table(Expression): "version": False, "format": False, "pattern": False, + "index": False, } @property @@ -3431,7 +3448,7 @@ class Pivot(Expression): arg_types = { "this": False, "alias": False, - "expressions": True, + "expressions": False, "field": False, "unpivot": False, "using": False, @@ -4777,6 +4794,11 @@ class Posexplode(Func): pass +# https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict#mlpredict_function +class Predict(Func): + arg_types = {"this": True, "expression": True, "params_struct": False} + + class Pow(Binary, Func): _sql_names = ["POWER", "POW"] diff --git a/sqlglot/generator.py b/sqlglot/generator.py index edc6939..7a2879c 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -73,6 +73,7 @@ class Generator: exp.ExternalProperty: lambda self, e: "EXTERNAL", exp.HeapProperty: lambda self, e: "HEAP", exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}", + exp.InputModelProperty: lambda self, e: f"INPUT{self.sql(e, 'this')}", exp.IntervalSpan: lambda self, e: f"{self.sql(e, 'this')} TO {self.sql(e, 'expression')}", exp.LanguageProperty: lambda self, e: self.naked_property(e), exp.LocationProperty: lambda self, e: self.naked_property(e), @@ -84,7 +85,9 @@ class Generator: exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS", exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}", exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}", + exp.OutputModelProperty: lambda self, e: f"OUTPUT{self.sql(e, 'this')}", exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}", + exp.RemoteWithConnectionModelProperty: lambda self, e: f"REMOTE WITH CONNECTION {self.sql(e, 'this')}", exp.ReturnsProperty: lambda self, e: self.naked_property(e), exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}", exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET", @@ -94,6 +97,7 @@ class Generator: exp.TemporaryProperty: lambda self, e: f"TEMPORARY", exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}", exp.TransientProperty: lambda self, e: "TRANSIENT", + exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions), exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}", exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE", exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]), @@ -278,6 +282,7 @@ class Generator: exp.FileFormatProperty: exp.Properties.Location.POST_WITH, exp.FreespaceProperty: exp.Properties.Location.POST_NAME, exp.HeapProperty: exp.Properties.Location.POST_WITH, + exp.InputModelProperty: exp.Properties.Location.POST_SCHEMA, exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME, exp.JournalProperty: exp.Properties.Location.POST_NAME, exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA, @@ -291,9 +296,11 @@ class Generator: exp.OnProperty: exp.Properties.Location.POST_SCHEMA, exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION, exp.Order: exp.Properties.Location.POST_SCHEMA, + exp.OutputModelProperty: exp.Properties.Location.POST_SCHEMA, exp.PartitionedByProperty: exp.Properties.Location.POST_WITH, exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA, exp.Property: exp.Properties.Location.POST_WITH, + exp.RemoteWithConnectionModelProperty: exp.Properties.Location.POST_SCHEMA, exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA, exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA, exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA, @@ -310,6 +317,7 @@ class Generator: exp.TemporaryProperty: exp.Properties.Location.POST_CREATE, exp.ToTableProperty: exp.Properties.Location.POST_SCHEMA, exp.TransientProperty: exp.Properties.Location.POST_CREATE, + exp.TransformModelProperty: exp.Properties.Location.POST_SCHEMA, exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA, exp.VolatileProperty: exp.Properties.Location.POST_CREATE, exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION, @@ -1350,13 +1358,17 @@ class Generator: pivots = f" {pivots}" if pivots else "" joins = self.expressions(expression, key="joins", sep="", skip_first=True) laterals = self.expressions(expression, key="laterals", sep="") + file_format = self.sql(expression, "format") if file_format: pattern = self.sql(expression, "pattern") pattern = f", PATTERN => {pattern}" if pattern else "" file_format = f" (FILE_FORMAT => {file_format}{pattern})" - return f"{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}" + index = self.sql(expression, "index") + index = f" AT {index}" if index else "" + + return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}" def tablesample_sql( self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " @@ -1401,6 +1413,9 @@ class Generator: if expression.this: this = self.sql(expression, "this") + if not expressions: + return f"UNPIVOT {this}" + on = f"{self.seg('ON')} {expressions}" using = self.expressions(expression, key="using", flat=True) using = f"{self.seg('USING')} {using}" if using else "" @@ -2880,6 +2895,14 @@ class Generator: def opclass_sql(self, expression: exp.Opclass) -> str: return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" + def predict_sql(self, expression: exp.Predict) -> str: + model = self.sql(expression, "this") + model = f"MODEL {model}" + table = self.sql(expression, "expression") + table = f"TABLE {table}" if not isinstance(expression.expression, exp.Subquery) else table + parameters = self.sql(expression, "params_struct") + return self.func("PREDICT", model, table, parameters or None) + def cached_generator( cache: t.Optional[t.Dict[int, str]] = None diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py index 32f3a92..ecea6a0 100644 --- a/sqlglot/optimizer/normalize_identifiers.py +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -2,7 +2,7 @@ from __future__ import annotations import typing as t -from sqlglot import exp +from sqlglot import exp, parse_one from sqlglot._typing import E from sqlglot.dialects.dialect import Dialect, DialectType @@ -49,7 +49,7 @@ def normalize_identifiers(expression, dialect=None): The transformed expression. """ if isinstance(expression, str): - expression = exp.to_identifier(expression) + expression = parse_one(expression, dialect=dialect, into=exp.Identifier) dialect = Dialect.get_or_raise(dialect) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 5e56961..510abfb 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -236,6 +236,7 @@ class Parser(metaclass=_Parser): TokenType.SCHEMA, TokenType.TABLE, TokenType.VIEW, + TokenType.MODEL, TokenType.DICTIONARY, } @@ -649,6 +650,7 @@ class Parser(metaclass=_Parser): "IMMUTABLE": lambda self: self.expression( exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE") ), + "INPUT": lambda self: self.expression(exp.InputModelProperty, this=self._parse_schema()), "JOURNAL": lambda self, **kwargs: self._parse_journal(**kwargs), "LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty), "LAYOUT": lambda self: self._parse_dict_property(this="LAYOUT"), @@ -664,11 +666,13 @@ class Parser(metaclass=_Parser): "NO": lambda self: self._parse_no_property(), "ON": lambda self: self._parse_on_property(), "ORDER BY": lambda self: self._parse_order(skip_order_token=True), + "OUTPUT": lambda self: self.expression(exp.OutputModelProperty, this=self._parse_schema()), "PARTITION BY": lambda self: self._parse_partitioned_by(), "PARTITIONED BY": lambda self: self._parse_partitioned_by(), "PARTITIONED_BY": lambda self: self._parse_partitioned_by(), "PRIMARY KEY": lambda self: self._parse_primary_key(in_props=True), "RANGE": lambda self: self._parse_dict_range(this="RANGE"), + "REMOTE": lambda self: self._parse_remote_with_connection(), "RETURNS": lambda self: self._parse_returns(), "ROW": lambda self: self._parse_row(), "ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty), @@ -690,6 +694,9 @@ class Parser(metaclass=_Parser): "TEMPORARY": lambda self: self.expression(exp.TemporaryProperty), "TO": lambda self: self._parse_to_table(), "TRANSIENT": lambda self: self.expression(exp.TransientProperty), + "TRANSFORM": lambda self: self.expression( + exp.TransformModelProperty, expressions=self._parse_wrapped_csv(self._parse_expression) + ), "TTL": lambda self: self._parse_ttl(), "USING": lambda self: self._parse_property_assignment(exp.FileFormatProperty), "VOLATILE": lambda self: self._parse_volatile_property(), @@ -789,6 +796,7 @@ class Parser(metaclass=_Parser): "MATCH": lambda self: self._parse_match_against(), "OPENJSON": lambda self: self._parse_open_json(), "POSITION": lambda self: self._parse_position(), + "PREDICT": lambda self: self._parse_predict(), "SAFE_CAST": lambda self: self._parse_cast(False), "STRING_AGG": lambda self: self._parse_string_agg(), "SUBSTRING": lambda self: self._parse_substring(), @@ -1787,6 +1795,12 @@ class Parser(metaclass=_Parser): exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default ) + def _parse_remote_with_connection(self) -> exp.RemoteWithConnectionModelProperty: + self._match_text_seq("WITH", "CONNECTION") + return self.expression( + exp.RemoteWithConnectionModelProperty, this=self._parse_table_parts() + ) + def _parse_returns(self) -> exp.ReturnsProperty: value: t.Optional[exp.Expression] is_table = self._match(TokenType.TABLE) @@ -2622,7 +2636,9 @@ class Parser(metaclass=_Parser): bracket = parse_bracket and self._parse_bracket(None) bracket = self.expression(exp.Table, this=bracket) if bracket else None - this: exp.Expression = bracket or self._parse_table_parts(schema=schema) + this = t.cast( + exp.Expression, bracket or self._parse_bracket(self._parse_table_parts(schema=schema)) + ) if schema: return self._parse_schema(this=this) @@ -2639,6 +2655,9 @@ class Parser(metaclass=_Parser): if alias: this.set("alias", alias) + if self._match_text_seq("AT"): + this.set("index", self._parse_id_var()) + this.set("hints", self._parse_table_hints()) if not this.args.get("pivots"): @@ -3886,7 +3905,9 @@ class Parser(metaclass=_Parser): def _parse_unnamed_constraint( self, constraints: t.Optional[t.Collection[str]] = None ) -> t.Optional[exp.Expression]: - if not self._match_texts(constraints or self.CONSTRAINT_PARSERS): + if self._match(TokenType.IDENTIFIER, advance=False) or not self._match_texts( + constraints or self.CONSTRAINT_PARSERS + ): return None constraint = self._prev.text.upper() @@ -4402,6 +4423,20 @@ class Parser(metaclass=_Parser): exp.StrPosition, this=haystack, substr=needle, position=seq_get(args, 2) ) + def _parse_predict(self) -> exp.Predict: + self._match_text_seq("MODEL") + this = self._parse_table() + + self._match(TokenType.COMMA) + self._match_text_seq("TABLE") + + return self.expression( + exp.Predict, + this=this, + expression=self._parse_table(), + params_struct=self._match(TokenType.COMMA) and self._parse_bitwise(), + ) + def _parse_join_hint(self, func_name: str) -> exp.JoinHint: args = self._parse_csv(self._parse_table) return exp.JoinHint(this=func_name.upper(), expressions=args) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 080a86b..4ab01dd 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -263,6 +263,7 @@ class TokenType(AutoName): MEMBER_OF = auto() MERGE = auto() MOD = auto() + MODEL = auto() NATURAL = auto() NEXT = auto() NOTNULL = auto() |