From 8bec55350caa5c760d8b7e7e2d0ba6c77a32bc71 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 8 Feb 2023 05:14:34 +0100 Subject: Merging upstream version 10.6.3. Signed-off-by: Daniel Baumann --- README.md | 10 +- sqlglot/__init__.py | 66 +++++++- sqlglot/dataframe/sql/column.py | 6 +- sqlglot/dataframe/sql/functions.py | 20 ++- sqlglot/dialects/bigquery.py | 8 - sqlglot/dialects/dialect.py | 10 ++ sqlglot/dialects/drill.py | 5 +- sqlglot/dialects/duckdb.py | 12 +- sqlglot/dialects/hive.py | 16 +- sqlglot/dialects/mysql.py | 17 +- sqlglot/dialects/oracle.py | 4 + sqlglot/dialects/postgres.py | 45 +++--- sqlglot/dialects/presto.py | 32 +++- sqlglot/dialects/redshift.py | 11 +- sqlglot/dialects/snowflake.py | 9 -- sqlglot/dialects/spark.py | 37 +++++ sqlglot/dialects/tableau.py | 1 - sqlglot/dialects/teradata.py | 8 + sqlglot/dialects/tsql.py | 2 + sqlglot/diff.py | 9 +- sqlglot/executor/__init__.py | 61 +++++-- sqlglot/executor/env.py | 1 + sqlglot/executor/table.py | 7 +- sqlglot/expressions.py | 158 ++++++++++++++++-- sqlglot/generator.py | 187 +++++++++++++++------- sqlglot/lineage.py | 7 +- sqlglot/optimizer/eliminate_subqueries.py | 2 +- sqlglot/optimizer/scope.py | 2 +- sqlglot/optimizer/simplify.py | 6 +- sqlglot/parser.py | 122 +++++++++++--- sqlglot/schema.py | 3 +- sqlglot/tokens.py | 1 + tests/dialects/test_bigquery.py | 4 +- tests/dialects/test_dialect.py | 58 +++++-- tests/dialects/test_duckdb.py | 19 ++- tests/dialects/test_hive.py | 6 +- tests/dialects/test_mysql.py | 20 +++ tests/dialects/test_postgres.py | 20 ++- tests/dialects/test_presto.py | 87 +++++++++- tests/dialects/test_spark.py | 11 ++ tests/dialects/test_sqlite.py | 6 + tests/dialects/test_teradata.py | 3 + tests/fixtures/identity.sql | 16 +- tests/fixtures/optimizer/eliminate_subqueries.sql | 4 + tests/fixtures/pretty.sql | 2 +- tests/test_build.py | 10 ++ tests/test_expressions.py | 14 ++ tests/test_transpile.py | 7 +- 48 files changed, 906 insertions(+), 266 deletions(-) diff --git a/README.md b/README.md index a2e2836..08c7f36 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ # SQLGlot -SQLGlot is a no dependency Python SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between [19 different dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects. +SQLGlot is a no-dependency SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between [19 different dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects. -It is a very comprehensive generic SQL parser with a robust [test suite](https://github.com/tobymao/sqlglot/blob/main/tests/). It is also quite [performant](#benchmarks) while being written purely in Python. +It is a very comprehensive generic SQL parser with a robust [test suite](https://github.com/tobymao/sqlglot/blob/main/tests/). It is also quite [performant](#benchmarks), while being written purely in Python. You can easily [customize](#custom-dialects) the parser, [analyze](#metadata) queries, traverse expression trees, and programmatically [build](#build-and-modify-sql) SQL. -Syntax [errors](#parser-errors) are highlighted and dialect incompatibilities can warn or raise depending on configurations. However, it should be noted that the parser is very lenient when it comes to detecting errors, because it aims to consume as much SQL as possible. On one hand, this makes its implementation simpler, and thus more comprehensible, but on the other hand it means that syntax errors may sometimes go unnoticed. +Syntax [errors](#parser-errors) are highlighted and dialect incompatibilities can warn or raise depending on configurations. However, it should be noted that SQL validation is not SQLGlot’s goal, so some syntax errors may go unnoticed. Contributions are very welcome in SQLGlot; read the [contribution guide](https://github.com/tobymao/sqlglot/blob/main/CONTRIBUTING.md) to get started! @@ -432,6 +432,8 @@ user_id price 2 3.0 ``` +See also: [Writing a Python SQL engine from scratch](https://github.com/tobymao/sqlglot/blob/main/posts/python_sql_engine.md). + ## Used By * [Fugue](https://github.com/fugue-project/fugue) * [ibis](https://github.com/ibis-project/ibis) @@ -442,7 +444,7 @@ user_id price ## Documentation -SQLGlot uses [pdocs](https://pdoc.dev/) to serve its API documentation: +SQLGlot uses [pdoc](https://pdoc.dev/) to serve its API documentation: ``` make docs-serve diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index bfcabb3..714897f 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -33,7 +33,13 @@ from sqlglot.parser import Parser from sqlglot.schema import MappingSchema, Schema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "10.6.0" +if t.TYPE_CHECKING: + from sqlglot.dialects.dialect import DialectType + + T = t.TypeVar("T", bound=Expression) + + +__version__ = "10.6.3" pretty = False """Whether to format generated SQL by default.""" @@ -42,9 +48,7 @@ schema = MappingSchema() """The default schema used by SQLGlot (e.g. in the optimizer).""" -def parse( - sql: str, read: t.Optional[str | Dialect] = None, **opts -) -> t.List[t.Optional[Expression]]: +def parse(sql: str, read: DialectType = None, **opts) -> t.List[t.Optional[Expression]]: """ Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement. @@ -60,9 +64,57 @@ def parse( return dialect.parse(sql, **opts) +@t.overload +def parse_one( + sql: str, + read: None = None, + into: t.Type[T] = ..., + **opts, +) -> T: + ... + + +@t.overload +def parse_one( + sql: str, + read: DialectType, + into: t.Type[T], + **opts, +) -> T: + ... + + +@t.overload +def parse_one( + sql: str, + read: None = None, + into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]] = ..., + **opts, +) -> Expression: + ... + + +@t.overload +def parse_one( + sql: str, + read: DialectType, + into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]], + **opts, +) -> Expression: + ... + + +@t.overload +def parse_one( + sql: str, + **opts, +) -> Expression: + ... + + def parse_one( sql: str, - read: t.Optional[str | Dialect] = None, + read: DialectType = None, into: t.Optional[exp.IntoType] = None, **opts, ) -> Expression: @@ -96,8 +148,8 @@ def parse_one( def transpile( sql: str, - read: t.Optional[str | Dialect] = None, - write: t.Optional[str | Dialect] = None, + read: DialectType = None, + write: DialectType = None, identity: bool = True, error_level: t.Optional[ErrorLevel] = None, **opts, diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index 40ffe3e..f5b0974 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -260,11 +260,7 @@ class Column: """ if isinstance(dataType, DataType): dataType = dataType.simpleString() - new_expression = exp.Cast( - this=self.column_expression, - to=sqlglot.parse_one(dataType, into=exp.DataType, read="spark"), # type: ignore - ) - return Column(new_expression) + return Column(exp.cast(self.column_expression, dataType, dialect="spark")) def startswith(self, value: t.Union[str, Column]) -> Column: value = self._lit(value) if not isinstance(value, Column) else value diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index a141fe4..47d5e7b 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -536,15 +536,15 @@ def month(col: ColumnOrName) -> Column: def dayofweek(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "DAYOFWEEK") + return Column.invoke_expression_over_column(col, glotexp.DayOfWeek) def dayofmonth(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "DAYOFMONTH") + return Column.invoke_expression_over_column(col, glotexp.DayOfMonth) def dayofyear(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "DAYOFYEAR") + return Column.invoke_expression_over_column(col, glotexp.DayOfYear) def hour(col: ColumnOrName) -> Column: @@ -560,7 +560,7 @@ def second(col: ColumnOrName) -> Column: def weekofyear(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "WEEKOFYEAR") + return Column.invoke_expression_over_column(col, glotexp.WeekOfYear) def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column: @@ -1144,10 +1144,16 @@ def aggregate( merge_exp = _get_lambda_from_func(merge) if finish is not None: finish_exp = _get_lambda_from_func(finish) - return Column.invoke_anonymous_function( - col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp) + return Column.invoke_expression_over_column( + col, + glotexp.Reduce, + initial=initialValue, + merge=Column(merge_exp), + finish=Column(finish_exp), ) - return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp)) + return Column.invoke_expression_over_column( + col, glotexp.Reduce, initial=initialValue, merge=Column(merge_exp) + ) def transform( diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 27dca48..90ae229 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -222,14 +222,6 @@ class BigQuery(Dialect): exp.DataType.Type.NVARCHAR: "STRING", } - ROOT_PROPERTIES = { - exp.LanguageProperty, - exp.ReturnsProperty, - exp.VolatilityProperty, - } - - WITH_PROPERTIES = {exp.Property} - EXPLICIT_UNION = True def array_sql(self, expression: exp.Array) -> str: diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 0c2beba..1b20e0a 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -122,9 +122,15 @@ class Dialect(metaclass=_Dialect): def get_or_raise(cls, dialect): if not dialect: return cls + if isinstance(dialect, _Dialect): + return dialect + if isinstance(dialect, Dialect): + return dialect.__class__ + result = cls.get(dialect) if not result: raise ValueError(f"Unknown dialect '{dialect}'") + return result @classmethod @@ -196,6 +202,10 @@ class Dialect(metaclass=_Dialect): ) +if t.TYPE_CHECKING: + DialectType = t.Union[str, Dialect, t.Type[Dialect], None] + + def rename_func(name): def _rename(self, expression): args = flatten(expression.args.values()) diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 4e3c0e1..d0a0251 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -137,7 +137,10 @@ class Drill(Dialect): exp.DataType.Type.DATETIME: "TIMESTAMP", } - ROOT_PROPERTIES = {exp.PartitionedByProperty} + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + } TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 4646eb4..95ff95c 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -20,10 +20,6 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _unix_to_time(self, expression): - return f"TO_TIMESTAMP(CAST({self.sql(expression, 'this')} AS BIGINT))" - - def _str_to_time_sql(self, expression): return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})" @@ -113,7 +109,7 @@ class DuckDB(Dialect): "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, "STRUCT_PACK": exp.Struct.from_arg_list, - "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list, + "TO_TIMESTAMP": exp.UnixToTime.from_arg_list, "UNNEST": exp.Explode.from_arg_list, } @@ -162,9 +158,9 @@ class DuckDB(Dialect): exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)", exp.TsOrDsAdd: _ts_or_ds_add, exp.TsOrDsToDate: _ts_or_ds_to_date_sql, - exp.UnixToStr: lambda self, e: f"STRFTIME({_unix_to_time(self, e)}, {self.format_time(e)})", - exp.UnixToTime: _unix_to_time, - exp.UnixToTimeStr: lambda self, e: f"CAST({_unix_to_time(self, e)} AS TEXT)", + exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})", + exp.UnixToTime: rename_func("TO_TIMESTAMP"), + exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)", } TYPE_MAPPING = { diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 4bbec70..f2b6eaa 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -322,17 +322,11 @@ class Hive(Dialect): exp.LastDateOfMonth: rename_func("LAST_DAY"), } - WITH_PROPERTIES = {exp.Property} - - ROOT_PROPERTIES = { - exp.PartitionedByProperty, - exp.FileFormatProperty, - exp.SchemaCommentProperty, - exp.LocationProperty, - exp.TableFormatProperty, - exp.RowFormatDelimitedProperty, - exp.RowFormatSerdeProperty, - exp.SerdeProperties, + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT, } def with_properties(self, properties): diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index cd8c30c..a5bd86b 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -1,7 +1,5 @@ from __future__ import annotations -import typing as t - from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, @@ -98,6 +96,8 @@ def _date_add_sql(kind): class MySQL(Dialect): + time_format = "'%Y-%m-%d %T'" + # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions time_mapping = { "%M": "%B", @@ -110,6 +110,7 @@ class MySQL(Dialect): "%u": "%W", "%k": "%-H", "%l": "%-I", + "%T": "%H:%M:%S", } class Tokenizer(tokens.Tokenizer): @@ -428,6 +429,7 @@ class MySQL(Dialect): ) class Generator(generator.Generator): + LOCKING_READS_SUPPORTED = True NULL_ORDERING_SUPPORTED = False TRANSFORMS = { @@ -449,23 +451,12 @@ class MySQL(Dialect): exp.StrPosition: strposition_to_locate_sql, } - ROOT_PROPERTIES = { - exp.EngineProperty, - exp.AutoIncrementProperty, - exp.CharacterSetProperty, - exp.CollateProperty, - exp.SchemaCommentProperty, - exp.LikeProperty, - } - TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy() TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT) TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT) TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB) TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB) - WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set() - def show_sql(self, expression): this = f" {expression.name}" full = " FULL" if expression.args.get("full") else "" diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 67d791d..fde845e 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -44,6 +44,8 @@ class Oracle(Dialect): } class Generator(generator.Generator): + LOCKING_READS_SUPPORTED = True + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TINYINT: "NUMBER", @@ -69,6 +71,7 @@ class Oracle(Dialect): exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", + exp.Substring: rename_func("SUBSTR"), } def query_modifiers(self, expression, *sqls): @@ -90,6 +93,7 @@ class Oracle(Dialect): self.sql(expression, "order"), self.sql(expression, "offset"), # offset before limit in oracle self.sql(expression, "limit"), + self.sql(expression, "lock"), sep="", ) diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 0d74b3a..6418032 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -148,6 +148,22 @@ def _serial_to_generated(expression): return expression +def _generate_series(args): + # The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day + step = seq_get(args, 2) + + if step is None: + # Postgres allows calls with just two arguments -- the "step" argument defaults to 1 + return exp.GenerateSeries.from_arg_list(args) + + if step.is_string: + args[2] = exp.to_interval(step.this) + elif isinstance(step, exp.Interval) and not step.args.get("unit"): + args[2] = exp.to_interval(step.this.this) + + return exp.GenerateSeries.from_arg_list(args) + + def _to_timestamp(args): # TO_TIMESTAMP accepts either a single double argument or (text, text) if len(args) == 1: @@ -195,29 +211,6 @@ class Postgres(Dialect): HEX_STRINGS = [("x'", "'"), ("X'", "'")] BYTE_STRINGS = [("e'", "'"), ("E'", "'")] - CREATABLES = ( - "AGGREGATE", - "CAST", - "CONVERSION", - "COLLATION", - "DEFAULT CONVERSION", - "CONSTRAINT", - "DOMAIN", - "EXTENSION", - "FOREIGN", - "FUNCTION", - "OPERATOR", - "POLICY", - "ROLE", - "RULE", - "SEQUENCE", - "TEXT", - "TRIGGER", - "TYPE", - "UNLOGGED", - "USER", - ) - KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "~~": TokenType.LIKE, @@ -243,8 +236,6 @@ class Postgres(Dialect): "TEMP": TokenType.TEMPORARY, "UUID": TokenType.UUID, "CSTRING": TokenType.PSEUDO_TYPE, - **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES}, - **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES}, } QUOTES = ["'", "$$"] SINGLE_TOKENS = { @@ -257,8 +248,10 @@ class Postgres(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore + "NOW": exp.CurrentTimestamp.from_arg_list, "TO_TIMESTAMP": _to_timestamp, "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), + "GENERATE_SERIES": _generate_series, } BITWISE = { @@ -272,6 +265,8 @@ class Postgres(Dialect): } class Generator(generator.Generator): + LOCKING_READS_SUPPORTED = True + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TINYINT: "SMALLINT", diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 8175d6f..6c1a474 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -105,6 +105,29 @@ def _ts_or_ds_add_sql(self, expression): return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))" +def _sequence_sql(self, expression): + start = expression.args["start"] + end = expression.args["end"] + step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series + + target_type = None + + if isinstance(start, exp.Cast): + target_type = start.to + elif isinstance(end, exp.Cast): + target_type = end.to + + if target_type and target_type.this == exp.DataType.Type.TIMESTAMP: + to = target_type.copy() + + if target_type is start.to: + end = exp.Cast(this=end, to=to) + else: + start = exp.Cast(this=start, to=to) + + return f"SEQUENCE({self.format_args(start, end, step)})" + + def _ensure_utf8(charset): if charset.name.lower() != "utf-8": raise UnsupportedError(f"Unsupported charset {charset}") @@ -145,7 +168,7 @@ def _from_unixtime(args): class Presto(Dialect): index_offset = 1 null_ordering = "nulls_are_last" - time_format = "'%Y-%m-%d %H:%i:%S'" + time_format = MySQL.time_format # type: ignore time_mapping = MySQL.time_mapping # type: ignore class Tokenizer(tokens.Tokenizer): @@ -197,7 +220,10 @@ class Presto(Dialect): class Generator(generator.Generator): STRUCT_DELIMITER = ("(", ")") - ROOT_PROPERTIES = {exp.SchemaCommentProperty} + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.LocationProperty: exp.Properties.Location.UNSUPPORTED, + } TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore @@ -223,6 +249,7 @@ class Presto(Dialect): exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DataType: _datatype_sql, exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", @@ -231,6 +258,7 @@ class Presto(Dialect): exp.Decode: _decode_sql, exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)", exp.Encode: _encode_sql, + exp.GenerateSeries: _sequence_sql, exp.Hex: rename_func("TO_HEX"), exp.If: if_sql, exp.ILike: no_ilike_sql, diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 7da881f..c3c99eb 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -61,14 +61,9 @@ class Redshift(Postgres): exp.DataType.Type.INT: "INTEGER", } - ROOT_PROPERTIES = { - exp.DistKeyProperty, - exp.SortKeyProperty, - exp.DistStyleProperty, - } - - WITH_PROPERTIES = { - exp.LikeProperty, + PROPERTIES_LOCATION = { + **Postgres.Generator.PROPERTIES_LOCATION, # type: ignore + exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_WITH, } TRANSFORMS = { diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index db72a34..3b83b02 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -234,15 +234,6 @@ class Snowflake(Dialect): "replace": "RENAME", } - ROOT_PROPERTIES = { - exp.PartitionedByProperty, - exp.ReturnsProperty, - exp.LanguageProperty, - exp.SchemaCommentProperty, - exp.ExecuteAsProperty, - exp.VolatilityProperty, - } - def except_op(self, expression): if not expression.args.get("distinct", False): self.unsupported("EXCEPT with All is not supported in Snowflake") diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index fc711ab..8ef4a87 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -73,6 +73,19 @@ class Spark(Hive): ), "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, "IIF": exp.If.from_arg_list, + "AGGREGATE": exp.Reduce.from_arg_list, + "DAYOFWEEK": lambda args: exp.DayOfWeek( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + ), + "DAYOFMONTH": lambda args: exp.DayOfMonth( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + ), + "DAYOFYEAR": lambda args: exp.DayOfYear( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + ), + "WEEKOFYEAR": lambda args: exp.WeekOfYear( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + ), } FUNCTION_PARSERS = { @@ -105,6 +118,14 @@ class Spark(Hive): exp.DataType.Type.BIGINT: "LONG", } + PROPERTIES_LOCATION = { + **Hive.Generator.PROPERTIES_LOCATION, # type: ignore + exp.EngineProperty: exp.Properties.Location.UNSUPPORTED, + exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED, + exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED, + exp.CollateProperty: exp.Properties.Location.UNSUPPORTED, + } + TRANSFORMS = { **Hive.Generator.TRANSFORMS, # type: ignore exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), @@ -126,11 +147,27 @@ class Spark(Hive): exp.VariancePop: rename_func("VAR_POP"), exp.DateFromParts: rename_func("MAKE_DATE"), exp.LogicalOr: rename_func("BOOL_OR"), + exp.DayOfWeek: rename_func("DAYOFWEEK"), + exp.DayOfMonth: rename_func("DAYOFMONTH"), + exp.DayOfYear: rename_func("DAYOFYEAR"), + exp.WeekOfYear: rename_func("WEEKOFYEAR"), + exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})", } TRANSFORMS.pop(exp.ArraySort) TRANSFORMS.pop(exp.ILike) WRAP_DERIVED_VALUES = False + def cast_sql(self, expression: exp.Cast) -> str: + if isinstance(expression.this, exp.Cast) and expression.this.is_type( + exp.DataType.Type.JSON + ): + schema = f"'{self.sql(expression, 'to')}'" + return f"FROM_JSON({self.format_args(self.sql(expression.this, 'this'), schema)})" + if expression.to.is_type(exp.DataType.Type.JSON): + return f"TO_JSON({self.sql(expression, 'this')})" + + return super(Spark.Generator, self).cast_sql(expression) + class Tokenizer(Hive.Tokenizer): HEX_STRINGS = [("X'", "'")] diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py index 36c085f..31b1c8d 100644 --- a/sqlglot/dialects/tableau.py +++ b/sqlglot/dialects/tableau.py @@ -31,6 +31,5 @@ class Tableau(Dialect): class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore - "IFNULL": exp.Coalesce.from_arg_list, "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)), } diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 4340820..123da04 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -76,6 +76,14 @@ class Teradata(Dialect): ) class Generator(generator.Generator): + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX, + } + + def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str: + return f"PARTITION BY {self.sql(expression, 'this')}" + # FROM before SET in Teradata UPDATE syntax # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause def update_sql(self, expression: exp.Update) -> str: diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 9f9099e..05ba53a 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -412,6 +412,8 @@ class TSQL(Dialect): return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) class Generator(generator.Generator): + LOCKING_READS_SUPPORTED = True + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.BOOLEAN: "BIT", diff --git a/sqlglot/diff.py b/sqlglot/diff.py index a5373b0..7d5ec21 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -14,10 +14,6 @@ from sqlglot import Dialect from sqlglot import expressions as exp from sqlglot.helper import ensure_collection -if t.TYPE_CHECKING: - T = t.TypeVar("T") - Edit = t.Union[Insert, Remove, Move, Update, Keep] - @dataclass(frozen=True) class Insert: @@ -56,6 +52,11 @@ class Keep: target: exp.Expression +if t.TYPE_CHECKING: + T = t.TypeVar("T") + Edit = t.Union[Insert, Remove, Move, Update, Keep] + + def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]: """ Returns the list of changes between the source and the target expressions. diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py index 04621b5..67b4b00 100644 --- a/sqlglot/executor/__init__.py +++ b/sqlglot/executor/__init__.py @@ -1,5 +1,13 @@ +""" +.. include:: ../../posts/python_sql_engine.md +---- +""" + +from __future__ import annotations + import logging import time +import typing as t from sqlglot import maybe_parse from sqlglot.errors import ExecuteError @@ -11,42 +19,63 @@ from sqlglot.schema import ensure_schema logger = logging.getLogger("sqlglot") +if t.TYPE_CHECKING: + from sqlglot.dialects.dialect import DialectType + from sqlglot.executor.table import Tables + from sqlglot.expressions import Expression + from sqlglot.schema import Schema -def execute(sql, schema=None, read=None, tables=None): + +def execute( + sql: str | Expression, + schema: t.Optional[t.Dict | Schema] = None, + read: DialectType = None, + tables: t.Optional[t.Dict] = None, +) -> Table: """ Run a sql query against data. Args: - sql (str|sqlglot.Expression): a sql statement - schema (dict|sqlglot.optimizer.Schema): database schema. - This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of - the following forms: - 1. {table: {col: type}} - 2. {db: {table: {col: type}}} - 3. {catalog: {db: {table: {col: type}}}} - read (str): the SQL dialect to apply during parsing - (eg. "spark", "hive", "presto", "mysql"). - tables (dict): additional tables to register. + sql: a sql statement. + schema: database schema. + This can either be an instance of `Schema` or a mapping in one of the following forms: + 1. {table: {col: type}} + 2. {db: {table: {col: type}}} + 3. {catalog: {db: {table: {col: type}}}} + read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql"). + tables: additional tables to register. + Returns: - sqlglot.executor.Table: Simple columnar data structure. + Simple columnar data structure. """ - tables = ensure_tables(tables) + tables_ = ensure_tables(tables) + if not schema: schema = { name: {column: type(table[0][column]).__name__ for column in table.columns} - for name, table in tables.mapping.items() + for name, table in tables_.mapping.items() } + schema = ensure_schema(schema) - if tables.supported_table_args and tables.supported_table_args != schema.supported_table_args: + + if tables_.supported_table_args and tables_.supported_table_args != schema.supported_table_args: raise ExecuteError("Tables must support the same table args as schema") + expression = maybe_parse(sql, dialect=read) + now = time.time() expression = optimize(expression, schema, leave_tables_isolated=True) + logger.debug("Optimization finished: %f", time.time() - now) logger.debug("Optimized SQL: %s", expression.sql(pretty=True)) + plan = Plan(expression) + logger.debug("Logical Plan: %s", plan) + now = time.time() - result = PythonExecutor(tables=tables).execute(plan) + result = PythonExecutor(tables=tables_).execute(plan) + logger.debug("Query finished: %f", time.time() - now) + return result diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index 04dc938..ba9cbbd 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -171,5 +171,6 @@ ENV = { "STRPOSITION": str_position, "SUB": null_if_any(lambda e, this: e - this), "SUBSTRING": substring, + "TIMESTRTOTIME": null_if_any(lambda arg: datetime.datetime.fromisoformat(arg)), "UPPER": null_if_any(lambda arg: arg.upper()), } diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py index f1b5b54..27e3e5e 100644 --- a/sqlglot/executor/table.py +++ b/sqlglot/executor/table.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot.helper import dict_depth from sqlglot.schema import AbstractMappingSchema @@ -106,11 +108,11 @@ class Tables(AbstractMappingSchema[Table]): pass -def ensure_tables(d: dict | None) -> Tables: +def ensure_tables(d: t.Optional[t.Dict]) -> Tables: return Tables(_ensure_tables(d)) -def _ensure_tables(d: dict | None) -> dict: +def _ensure_tables(d: t.Optional[t.Dict]) -> t.Dict: if not d: return {} @@ -127,4 +129,5 @@ def _ensure_tables(d: dict | None) -> dict: columns = tuple(table[0]) if table else () rows = [tuple(row[c] for c in columns) for row in table] result[name] = Table(columns=columns, rows=rows) + return result diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 7c1a116..6bb083a 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -32,13 +32,7 @@ from sqlglot.helper import ( from sqlglot.tokens import Token if t.TYPE_CHECKING: - from sqlglot.dialects.dialect import Dialect - - IntoType = t.Union[ - str, - t.Type[Expression], - t.Collection[t.Union[str, t.Type[Expression]]], - ] + from sqlglot.dialects.dialect import DialectType class _Expression(type): @@ -427,7 +421,7 @@ class Expression(metaclass=_Expression): def __repr__(self): return self._to_s() - def sql(self, dialect: Dialect | str | None = None, **opts) -> str: + def sql(self, dialect: DialectType = None, **opts) -> str: """ Returns SQL string representation of this tree. @@ -595,6 +589,14 @@ class Expression(metaclass=_Expression): return load(obj) +if t.TYPE_CHECKING: + IntoType = t.Union[ + str, + t.Type[Expression], + t.Collection[t.Union[str, t.Type[Expression]]], + ] + + class Condition(Expression): def and_(self, *expressions, dialect=None, **opts): """ @@ -1285,6 +1287,18 @@ class Property(Expression): arg_types = {"this": True, "value": True} +class AlgorithmProperty(Property): + arg_types = {"this": True} + + +class DefinerProperty(Property): + arg_types = {"this": True} + + +class SqlSecurityProperty(Property): + arg_types = {"definer": True} + + class TableFormatProperty(Property): arg_types = {"this": True} @@ -1425,13 +1439,15 @@ class IsolatedLoadingProperty(Property): class Properties(Expression): - arg_types = {"expressions": True, "before": False} + arg_types = {"expressions": True} NAME_TO_PROPERTY = { + "ALGORITHM": AlgorithmProperty, "AUTO_INCREMENT": AutoIncrementProperty, "CHARACTER SET": CharacterSetProperty, "COLLATE": CollateProperty, "COMMENT": SchemaCommentProperty, + "DEFINER": DefinerProperty, "DISTKEY": DistKeyProperty, "DISTSTYLE": DistStyleProperty, "ENGINE": EngineProperty, @@ -1447,6 +1463,14 @@ class Properties(Expression): PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()} + class Location(AutoName): + POST_CREATE = auto() + PRE_SCHEMA = auto() + POST_INDEX = auto() + POST_SCHEMA_ROOT = auto() + POST_SCHEMA_WITH = auto() + UNSUPPORTED = auto() + @classmethod def from_dict(cls, properties_dict) -> Properties: expressions = [] @@ -1592,6 +1616,7 @@ QUERY_MODIFIERS = { "order": False, "limit": False, "offset": False, + "lock": False, } @@ -1713,6 +1738,12 @@ class Schema(Expression): arg_types = {"this": False, "expressions": False} +# Used to represent the FOR UPDATE and FOR SHARE locking read types. +# https://dev.mysql.com/doc/refman/8.0/en/innodb-locking-reads.html +class Lock(Expression): + arg_types = {"update": True} + + class Select(Subqueryable): arg_types = { "with": False, @@ -2243,6 +2274,30 @@ class Select(Subqueryable): properties=properties_expression, ) + def lock(self, update: bool = True, copy: bool = True) -> Select: + """ + Set the locking read mode for this expression. + + Examples: + >>> Select().select("x").from_("tbl").where("x = 'a'").lock().sql("mysql") + "SELECT x FROM tbl WHERE x = 'a' FOR UPDATE" + + >>> Select().select("x").from_("tbl").where("x = 'a'").lock(update=False).sql("mysql") + "SELECT x FROM tbl WHERE x = 'a' FOR SHARE" + + Args: + update: if `True`, the locking type will be `FOR UPDATE`, else it will be `FOR SHARE`. + copy: if `False`, modify this expression instance in-place. + + Returns: + The modified expression. + """ + + inst = _maybe_copy(self, copy) + inst.set("lock", Lock(update=update)) + + return inst + @property def named_selects(self) -> t.List[str]: return [e.output_name for e in self.expressions if e.alias_or_name] @@ -2456,24 +2511,28 @@ class DataType(Expression): @classmethod def build( - cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs + cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs ) -> DataType: from sqlglot import parse_one if isinstance(dtype, str): - data_type_exp: t.Optional[Expression] if dtype.upper() in cls.Type.__members__: - data_type_exp = DataType(this=DataType.Type[dtype.upper()]) + data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()]) else: data_type_exp = parse_one(dtype, read=dialect, into=DataType) if data_type_exp is None: raise ValueError(f"Unparsable data type value: {dtype}") elif isinstance(dtype, DataType.Type): data_type_exp = DataType(this=dtype) + elif isinstance(dtype, DataType): + return dtype else: raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type") return DataType(**{**data_type_exp.args, **kwargs}) + def is_type(self, dtype: DataType.Type) -> bool: + return self.this == dtype + # https://www.postgresql.org/docs/15/datatype-pseudo.html class PseudoType(Expression): @@ -2840,6 +2899,10 @@ class Array(Func): is_var_len_args = True +class GenerateSeries(Func): + arg_types = {"start": True, "end": True, "step": False} + + class ArrayAgg(AggFunc): pass @@ -2909,6 +2972,9 @@ class Cast(Func): def output_name(self): return self.name + def is_type(self, dtype: DataType.Type) -> bool: + return self.to.is_type(dtype) + class Collate(Binary): pass @@ -2989,6 +3055,22 @@ class DatetimeTrunc(Func, TimeUnit): arg_types = {"this": True, "unit": True, "zone": False} +class DayOfWeek(Func): + _sql_names = ["DAY_OF_WEEK", "DAYOFWEEK"] + + +class DayOfMonth(Func): + _sql_names = ["DAY_OF_MONTH", "DAYOFMONTH"] + + +class DayOfYear(Func): + _sql_names = ["DAY_OF_YEAR", "DAYOFYEAR"] + + +class WeekOfYear(Func): + _sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"] + + class LastDateOfMonth(Func): pass @@ -3239,7 +3321,7 @@ class ReadCSV(Func): class Reduce(Func): - arg_types = {"this": True, "initial": True, "merge": True, "finish": True} + arg_types = {"this": True, "initial": True, "merge": True, "finish": False} class RegexpLike(Func): @@ -3476,7 +3558,7 @@ def maybe_parse( sql_or_expression: str | Expression, *, into: t.Optional[IntoType] = None, - dialect: t.Optional[str] = None, + dialect: DialectType = None, prefix: t.Optional[str] = None, **opts, ) -> Expression: @@ -3959,6 +4041,28 @@ def to_identifier(alias, quoted=None) -> t.Optional[Identifier]: return identifier +INTERVAL_STRING_RE = re.compile(r"\s*([0-9]+)\s*([a-zA-Z]+)\s*") + + +def to_interval(interval: str | Literal) -> Interval: + """Builds an interval expression from a string like '1 day' or '5 months'.""" + if isinstance(interval, Literal): + if not interval.is_string: + raise ValueError("Invalid interval string.") + + interval = interval.this + + interval_parts = INTERVAL_STRING_RE.match(interval) # type: ignore + + if not interval_parts: + raise ValueError("Invalid interval string.") + + return Interval( + this=Literal.string(interval_parts.group(1)), + unit=Var(this=interval_parts.group(2)), + ) + + @t.overload def to_table(sql_path: str | Table, **kwargs) -> Table: ... @@ -4050,7 +4154,8 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts): def subquery(expression, alias=None, dialect=None, **opts): """ Build a subquery expression. - Expample: + + Example: >>> subquery('select x from tbl', 'bar').select('x').sql() 'SELECT x FROM (SELECT x FROM tbl) AS bar' @@ -4072,6 +4177,7 @@ def subquery(expression, alias=None, dialect=None, **opts): def column(col, table=None, quoted=None) -> Column: """ Build a Column. + Args: col (str | Expression): column name table (str | Expression): table name @@ -4084,6 +4190,24 @@ def column(col, table=None, quoted=None) -> Column: ) +def cast(expression: str | Expression, to: str | DataType | DataType.Type, **opts) -> Cast: + """Cast an expression to a data type. + + Example: + >>> cast('x + 1', 'int').sql() + 'CAST(x + 1 AS INT)' + + Args: + expression: The expression to cast. + to: The datatype to cast to. + + Returns: + A cast node. + """ + expression = maybe_parse(expression, **opts) + return Cast(this=expression, to=DataType.build(to, **opts)) + + def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table: """Build a Table. @@ -4137,7 +4261,7 @@ def values( types = list(columns.values()) expressions[0].set( "expressions", - [Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)], + [cast(x, types[i]) for i, x in enumerate(expressions[0].expressions)], ) return Values( expressions=expressions, @@ -4373,7 +4497,7 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True return expression.transform(_expand, copy=copy) -def func(name: str, *args, dialect: t.Optional[Dialect | str] = None, **kwargs) -> Func: +def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func: """ Returns a Func expression. diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 3f3365a..b95e9bc 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -67,6 +67,7 @@ class Generator: exp.VolatilityProperty: lambda self, e: e.name, exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}", exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG", + exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", } # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed @@ -75,6 +76,9 @@ class Generator: # Whether or not null ordering is supported in order by NULL_ORDERING_SUPPORTED = True + # Whether or not locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported + LOCKING_READS_SUPPORTED = False + # Always do union distinct or union all EXPLICIT_UNION = False @@ -99,34 +103,42 @@ class Generator: STRUCT_DELIMITER = ("<", ">") - BEFORE_PROPERTIES = { - exp.FallbackProperty, - exp.WithJournalTableProperty, - exp.LogProperty, - exp.JournalProperty, - exp.AfterJournalProperty, - exp.ChecksumProperty, - exp.FreespaceProperty, - exp.MergeBlockRatioProperty, - exp.DataBlocksizeProperty, - exp.BlockCompressionProperty, - exp.IsolatedLoadingProperty, - } - - ROOT_PROPERTIES = { - exp.ReturnsProperty, - exp.LanguageProperty, - exp.DistStyleProperty, - exp.DistKeyProperty, - exp.SortKeyProperty, - exp.LikeProperty, - } - - WITH_PROPERTIES = { - exp.Property, - exp.FileFormatProperty, - exp.PartitionedByProperty, - exp.TableFormatProperty, + PROPERTIES_LOCATION = { + exp.AfterJournalProperty: exp.Properties.Location.PRE_SCHEMA, + exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE, + exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.BlockCompressionProperty: exp.Properties.Location.PRE_SCHEMA, + exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.ChecksumProperty: exp.Properties.Location.PRE_SCHEMA, + exp.CollateProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.DataBlocksizeProperty: exp.Properties.Location.PRE_SCHEMA, + exp.DefinerProperty: exp.Properties.Location.POST_CREATE, + exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.EngineProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.FallbackProperty: exp.Properties.Location.PRE_SCHEMA, + exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH, + exp.FreespaceProperty: exp.Properties.Location.PRE_SCHEMA, + exp.IsolatedLoadingProperty: exp.Properties.Location.PRE_SCHEMA, + exp.JournalProperty: exp.Properties.Location.PRE_SCHEMA, + exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.LocationProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.LogProperty: exp.Properties.Location.PRE_SCHEMA, + exp.MergeBlockRatioProperty: exp.Properties.Location.PRE_SCHEMA, + exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_WITH, + exp.Property: exp.Properties.Location.POST_SCHEMA_WITH, + exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE, + exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH, + exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.WithJournalTableProperty: exp.Properties.Location.PRE_SCHEMA, } WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary) @@ -284,10 +296,10 @@ class Generator: ) return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}" - def no_identify(self, func: t.Callable[[], str]) -> str: + def no_identify(self, func: t.Callable[..., str], *args, **kwargs) -> str: original = self.identify self.identify = False - result = func() + result = func(*args, **kwargs) self.identify = original return result @@ -455,19 +467,33 @@ class Generator: def create_sql(self, expression: exp.Create) -> str: kind = self.sql(expression, "kind").upper() - has_before_properties = expression.args.get("properties") - has_before_properties = ( - has_before_properties.args.get("before") if has_before_properties else None - ) - if kind == "TABLE" and has_before_properties: + properties = expression.args.get("properties") + properties_exp = expression.copy() + properties_locs = self.locate_properties(properties) if properties else {} + if properties_locs.get(exp.Properties.Location.POST_SCHEMA_ROOT) or properties_locs.get( + exp.Properties.Location.POST_SCHEMA_WITH + ): + properties_exp.set( + "properties", + exp.Properties( + expressions=[ + *properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT], + *properties_locs[exp.Properties.Location.POST_SCHEMA_WITH], + ] + ), + ) + if kind == "TABLE" and properties_locs.get(exp.Properties.Location.PRE_SCHEMA): this_name = self.sql(expression.this, "this") - this_properties = self.sql(expression, "properties") + this_properties = self.properties( + exp.Properties(expressions=properties_locs[exp.Properties.Location.PRE_SCHEMA]), + wrapped=False, + ) this_schema = f"({self.expressions(expression.this)})" this = f"{this_name}, {this_properties} {this_schema}" - properties = "" + properties_sql = "" else: this = self.sql(expression, "this") - properties = self.sql(expression, "properties") + properties_sql = self.sql(properties_exp, "properties") begin = " BEGIN" if expression.args.get("begin") else "" expression_sql = self.sql(expression, "expression") expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else "" @@ -514,11 +540,31 @@ class Generator: if index.args.get("columns") else "" ) + if index.args.get("primary") and properties_locs.get( + exp.Properties.Location.POST_INDEX + ): + postindex_props_sql = self.properties( + exp.Properties( + expressions=properties_locs[exp.Properties.Location.POST_INDEX] + ), + wrapped=False, + ) + ind_columns = f"{ind_columns} {postindex_props_sql}" + indexes_sql.append( f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}" ) index_sql = "".join(indexes_sql) + postcreate_props_sql = "" + if properties_locs.get(exp.Properties.Location.POST_CREATE): + postcreate_props_sql = self.properties( + exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_CREATE]), + sep=" ", + prefix=" ", + wrapped=False, + ) + modifiers = "".join( ( replace, @@ -531,6 +577,7 @@ class Generator: multiset, global_temporary, volatile, + postcreate_props_sql, ) ) no_schema_binding = ( @@ -539,7 +586,7 @@ class Generator: post_expression_modifiers = "".join((data, statistics, no_primary_index)) - expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}" + expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}" return self.prepend_ctes(expression, expression_sql) def describe_sql(self, expression: exp.Describe) -> str: @@ -665,24 +712,19 @@ class Generator: return f"PARTITION({self.expressions(expression)})" def properties_sql(self, expression: exp.Properties) -> str: - before_properties = [] root_properties = [] with_properties = [] for p in expression.expressions: - p_class = p.__class__ - if p_class in self.BEFORE_PROPERTIES: - before_properties.append(p) - elif p_class in self.WITH_PROPERTIES: + p_loc = self.PROPERTIES_LOCATION[p.__class__] + if p_loc == exp.Properties.Location.POST_SCHEMA_WITH: with_properties.append(p) - elif p_class in self.ROOT_PROPERTIES: + elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT: root_properties.append(p) - return ( - self.properties(exp.Properties(expressions=before_properties), before=True) - + self.root_properties(exp.Properties(expressions=root_properties)) - + self.with_properties(exp.Properties(expressions=with_properties)) - ) + return self.root_properties( + exp.Properties(expressions=root_properties) + ) + self.with_properties(exp.Properties(expressions=with_properties)) def root_properties(self, properties: exp.Properties) -> str: if properties.expressions: @@ -695,17 +737,41 @@ class Generator: prefix: str = "", sep: str = ", ", suffix: str = "", - before: bool = False, + wrapped: bool = True, ) -> str: if properties.expressions: expressions = self.expressions(properties, sep=sep, indent=False) - expressions = expressions if before else self.wrap(expressions) + expressions = self.wrap(expressions) if wrapped else expressions return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}" return "" def with_properties(self, properties: exp.Properties) -> str: return self.properties(properties, prefix=self.seg("WITH")) + def locate_properties( + self, properties: exp.Properties + ) -> t.Dict[exp.Properties.Location, list[exp.Property]]: + properties_locs: t.Dict[exp.Properties.Location, list[exp.Property]] = { + key: [] for key in exp.Properties.Location + } + + for p in properties.expressions: + p_loc = self.PROPERTIES_LOCATION[p.__class__] + if p_loc == exp.Properties.Location.PRE_SCHEMA: + properties_locs[exp.Properties.Location.PRE_SCHEMA].append(p) + elif p_loc == exp.Properties.Location.POST_INDEX: + properties_locs[exp.Properties.Location.POST_INDEX].append(p) + elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT: + properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT].append(p) + elif p_loc == exp.Properties.Location.POST_SCHEMA_WITH: + properties_locs[exp.Properties.Location.POST_SCHEMA_WITH].append(p) + elif p_loc == exp.Properties.Location.POST_CREATE: + properties_locs[exp.Properties.Location.POST_CREATE].append(p) + elif p_loc == exp.Properties.Location.UNSUPPORTED: + self.unsupported(f"Unsupported property {p.key}") + + return properties_locs + def property_sql(self, expression: exp.Property) -> str: property_cls = expression.__class__ if property_cls == exp.Property: @@ -713,7 +779,7 @@ class Generator: property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls) if not property_name: - self.unsupported(f"Unsupported property {property_name}") + self.unsupported(f"Unsupported property {expression.key}") return f"{property_name}={self.sql(expression, 'this')}" @@ -975,7 +1041,7 @@ class Generator: rollup = self.expressions(expression, key="rollup", indent=False) rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else "" - return f"{group_by}{grouping_sets}{cube}{rollup}" + return f"{group_by}{csv(grouping_sets, cube, rollup, sep=',')}" def having_sql(self, expression: exp.Having) -> str: this = self.indent(self.sql(expression, "this")) @@ -1015,7 +1081,7 @@ class Generator: def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str: args = self.expressions(expression, flat=True) args = f"({args})" if len(args.split(",")) > 1 else args - return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}") + return f"{args} {arrow_sep} {self.sql(expression, 'this')}" def lateral_sql(self, expression: exp.Lateral) -> str: this = self.sql(expression, "this") @@ -1043,6 +1109,14 @@ class Generator: this = self.sql(expression, "this") return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}" + def lock_sql(self, expression: exp.Lock) -> str: + if self.LOCKING_READS_SUPPORTED: + lock_type = "UPDATE" if expression.args["update"] else "SHARE" + return self.seg(f"FOR {lock_type}") + + self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported") + return "" + def literal_sql(self, expression: exp.Literal) -> str: text = expression.this or "" if expression.is_string: @@ -1163,6 +1237,7 @@ class Generator: self.sql(expression, "order"), self.sql(expression, "limit"), self.sql(expression, "offset"), + self.sql(expression, "lock"), sep="", ) @@ -1773,7 +1848,7 @@ class Generator: def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str: this = self.sql(expression, "this") - expressions = self.no_identify(lambda: self.expressions(expression)) + expressions = self.no_identify(self.expressions, expression) expressions = ( self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}" ) diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index 4e7eab8..a39ad8c 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -9,6 +9,9 @@ from sqlglot.optimizer import Scope, build_scope, optimize from sqlglot.optimizer.qualify_columns import qualify_columns from sqlglot.optimizer.qualify_tables import qualify_tables +if t.TYPE_CHECKING: + from sqlglot.dialects.dialect import DialectType + @dataclass(frozen=True) class Node: @@ -36,7 +39,7 @@ def lineage( schema: t.Optional[t.Dict | Schema] = None, sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None, rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns), - dialect: t.Optional[str] = None, + dialect: DialectType = None, ) -> Node: """Build the lineage graph for a column of a SQL query. @@ -126,7 +129,7 @@ class LineageHTML: def __init__( self, node: Node, - dialect: t.Optional[str] = None, + dialect: DialectType = None, imports: bool = True, **opts: t.Any, ): diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 2245cc2..c6bea5a 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -114,7 +114,7 @@ def _eliminate_union(scope, existing_ctes, taken): taken[alias] = scope # Try to maintain the selections - expressions = scope.expression.args.get("expressions") + expressions = scope.selects selects = [ exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name) for e in expressions diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 5a3ed5a..badbb87 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -300,7 +300,7 @@ class Scope: list[exp.Expression]: expressions """ if isinstance(self.expression, exp.Union): - return [] + return self.expression.unnest().selects return self.expression.selects @property diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index f560760..f80484d 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -456,8 +456,10 @@ def extract_interval(interval): def date_literal(date): - expr_type = exp.DataType.build("DATETIME" if isinstance(date, datetime.datetime) else "DATE") - return exp.Cast(this=exp.Literal.string(date), to=expr_type) + return exp.cast( + exp.Literal.string(date), + "DATETIME" if isinstance(date, datetime.datetime) else "DATE", + ) def boolean_literal(condition): diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 6229105..e2b2c54 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -80,6 +80,7 @@ class Parser(metaclass=_Parser): length=exp.Literal.number(10), ), "VAR_MAP": parse_var_map, + "IFNULL": exp.Coalesce.from_arg_list, } NO_PAREN_FUNCTIONS = { @@ -567,6 +568,8 @@ class Parser(metaclass=_Parser): default=self._prev.text.upper() == "DEFAULT" ), "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(), + "ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty), + "DEFINER": lambda self: self._parse_definer(), } CONSTRAINT_PARSERS = { @@ -608,6 +611,7 @@ class Parser(metaclass=_Parser): "order": lambda self: self._parse_order(), "limit": lambda self: self._parse_limit(), "offset": lambda self: self._parse_offset(), + "lock": lambda self: self._parse_lock(), } SHOW_PARSERS: t.Dict[str, t.Callable] = {} @@ -850,7 +854,7 @@ class Parser(metaclass=_Parser): self.raise_error(error_message) def _find_sql(self, start: Token, end: Token) -> str: - return self.sql[self._find_token(start) : self._find_token(end)] + return self.sql[self._find_token(start) : self._find_token(end) + len(end.text)] def _find_token(self, token: Token) -> int: line = 1 @@ -901,6 +905,7 @@ class Parser(metaclass=_Parser): return expression def _parse_drop(self, default_kind: t.Optional[str] = None) -> t.Optional[exp.Expression]: + start = self._prev temporary = self._match(TokenType.TEMPORARY) materialized = self._match(TokenType.MATERIALIZED) kind = self._match_set(self.CREATABLES) and self._prev.text @@ -908,8 +913,7 @@ class Parser(metaclass=_Parser): if default_kind: kind = default_kind else: - self.raise_error(f"Expected {self.CREATABLES}") - return None + return self._parse_as_command(start) return self.expression( exp.Drop, @@ -929,6 +933,7 @@ class Parser(metaclass=_Parser): ) def _parse_create(self) -> t.Optional[exp.Expression]: + start = self._prev replace = self._match_pair(TokenType.OR, TokenType.REPLACE) set_ = self._match(TokenType.SET) # Teradata multiset = self._match_text_seq("MULTISET") # Teradata @@ -943,16 +948,19 @@ class Parser(metaclass=_Parser): if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False): self._match(TokenType.TABLE) + properties = None create_token = self._match_set(self.CREATABLES) and self._prev if not create_token: - self.raise_error(f"Expected {self.CREATABLES}") - return None + properties = self._parse_properties() + create_token = self._match_set(self.CREATABLES) and self._prev + + if not properties or not create_token: + return self._parse_as_command(start) exists = self._parse_exists(not_=True) this = None expression = None - properties = None data = None statistics = None no_primary_index = None @@ -1006,6 +1014,14 @@ class Parser(metaclass=_Parser): indexes = [] while True: index = self._parse_create_table_index() + + # post index PARTITION BY property + if self._match(TokenType.PARTITION_BY, advance=False): + if properties: + properties.expressions.append(self._parse_property()) + else: + properties = self._parse_properties() + if not index: break else: @@ -1040,6 +1056,9 @@ class Parser(metaclass=_Parser): ) def _parse_property_before(self) -> t.Optional[exp.Expression]: + self._match(TokenType.COMMA) + + # parsers look to _prev for no/dual/default, so need to consume first self._match_text_seq("NO") self._match_text_seq("DUAL") self._match_text_seq("DEFAULT") @@ -1059,6 +1078,9 @@ class Parser(metaclass=_Parser): if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY): return self._parse_sortkey(compound=True) + if self._match_text_seq("SQL", "SECURITY"): + return self.expression(exp.SqlSecurityProperty, definer=self._match_text_seq("DEFINER")) + assignment = self._match_pair( TokenType.VAR, TokenType.EQ, advance=False ) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False) @@ -1083,7 +1105,6 @@ class Parser(metaclass=_Parser): while True: if before: - self._match(TokenType.COMMA) identified_property = self._parse_property_before() else: identified_property = self._parse_property() @@ -1094,7 +1115,7 @@ class Parser(metaclass=_Parser): properties.append(p) if properties: - return self.expression(exp.Properties, expressions=properties, before=before) + return self.expression(exp.Properties, expressions=properties) return None @@ -1118,6 +1139,19 @@ class Parser(metaclass=_Parser): return self._parse_withisolatedloading() + # https://dev.mysql.com/doc/refman/8.0/en/create-view.html + def _parse_definer(self) -> t.Optional[exp.Expression]: + self._match(TokenType.EQ) + + user = self._parse_id_var() + self._match(TokenType.PARAMETER) + host = self._parse_id_var() or (self._match(TokenType.MOD) and self._prev.text) + + if not user or not host: + return None + + return exp.DefinerProperty(this=f"{user}@{host}") + def _parse_withjournaltable(self) -> exp.Expression: self._match_text_seq("WITH", "JOURNAL", "TABLE") self._match(TokenType.EQ) @@ -1695,12 +1729,10 @@ class Parser(metaclass=_Parser): paren += 1 if self._curr.token_type == TokenType.R_PAREN: paren -= 1 + end = self._prev self._advance() if paren > 0: self.raise_error("Expecting )", self._curr) - if not self._curr: - self.raise_error("Expecting pattern", self._curr) - end = self._prev pattern = exp.Var(this=self._find_sql(start, end)) else: pattern = None @@ -2044,9 +2076,16 @@ class Parser(metaclass=_Parser): expressions = self._parse_csv(self._parse_conjunction) grouping_sets = self._parse_grouping_sets() + self._match(TokenType.COMMA) with_ = self._match(TokenType.WITH) - cube = self._match(TokenType.CUBE) and (with_ or self._parse_wrapped_id_vars()) - rollup = self._match(TokenType.ROLLUP) and (with_ or self._parse_wrapped_id_vars()) + cube = self._match(TokenType.CUBE) and ( + with_ or self._parse_wrapped_csv(self._parse_column) + ) + + self._match(TokenType.COMMA) + rollup = self._match(TokenType.ROLLUP) and ( + with_ or self._parse_wrapped_csv(self._parse_column) + ) return self.expression( exp.Group, @@ -2149,6 +2188,14 @@ class Parser(metaclass=_Parser): self._match_set((TokenType.ROW, TokenType.ROWS)) return self.expression(exp.Offset, this=this, expression=count) + def _parse_lock(self) -> t.Optional[exp.Expression]: + if self._match_text_seq("FOR", "UPDATE"): + return self.expression(exp.Lock, update=True) + if self._match_text_seq("FOR", "SHARE"): + return self.expression(exp.Lock, update=False) + + return None + def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match_set(self.SET_OPERATIONS): return this @@ -2330,12 +2377,21 @@ class Parser(metaclass=_Parser): maybe_func = True if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): - return exp.DataType( + this = exp.DataType( this=exp.DataType.Type.ARRAY, expressions=[exp.DataType.build(type_token.value, expressions=expressions)], nested=True, ) + while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): + this = exp.DataType( + this=exp.DataType.Type.ARRAY, + expressions=[this], + nested=True, + ) + + return this + if self._match(TokenType.L_BRACKET): self._retreat(index) return None @@ -2430,7 +2486,12 @@ class Parser(metaclass=_Parser): self.raise_error("Expected type") elif op: self._advance() - field = exp.Literal.string(self._prev.text) + value = self._prev.text + field = ( + exp.Literal.number(value) + if self._prev.token_type == TokenType.NUMBER + else exp.Literal.string(value) + ) else: field = self._parse_star() or self._parse_function() or self._parse_id_var() @@ -2752,7 +2813,23 @@ class Parser(metaclass=_Parser): if not self._curr: break - if self._match_text_seq("NOT", "ENFORCED"): + if self._match(TokenType.ON): + action = None + on = self._advance_any() and self._prev.text + + if self._match(TokenType.NO_ACTION): + action = "NO ACTION" + elif self._match(TokenType.CASCADE): + action = "CASCADE" + elif self._match_pair(TokenType.SET, TokenType.NULL): + action = "SET NULL" + elif self._match_pair(TokenType.SET, TokenType.DEFAULT): + action = "SET DEFAULT" + else: + self.raise_error("Invalid key constraint") + + options.append(f"ON {on} {action}") + elif self._match_text_seq("NOT", "ENFORCED"): options.append("NOT ENFORCED") elif self._match_text_seq("DEFERRABLE"): options.append("DEFERRABLE") @@ -2762,10 +2839,6 @@ class Parser(metaclass=_Parser): options.append("NORELY") elif self._match_text_seq("MATCH", "FULL"): options.append("MATCH FULL") - elif self._match_text_seq("ON", "UPDATE", "NO ACTION"): - options.append("ON UPDATE NO ACTION") - elif self._match_text_seq("ON", "DELETE", "NO ACTION"): - options.append("ON DELETE NO ACTION") else: break @@ -3158,7 +3231,9 @@ class Parser(metaclass=_Parser): prefix += self._prev.text if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS): - return exp.Identifier(this=prefix + self._prev.text, quoted=False) + quoted = self._prev.token_type == TokenType.STRING + return exp.Identifier(this=prefix + self._prev.text, quoted=quoted) + return None def _parse_string(self) -> t.Optional[exp.Expression]: @@ -3486,6 +3561,11 @@ class Parser(metaclass=_Parser): def _parse_set(self) -> exp.Expression: return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item)) + def _parse_as_command(self, start: Token) -> exp.Command: + while self._curr: + self._advance() + return exp.Command(this=self._find_sql(start, self._prev)) + def _find_parser( self, parsers: t.Dict[str, t.Callable], trie: t.Dict ) -> t.Optional[t.Callable]: diff --git a/sqlglot/schema.py b/sqlglot/schema.py index f6f3883..f5d9f2b 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -11,6 +11,7 @@ from sqlglot.trie import in_trie, new_trie if t.TYPE_CHECKING: from sqlglot.dataframe.sql.types import StructType + from sqlglot.dialects.dialect import DialectType ColumnMapping = t.Union[t.Dict, str, StructType, t.List] @@ -153,7 +154,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): self, schema: t.Optional[t.Dict] = None, visible: t.Optional[t.Dict] = None, - dialect: t.Optional[str] = None, + dialect: DialectType = None, ) -> None: self.dialect = dialect self.visible = visible or {} diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 8bdd338..e95057a 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -665,6 +665,7 @@ class Tokenizer(metaclass=_Tokenizer): "STRING": TokenType.TEXT, "TEXT": TokenType.TEXT, "CLOB": TokenType.TEXT, + "LONGVARCHAR": TokenType.TEXT, "BINARY": TokenType.BINARY, "BLOB": TokenType.VARBINARY, "BYTEA": TokenType.VARBINARY, diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index e5b1c94..241f496 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -170,7 +170,7 @@ class TestBigQuery(Validator): "bigquery": "CURRENT_TIMESTAMP()", "duckdb": "CURRENT_TIMESTAMP()", "postgres": "CURRENT_TIMESTAMP", - "presto": "CURRENT_TIMESTAMP()", + "presto": "CURRENT_TIMESTAMP", "hive": "CURRENT_TIMESTAMP()", "spark": "CURRENT_TIMESTAMP()", }, @@ -181,7 +181,7 @@ class TestBigQuery(Validator): "bigquery": "CURRENT_TIMESTAMP()", "duckdb": "CURRENT_TIMESTAMP()", "postgres": "CURRENT_TIMESTAMP", - "presto": "CURRENT_TIMESTAMP()", + "presto": "CURRENT_TIMESTAMP", "hive": "CURRENT_TIMESTAMP()", "spark": "CURRENT_TIMESTAMP()", }, diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 5a13655..a456415 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1,6 +1,7 @@ import unittest from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one +from sqlglot.dialects import Hive class Validator(unittest.TestCase): @@ -67,6 +68,11 @@ class TestDialect(Validator): self.assertIsNotNone(Dialect.get_or_raise(dialect)) self.assertIsNotNone(Dialect[dialect.value]) + def test_get_or_raise(self): + self.assertEqual(Dialect.get_or_raise(Hive), Hive) + self.assertEqual(Dialect.get_or_raise(Hive()), Hive) + self.assertEqual(Dialect.get_or_raise("hive"), Hive) + def test_cast(self): self.validate_all( "CAST(a AS TEXT)", @@ -280,6 +286,21 @@ class TestDialect(Validator): write={"oracle": "CAST(a AS NUMBER)"}, ) + def test_if_null(self): + self.validate_all( + "SELECT IFNULL(1, NULL) FROM foo", + write={ + "": "SELECT COALESCE(1, NULL) FROM foo", + "redshift": "SELECT COALESCE(1, NULL) FROM foo", + "postgres": "SELECT COALESCE(1, NULL) FROM foo", + "mysql": "SELECT COALESCE(1, NULL) FROM foo", + "duckdb": "SELECT COALESCE(1, NULL) FROM foo", + "spark": "SELECT COALESCE(1, NULL) FROM foo", + "bigquery": "SELECT COALESCE(1, NULL) FROM foo", + "presto": "SELECT COALESCE(1, NULL) FROM foo", + }, + ) + def test_time(self): self.validate_all( "STR_TO_TIME(x, '%Y-%m-%dT%H:%M:%S')", @@ -287,10 +308,10 @@ class TestDialect(Validator): "duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')", }, write={ - "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", + "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%T')", "duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)", - "presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')", + "presto": "DATE_PARSE(x, '%Y-%m-%dT%T')", "drill": "TO_TIMESTAMP(x, 'yyyy-MM-dd''T''HH:mm:ss')", "redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH:MI:SS')", "spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')", @@ -356,7 +377,7 @@ class TestDialect(Validator): write={ "duckdb": "EPOCH(CAST('2020-01-01' AS TIMESTAMP))", "hive": "UNIX_TIMESTAMP('2020-01-01')", - "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%S'))", + "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%m-%d %T'))", }, ) self.validate_all( @@ -418,7 +439,7 @@ class TestDialect(Validator): self.validate_all( "UNIX_TO_STR(x, y)", write={ - "duckdb": "STRFTIME(TO_TIMESTAMP(CAST(x AS BIGINT)), y)", + "duckdb": "STRFTIME(TO_TIMESTAMP(x), y)", "hive": "FROM_UNIXTIME(x, y)", "presto": "DATE_FORMAT(FROM_UNIXTIME(x), y)", "starrocks": "FROM_UNIXTIME(x, y)", @@ -427,7 +448,7 @@ class TestDialect(Validator): self.validate_all( "UNIX_TO_TIME(x)", write={ - "duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))", + "duckdb": "TO_TIMESTAMP(x)", "hive": "FROM_UNIXTIME(x)", "oracle": "TO_DATE('1970-01-01','YYYY-MM-DD') + (x / 86400)", "postgres": "TO_TIMESTAMP(x)", @@ -438,7 +459,7 @@ class TestDialect(Validator): self.validate_all( "UNIX_TO_TIME_STR(x)", write={ - "duckdb": "CAST(TO_TIMESTAMP(CAST(x AS BIGINT)) AS TEXT)", + "duckdb": "CAST(TO_TIMESTAMP(x) AS TEXT)", "hive": "FROM_UNIXTIME(x)", "presto": "CAST(FROM_UNIXTIME(x) AS VARCHAR)", }, @@ -575,10 +596,10 @@ class TestDialect(Validator): }, write={ "drill": "TO_DATE(x, 'yyyy-MM-dd''T''HH:mm:ss')", - "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", - "starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", + "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%T')", + "starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%T')", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)", - "presto": "CAST(DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S') AS DATE)", + "presto": "CAST(DATE_PARSE(x, '%Y-%m-%dT%T') AS DATE)", "spark": "TO_DATE(x, 'yyyy-MM-ddTHH:mm:ss')", }, ) @@ -709,6 +730,7 @@ class TestDialect(Validator): "hive": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", "presto": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", "spark": "AGGREGATE(x, 0, (acc, x) -> acc + x, acc -> acc)", + "presto": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", }, ) @@ -1381,3 +1403,21 @@ SELECT "spark": "MERGE INTO a AS b USING c AS d ON b.id = d.id WHEN MATCHED AND EXISTS(SELECT b.name EXCEPT SELECT d.name) THEN UPDATE SET b.name = d.name", }, ) + + def test_substring(self): + self.validate_all( + "SUBSTR('123456', 2, 3)", + write={ + "bigquery": "SUBSTR('123456', 2, 3)", + "oracle": "SUBSTR('123456', 2, 3)", + "postgres": "SUBSTR('123456', 2, 3)", + }, + ) + self.validate_all( + "SUBSTRING('123456', 2, 3)", + write={ + "bigquery": "SUBSTRING('123456', 2, 3)", + "oracle": "SUBSTR('123456', 2, 3)", + "postgres": "SUBSTRING('123456' FROM 2 FOR 3)", + }, + ) diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index f6446ca..f01a604 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -22,7 +22,7 @@ class TestDuckDB(Validator): "EPOCH_MS(x)", write={ "bigquery": "UNIX_TO_TIME(x / 1000)", - "duckdb": "TO_TIMESTAMP(CAST(x / 1000 AS BIGINT))", + "duckdb": "TO_TIMESTAMP(x / 1000)", "presto": "FROM_UNIXTIME(x / 1000)", "spark": "FROM_UNIXTIME(x / 1000)", }, @@ -41,7 +41,7 @@ class TestDuckDB(Validator): "STRFTIME(x, '%Y-%m-%d %H:%M:%S')", write={ "duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')", - "presto": "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')", + "presto": "DATE_FORMAT(x, '%Y-%m-%d %T')", "hive": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')", }, ) @@ -58,9 +58,10 @@ class TestDuckDB(Validator): self.validate_all( "TO_TIMESTAMP(x)", write={ - "duckdb": "CAST(x AS TIMESTAMP)", - "presto": "CAST(x AS TIMESTAMP)", - "hive": "CAST(x AS TIMESTAMP)", + "bigquery": "UNIX_TO_TIME(x)", + "duckdb": "TO_TIMESTAMP(x)", + "presto": "FROM_UNIXTIME(x)", + "hive": "FROM_UNIXTIME(x)", }, ) self.validate_all( @@ -334,6 +335,14 @@ class TestDuckDB(Validator): }, ) + self.validate_all( + "cast([[1]] as int[][])", + write={ + "duckdb": "CAST(LIST_VALUE(LIST_VALUE(1)) AS INT[][])", + "spark": "CAST(ARRAY(ARRAY(1)) AS ARRAY>)", + }, + ) + def test_bool_or(self): self.validate_all( "SELECT a, LOGICAL_OR(b) FROM table GROUP BY a", diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index c41e4f7..1f35d1d 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -259,7 +259,7 @@ class TestHive(Validator): self.validate_all( """from_unixtime(x, "yyyy-MM-dd'T'HH")""", write={ - "duckdb": "STRFTIME(TO_TIMESTAMP(CAST(x AS BIGINT)), '%Y-%m-%d''T''%H')", + "duckdb": "STRFTIME(TO_TIMESTAMP(x), '%Y-%m-%d''T''%H')", "presto": "DATE_FORMAT(FROM_UNIXTIME(x), '%Y-%m-%d''T''%H')", "hive": "FROM_UNIXTIME(x, 'yyyy-MM-dd\\'T\\'HH')", "spark": "FROM_UNIXTIME(x, 'yyyy-MM-dd\\'T\\'HH')", @@ -269,7 +269,7 @@ class TestHive(Validator): "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')", write={ "duckdb": "STRFTIME(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %H:%M:%S')", - "presto": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %H:%i:%S')", + "presto": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %T')", "hive": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss')", "spark": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss')", }, @@ -308,7 +308,7 @@ class TestHive(Validator): "UNIX_TIMESTAMP(x)", write={ "duckdb": "EPOCH(STRPTIME(x, '%Y-%m-%d %H:%M:%S'))", - "presto": "TO_UNIXTIME(DATE_PARSE(x, '%Y-%m-%d %H:%i:%S'))", + "presto": "TO_UNIXTIME(DATE_PARSE(x, '%Y-%m-%d %T'))", "hive": "UNIX_TIMESTAMP(x)", "spark": "UNIX_TIMESTAMP(x)", "": "STR_TO_UNIX(x, '%Y-%m-%d %H:%M:%S')", diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index ce865e1..3e3b0d3 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -195,6 +195,26 @@ class TestMySQL(Validator): ) def test_mysql(self): + self.validate_all( + "SELECT a FROM tbl FOR UPDATE", + write={ + "": "SELECT a FROM tbl", + "mysql": "SELECT a FROM tbl FOR UPDATE", + "oracle": "SELECT a FROM tbl FOR UPDATE", + "postgres": "SELECT a FROM tbl FOR UPDATE", + "tsql": "SELECT a FROM tbl FOR UPDATE", + }, + ) + self.validate_all( + "SELECT a FROM tbl FOR SHARE", + write={ + "": "SELECT a FROM tbl", + "mysql": "SELECT a FROM tbl FOR SHARE", + "oracle": "SELECT a FROM tbl FOR SHARE", + "postgres": "SELECT a FROM tbl FOR SHARE", + "tsql": "SELECT a FROM tbl FOR SHARE", + }, + ) self.validate_all( "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)", write={ diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 8a17b78..5664a2a 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -112,6 +112,22 @@ class TestPostgres(Validator): self.validate_identity("x ~ 'y'") self.validate_identity("x ~* 'y'") + self.validate_all( + "GENERATE_SERIES(a, b, ' 2 days ')", + write={ + "postgres": "GENERATE_SERIES(a, b, INTERVAL '2' days)", + "presto": "SEQUENCE(a, b, INTERVAL '2' days)", + "trino": "SEQUENCE(a, b, INTERVAL '2' days)", + }, + ) + self.validate_all( + "GENERATE_SERIES('2019-01-01'::TIMESTAMP, NOW(), '1day')", + write={ + "postgres": "GENERATE_SERIES(CAST('2019-01-01' AS TIMESTAMP), CURRENT_TIMESTAMP, INTERVAL '1' day)", + "presto": "SEQUENCE(CAST('2019-01-01' AS TIMESTAMP), CAST(CURRENT_TIMESTAMP AS TIMESTAMP), INTERVAL '1' day)", + "trino": "SEQUENCE(CAST('2019-01-01' AS TIMESTAMP), CAST(CURRENT_TIMESTAMP AS TIMESTAMP), INTERVAL '1' day)", + }, + ) self.validate_all( "END WORK AND NO CHAIN", write={"postgres": "COMMIT AND NO CHAIN"}, @@ -249,7 +265,7 @@ class TestPostgres(Validator): ) self.validate_all( "'[1,2,3]'::json->2", - write={"postgres": "CAST('[1,2,3]' AS JSON) -> '2'"}, + write={"postgres": "CAST('[1,2,3]' AS JSON) -> 2"}, ) self.validate_all( """'{"a":1,"b":2}'::json->'b'""", @@ -265,7 +281,7 @@ class TestPostgres(Validator): ) self.validate_all( """'[1,2,3]'::json->>2""", - write={"postgres": "CAST('[1,2,3]' AS JSON) ->> '2'"}, + write={"postgres": "CAST('[1,2,3]' AS JSON) ->> 2"}, ) self.validate_all( """'{"a":1,"b":2}'::json->>'b'""", diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 5ecd69a..9815dcc 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -111,7 +111,7 @@ class TestPresto(Validator): "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')", write={ "duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')", - "presto": "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')", + "presto": "DATE_FORMAT(x, '%Y-%m-%d %T')", "hive": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')", "spark": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')", }, @@ -120,7 +120,7 @@ class TestPresto(Validator): "DATE_PARSE(x, '%Y-%m-%d %H:%i:%S')", write={ "duckdb": "STRPTIME(x, '%Y-%m-%d %H:%M:%S')", - "presto": "DATE_PARSE(x, '%Y-%m-%d %H:%i:%S')", + "presto": "DATE_PARSE(x, '%Y-%m-%d %T')", "hive": "CAST(x AS TIMESTAMP)", "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss')", }, @@ -134,6 +134,12 @@ class TestPresto(Validator): "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd')", }, ) + self.validate_all( + "DATE_FORMAT(x, '%T')", + write={ + "hive": "DATE_FORMAT(x, 'HH:mm:ss')", + }, + ) self.validate_all( "DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')", write={ @@ -146,7 +152,7 @@ class TestPresto(Validator): self.validate_all( "FROM_UNIXTIME(x)", write={ - "duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))", + "duckdb": "TO_TIMESTAMP(x)", "presto": "FROM_UNIXTIME(x)", "hive": "FROM_UNIXTIME(x)", "spark": "FROM_UNIXTIME(x)", @@ -177,11 +183,51 @@ class TestPresto(Validator): self.validate_all( "NOW()", write={ - "presto": "CURRENT_TIMESTAMP()", + "presto": "CURRENT_TIMESTAMP", "hive": "CURRENT_TIMESTAMP()", }, ) + self.validate_all( + "DAY_OF_WEEK(timestamp '2012-08-08 01:00')", + write={ + "spark": "DAYOFWEEK(CAST('2012-08-08 01:00' AS TIMESTAMP))", + "presto": "DAY_OF_WEEK(CAST('2012-08-08 01:00' AS TIMESTAMP))", + }, + ) + + self.validate_all( + "DAY_OF_MONTH(timestamp '2012-08-08 01:00')", + write={ + "spark": "DAYOFMONTH(CAST('2012-08-08 01:00' AS TIMESTAMP))", + "presto": "DAY_OF_MONTH(CAST('2012-08-08 01:00' AS TIMESTAMP))", + }, + ) + + self.validate_all( + "DAY_OF_YEAR(timestamp '2012-08-08 01:00')", + write={ + "spark": "DAYOFYEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))", + "presto": "DAY_OF_YEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))", + }, + ) + + self.validate_all( + "WEEK_OF_YEAR(timestamp '2012-08-08 01:00')", + write={ + "spark": "WEEKOFYEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))", + "presto": "WEEK_OF_YEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))", + }, + ) + + self.validate_all( + "SELECT timestamp '2012-10-31 00:00' AT TIME ZONE 'America/Sao_Paulo'", + write={ + "spark": "SELECT FROM_UTC_TIMESTAMP(CAST('2012-10-31 00:00' AS TIMESTAMP), 'America/Sao_Paulo')", + "presto": "SELECT CAST('2012-10-31 00:00' AS TIMESTAMP) AT TIME ZONE 'America/Sao_Paulo'", + }, + ) + def test_ddl(self): self.validate_all( "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1", @@ -314,6 +360,11 @@ class TestPresto(Validator): def test_presto(self): self.validate_identity("SELECT BOOL_OR(a > 10) FROM asd AS T(a)") + self.validate_identity("SELECT * FROM (VALUES (1))") + self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE") + self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ") + self.validate_identity("APPROX_PERCENTILE(a, b, c, d)") + self.validate_all( 'SELECT a."b" FROM "foo"', write={ @@ -455,10 +506,6 @@ class TestPresto(Validator): "spark": UnsupportedError, }, ) - self.validate_identity("SELECT * FROM (VALUES (1))") - self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE") - self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ") - self.validate_identity("APPROX_PERCENTILE(a, b, c, d)") def test_encode_decode(self): self.validate_all( @@ -529,3 +576,27 @@ class TestPresto(Validator): "presto": "FROM_HEX(x)", }, ) + + def test_json(self): + self.validate_all( + "SELECT CAST(JSON '[1,23,456]' AS ARRAY(INTEGER))", + write={ + "spark": "SELECT FROM_JSON('[1,23,456]', 'ARRAY')", + "presto": "SELECT CAST(CAST('[1,23,456]' AS JSON) AS ARRAY(INTEGER))", + }, + ) + self.validate_all( + """SELECT CAST(JSON '{"k1":1,"k2":23,"k3":456}' AS MAP(VARCHAR, INTEGER))""", + write={ + "spark": 'SELECT FROM_JSON(\'{"k1":1,"k2":23,"k3":456}\', \'MAP\')', + "presto": 'SELECT CAST(CAST(\'{"k1":1,"k2":23,"k3":456}\' AS JSON) AS MAP(VARCHAR, INTEGER))', + }, + ) + + self.validate_all( + "SELECT CAST(ARRAY [1, 23, 456] AS JSON)", + write={ + "spark": "SELECT TO_JSON(ARRAY(1, 23, 456))", + "presto": "SELECT CAST(ARRAY[1, 23, 456] AS JSON)", + }, + ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 02d43aa..be74a27 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -212,6 +212,17 @@ TBLPROPERTIES ( self.validate_identity("TRIM(BOTH 'SL' FROM 'SSparkSQLS')") self.validate_identity("TRIM(LEADING 'SL' FROM 'SSparkSQLS')") self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')") + + self.validate_all( + "AGGREGATE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)", + write={ + "trino": "REDUCE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)", + "duckdb": "REDUCE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)", + "hive": "REDUCE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)", + "presto": "REDUCE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)", + "spark": "AGGREGATE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)", + }, + ) self.validate_all( "TRIM('SL', 'SSparkSQLS')", write={"spark": "TRIM('SL' FROM 'SSparkSQLS')"} ) diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py index e54a4bc..c4f4a6e 100644 --- a/tests/dialects/test_sqlite.py +++ b/tests/dialects/test_sqlite.py @@ -92,3 +92,9 @@ class TestSQLite(Validator): "sqlite": "SELECT FIRST_VALUE(Name) OVER (PARTITION BY AlbumId ORDER BY Bytes DESC) AS LargestTrack FROM tracks" }, ) + + def test_longvarchar_dtype(self): + self.validate_all( + "CREATE TABLE foo (bar LONGVARCHAR)", + write={"sqlite": "CREATE TABLE foo (bar TEXT)"}, + ) diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index e56de25..9e82961 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -21,3 +21,6 @@ class TestTeradata(Validator): "mysql": "UPDATE A SET col2 = '' FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B WHERE A.col1 = B.col1", }, ) + + def test_create(self): + self.validate_identity("CREATE TABLE x (y INT) PRIMARY INDEX (y) PARTITION BY y INDEX (y)") diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index f2830b1..5a4871d 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -161,6 +161,7 @@ SELECT 1 FROM test SELECT * FROM a, b, (SELECT 1) AS c SELECT a FROM test SELECT 1 AS filter +SELECT 1 AS "quoted alias" SELECT SUM(x) AS filter SELECT 1 AS range FROM test SELECT 1 AS count FROM test @@ -264,7 +265,9 @@ SELECT a FROM test GROUP BY GROUPING SETS (x, ()) SELECT a FROM test GROUP BY GROUPING SETS (x, (x, y), (x, y, z), q) SELECT a FROM test GROUP BY CUBE (x) SELECT a FROM test GROUP BY ROLLUP (x) -SELECT a FROM test GROUP BY CUBE (x) ROLLUP (x, y, z) +SELECT t.a FROM test AS t GROUP BY ROLLUP (t.x) +SELECT a FROM test GROUP BY GROUPING SETS ((x, y)), ROLLUP (b) +SELECT a FROM test GROUP BY CUBE (x), ROLLUP (x, y, z) SELECT CASE WHEN a < b THEN 1 WHEN a < c THEN 2 ELSE 3 END FROM test SELECT CASE 1 WHEN 1 THEN 1 ELSE 2 END SELECT CASE 1 WHEN 1 THEN MAP('a', 'b') ELSE MAP('b', 'c') END['a'] @@ -339,7 +342,6 @@ SELECT CAST(a AS ARRAY) FROM test SELECT CAST(a AS VARIANT) FROM test SELECT TRY_CAST(a AS INT) FROM test SELECT COALESCE(a, b, c) FROM test -SELECT IFNULL(a, b) FROM test SELECT ANY_VALUE(a) FROM test SELECT 1 FROM a JOIN b ON a.x = b.x SELECT 1 FROM a JOIN b AS c ON a.x = b.x @@ -510,6 +512,14 @@ CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT) CREATE TABLE z (a INT REFERENCES parent(b, c)) CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id)) CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c)) +CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON DELETE NO ACTION) +CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON DELETE CASCADE) +CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON DELETE SET NULL) +CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON DELETE SET DEFAULT) +CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON UPDATE NO ACTION) +CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON UPDATE CASCADE) +CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON UPDATE SET NULL) +CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON UPDATE SET DEFAULT) CREATE TABLE asd AS SELECT asd FROM asd WITH NO DATA CREATE TABLE asd AS SELECT asd FROM asd WITH DATA CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY) @@ -526,6 +536,7 @@ CREATE TABLE a, DUAL JOURNAL, DUAL AFTER JOURNAL, MERGEBLOCKRATIO=1 PERCENT, DAT CREATE TABLE a, DUAL BEFORE JOURNAL, LOCAL AFTER JOURNAL, MAXIMUM DATABLOCKSIZE, BLOCKCOMPRESSION=AUTOTEMP(c1 INT) (a INT) CREATE SET GLOBAL TEMPORARY TABLE a, NO BEFORE JOURNAL, NO AFTER JOURNAL, MINIMUM DATABLOCKSIZE, BLOCKCOMPRESSION=NEVER (a INT) CREATE MULTISET VOLATILE TABLE a, NOT LOCAL AFTER JOURNAL, FREESPACE=1 PERCENT, DATABLOCKSIZE=10 BYTES, WITH NO CONCURRENT ISOLATED LOADING FOR ALL (a INT) +CREATE ALGORITHM=UNDEFINED DEFINER=foo@% SQL SECURITY DEFINER VIEW a AS (SELECT a FROM b) CREATE TEMPORARY TABLE x AS SELECT a FROM d CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d CREATE VIEW x AS SELECT a FROM b @@ -555,6 +566,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS my_idx ON tbl (a, b) CREATE SCHEMA x CREATE SCHEMA IF NOT EXISTS y CREATE PROCEDURE IF NOT EXISTS a.b.c() AS 'DECLARE BEGIN; END' +CREATE OR REPLACE STAGE DESCRIBE x DROP INDEX a.b.c DROP FUNCTION a.b.c (INT) diff --git a/tests/fixtures/optimizer/eliminate_subqueries.sql b/tests/fixtures/optimizer/eliminate_subqueries.sql index c566657..4fa63dd 100644 --- a/tests/fixtures/optimizer/eliminate_subqueries.sql +++ b/tests/fixtures/optimizer/eliminate_subqueries.sql @@ -50,6 +50,10 @@ WITH cte AS (SELECT 1 AS x, 2 AS y) SELECT cte.x AS x, cte.y AS y FROM cte AS ct (SELECT a FROM (SELECT b FROM x)) UNION (SELECT a FROM (SELECT b FROM y)); WITH cte AS (SELECT b FROM x), cte_2 AS (SELECT a FROM cte AS cte), cte_3 AS (SELECT b FROM y), cte_4 AS (SELECT a FROM cte_3 AS cte_3) (SELECT cte_2.a AS a FROM cte_2 AS cte_2) UNION (SELECT cte_4.a AS a FROM cte_4 AS cte_4); +-- Three unions +SELECT a FROM x UNION ALL SELECT a FROM y UNION ALL SELECT a FROM z; +WITH cte AS (SELECT a FROM x), cte_2 AS (SELECT a FROM y), cte_3 AS (SELECT a FROM z), cte_4 AS (SELECT cte_2.a AS a FROM cte_2 AS cte_2 UNION ALL SELECT cte_3.a AS a FROM cte_3 AS cte_3) SELECT cte.a AS a FROM cte AS cte UNION ALL SELECT cte_4.a AS a FROM cte_4 AS cte_4; + -- Subquery SELECT a FROM x WHERE b = (SELECT y.c FROM y); SELECT a FROM x WHERE b = (SELECT y.c FROM y); diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql index 64806eb..a240597 100644 --- a/tests/fixtures/pretty.sql +++ b/tests/fixtures/pretty.sql @@ -99,7 +99,7 @@ WITH cte1 AS ( GROUPING SETS ( a, (b, c) - ) + ), CUBE ( y, z diff --git a/tests/test_build.py b/tests/test_build.py index a1a268d..fbfbb62 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -62,6 +62,16 @@ class TestBuild(unittest.TestCase): lambda: select("x").from_("tbl").where("x > 0").where("x < 9", append=False), "SELECT x FROM tbl WHERE x < 9", ), + ( + lambda: select("x").from_("tbl").where("x > 0").lock(), + "SELECT x FROM tbl WHERE x > 0 FOR UPDATE", + "mysql", + ), + ( + lambda: select("x").from_("tbl").where("x > 0").lock(update=False), + "SELECT x FROM tbl WHERE x > 0 FOR SHARE", + "postgres", + ), ( lambda: select("x", "y").from_("tbl").group_by("x"), "SELECT x, y FROM tbl GROUP BY x", diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 2d5407e..55e07d1 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -466,6 +466,7 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(parse_one("BEGIN DEFERRED TRANSACTION"), exp.Transaction) self.assertIsInstance(parse_one("COMMIT"), exp.Commit) self.assertIsInstance(parse_one("ROLLBACK"), exp.Rollback) + self.assertIsInstance(parse_one("GENERATE_SERIES(a, b, c)"), exp.GenerateSeries) def test_column(self): dot = parse_one("a.b.c") @@ -630,6 +631,19 @@ FROM foo""", FROM foo""", ) + def test_to_interval(self): + self.assertEqual(exp.to_interval("1day").sql(), "INTERVAL '1' day") + self.assertEqual(exp.to_interval(" 5 months").sql(), "INTERVAL '5' months") + with self.assertRaises(ValueError): + exp.to_interval("bla") + + self.assertEqual(exp.to_interval(exp.Literal.string("1day")).sql(), "INTERVAL '1' day") + self.assertEqual( + exp.to_interval(exp.Literal.string(" 5 months")).sql(), "INTERVAL '5' months" + ) + with self.assertRaises(ValueError): + exp.to_interval(exp.Literal.string("bla")) + def test_to_table(self): table_only = exp.to_table("table_name") self.assertEqual(table_only.name, "table_name") diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 3e094f5..c0d518d 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -326,12 +326,12 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", self.validate("TIME_TO_UNIX(x)", "EPOCH(x)", write="duckdb") self.validate( "UNIX_TO_STR(123, 'y')", - "STRFTIME(TO_TIMESTAMP(CAST(123 AS BIGINT)), 'y')", + "STRFTIME(TO_TIMESTAMP(123), 'y')", write="duckdb", ) self.validate( "UNIX_TO_TIME(123)", - "TO_TIMESTAMP(CAST(123 AS BIGINT))", + "TO_TIMESTAMP(123)", write="duckdb", ) @@ -426,6 +426,9 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", mock_logger.warning.assert_any_call("Applying array index offset (%s)", 1) mock_logger.warning.assert_any_call("Applying array index offset (%s)", -1) + def test_identify_lambda(self): + self.validate("x(y -> y)", 'X("y" -> "y")', identify=True) + def test_identity(self): self.assertEqual(transpile("")[0], "") for sql in load_sql_fixtures("identity.sql"): -- cgit v1.2.3