From 8bec55350caa5c760d8b7e7e2d0ba6c77a32bc71 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 8 Feb 2023 05:14:34 +0100 Subject: Merging upstream version 10.6.3. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 66 +++++++++-- sqlglot/dataframe/sql/column.py | 6 +- sqlglot/dataframe/sql/functions.py | 20 ++-- sqlglot/dialects/bigquery.py | 8 -- sqlglot/dialects/dialect.py | 10 ++ sqlglot/dialects/drill.py | 5 +- sqlglot/dialects/duckdb.py | 12 +- sqlglot/dialects/hive.py | 16 +-- sqlglot/dialects/mysql.py | 17 +-- sqlglot/dialects/oracle.py | 4 + sqlglot/dialects/postgres.py | 45 ++++--- sqlglot/dialects/presto.py | 32 ++++- sqlglot/dialects/redshift.py | 11 +- sqlglot/dialects/snowflake.py | 9 -- sqlglot/dialects/spark.py | 37 ++++++ sqlglot/dialects/tableau.py | 1 - sqlglot/dialects/teradata.py | 8 ++ sqlglot/dialects/tsql.py | 2 + sqlglot/diff.py | 9 +- sqlglot/executor/__init__.py | 61 +++++++--- sqlglot/executor/env.py | 1 + sqlglot/executor/table.py | 7 +- sqlglot/expressions.py | 158 ++++++++++++++++++++++--- sqlglot/generator.py | 187 +++++++++++++++++++++--------- sqlglot/lineage.py | 7 +- sqlglot/optimizer/eliminate_subqueries.py | 2 +- sqlglot/optimizer/scope.py | 2 +- sqlglot/optimizer/simplify.py | 6 +- sqlglot/parser.py | 122 +++++++++++++++---- sqlglot/schema.py | 3 +- sqlglot/tokens.py | 1 + 31 files changed, 647 insertions(+), 228 deletions(-) (limited to 'sqlglot') diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index bfcabb3..714897f 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -33,7 +33,13 @@ from sqlglot.parser import Parser from sqlglot.schema import MappingSchema, Schema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "10.6.0" +if t.TYPE_CHECKING: + from sqlglot.dialects.dialect import DialectType + + T = t.TypeVar("T", bound=Expression) + + +__version__ = "10.6.3" pretty = False """Whether to format generated SQL by default.""" @@ -42,9 +48,7 @@ schema = MappingSchema() """The default schema used by SQLGlot (e.g. in the optimizer).""" -def parse( - sql: str, read: t.Optional[str | Dialect] = None, **opts -) -> t.List[t.Optional[Expression]]: +def parse(sql: str, read: DialectType = None, **opts) -> t.List[t.Optional[Expression]]: """ Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement. @@ -60,9 +64,57 @@ def parse( return dialect.parse(sql, **opts) +@t.overload +def parse_one( + sql: str, + read: None = None, + into: t.Type[T] = ..., + **opts, +) -> T: + ... + + +@t.overload +def parse_one( + sql: str, + read: DialectType, + into: t.Type[T], + **opts, +) -> T: + ... + + +@t.overload +def parse_one( + sql: str, + read: None = None, + into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]] = ..., + **opts, +) -> Expression: + ... + + +@t.overload +def parse_one( + sql: str, + read: DialectType, + into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]], + **opts, +) -> Expression: + ... + + +@t.overload +def parse_one( + sql: str, + **opts, +) -> Expression: + ... 
+ + def parse_one( sql: str, - read: t.Optional[str | Dialect] = None, + read: DialectType = None, into: t.Optional[exp.IntoType] = None, **opts, ) -> Expression: @@ -96,8 +148,8 @@ def parse_one( def transpile( sql: str, - read: t.Optional[str | Dialect] = None, - write: t.Optional[str | Dialect] = None, + read: DialectType = None, + write: DialectType = None, identity: bool = True, error_level: t.Optional[ErrorLevel] = None, **opts, diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index 40ffe3e..f5b0974 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -260,11 +260,7 @@ class Column: """ if isinstance(dataType, DataType): dataType = dataType.simpleString() - new_expression = exp.Cast( - this=self.column_expression, - to=sqlglot.parse_one(dataType, into=exp.DataType, read="spark"), # type: ignore - ) - return Column(new_expression) + return Column(exp.cast(self.column_expression, dataType, dialect="spark")) def startswith(self, value: t.Union[str, Column]) -> Column: value = self._lit(value) if not isinstance(value, Column) else value diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index a141fe4..47d5e7b 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -536,15 +536,15 @@ def month(col: ColumnOrName) -> Column: def dayofweek(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "DAYOFWEEK") + return Column.invoke_expression_over_column(col, glotexp.DayOfWeek) def dayofmonth(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "DAYOFMONTH") + return Column.invoke_expression_over_column(col, glotexp.DayOfMonth) def dayofyear(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "DAYOFYEAR") + return Column.invoke_expression_over_column(col, glotexp.DayOfYear) def hour(col: ColumnOrName) -> Column: @@ -560,7 +560,7 @@ def second(col: ColumnOrName) -> Column: def weekofyear(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "WEEKOFYEAR") + return Column.invoke_expression_over_column(col, glotexp.WeekOfYear) def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column: @@ -1144,10 +1144,16 @@ def aggregate( merge_exp = _get_lambda_from_func(merge) if finish is not None: finish_exp = _get_lambda_from_func(finish) - return Column.invoke_anonymous_function( - col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp) + return Column.invoke_expression_over_column( + col, + glotexp.Reduce, + initial=initialValue, + merge=Column(merge_exp), + finish=Column(finish_exp), ) - return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp)) + return Column.invoke_expression_over_column( + col, glotexp.Reduce, initial=initialValue, merge=Column(merge_exp) + ) def transform( diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 27dca48..90ae229 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -222,14 +222,6 @@ class BigQuery(Dialect): exp.DataType.Type.NVARCHAR: "STRING", } - ROOT_PROPERTIES = { - exp.LanguageProperty, - exp.ReturnsProperty, - exp.VolatilityProperty, - } - - WITH_PROPERTIES = {exp.Property} - EXPLICIT_UNION = True def array_sql(self, expression: exp.Array) -> str: diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 0c2beba..1b20e0a 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ 
-122,9 +122,15 @@ class Dialect(metaclass=_Dialect): def get_or_raise(cls, dialect): if not dialect: return cls + if isinstance(dialect, _Dialect): + return dialect + if isinstance(dialect, Dialect): + return dialect.__class__ + result = cls.get(dialect) if not result: raise ValueError(f"Unknown dialect '{dialect}'") + return result @classmethod @@ -196,6 +202,10 @@ class Dialect(metaclass=_Dialect): ) +if t.TYPE_CHECKING: + DialectType = t.Union[str, Dialect, t.Type[Dialect], None] + + def rename_func(name): def _rename(self, expression): args = flatten(expression.args.values()) diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 4e3c0e1..d0a0251 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -137,7 +137,10 @@ class Drill(Dialect): exp.DataType.Type.DATETIME: "TIMESTAMP", } - ROOT_PROPERTIES = {exp.PartitionedByProperty} + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + } TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 4646eb4..95ff95c 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -20,10 +20,6 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _unix_to_time(self, expression): - return f"TO_TIMESTAMP(CAST({self.sql(expression, 'this')} AS BIGINT))" - - def _str_to_time_sql(self, expression): return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})" @@ -113,7 +109,7 @@ class DuckDB(Dialect): "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, "STRUCT_PACK": exp.Struct.from_arg_list, - "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list, + "TO_TIMESTAMP": exp.UnixToTime.from_arg_list, "UNNEST": exp.Explode.from_arg_list, } @@ -162,9 +158,9 @@ class DuckDB(Dialect): exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)", exp.TsOrDsAdd: _ts_or_ds_add, exp.TsOrDsToDate: _ts_or_ds_to_date_sql, - exp.UnixToStr: lambda self, e: f"STRFTIME({_unix_to_time(self, e)}, {self.format_time(e)})", - exp.UnixToTime: _unix_to_time, - exp.UnixToTimeStr: lambda self, e: f"CAST({_unix_to_time(self, e)} AS TEXT)", + exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})", + exp.UnixToTime: rename_func("TO_TIMESTAMP"), + exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)", } TYPE_MAPPING = { diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 4bbec70..f2b6eaa 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -322,17 +322,11 @@ class Hive(Dialect): exp.LastDateOfMonth: rename_func("LAST_DAY"), } - WITH_PROPERTIES = {exp.Property} - - ROOT_PROPERTIES = { - exp.PartitionedByProperty, - exp.FileFormatProperty, - exp.SchemaCommentProperty, - exp.LocationProperty, - exp.TableFormatProperty, - exp.RowFormatDelimitedProperty, - exp.RowFormatSerdeProperty, - exp.SerdeProperties, + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT, } def with_properties(self, properties): diff --git a/sqlglot/dialects/mysql.py 
b/sqlglot/dialects/mysql.py index cd8c30c..a5bd86b 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -1,7 +1,5 @@ from __future__ import annotations -import typing as t - from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, @@ -98,6 +96,8 @@ def _date_add_sql(kind): class MySQL(Dialect): + time_format = "'%Y-%m-%d %T'" + # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions time_mapping = { "%M": "%B", @@ -110,6 +110,7 @@ class MySQL(Dialect): "%u": "%W", "%k": "%-H", "%l": "%-I", + "%T": "%H:%M:%S", } class Tokenizer(tokens.Tokenizer): @@ -428,6 +429,7 @@ class MySQL(Dialect): ) class Generator(generator.Generator): + LOCKING_READS_SUPPORTED = True NULL_ORDERING_SUPPORTED = False TRANSFORMS = { @@ -449,23 +451,12 @@ class MySQL(Dialect): exp.StrPosition: strposition_to_locate_sql, } - ROOT_PROPERTIES = { - exp.EngineProperty, - exp.AutoIncrementProperty, - exp.CharacterSetProperty, - exp.CollateProperty, - exp.SchemaCommentProperty, - exp.LikeProperty, - } - TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy() TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT) TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT) TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB) TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB) - WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set() - def show_sql(self, expression): this = f" {expression.name}" full = " FULL" if expression.args.get("full") else "" diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 67d791d..fde845e 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -44,6 +44,8 @@ class Oracle(Dialect): } class Generator(generator.Generator): + LOCKING_READS_SUPPORTED = True + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TINYINT: "NUMBER", @@ -69,6 +71,7 @@ class Oracle(Dialect): exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", + exp.Substring: rename_func("SUBSTR"), } def query_modifiers(self, expression, *sqls): @@ -90,6 +93,7 @@ class Oracle(Dialect): self.sql(expression, "order"), self.sql(expression, "offset"), # offset before limit in oracle self.sql(expression, "limit"), + self.sql(expression, "lock"), sep="", ) diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 0d74b3a..6418032 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -148,6 +148,22 @@ def _serial_to_generated(expression): return expression +def _generate_series(args): + # The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day + step = seq_get(args, 2) + + if step is None: + # Postgres allows calls with just two arguments -- the "step" argument defaults to 1 + return exp.GenerateSeries.from_arg_list(args) + + if step.is_string: + args[2] = exp.to_interval(step.this) + elif isinstance(step, exp.Interval) and not step.args.get("unit"): + args[2] = exp.to_interval(step.this.this) + + return exp.GenerateSeries.from_arg_list(args) + + def _to_timestamp(args): # TO_TIMESTAMP accepts either a single double argument or (text, text) if len(args) == 1: @@ -195,29 +211,6 @@ class Postgres(Dialect): HEX_STRINGS = [("x'", "'"), ("X'", "'")] BYTE_STRINGS = [("e'", "'"), ("E'", "'")] - CREATABLES = ( - "AGGREGATE", - 
"CAST", - "CONVERSION", - "COLLATION", - "DEFAULT CONVERSION", - "CONSTRAINT", - "DOMAIN", - "EXTENSION", - "FOREIGN", - "FUNCTION", - "OPERATOR", - "POLICY", - "ROLE", - "RULE", - "SEQUENCE", - "TEXT", - "TRIGGER", - "TYPE", - "UNLOGGED", - "USER", - ) - KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "~~": TokenType.LIKE, @@ -243,8 +236,6 @@ class Postgres(Dialect): "TEMP": TokenType.TEMPORARY, "UUID": TokenType.UUID, "CSTRING": TokenType.PSEUDO_TYPE, - **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES}, - **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES}, } QUOTES = ["'", "$$"] SINGLE_TOKENS = { @@ -257,8 +248,10 @@ class Postgres(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore + "NOW": exp.CurrentTimestamp.from_arg_list, "TO_TIMESTAMP": _to_timestamp, "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), + "GENERATE_SERIES": _generate_series, } BITWISE = { @@ -272,6 +265,8 @@ class Postgres(Dialect): } class Generator(generator.Generator): + LOCKING_READS_SUPPORTED = True + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TINYINT: "SMALLINT", diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 8175d6f..6c1a474 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -105,6 +105,29 @@ def _ts_or_ds_add_sql(self, expression): return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))" +def _sequence_sql(self, expression): + start = expression.args["start"] + end = expression.args["end"] + step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series + + target_type = None + + if isinstance(start, exp.Cast): + target_type = start.to + elif isinstance(end, exp.Cast): + target_type = end.to + + if target_type and target_type.this == exp.DataType.Type.TIMESTAMP: + to = target_type.copy() + + if target_type is start.to: + end = exp.Cast(this=end, to=to) + else: + start = exp.Cast(this=start, to=to) + + return f"SEQUENCE({self.format_args(start, end, step)})" + + def _ensure_utf8(charset): if charset.name.lower() != "utf-8": raise UnsupportedError(f"Unsupported charset {charset}") @@ -145,7 +168,7 @@ def _from_unixtime(args): class Presto(Dialect): index_offset = 1 null_ordering = "nulls_are_last" - time_format = "'%Y-%m-%d %H:%i:%S'" + time_format = MySQL.time_format # type: ignore time_mapping = MySQL.time_mapping # type: ignore class Tokenizer(tokens.Tokenizer): @@ -197,7 +220,10 @@ class Presto(Dialect): class Generator(generator.Generator): STRUCT_DELIMITER = ("(", ")") - ROOT_PROPERTIES = {exp.SchemaCommentProperty} + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.LocationProperty: exp.Properties.Location.UNSUPPORTED, + } TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore @@ -223,6 +249,7 @@ class Presto(Dialect): exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DataType: _datatype_sql, exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 
'this')})""", @@ -231,6 +258,7 @@ class Presto(Dialect): exp.Decode: _decode_sql, exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)", exp.Encode: _encode_sql, + exp.GenerateSeries: _sequence_sql, exp.Hex: rename_func("TO_HEX"), exp.If: if_sql, exp.ILike: no_ilike_sql, diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 7da881f..c3c99eb 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -61,14 +61,9 @@ class Redshift(Postgres): exp.DataType.Type.INT: "INTEGER", } - ROOT_PROPERTIES = { - exp.DistKeyProperty, - exp.SortKeyProperty, - exp.DistStyleProperty, - } - - WITH_PROPERTIES = { - exp.LikeProperty, + PROPERTIES_LOCATION = { + **Postgres.Generator.PROPERTIES_LOCATION, # type: ignore + exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_WITH, } TRANSFORMS = { diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index db72a34..3b83b02 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -234,15 +234,6 @@ class Snowflake(Dialect): "replace": "RENAME", } - ROOT_PROPERTIES = { - exp.PartitionedByProperty, - exp.ReturnsProperty, - exp.LanguageProperty, - exp.SchemaCommentProperty, - exp.ExecuteAsProperty, - exp.VolatilityProperty, - } - def except_op(self, expression): if not expression.args.get("distinct", False): self.unsupported("EXCEPT with All is not supported in Snowflake") diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index fc711ab..8ef4a87 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -73,6 +73,19 @@ class Spark(Hive): ), "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, "IIF": exp.If.from_arg_list, + "AGGREGATE": exp.Reduce.from_arg_list, + "DAYOFWEEK": lambda args: exp.DayOfWeek( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + ), + "DAYOFMONTH": lambda args: exp.DayOfMonth( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + ), + "DAYOFYEAR": lambda args: exp.DayOfYear( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + ), + "WEEKOFYEAR": lambda args: exp.WeekOfYear( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + ), } FUNCTION_PARSERS = { @@ -105,6 +118,14 @@ class Spark(Hive): exp.DataType.Type.BIGINT: "LONG", } + PROPERTIES_LOCATION = { + **Hive.Generator.PROPERTIES_LOCATION, # type: ignore + exp.EngineProperty: exp.Properties.Location.UNSUPPORTED, + exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED, + exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED, + exp.CollateProperty: exp.Properties.Location.UNSUPPORTED, + } + TRANSFORMS = { **Hive.Generator.TRANSFORMS, # type: ignore exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), @@ -126,11 +147,27 @@ class Spark(Hive): exp.VariancePop: rename_func("VAR_POP"), exp.DateFromParts: rename_func("MAKE_DATE"), exp.LogicalOr: rename_func("BOOL_OR"), + exp.DayOfWeek: rename_func("DAYOFWEEK"), + exp.DayOfMonth: rename_func("DAYOFMONTH"), + exp.DayOfYear: rename_func("DAYOFYEAR"), + exp.WeekOfYear: rename_func("WEEKOFYEAR"), + exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})", } TRANSFORMS.pop(exp.ArraySort) TRANSFORMS.pop(exp.ILike) WRAP_DERIVED_VALUES = False + def cast_sql(self, expression: exp.Cast) -> str: + if isinstance(expression.this, exp.Cast) and expression.this.is_type( + exp.DataType.Type.JSON + ): + schema = f"'{self.sql(expression, 'to')}'" + return f"FROM_JSON({self.format_args(self.sql(expression.this, 'this'), 
schema)})" + if expression.to.is_type(exp.DataType.Type.JSON): + return f"TO_JSON({self.sql(expression, 'this')})" + + return super(Spark.Generator, self).cast_sql(expression) + class Tokenizer(Hive.Tokenizer): HEX_STRINGS = [("X'", "'")] diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py index 36c085f..31b1c8d 100644 --- a/sqlglot/dialects/tableau.py +++ b/sqlglot/dialects/tableau.py @@ -31,6 +31,5 @@ class Tableau(Dialect): class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore - "IFNULL": exp.Coalesce.from_arg_list, "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)), } diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 4340820..123da04 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -76,6 +76,14 @@ class Teradata(Dialect): ) class Generator(generator.Generator): + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX, + } + + def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str: + return f"PARTITION BY {self.sql(expression, 'this')}" + # FROM before SET in Teradata UPDATE syntax # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause def update_sql(self, expression: exp.Update) -> str: diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 9f9099e..05ba53a 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -412,6 +412,8 @@ class TSQL(Dialect): return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) class Generator(generator.Generator): + LOCKING_READS_SUPPORTED = True + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.BOOLEAN: "BIT", diff --git a/sqlglot/diff.py b/sqlglot/diff.py index a5373b0..7d5ec21 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -14,10 +14,6 @@ from sqlglot import Dialect from sqlglot import expressions as exp from sqlglot.helper import ensure_collection -if t.TYPE_CHECKING: - T = t.TypeVar("T") - Edit = t.Union[Insert, Remove, Move, Update, Keep] - @dataclass(frozen=True) class Insert: @@ -56,6 +52,11 @@ class Keep: target: exp.Expression +if t.TYPE_CHECKING: + T = t.TypeVar("T") + Edit = t.Union[Insert, Remove, Move, Update, Keep] + + def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]: """ Returns the list of changes between the source and the target expressions. diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py index 04621b5..67b4b00 100644 --- a/sqlglot/executor/__init__.py +++ b/sqlglot/executor/__init__.py @@ -1,5 +1,13 @@ +""" +.. 
include:: ../../posts/python_sql_engine.md +---- +""" + +from __future__ import annotations + import logging import time +import typing as t from sqlglot import maybe_parse from sqlglot.errors import ExecuteError @@ -11,42 +19,63 @@ from sqlglot.schema import ensure_schema logger = logging.getLogger("sqlglot") +if t.TYPE_CHECKING: + from sqlglot.dialects.dialect import DialectType + from sqlglot.executor.table import Tables + from sqlglot.expressions import Expression + from sqlglot.schema import Schema -def execute(sql, schema=None, read=None, tables=None): + +def execute( + sql: str | Expression, + schema: t.Optional[t.Dict | Schema] = None, + read: DialectType = None, + tables: t.Optional[t.Dict] = None, +) -> Table: """ Run a sql query against data. Args: - sql (str|sqlglot.Expression): a sql statement - schema (dict|sqlglot.optimizer.Schema): database schema. - This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of - the following forms: - 1. {table: {col: type}} - 2. {db: {table: {col: type}}} - 3. {catalog: {db: {table: {col: type}}}} - read (str): the SQL dialect to apply during parsing - (eg. "spark", "hive", "presto", "mysql"). - tables (dict): additional tables to register. + sql: a sql statement. + schema: database schema. + This can either be an instance of `Schema` or a mapping in one of the following forms: + 1. {table: {col: type}} + 2. {db: {table: {col: type}}} + 3. {catalog: {db: {table: {col: type}}}} + read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql"). + tables: additional tables to register. + Returns: - sqlglot.executor.Table: Simple columnar data structure. + Simple columnar data structure. """ - tables = ensure_tables(tables) + tables_ = ensure_tables(tables) + if not schema: schema = { name: {column: type(table[0][column]).__name__ for column in table.columns} - for name, table in tables.mapping.items() + for name, table in tables_.mapping.items() } + schema = ensure_schema(schema) - if tables.supported_table_args and tables.supported_table_args != schema.supported_table_args: + + if tables_.supported_table_args and tables_.supported_table_args != schema.supported_table_args: raise ExecuteError("Tables must support the same table args as schema") + expression = maybe_parse(sql, dialect=read) + now = time.time() expression = optimize(expression, schema, leave_tables_isolated=True) + logger.debug("Optimization finished: %f", time.time() - now) logger.debug("Optimized SQL: %s", expression.sql(pretty=True)) + plan = Plan(expression) + logger.debug("Logical Plan: %s", plan) + now = time.time() - result = PythonExecutor(tables=tables).execute(plan) + result = PythonExecutor(tables=tables_).execute(plan) + logger.debug("Query finished: %f", time.time() - now) + return result diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index 04dc938..ba9cbbd 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -171,5 +171,6 @@ ENV = { "STRPOSITION": str_position, "SUB": null_if_any(lambda e, this: e - this), "SUBSTRING": substring, + "TIMESTRTOTIME": null_if_any(lambda arg: datetime.datetime.fromisoformat(arg)), "UPPER": null_if_any(lambda arg: arg.upper()), } diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py index f1b5b54..27e3e5e 100644 --- a/sqlglot/executor/table.py +++ b/sqlglot/executor/table.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot.helper import dict_depth from sqlglot.schema import AbstractMappingSchema @@ -106,11 
+108,11 @@ class Tables(AbstractMappingSchema[Table]): pass -def ensure_tables(d: dict | None) -> Tables: +def ensure_tables(d: t.Optional[t.Dict]) -> Tables: return Tables(_ensure_tables(d)) -def _ensure_tables(d: dict | None) -> dict: +def _ensure_tables(d: t.Optional[t.Dict]) -> t.Dict: if not d: return {} @@ -127,4 +129,5 @@ def _ensure_tables(d: dict | None) -> dict: columns = tuple(table[0]) if table else () rows = [tuple(row[c] for c in columns) for row in table] result[name] = Table(columns=columns, rows=rows) + return result diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 7c1a116..6bb083a 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -32,13 +32,7 @@ from sqlglot.helper import ( from sqlglot.tokens import Token if t.TYPE_CHECKING: - from sqlglot.dialects.dialect import Dialect - - IntoType = t.Union[ - str, - t.Type[Expression], - t.Collection[t.Union[str, t.Type[Expression]]], - ] + from sqlglot.dialects.dialect import DialectType class _Expression(type): @@ -427,7 +421,7 @@ class Expression(metaclass=_Expression): def __repr__(self): return self._to_s() - def sql(self, dialect: Dialect | str | None = None, **opts) -> str: + def sql(self, dialect: DialectType = None, **opts) -> str: """ Returns SQL string representation of this tree. @@ -595,6 +589,14 @@ class Expression(metaclass=_Expression): return load(obj) +if t.TYPE_CHECKING: + IntoType = t.Union[ + str, + t.Type[Expression], + t.Collection[t.Union[str, t.Type[Expression]]], + ] + + class Condition(Expression): def and_(self, *expressions, dialect=None, **opts): """ @@ -1285,6 +1287,18 @@ class Property(Expression): arg_types = {"this": True, "value": True} +class AlgorithmProperty(Property): + arg_types = {"this": True} + + +class DefinerProperty(Property): + arg_types = {"this": True} + + +class SqlSecurityProperty(Property): + arg_types = {"definer": True} + + class TableFormatProperty(Property): arg_types = {"this": True} @@ -1425,13 +1439,15 @@ class IsolatedLoadingProperty(Property): class Properties(Expression): - arg_types = {"expressions": True, "before": False} + arg_types = {"expressions": True} NAME_TO_PROPERTY = { + "ALGORITHM": AlgorithmProperty, "AUTO_INCREMENT": AutoIncrementProperty, "CHARACTER SET": CharacterSetProperty, "COLLATE": CollateProperty, "COMMENT": SchemaCommentProperty, + "DEFINER": DefinerProperty, "DISTKEY": DistKeyProperty, "DISTSTYLE": DistStyleProperty, "ENGINE": EngineProperty, @@ -1447,6 +1463,14 @@ class Properties(Expression): PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()} + class Location(AutoName): + POST_CREATE = auto() + PRE_SCHEMA = auto() + POST_INDEX = auto() + POST_SCHEMA_ROOT = auto() + POST_SCHEMA_WITH = auto() + UNSUPPORTED = auto() + @classmethod def from_dict(cls, properties_dict) -> Properties: expressions = [] @@ -1592,6 +1616,7 @@ QUERY_MODIFIERS = { "order": False, "limit": False, "offset": False, + "lock": False, } @@ -1713,6 +1738,12 @@ class Schema(Expression): arg_types = {"this": False, "expressions": False} +# Used to represent the FOR UPDATE and FOR SHARE locking read types. +# https://dev.mysql.com/doc/refman/8.0/en/innodb-locking-reads.html +class Lock(Expression): + arg_types = {"update": True} + + class Select(Subqueryable): arg_types = { "with": False, @@ -2243,6 +2274,30 @@ class Select(Subqueryable): properties=properties_expression, ) + def lock(self, update: bool = True, copy: bool = True) -> Select: + """ + Set the locking read mode for this expression. 
+ + Examples: + >>> Select().select("x").from_("tbl").where("x = 'a'").lock().sql("mysql") + "SELECT x FROM tbl WHERE x = 'a' FOR UPDATE" + + >>> Select().select("x").from_("tbl").where("x = 'a'").lock(update=False).sql("mysql") + "SELECT x FROM tbl WHERE x = 'a' FOR SHARE" + + Args: + update: if `True`, the locking type will be `FOR UPDATE`, else it will be `FOR SHARE`. + copy: if `False`, modify this expression instance in-place. + + Returns: + The modified expression. + """ + + inst = _maybe_copy(self, copy) + inst.set("lock", Lock(update=update)) + + return inst + @property def named_selects(self) -> t.List[str]: return [e.output_name for e in self.expressions if e.alias_or_name] @@ -2456,24 +2511,28 @@ class DataType(Expression): @classmethod def build( - cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs + cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs ) -> DataType: from sqlglot import parse_one if isinstance(dtype, str): - data_type_exp: t.Optional[Expression] if dtype.upper() in cls.Type.__members__: - data_type_exp = DataType(this=DataType.Type[dtype.upper()]) + data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()]) else: data_type_exp = parse_one(dtype, read=dialect, into=DataType) if data_type_exp is None: raise ValueError(f"Unparsable data type value: {dtype}") elif isinstance(dtype, DataType.Type): data_type_exp = DataType(this=dtype) + elif isinstance(dtype, DataType): + return dtype else: raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type") return DataType(**{**data_type_exp.args, **kwargs}) + def is_type(self, dtype: DataType.Type) -> bool: + return self.this == dtype + # https://www.postgresql.org/docs/15/datatype-pseudo.html class PseudoType(Expression): @@ -2840,6 +2899,10 @@ class Array(Func): is_var_len_args = True +class GenerateSeries(Func): + arg_types = {"start": True, "end": True, "step": False} + + class ArrayAgg(AggFunc): pass @@ -2909,6 +2972,9 @@ class Cast(Func): def output_name(self): return self.name + def is_type(self, dtype: DataType.Type) -> bool: + return self.to.is_type(dtype) + class Collate(Binary): pass @@ -2989,6 +3055,22 @@ class DatetimeTrunc(Func, TimeUnit): arg_types = {"this": True, "unit": True, "zone": False} +class DayOfWeek(Func): + _sql_names = ["DAY_OF_WEEK", "DAYOFWEEK"] + + +class DayOfMonth(Func): + _sql_names = ["DAY_OF_MONTH", "DAYOFMONTH"] + + +class DayOfYear(Func): + _sql_names = ["DAY_OF_YEAR", "DAYOFYEAR"] + + +class WeekOfYear(Func): + _sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"] + + class LastDateOfMonth(Func): pass @@ -3239,7 +3321,7 @@ class ReadCSV(Func): class Reduce(Func): - arg_types = {"this": True, "initial": True, "merge": True, "finish": True} + arg_types = {"this": True, "initial": True, "merge": True, "finish": False} class RegexpLike(Func): @@ -3476,7 +3558,7 @@ def maybe_parse( sql_or_expression: str | Expression, *, into: t.Optional[IntoType] = None, - dialect: t.Optional[str] = None, + dialect: DialectType = None, prefix: t.Optional[str] = None, **opts, ) -> Expression: @@ -3959,6 +4041,28 @@ def to_identifier(alias, quoted=None) -> t.Optional[Identifier]: return identifier +INTERVAL_STRING_RE = re.compile(r"\s*([0-9]+)\s*([a-zA-Z]+)\s*") + + +def to_interval(interval: str | Literal) -> Interval: + """Builds an interval expression from a string like '1 day' or '5 months'.""" + if isinstance(interval, Literal): + if not interval.is_string: + raise ValueError("Invalid interval 
string.") + + interval = interval.this + + interval_parts = INTERVAL_STRING_RE.match(interval) # type: ignore + + if not interval_parts: + raise ValueError("Invalid interval string.") + + return Interval( + this=Literal.string(interval_parts.group(1)), + unit=Var(this=interval_parts.group(2)), + ) + + @t.overload def to_table(sql_path: str | Table, **kwargs) -> Table: ... @@ -4050,7 +4154,8 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts): def subquery(expression, alias=None, dialect=None, **opts): """ Build a subquery expression. - Expample: + + Example: >>> subquery('select x from tbl', 'bar').select('x').sql() 'SELECT x FROM (SELECT x FROM tbl) AS bar' @@ -4072,6 +4177,7 @@ def subquery(expression, alias=None, dialect=None, **opts): def column(col, table=None, quoted=None) -> Column: """ Build a Column. + Args: col (str | Expression): column name table (str | Expression): table name @@ -4084,6 +4190,24 @@ def column(col, table=None, quoted=None) -> Column: ) +def cast(expression: str | Expression, to: str | DataType | DataType.Type, **opts) -> Cast: + """Cast an expression to a data type. + + Example: + >>> cast('x + 1', 'int').sql() + 'CAST(x + 1 AS INT)' + + Args: + expression: The expression to cast. + to: The datatype to cast to. + + Returns: + A cast node. + """ + expression = maybe_parse(expression, **opts) + return Cast(this=expression, to=DataType.build(to, **opts)) + + def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table: """Build a Table. @@ -4137,7 +4261,7 @@ def values( types = list(columns.values()) expressions[0].set( "expressions", - [Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)], + [cast(x, types[i]) for i, x in enumerate(expressions[0].expressions)], ) return Values( expressions=expressions, @@ -4373,7 +4497,7 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True return expression.transform(_expand, copy=copy) -def func(name: str, *args, dialect: t.Optional[Dialect | str] = None, **kwargs) -> Func: +def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func: """ Returns a Func expression. diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 3f3365a..b95e9bc 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -67,6 +67,7 @@ class Generator: exp.VolatilityProperty: lambda self, e: e.name, exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}", exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG", + exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", } # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed @@ -75,6 +76,9 @@ class Generator: # Whether or not null ordering is supported in order by NULL_ORDERING_SUPPORTED = True + # Whether or not locking reads (i.e. SELECT ... 
FOR UPDATE/SHARE) are supported + LOCKING_READS_SUPPORTED = False + # Always do union distinct or union all EXPLICIT_UNION = False @@ -99,34 +103,42 @@ class Generator: STRUCT_DELIMITER = ("<", ">") - BEFORE_PROPERTIES = { - exp.FallbackProperty, - exp.WithJournalTableProperty, - exp.LogProperty, - exp.JournalProperty, - exp.AfterJournalProperty, - exp.ChecksumProperty, - exp.FreespaceProperty, - exp.MergeBlockRatioProperty, - exp.DataBlocksizeProperty, - exp.BlockCompressionProperty, - exp.IsolatedLoadingProperty, - } - - ROOT_PROPERTIES = { - exp.ReturnsProperty, - exp.LanguageProperty, - exp.DistStyleProperty, - exp.DistKeyProperty, - exp.SortKeyProperty, - exp.LikeProperty, - } - - WITH_PROPERTIES = { - exp.Property, - exp.FileFormatProperty, - exp.PartitionedByProperty, - exp.TableFormatProperty, + PROPERTIES_LOCATION = { + exp.AfterJournalProperty: exp.Properties.Location.PRE_SCHEMA, + exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE, + exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.BlockCompressionProperty: exp.Properties.Location.PRE_SCHEMA, + exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.ChecksumProperty: exp.Properties.Location.PRE_SCHEMA, + exp.CollateProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.DataBlocksizeProperty: exp.Properties.Location.PRE_SCHEMA, + exp.DefinerProperty: exp.Properties.Location.POST_CREATE, + exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.EngineProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.FallbackProperty: exp.Properties.Location.PRE_SCHEMA, + exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH, + exp.FreespaceProperty: exp.Properties.Location.PRE_SCHEMA, + exp.IsolatedLoadingProperty: exp.Properties.Location.PRE_SCHEMA, + exp.JournalProperty: exp.Properties.Location.PRE_SCHEMA, + exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.LocationProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.LogProperty: exp.Properties.Location.PRE_SCHEMA, + exp.MergeBlockRatioProperty: exp.Properties.Location.PRE_SCHEMA, + exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_WITH, + exp.Property: exp.Properties.Location.POST_SCHEMA_WITH, + exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE, + exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH, + exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.WithJournalTableProperty: exp.Properties.Location.PRE_SCHEMA, } WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary) @@ -284,10 +296,10 @@ class Generator: ) return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}" - def no_identify(self, func: t.Callable[[], str]) -> str: + def no_identify(self, func: t.Callable[..., str], *args, **kwargs) -> str: original = self.identify self.identify = False - result = func() + result = func(*args, **kwargs) self.identify = 
original return result @@ -455,19 +467,33 @@ class Generator: def create_sql(self, expression: exp.Create) -> str: kind = self.sql(expression, "kind").upper() - has_before_properties = expression.args.get("properties") - has_before_properties = ( - has_before_properties.args.get("before") if has_before_properties else None - ) - if kind == "TABLE" and has_before_properties: + properties = expression.args.get("properties") + properties_exp = expression.copy() + properties_locs = self.locate_properties(properties) if properties else {} + if properties_locs.get(exp.Properties.Location.POST_SCHEMA_ROOT) or properties_locs.get( + exp.Properties.Location.POST_SCHEMA_WITH + ): + properties_exp.set( + "properties", + exp.Properties( + expressions=[ + *properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT], + *properties_locs[exp.Properties.Location.POST_SCHEMA_WITH], + ] + ), + ) + if kind == "TABLE" and properties_locs.get(exp.Properties.Location.PRE_SCHEMA): this_name = self.sql(expression.this, "this") - this_properties = self.sql(expression, "properties") + this_properties = self.properties( + exp.Properties(expressions=properties_locs[exp.Properties.Location.PRE_SCHEMA]), + wrapped=False, + ) this_schema = f"({self.expressions(expression.this)})" this = f"{this_name}, {this_properties} {this_schema}" - properties = "" + properties_sql = "" else: this = self.sql(expression, "this") - properties = self.sql(expression, "properties") + properties_sql = self.sql(properties_exp, "properties") begin = " BEGIN" if expression.args.get("begin") else "" expression_sql = self.sql(expression, "expression") expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else "" @@ -514,11 +540,31 @@ class Generator: if index.args.get("columns") else "" ) + if index.args.get("primary") and properties_locs.get( + exp.Properties.Location.POST_INDEX + ): + postindex_props_sql = self.properties( + exp.Properties( + expressions=properties_locs[exp.Properties.Location.POST_INDEX] + ), + wrapped=False, + ) + ind_columns = f"{ind_columns} {postindex_props_sql}" + indexes_sql.append( f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}" ) index_sql = "".join(indexes_sql) + postcreate_props_sql = "" + if properties_locs.get(exp.Properties.Location.POST_CREATE): + postcreate_props_sql = self.properties( + exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_CREATE]), + sep=" ", + prefix=" ", + wrapped=False, + ) + modifiers = "".join( ( replace, @@ -531,6 +577,7 @@ class Generator: multiset, global_temporary, volatile, + postcreate_props_sql, ) ) no_schema_binding = ( @@ -539,7 +586,7 @@ class Generator: post_expression_modifiers = "".join((data, statistics, no_primary_index)) - expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}" + expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}" return self.prepend_ctes(expression, expression_sql) def describe_sql(self, expression: exp.Describe) -> str: @@ -665,24 +712,19 @@ class Generator: return f"PARTITION({self.expressions(expression)})" def properties_sql(self, expression: exp.Properties) -> str: - before_properties = [] root_properties = [] with_properties = [] for p in expression.expressions: - p_class = p.__class__ - if p_class in self.BEFORE_PROPERTIES: - before_properties.append(p) - elif p_class in self.WITH_PROPERTIES: + p_loc 
= self.PROPERTIES_LOCATION[p.__class__] + if p_loc == exp.Properties.Location.POST_SCHEMA_WITH: with_properties.append(p) - elif p_class in self.ROOT_PROPERTIES: + elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT: root_properties.append(p) - return ( - self.properties(exp.Properties(expressions=before_properties), before=True) - + self.root_properties(exp.Properties(expressions=root_properties)) - + self.with_properties(exp.Properties(expressions=with_properties)) - ) + return self.root_properties( + exp.Properties(expressions=root_properties) + ) + self.with_properties(exp.Properties(expressions=with_properties)) def root_properties(self, properties: exp.Properties) -> str: if properties.expressions: @@ -695,17 +737,41 @@ class Generator: prefix: str = "", sep: str = ", ", suffix: str = "", - before: bool = False, + wrapped: bool = True, ) -> str: if properties.expressions: expressions = self.expressions(properties, sep=sep, indent=False) - expressions = expressions if before else self.wrap(expressions) + expressions = self.wrap(expressions) if wrapped else expressions return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}" return "" def with_properties(self, properties: exp.Properties) -> str: return self.properties(properties, prefix=self.seg("WITH")) + def locate_properties( + self, properties: exp.Properties + ) -> t.Dict[exp.Properties.Location, list[exp.Property]]: + properties_locs: t.Dict[exp.Properties.Location, list[exp.Property]] = { + key: [] for key in exp.Properties.Location + } + + for p in properties.expressions: + p_loc = self.PROPERTIES_LOCATION[p.__class__] + if p_loc == exp.Properties.Location.PRE_SCHEMA: + properties_locs[exp.Properties.Location.PRE_SCHEMA].append(p) + elif p_loc == exp.Properties.Location.POST_INDEX: + properties_locs[exp.Properties.Location.POST_INDEX].append(p) + elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT: + properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT].append(p) + elif p_loc == exp.Properties.Location.POST_SCHEMA_WITH: + properties_locs[exp.Properties.Location.POST_SCHEMA_WITH].append(p) + elif p_loc == exp.Properties.Location.POST_CREATE: + properties_locs[exp.Properties.Location.POST_CREATE].append(p) + elif p_loc == exp.Properties.Location.UNSUPPORTED: + self.unsupported(f"Unsupported property {p.key}") + + return properties_locs + def property_sql(self, expression: exp.Property) -> str: property_cls = expression.__class__ if property_cls == exp.Property: @@ -713,7 +779,7 @@ class Generator: property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls) if not property_name: - self.unsupported(f"Unsupported property {property_name}") + self.unsupported(f"Unsupported property {expression.key}") return f"{property_name}={self.sql(expression, 'this')}" @@ -975,7 +1041,7 @@ class Generator: rollup = self.expressions(expression, key="rollup", indent=False) rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else "" - return f"{group_by}{grouping_sets}{cube}{rollup}" + return f"{group_by}{csv(grouping_sets, cube, rollup, sep=',')}" def having_sql(self, expression: exp.Having) -> str: this = self.indent(self.sql(expression, "this")) @@ -1015,7 +1081,7 @@ class Generator: def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str: args = self.expressions(expression, flat=True) args = f"({args})" if len(args.split(",")) > 1 else args - return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}") + return f"{args} {arrow_sep} 
{self.sql(expression, 'this')}" def lateral_sql(self, expression: exp.Lateral) -> str: this = self.sql(expression, "this") @@ -1043,6 +1109,14 @@ class Generator: this = self.sql(expression, "this") return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}" + def lock_sql(self, expression: exp.Lock) -> str: + if self.LOCKING_READS_SUPPORTED: + lock_type = "UPDATE" if expression.args["update"] else "SHARE" + return self.seg(f"FOR {lock_type}") + + self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported") + return "" + def literal_sql(self, expression: exp.Literal) -> str: text = expression.this or "" if expression.is_string: @@ -1163,6 +1237,7 @@ class Generator: self.sql(expression, "order"), self.sql(expression, "limit"), self.sql(expression, "offset"), + self.sql(expression, "lock"), sep="", ) @@ -1773,7 +1848,7 @@ class Generator: def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str: this = self.sql(expression, "this") - expressions = self.no_identify(lambda: self.expressions(expression)) + expressions = self.no_identify(self.expressions, expression) expressions = ( self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}" ) diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index 4e7eab8..a39ad8c 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -9,6 +9,9 @@ from sqlglot.optimizer import Scope, build_scope, optimize from sqlglot.optimizer.qualify_columns import qualify_columns from sqlglot.optimizer.qualify_tables import qualify_tables +if t.TYPE_CHECKING: + from sqlglot.dialects.dialect import DialectType + @dataclass(frozen=True) class Node: @@ -36,7 +39,7 @@ def lineage( schema: t.Optional[t.Dict | Schema] = None, sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None, rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns), - dialect: t.Optional[str] = None, + dialect: DialectType = None, ) -> Node: """Build the lineage graph for a column of a SQL query. 
@@ -126,7 +129,7 @@ class LineageHTML: def __init__( self, node: Node, - dialect: t.Optional[str] = None, + dialect: DialectType = None, imports: bool = True, **opts: t.Any, ): diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 2245cc2..c6bea5a 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -114,7 +114,7 @@ def _eliminate_union(scope, existing_ctes, taken): taken[alias] = scope # Try to maintain the selections - expressions = scope.expression.args.get("expressions") + expressions = scope.selects selects = [ exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name) for e in expressions diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 5a3ed5a..badbb87 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -300,7 +300,7 @@ class Scope: list[exp.Expression]: expressions """ if isinstance(self.expression, exp.Union): - return [] + return self.expression.unnest().selects return self.expression.selects @property diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index f560760..f80484d 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -456,8 +456,10 @@ def extract_interval(interval): def date_literal(date): - expr_type = exp.DataType.build("DATETIME" if isinstance(date, datetime.datetime) else "DATE") - return exp.Cast(this=exp.Literal.string(date), to=expr_type) + return exp.cast( + exp.Literal.string(date), + "DATETIME" if isinstance(date, datetime.datetime) else "DATE", + ) def boolean_literal(condition): diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 6229105..e2b2c54 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -80,6 +80,7 @@ class Parser(metaclass=_Parser): length=exp.Literal.number(10), ), "VAR_MAP": parse_var_map, + "IFNULL": exp.Coalesce.from_arg_list, } NO_PAREN_FUNCTIONS = { @@ -567,6 +568,8 @@ class Parser(metaclass=_Parser): default=self._prev.text.upper() == "DEFAULT" ), "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(), + "ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty), + "DEFINER": lambda self: self._parse_definer(), } CONSTRAINT_PARSERS = { @@ -608,6 +611,7 @@ class Parser(metaclass=_Parser): "order": lambda self: self._parse_order(), "limit": lambda self: self._parse_limit(), "offset": lambda self: self._parse_offset(), + "lock": lambda self: self._parse_lock(), } SHOW_PARSERS: t.Dict[str, t.Callable] = {} @@ -850,7 +854,7 @@ class Parser(metaclass=_Parser): self.raise_error(error_message) def _find_sql(self, start: Token, end: Token) -> str: - return self.sql[self._find_token(start) : self._find_token(end)] + return self.sql[self._find_token(start) : self._find_token(end) + len(end.text)] def _find_token(self, token: Token) -> int: line = 1 @@ -901,6 +905,7 @@ class Parser(metaclass=_Parser): return expression def _parse_drop(self, default_kind: t.Optional[str] = None) -> t.Optional[exp.Expression]: + start = self._prev temporary = self._match(TokenType.TEMPORARY) materialized = self._match(TokenType.MATERIALIZED) kind = self._match_set(self.CREATABLES) and self._prev.text @@ -908,8 +913,7 @@ class Parser(metaclass=_Parser): if default_kind: kind = default_kind else: - self.raise_error(f"Expected {self.CREATABLES}") - return None + return self._parse_as_command(start) return self.expression( exp.Drop, @@ -929,6 +933,7 @@ class Parser(metaclass=_Parser): ) def _parse_create(self) -> 
t.Optional[exp.Expression]: + start = self._prev replace = self._match_pair(TokenType.OR, TokenType.REPLACE) set_ = self._match(TokenType.SET) # Teradata multiset = self._match_text_seq("MULTISET") # Teradata @@ -943,16 +948,19 @@ class Parser(metaclass=_Parser): if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False): self._match(TokenType.TABLE) + properties = None create_token = self._match_set(self.CREATABLES) and self._prev if not create_token: - self.raise_error(f"Expected {self.CREATABLES}") - return None + properties = self._parse_properties() + create_token = self._match_set(self.CREATABLES) and self._prev + + if not properties or not create_token: + return self._parse_as_command(start) exists = self._parse_exists(not_=True) this = None expression = None - properties = None data = None statistics = None no_primary_index = None @@ -1006,6 +1014,14 @@ class Parser(metaclass=_Parser): indexes = [] while True: index = self._parse_create_table_index() + + # post index PARTITION BY property + if self._match(TokenType.PARTITION_BY, advance=False): + if properties: + properties.expressions.append(self._parse_property()) + else: + properties = self._parse_properties() + if not index: break else: @@ -1040,6 +1056,9 @@ class Parser(metaclass=_Parser): ) def _parse_property_before(self) -> t.Optional[exp.Expression]: + self._match(TokenType.COMMA) + + # parsers look to _prev for no/dual/default, so need to consume first self._match_text_seq("NO") self._match_text_seq("DUAL") self._match_text_seq("DEFAULT") @@ -1059,6 +1078,9 @@ class Parser(metaclass=_Parser): if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY): return self._parse_sortkey(compound=True) + if self._match_text_seq("SQL", "SECURITY"): + return self.expression(exp.SqlSecurityProperty, definer=self._match_text_seq("DEFINER")) + assignment = self._match_pair( TokenType.VAR, TokenType.EQ, advance=False ) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False) @@ -1083,7 +1105,6 @@ class Parser(metaclass=_Parser): while True: if before: - self._match(TokenType.COMMA) identified_property = self._parse_property_before() else: identified_property = self._parse_property() @@ -1094,7 +1115,7 @@ class Parser(metaclass=_Parser): properties.append(p) if properties: - return self.expression(exp.Properties, expressions=properties, before=before) + return self.expression(exp.Properties, expressions=properties) return None @@ -1118,6 +1139,19 @@ class Parser(metaclass=_Parser): return self._parse_withisolatedloading() + # https://dev.mysql.com/doc/refman/8.0/en/create-view.html + def _parse_definer(self) -> t.Optional[exp.Expression]: + self._match(TokenType.EQ) + + user = self._parse_id_var() + self._match(TokenType.PARAMETER) + host = self._parse_id_var() or (self._match(TokenType.MOD) and self._prev.text) + + if not user or not host: + return None + + return exp.DefinerProperty(this=f"{user}@{host}") + def _parse_withjournaltable(self) -> exp.Expression: self._match_text_seq("WITH", "JOURNAL", "TABLE") self._match(TokenType.EQ) @@ -1695,12 +1729,10 @@ class Parser(metaclass=_Parser): paren += 1 if self._curr.token_type == TokenType.R_PAREN: paren -= 1 + end = self._prev self._advance() if paren > 0: self.raise_error("Expecting )", self._curr) - if not self._curr: - self.raise_error("Expecting pattern", self._curr) - end = self._prev pattern = exp.Var(this=self._find_sql(start, end)) else: pattern = None @@ -2044,9 +2076,16 @@ class Parser(metaclass=_Parser): expressions = 
self._parse_csv(self._parse_conjunction) grouping_sets = self._parse_grouping_sets() + self._match(TokenType.COMMA) with_ = self._match(TokenType.WITH) - cube = self._match(TokenType.CUBE) and (with_ or self._parse_wrapped_id_vars()) - rollup = self._match(TokenType.ROLLUP) and (with_ or self._parse_wrapped_id_vars()) + cube = self._match(TokenType.CUBE) and ( + with_ or self._parse_wrapped_csv(self._parse_column) + ) + + self._match(TokenType.COMMA) + rollup = self._match(TokenType.ROLLUP) and ( + with_ or self._parse_wrapped_csv(self._parse_column) + ) return self.expression( exp.Group, @@ -2149,6 +2188,14 @@ class Parser(metaclass=_Parser): self._match_set((TokenType.ROW, TokenType.ROWS)) return self.expression(exp.Offset, this=this, expression=count) + def _parse_lock(self) -> t.Optional[exp.Expression]: + if self._match_text_seq("FOR", "UPDATE"): + return self.expression(exp.Lock, update=True) + if self._match_text_seq("FOR", "SHARE"): + return self.expression(exp.Lock, update=False) + + return None + def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match_set(self.SET_OPERATIONS): return this @@ -2330,12 +2377,21 @@ class Parser(metaclass=_Parser): maybe_func = True if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): - return exp.DataType( + this = exp.DataType( this=exp.DataType.Type.ARRAY, expressions=[exp.DataType.build(type_token.value, expressions=expressions)], nested=True, ) + while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): + this = exp.DataType( + this=exp.DataType.Type.ARRAY, + expressions=[this], + nested=True, + ) + + return this + if self._match(TokenType.L_BRACKET): self._retreat(index) return None @@ -2430,7 +2486,12 @@ class Parser(metaclass=_Parser): self.raise_error("Expected type") elif op: self._advance() - field = exp.Literal.string(self._prev.text) + value = self._prev.text + field = ( + exp.Literal.number(value) + if self._prev.token_type == TokenType.NUMBER + else exp.Literal.string(value) + ) else: field = self._parse_star() or self._parse_function() or self._parse_id_var() @@ -2752,7 +2813,23 @@ class Parser(metaclass=_Parser): if not self._curr: break - if self._match_text_seq("NOT", "ENFORCED"): + if self._match(TokenType.ON): + action = None + on = self._advance_any() and self._prev.text + + if self._match(TokenType.NO_ACTION): + action = "NO ACTION" + elif self._match(TokenType.CASCADE): + action = "CASCADE" + elif self._match_pair(TokenType.SET, TokenType.NULL): + action = "SET NULL" + elif self._match_pair(TokenType.SET, TokenType.DEFAULT): + action = "SET DEFAULT" + else: + self.raise_error("Invalid key constraint") + + options.append(f"ON {on} {action}") + elif self._match_text_seq("NOT", "ENFORCED"): options.append("NOT ENFORCED") elif self._match_text_seq("DEFERRABLE"): options.append("DEFERRABLE") @@ -2762,10 +2839,6 @@ class Parser(metaclass=_Parser): options.append("NORELY") elif self._match_text_seq("MATCH", "FULL"): options.append("MATCH FULL") - elif self._match_text_seq("ON", "UPDATE", "NO ACTION"): - options.append("ON UPDATE NO ACTION") - elif self._match_text_seq("ON", "DELETE", "NO ACTION"): - options.append("ON DELETE NO ACTION") else: break @@ -3158,7 +3231,9 @@ class Parser(metaclass=_Parser): prefix += self._prev.text if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS): - return exp.Identifier(this=prefix + self._prev.text, quoted=False) + quoted = self._prev.token_type == TokenType.STRING + 
return exp.Identifier(this=prefix + self._prev.text, quoted=quoted) + return None def _parse_string(self) -> t.Optional[exp.Expression]: @@ -3486,6 +3561,11 @@ class Parser(metaclass=_Parser): def _parse_set(self) -> exp.Expression: return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item)) + def _parse_as_command(self, start: Token) -> exp.Command: + while self._curr: + self._advance() + return exp.Command(this=self._find_sql(start, self._prev)) + def _find_parser( self, parsers: t.Dict[str, t.Callable], trie: t.Dict ) -> t.Optional[t.Callable]: diff --git a/sqlglot/schema.py b/sqlglot/schema.py index f6f3883..f5d9f2b 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -11,6 +11,7 @@ from sqlglot.trie import in_trie, new_trie if t.TYPE_CHECKING: from sqlglot.dataframe.sql.types import StructType + from sqlglot.dialects.dialect import DialectType ColumnMapping = t.Union[t.Dict, str, StructType, t.List] @@ -153,7 +154,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): self, schema: t.Optional[t.Dict] = None, visible: t.Optional[t.Dict] = None, - dialect: t.Optional[str] = None, + dialect: DialectType = None, ) -> None: self.dialect = dialect self.visible = visible or {} diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 8bdd338..e95057a 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -665,6 +665,7 @@ class Tokenizer(metaclass=_Tokenizer): "STRING": TokenType.TEXT, "TEXT": TokenType.TEXT, "CLOB": TokenType.TEXT, + "LONGVARCHAR": TokenType.TEXT, "BINARY": TokenType.BINARY, "BLOB": TokenType.VARBINARY, "BYTEA": TokenType.VARBINARY, -- cgit v1.2.3
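A minimal usage sketch of the typed parse_one overloads and the DialectType alias added above (the queries and dialect choices are illustrative; the into=exp.DataType call mirrors the dataframe Column.cast site rewritten in this patch):

    import sqlglot
    from sqlglot import exp

    # `read` accepts any DialectType: a dialect name, a Dialect subclass,
    # or a Dialect instance.
    select = sqlglot.parse_one("SELECT a FROM t", read="spark")

    # Passing an Expression subclass via `into` parses straight into that
    # node, and the new @t.overload stubs let type checkers narrow the
    # return type to exp.DataType instead of Optional[Expression].
    dtype = sqlglot.parse_one("STRUCT<x INT>", read="spark", into=exp.DataType)
    assert isinstance(dtype, exp.DataType)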
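The locking-read support threaded through expressions.py, parser.py, generator.py and the MySQL/Oracle/Postgres/T-SQL generators composes as sketched below; duckdb is only an illustrative target that leaves LOCKING_READS_SUPPORTED at its False default:

    from sqlglot import exp, transpile

    # Select.lock() attaches the new exp.Lock node; generators with
    # LOCKING_READS_SUPPORTED render it as FOR UPDATE / FOR SHARE.
    query = exp.select("x").from_("tbl").lock(update=False)
    print(query.sql("mysql"))  # SELECT x FROM tbl FOR SHARE

    # Generators without support warn via unsupported() and drop the
    # clause instead of emitting invalid SQL.
    print(transpile("SELECT x FROM tbl FOR UPDATE", read="mysql", write="duckdb")[0])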
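The new exp.cast helper (used in this patch to simplify Column.cast, values() and simplify's date_literal) accepts a string, DataType or DataType.Type target, since DataType.build now handles all three; a quick sketch:

    from sqlglot import exp

    print(exp.cast("x + 1", "int").sql())               # CAST(x + 1 AS INT)
    print(exp.cast("x", exp.DataType.Type.TEXT).sql())  # CAST(x AS TEXT)

    # The companion is_type helpers work on DataType and Cast nodes alike,
    # as exercised by the Spark JSON handling above.
    assert exp.cast("x", "json").to.is_type(exp.DataType.Type.JSON)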
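exp.to_interval, the Postgres GENERATE_SERIES normalization and the Presto _sequence_sql hook fit together as follows (a sketch; the literal arguments are arbitrary):

    from sqlglot import transpile
    from sqlglot.expressions import to_interval

    # Step strings like '1 day' are normalized to INTERVAL '1' day.
    print(to_interval("1 day").sql())  # INTERVAL '1' day

    # Postgres GENERATE_SERIES transpiles to Presto's SEQUENCE.
    print(transpile("SELECT * FROM GENERATE_SERIES(1, 5, 2)",
                    read="postgres", write="presto")[0])
    # SELECT * FROM SEQUENCE(1, 5, 2)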
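The Spark cast_sql override and the new DayOfWeek/DayOfMonth/DayOfYear/WeekOfYear expressions can be exercised like this (illustrative literals; expected output shown for the JSON case only):

    from sqlglot import transpile

    # Casting to JSON in Spark emits TO_JSON rather than a plain CAST.
    print(transpile("SELECT CAST(col AS JSON)", write="spark")[0])
    # SELECT TO_JSON(col)

    # Spark's DAYOFWEEK wraps its argument in TsOrDsToDate, so targets such
    # as Presto receive a real date and the canonical DAY_OF_WEEK name.
    print(transpile("SELECT DAYOFWEEK('2023-02-08')", read="spark", write="presto")[0])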
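The executor's retyped entry point is used exactly as before; a self-contained sketch with made-up table data, letting the schema be inferred the way the new execute() code above does:

    from sqlglot.executor import execute

    tables = {"sushi": [{"id": 1, "price": 1.0}, {"id": 2, "price": 2.0}]}

    result = execute("SELECT id FROM sushi WHERE price > 1.5", tables=tables)
    print(result.columns, result.rows)  # e.g. ('id',) [(2,)]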
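The per-dialect ROOT_PROPERTIES/WITH_PROPERTIES/BEFORE_PROPERTIES sets are consolidated into a single PROPERTIES_LOCATION mapping keyed by property class; a dialect's placement can be inspected as below (Hive shown as an example):

    from sqlglot import exp
    from sqlglot.dialects.hive import Hive

    # Hive pins PARTITIONED BY after the schema instead of inside WITH (...).
    print(Hive.Generator.PROPERTIES_LOCATION[exp.PartitionedByProperty])
    # Location.POST_SCHEMA_ROOT

    # Entries mapped to Location.UNSUPPORTED (e.g. EngineProperty in Spark)
    # make locate_properties() warn and drop the property.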