author     Daniel Baumann <daniel.baumann@progress-linux.org>  2023-10-10 08:53:14 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2023-10-10 08:53:14 +0000
commit     cd37a3bcaced9283c20baa52837c96b524baec54 (patch)
tree       101b1c1487aa832a982dd635cd3b00d4d2ea3ae9 /sqlglot
parent     Releasing progress-linux version 18.11.2-1. (diff)
download   sqlglot-cd37a3bcaced9283c20baa52837c96b524baec54.tar.xz
           sqlglot-cd37a3bcaced9283c20baa52837c96b524baec54.zip
Merging upstream version 18.11.6.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/dataframe/sql/functions.py          |  2
-rw-r--r--  sqlglot/dialects/bigquery.py                |  7
-rw-r--r--  sqlglot/dialects/postgres.py                |  4
-rw-r--r--  sqlglot/dialects/redshift.py                | 19
-rw-r--r--  sqlglot/dialects/spark2.py                  |  4
-rw-r--r--  sqlglot/dialects/tsql.py                    |  1
-rw-r--r--  sqlglot/expressions.py                      | 28
-rw-r--r--  sqlglot/generator.py                        | 25
-rw-r--r--  sqlglot/optimizer/normalize_identifiers.py  |  4
-rw-r--r--  sqlglot/parser.py                           | 39
-rw-r--r--  sqlglot/tokens.py                           |  1
11 files changed, 121 insertions(+), 13 deletions(-)
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index d0ae50c..9ab00d5 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -656,7 +656,7 @@ def unix_timestamp(
def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
tz_column = tz if isinstance(tz, Column) else lit(tz)
- return Column.invoke_anonymous_function(timestamp, "FROM_UTC_TIMESTAMP", tz_column)
+ return Column.invoke_expression_over_column(timestamp, expression.AtTimeZone, zone=tz_column)
def to_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
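The dataframe helper no longer emits an anonymous FROM_UTC_TIMESTAMP call; it builds an exp.AtTimeZone node that each dialect can render natively. A minimal sketch of the node it now produces (not part of this commit; exact output may vary by version):

    from sqlglot import exp

    # Build the same kind of node the helper now creates and render it.
    node = exp.AtTimeZone(this=exp.column("ts"), zone=exp.Literal.string("UTC"))
    print(node.sql())                 # e.g. ts AT TIME ZONE 'UTC'
    print(node.sql(dialect="spark"))  # e.g. FROM_UTC_TIMESTAMP(ts, 'UTC')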
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 0d741b5..7f69dd9 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -260,15 +260,16 @@ class BigQuery(Dialect):
"ANY TYPE": TokenType.VARIANT,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
- "CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"BYTES": TokenType.BINARY,
+ "CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"DECLARE": TokenType.COMMAND,
"FLOAT64": TokenType.DOUBLE,
+ "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
"INT64": TokenType.BIGINT,
+ "MODEL": TokenType.MODEL,
+ "NOT DETERMINISTIC": TokenType.VOLATILE,
"RECORD": TokenType.STRUCT,
"TIMESTAMP": TokenType.TIMESTAMPTZ,
- "NOT DETERMINISTIC": TokenType.VOLATILE,
- "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
}
KEYWORDS.pop("DIV")
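The keyword map is re-sorted alphabetically and gains MODEL, the hook for the CREATE MODEL / ML.PREDICT support added further down. A hedged sanity-check sketch (not from the commit), relying on KEYWORDS being a plain class-level dict:

    from sqlglot.dialects.bigquery import BigQuery
    from sqlglot.tokens import TokenType

    # The new keyword should now map to its own token type.
    print(BigQuery.Tokenizer.KEYWORDS["MODEL"] is TokenType.MODEL)  # True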
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 008727c..c435309 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -205,7 +205,7 @@ def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
for when in expression.expressions:
when.transform(
- lambda node: exp.column(node.name)
+ lambda node: exp.column(node.this)
if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
else node,
copy=False,
@@ -439,6 +439,8 @@ class Postgres(Dialect):
exp.TryCast: no_trycast_sql,
exp.TsOrDsToDate: ts_or_ds_to_date_sql("postgres"),
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
+ exp.VariancePop: rename_func("VAR_POP"),
+ exp.Variance: rename_func("VAR_SAMP"),
exp.Xor: bool_xor_sql,
}
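Postgres now renders the generic variance expressions with its native function names. A hedged transpile sketch (output shown approximately):

    import sqlglot

    sql = "SELECT VARIANCE(x), VAR_POP(x) FROM t"
    print(sqlglot.transpile(sql, write="postgres")[0])
    # roughly: SELECT VAR_SAMP(x), VAR_POP(x) FROM t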
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 88e4448..b70a8a1 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -31,6 +31,7 @@ class Redshift(Postgres):
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
SUPPORTS_USER_DEFINED_TYPES = False
+ INDEX_OFFSET = 0
TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
TIME_MAPPING = {
@@ -57,6 +58,24 @@ class Redshift(Postgres):
"STRTOL": exp.FromBase.from_arg_list,
}
+ def _parse_table(
+ self,
+ schema: bool = False,
+ joins: bool = False,
+ alias_tokens: t.Optional[t.Collection[TokenType]] = None,
+ parse_bracket: bool = False,
+ ) -> t.Optional[exp.Expression]:
+ # Redshift supports UNPIVOTing SUPER objects, e.g. `UNPIVOT foo.obj[0] AS val AT attr`
+ unpivot = self._match(TokenType.UNPIVOT)
+ table = super()._parse_table(
+ schema=schema,
+ joins=joins,
+ alias_tokens=alias_tokens,
+ parse_bracket=parse_bracket,
+ )
+
+ return self.expression(exp.Pivot, this=table, unpivot=True) if unpivot else table
+
def _parse_types(
self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
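Redshift picks up INDEX_OFFSET = 0 and learns to parse UNPIVOT over SUPER expressions in the FROM clause, wrapping the table in an exp.Pivot with unpivot=True. A hedged sketch following the shape described in the comment above (table and column names are placeholders; round-trip formatting may differ slightly):

    import sqlglot

    sql = "SELECT attr, val FROM customers AS c, UNPIVOT c.orders AS val AT attr"
    ast = sqlglot.parse_one(sql, read="redshift")
    print(ast.sql(dialect="redshift"))  # expected to round-trip roughly unchanged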
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 4130375..2fd4f4e 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -117,6 +117,10 @@ class Spark2(Hive):
"DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DOUBLE": _parse_as_cast("double"),
"FLOAT": _parse_as_cast("float"),
+ "FROM_UTC_TIMESTAMP": lambda args: exp.AtTimeZone(
+ this=exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("timestamp")),
+ zone=seq_get(args, 1),
+ ),
"IIF": exp.If.from_arg_list,
"INT": _parse_as_cast("int"),
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 6aa49e4..d8bea6d 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -598,6 +598,7 @@ class TSQL(Dialect):
exp.DataType.Type.BOOLEAN: "BIT",
exp.DataType.Type.DECIMAL: "NUMERIC",
exp.DataType.Type.DATETIME: "DATETIME2",
+ exp.DataType.Type.DOUBLE: "FLOAT",
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.TIMESTAMP: "DATETIME2",
exp.DataType.Type.TIMESTAMPTZ: "DATETIMEOFFSET",
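T-SQL now maps the generic DOUBLE type to its native FLOAT. A hedged sketch:

    import sqlglot

    print(sqlglot.transpile("CREATE TABLE t (x DOUBLE)", write="tsql")[0])
    # roughly: CREATE TABLE t (x FLOAT)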
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 1e4aad6..80f1c0f 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -2040,8 +2040,12 @@ class FreespaceProperty(Property):
arg_types = {"this": True, "percent": False}
-class InputOutputFormat(Expression):
- arg_types = {"input_format": False, "output_format": False}
+class InputModelProperty(Property):
+ arg_types = {"this": True}
+
+
+class OutputModelProperty(Property):
+ arg_types = {"this": True}
class IsolatedLoadingProperty(Property):
@@ -2137,6 +2141,10 @@ class PartitionedByProperty(Property):
arg_types = {"this": True}
+class RemoteWithConnectionModelProperty(Property):
+ arg_types = {"this": True}
+
+
class ReturnsProperty(Property):
arg_types = {"this": True, "is_table": False, "table": False}
@@ -2211,6 +2219,10 @@ class TemporaryProperty(Property):
arg_types = {}
+class TransformModelProperty(Property):
+ arg_types = {"expressions": True}
+
+
class TransientProperty(Property):
arg_types = {"this": False}
@@ -2293,6 +2305,10 @@ class Qualify(Expression):
pass
+class InputOutputFormat(Expression):
+ arg_types = {"input_format": False, "output_format": False}
+
+
# https://www.ibm.com/docs/en/ias?topic=procedures-return-statement-in-sql
class Return(Expression):
pass
@@ -2465,6 +2481,7 @@ class Table(Expression):
"version": False,
"format": False,
"pattern": False,
+ "index": False,
}
@property
@@ -3431,7 +3448,7 @@ class Pivot(Expression):
arg_types = {
"this": False,
"alias": False,
- "expressions": True,
+ "expressions": False,
"field": False,
"unpivot": False,
"using": False,
@@ -4777,6 +4794,11 @@ class Posexplode(Func):
pass
+# https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict#mlpredict_function
+class Predict(Func):
+ arg_types = {"this": True, "expression": True, "params_struct": False}
+
+
class Pow(Binary, Func):
_sql_names = ["POWER", "POW"]
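The new exp.Predict node models the arguments of BigQuery's ML.PREDICT: a model reference, an input relation, and an optional parameter struct. A minimal construction sketch against the default generator (names are placeholders; output shown approximately):

    from sqlglot import exp

    predict = exp.Predict(
        this=exp.to_table("my_dataset.my_model"),
        expression=exp.to_table("my_dataset.my_table"),
    )
    print(predict.sql())
    # roughly: PREDICT(MODEL my_dataset.my_model, TABLE my_dataset.my_table)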
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index edc6939..7a2879c 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -73,6 +73,7 @@ class Generator:
exp.ExternalProperty: lambda self, e: "EXTERNAL",
exp.HeapProperty: lambda self, e: "HEAP",
exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
+ exp.InputModelProperty: lambda self, e: f"INPUT{self.sql(e, 'this')}",
exp.IntervalSpan: lambda self, e: f"{self.sql(e, 'this')} TO {self.sql(e, 'expression')}",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
@@ -84,7 +85,9 @@ class Generator:
exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}",
exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
+ exp.OutputModelProperty: lambda self, e: f"OUTPUT{self.sql(e, 'this')}",
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
+ exp.RemoteWithConnectionModelProperty: lambda self, e: f"REMOTE WITH CONNECTION {self.sql(e, 'this')}",
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
@@ -94,6 +97,7 @@ class Generator:
exp.TemporaryProperty: lambda self, e: f"TEMPORARY",
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransientProperty: lambda self, e: "TRANSIENT",
+ exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE",
exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
@@ -278,6 +282,7 @@ class Generator:
exp.FileFormatProperty: exp.Properties.Location.POST_WITH,
exp.FreespaceProperty: exp.Properties.Location.POST_NAME,
exp.HeapProperty: exp.Properties.Location.POST_WITH,
+ exp.InputModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME,
exp.JournalProperty: exp.Properties.Location.POST_NAME,
exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA,
@@ -291,9 +296,11 @@ class Generator:
exp.OnProperty: exp.Properties.Location.POST_SCHEMA,
exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION,
exp.Order: exp.Properties.Location.POST_SCHEMA,
+ exp.OutputModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA,
exp.Property: exp.Properties.Location.POST_WITH,
+ exp.RemoteWithConnectionModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA,
@@ -310,6 +317,7 @@ class Generator:
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
exp.ToTableProperty: exp.Properties.Location.POST_SCHEMA,
exp.TransientProperty: exp.Properties.Location.POST_CREATE,
+ exp.TransformModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
@@ -1350,13 +1358,17 @@ class Generator:
pivots = f" {pivots}" if pivots else ""
joins = self.expressions(expression, key="joins", sep="", skip_first=True)
laterals = self.expressions(expression, key="laterals", sep="")
+
file_format = self.sql(expression, "format")
if file_format:
pattern = self.sql(expression, "pattern")
pattern = f", PATTERN => {pattern}" if pattern else ""
file_format = f" (FILE_FORMAT => {file_format}{pattern})"
- return f"{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}"
+ index = self.sql(expression, "index")
+ index = f" AT {index}" if index else ""
+
+ return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}"
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
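table_sql now appends a Redshift-style "AT <identifier>" suffix when a Table carries the new index arg. A hedged sketch building one by hand (identifier names are placeholders):

    from sqlglot import exp

    table = exp.to_table("c.c_orders")
    table.set("index", exp.to_identifier("idx"))
    print(table.sql(dialect="redshift"))  # roughly: c.c_orders AT idx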
@@ -1401,6 +1413,9 @@ class Generator:
if expression.this:
this = self.sql(expression, "this")
+ if not expressions:
+ return f"UNPIVOT {this}"
+
on = f"{self.seg('ON')} {expressions}"
using = self.expressions(expression, key="using", flat=True)
using = f"{self.seg('USING')} {using}" if using else ""
@@ -2880,6 +2895,14 @@ class Generator:
def opclass_sql(self, expression: exp.Opclass) -> str:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
+ def predict_sql(self, expression: exp.Predict) -> str:
+ model = self.sql(expression, "this")
+ model = f"MODEL {model}"
+ table = self.sql(expression, "expression")
+ table = f"TABLE {table}" if not isinstance(expression.expression, exp.Subquery) else table
+ parameters = self.sql(expression, "params_struct")
+ return self.func("PREDICT", model, table, parameters or None)
+
def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None
diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py
index 32f3a92..ecea6a0 100644
--- a/sqlglot/optimizer/normalize_identifiers.py
+++ b/sqlglot/optimizer/normalize_identifiers.py
@@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
-from sqlglot import exp
+from sqlglot import exp, parse_one
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
@@ -49,7 +49,7 @@ def normalize_identifiers(expression, dialect=None):
The transformed expression.
"""
if isinstance(expression, str):
- expression = exp.to_identifier(expression)
+ expression = parse_one(expression, dialect=dialect, into=exp.Identifier)
dialect = Dialect.get_or_raise(dialect)
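A string argument is now parsed with the target dialect instead of being wrapped blindly, so quoting rules are honored when deciding case. A hedged sketch (Snowflake uppercases unquoted identifiers and leaves quoted ones alone):

    from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

    print(normalize_identifiers("foo", dialect="snowflake").sql(dialect="snowflake"))    # FOO
    print(normalize_identifiers('"foo"', dialect="snowflake").sql(dialect="snowflake"))  # "foo"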
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 5e56961..510abfb 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -236,6 +236,7 @@ class Parser(metaclass=_Parser):
TokenType.SCHEMA,
TokenType.TABLE,
TokenType.VIEW,
+ TokenType.MODEL,
TokenType.DICTIONARY,
}
@@ -649,6 +650,7 @@ class Parser(metaclass=_Parser):
"IMMUTABLE": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
+ "INPUT": lambda self: self.expression(exp.InputModelProperty, this=self._parse_schema()),
"JOURNAL": lambda self, **kwargs: self._parse_journal(**kwargs),
"LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty),
"LAYOUT": lambda self: self._parse_dict_property(this="LAYOUT"),
@@ -664,11 +666,13 @@ class Parser(metaclass=_Parser):
"NO": lambda self: self._parse_no_property(),
"ON": lambda self: self._parse_on_property(),
"ORDER BY": lambda self: self._parse_order(skip_order_token=True),
+ "OUTPUT": lambda self: self.expression(exp.OutputModelProperty, this=self._parse_schema()),
"PARTITION BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
"PRIMARY KEY": lambda self: self._parse_primary_key(in_props=True),
"RANGE": lambda self: self._parse_dict_range(this="RANGE"),
+ "REMOTE": lambda self: self._parse_remote_with_connection(),
"RETURNS": lambda self: self._parse_returns(),
"ROW": lambda self: self._parse_row(),
"ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty),
@@ -690,6 +694,9 @@ class Parser(metaclass=_Parser):
"TEMPORARY": lambda self: self.expression(exp.TemporaryProperty),
"TO": lambda self: self._parse_to_table(),
"TRANSIENT": lambda self: self.expression(exp.TransientProperty),
+ "TRANSFORM": lambda self: self.expression(
+ exp.TransformModelProperty, expressions=self._parse_wrapped_csv(self._parse_expression)
+ ),
"TTL": lambda self: self._parse_ttl(),
"USING": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
"VOLATILE": lambda self: self._parse_volatile_property(),
@@ -789,6 +796,7 @@ class Parser(metaclass=_Parser):
"MATCH": lambda self: self._parse_match_against(),
"OPENJSON": lambda self: self._parse_open_json(),
"POSITION": lambda self: self._parse_position(),
+ "PREDICT": lambda self: self._parse_predict(),
"SAFE_CAST": lambda self: self._parse_cast(False),
"STRING_AGG": lambda self: self._parse_string_agg(),
"SUBSTRING": lambda self: self._parse_substring(),
@@ -1787,6 +1795,12 @@ class Parser(metaclass=_Parser):
exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default
)
+ def _parse_remote_with_connection(self) -> exp.RemoteWithConnectionModelProperty:
+ self._match_text_seq("WITH", "CONNECTION")
+ return self.expression(
+ exp.RemoteWithConnectionModelProperty, this=self._parse_table_parts()
+ )
+
def _parse_returns(self) -> exp.ReturnsProperty:
value: t.Optional[exp.Expression]
is_table = self._match(TokenType.TABLE)
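Together with the MODEL token and creatable kind, the new INPUT, OUTPUT, REMOTE WITH CONNECTION and TRANSFORM property parsers cover BigQuery's model DDL. A hedged sketch of a remote-model statement of that shape (project, dataset and connection names are placeholders; round-trip formatting may differ):

    import sqlglot

    sql = """
    CREATE MODEL my_dataset.my_model
    INPUT (x INT64)
    OUTPUT (y FLOAT64)
    REMOTE WITH CONNECTION my_project.us.my_connection
    """
    print(sqlglot.parse_one(sql, read="bigquery").sql(dialect="bigquery"))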
@@ -2622,7 +2636,9 @@ class Parser(metaclass=_Parser):
bracket = parse_bracket and self._parse_bracket(None)
bracket = self.expression(exp.Table, this=bracket) if bracket else None
- this: exp.Expression = bracket or self._parse_table_parts(schema=schema)
+ this = t.cast(
+ exp.Expression, bracket or self._parse_bracket(self._parse_table_parts(schema=schema))
+ )
if schema:
return self._parse_schema(this=this)
@@ -2639,6 +2655,9 @@ class Parser(metaclass=_Parser):
if alias:
this.set("alias", alias)
+ if self._match_text_seq("AT"):
+ this.set("index", self._parse_id_var())
+
this.set("hints", self._parse_table_hints())
if not this.args.get("pivots"):
@@ -3886,7 +3905,9 @@ class Parser(metaclass=_Parser):
def _parse_unnamed_constraint(
self, constraints: t.Optional[t.Collection[str]] = None
) -> t.Optional[exp.Expression]:
- if not self._match_texts(constraints or self.CONSTRAINT_PARSERS):
+ if self._match(TokenType.IDENTIFIER, advance=False) or not self._match_texts(
+ constraints or self.CONSTRAINT_PARSERS
+ ):
return None
constraint = self._prev.text.upper()
@@ -4402,6 +4423,20 @@ class Parser(metaclass=_Parser):
exp.StrPosition, this=haystack, substr=needle, position=seq_get(args, 2)
)
+ def _parse_predict(self) -> exp.Predict:
+ self._match_text_seq("MODEL")
+ this = self._parse_table()
+
+ self._match(TokenType.COMMA)
+ self._match_text_seq("TABLE")
+
+ return self.expression(
+ exp.Predict,
+ this=this,
+ expression=self._parse_table(),
+ params_struct=self._match(TokenType.COMMA) and self._parse_bitwise(),
+ )
+
def _parse_join_hint(self, func_name: str) -> exp.JoinHint:
args = self._parse_csv(self._parse_table)
return exp.JoinHint(this=func_name.upper(), expressions=args)
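With the PREDICT function parser wired up, BigQuery's ML.PREDICT table function should parse and regenerate. A hedged sketch (dataset, model and table names are placeholders):

    import sqlglot

    sql = "SELECT * FROM ML.PREDICT(MODEL my_dataset.my_model, TABLE my_dataset.my_table)"
    print(sqlglot.parse_one(sql, read="bigquery").sql(dialect="bigquery"))
    # expected to round-trip roughly unchanged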
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 080a86b..4ab01dd 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -263,6 +263,7 @@ class TokenType(AutoName):
MEMBER_OF = auto()
MERGE = auto()
MOD = auto()
+ MODEL = auto()
NATURAL = auto()
NEXT = auto()
NOTNULL = auto()