diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-11-19 14:50:39 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-11-19 14:50:39 +0000 |
commit | f2981e8e4d28233864f1ca06ecec45ab80bf9eae (patch) | |
tree | b70cb633916830138ce3424aa361f0bbaff02be2 | |
parent | Releasing debian version 10.0.1-1. (diff) | |
download | sqlglot-f2981e8e4d28233864f1ca06ecec45ab80bf9eae.tar.xz sqlglot-f2981e8e4d28233864f1ca06ecec45ab80bf9eae.zip |
Merging upstream version 10.0.8.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
67 files changed, 2463 insertions, 842 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index 87dd21d..70f2b55 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,21 +7,37 @@ v10.0.0 Changes: - Breaking: replaced SQLGlot annotations with comments. Now comments can be preserved after transpilation, and they can appear in other places besides SELECT's expressions. + - Breaking: renamed list_get to seq_get. + - Breaking: activated mypy type checking for SQLGlot. + - New: Azure Databricks support. + - New: placeholders can now be replaced in an expression. + - New: null safe equal operator (<=>). + - New: [SET statements](https://github.com/tobymao/sqlglot/pull/673) for MySQL. + - New: [SHOW commands](https://dev.mysql.com/doc/refman/8.0/en/show.html) for MySQL. + - New: [FORMAT function](https://www.w3schools.com/sql/func_sqlserver_format.asp) for TSQL. + - New: CROSS APPLY / OUTER APPLY [support](https://github.com/tobymao/sqlglot/pull/641) for TSQL. -- New: added formats for TSQL's [DATENAME/DATEPART functions](https://learn.microsoft.com/en-us/sql/t-sql/functions/datename-transact-sql?view=sql-server-ver16) + +- New: added formats for TSQL's [DATENAME/DATEPART functions](https://learn.microsoft.com/en-us/sql/t-sql/functions/datename-transact-sql?view=sql-server-ver16). + - New: added styles for TSQL's [CONVERT function](https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16). + - Improvement: [refactored the schema](https://github.com/tobymao/sqlglot/pull/668) to be more lenient; before it needed to do an exact match of db.table, now it finds table if there are no ambiguities. + - Improvement: allow functions to [inherit](https://github.com/tobymao/sqlglot/pull/674) their arguments' types, so that annotating CASE, IF etc. is possible. + - Improvement: allow [joining with same names](https://github.com/tobymao/sqlglot/pull/660) in the python executor. + - Improvement: the "using" field can now be set for the [join expression builders](https://github.com/tobymao/sqlglot/pull/636). + - Improvement: qualify_columns [now qualifies](https://github.com/tobymao/sqlglot/pull/635) only non-alias columns in the having clause. v9.0.0 @@ -37,6 +53,7 @@ v8.0.0 Changes: - Breaking : New add\_table method in Schema ABC. + - New: SQLGlot now supports the [PySpark](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dataframe) dataframe API. This is still relatively experimental. v7.1.0 @@ -45,8 +62,11 @@ v7.1.0 Changes: - Improvement: Pretty generator now takes max\_text\_width which breaks segments into new lines + - New: exp.to\_table helper to turn table names into table expression objects + - New: int[] type parsers + - New: annotations are now generated in sql v7.0.0 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1d3b822..97c795d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -21,7 +21,7 @@ Pull requests are the best way to propose changes to the codebase. We actively w 5. Issue that pull request and wait for it to be reviewed by a maintainer or contributor! ## Report bugs using Github's [issues](https://github.com/tobymao/sqlglot/issues) -We use GitHub issues to track public bugs. Report a bug by [opening a new issue](). +We use GitHub issues to track public bugs. Report a bug by opening a new issue. **Great Bug Reports** tend to have: @@ -90,7 +90,7 @@ sqlglot.transpile("SELECT STRFTIME(x, '%y-%-m-%S')", read="duckdb", write="hive" "SELECT DATE_FORMAT(x, 'yy-M-ss')" ``` -As another example, let's suppose that we want to read in a SQL query that contains a CTE and a cast to `REAL`, and then transpile it to Spark, which uses backticks as identifiers and `FLOAT` instead of `REAL`: +As another example, let's suppose that we want to read in a SQL query that contains a CTE and a cast to `REAL`, and then transpile it to Spark, which uses backticks for identifiers and `FLOAT` instead of `REAL`: ```python import sqlglot @@ -376,12 +376,12 @@ print(Dialect["custom"]) [Benchmarks](benchmarks) run on Python 3.10.5 in seconds. -| Query | sqlglot | sqltree | sqlparse | moz_sql_parser | sqloxide | -| --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | -| tpch | 0.01178 (1.0) | 0.01173 (0.995) | 0.04676 (3.966) | 0.06800 (5.768) | 0.00094 (0.080) | -| short | 0.00084 (1.0) | 0.00079 (0.948) | 0.00296 (3.524) | 0.00443 (5.266) | 0.00006 (0.072) | -| long | 0.01102 (1.0) | 0.01044 (0.947) | 0.04349 (3.945) | 0.05998 (5.440) | 0.00084 (0.077) | -| crazy | 0.03751 (1.0) | 0.03471 (0.925) | 11.0796 (295.3) | 1.03355 (27.55) | 0.00529 (0.141) | +| Query | sqlglot | sqlfluff | sqltree | sqlparse | moz_sql_parser | sqloxide | +| --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | +| tpch | 0.01308 (1.0) | 1.60626 (122.7) | 0.01168 (0.893) | 0.04958 (3.791) | 0.08543 (6.531) | 0.00136 (0.104) | +| short | 0.00109 (1.0) | 0.14134 (129.2) | 0.00099 (0.906) | 0.00342 (3.131) | 0.00652 (5.970) | 8.76621 (0.080) | +| long | 0.01399 (1.0) | 2.12632 (151.9) | 0.01126 (0.805) | 0.04410 (3.151) | 0.06671 (4.767) | 0.00107 (0.076) | +| crazy | 0.03969 (1.0) | 24.3777 (614.1) | 0.03917 (0.987) | 11.7043 (294.8) | 1.03280 (26.02) | 0.00625 (0.157) | ## Optional Dependencies diff --git a/benchmarks/bench.py b/benchmarks/bench.py index cef62a8..2475608 100644 --- a/benchmarks/bench.py +++ b/benchmarks/bench.py @@ -5,8 +5,10 @@ collections.Iterable = collections.abc.Iterable import gc import timeit -import moz_sql_parser import numpy as np + +import sqlfluff +import moz_sql_parser import sqloxide import sqlparse import sqltree @@ -177,6 +179,10 @@ def sqloxide_parse(sql): sqloxide.parse_sql(sql, dialect="ansi") +def sqlfluff_parse(sql): + sqlfluff.parse(sql) + + def border(columns): columns = " | ".join(columns) return f"| {columns} |" @@ -193,6 +199,7 @@ def diff(row, column): libs = [ "sqlglot", + "sqlfluff", "sqltree", "sqlparse", "moz_sql_parser", @@ -206,7 +213,8 @@ for name, sql in {"tpch": tpch, "short": short, "long": long, "crazy": crazy}.it for lib in libs: try: row[lib] = np.mean(timeit.repeat(lambda: globals()[lib + "_parse"](sql), number=3)) - except: + except Exception as e: + print(e) row[lib] = "error" columns = ["Query"] + libs diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 6e67b19..50e2d9c 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -30,7 +30,7 @@ from sqlglot.parser import Parser from sqlglot.schema import MappingSchema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "10.0.1" +__version__ = "10.0.8" pretty = False diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index f9e1c5b..22075e9 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -260,7 +260,10 @@ class Column: """ if isinstance(dataType, DataType): dataType = dataType.simpleString() - new_expression = exp.Cast(this=self.column_expression, to=dataType) + new_expression = exp.Cast( + this=self.column_expression, + to=sqlglot.parse_one(dataType, into=exp.DataType, read="spark"), # type: ignore + ) return Column(new_expression) def startswith(self, value: t.Union[str, Column]) -> Column: diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 40cd6c9..548c322 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -314,7 +314,13 @@ class DataFrame: replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore cache_table_name ) - sqlglot.schema.add_table(cache_table_name, select_expression.named_selects) + sqlglot.schema.add_table( + cache_table_name, + { + expression.alias_or_name: expression.type.name + for expression in select_expression.expressions + }, + ) cache_storage_level = select_expression.args["cache_storage_level"] options = [ exp.Literal.string("storageLevel"), diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index dbfb06f..1ee361a 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -757,11 +757,15 @@ def concat_ws(sep: str, *cols: ColumnOrName) -> Column: def decode(col: ColumnOrName, charset: str) -> Column: - return Column.invoke_anonymous_function(col, "DECODE", lit(charset)) + return Column.invoke_expression_over_column( + col, glotexp.Decode, charset=glotexp.Literal.string(charset) + ) def encode(col: ColumnOrName, charset: str) -> Column: - return Column.invoke_anonymous_function(col, "ENCODE", lit(charset)) + return Column.invoke_expression_over_column( + col, glotexp.Encode, charset=glotexp.Literal.string(charset) + ) def format_number(col: ColumnOrName, d: int) -> Column: @@ -867,11 +871,11 @@ def bin(col: ColumnOrName) -> Column: def hex(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "HEX") + return Column.invoke_expression_over_column(col, glotexp.Hex) def unhex(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "UNHEX") + return Column.invoke_expression_over_column(col, glotexp.Unhex) def length(col: ColumnOrName) -> Column: @@ -939,11 +943,7 @@ def array_join( def concat(*cols: ColumnOrName) -> Column: - if len(cols) == 1: - return Column.invoke_anonymous_function(cols[0], "CONCAT") - return Column.invoke_anonymous_function( - cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]] - ) + return Column.invoke_expression_over_column(None, glotexp.Concat, expressions=cols) def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column: diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py index 8cb16ef..c4a22c6 100644 --- a/sqlglot/dataframe/sql/session.py +++ b/sqlglot/dataframe/sql/session.py @@ -88,14 +88,14 @@ class SparkSession: "expressions": sel_columns, "from": exp.From( expressions=[ - exp.Subquery( - this=exp.Values(expressions=data_expressions), + exp.Values( + expressions=data_expressions, alias=exp.TableAlias( this=exp.to_identifier(self._auto_incrementing_name), columns=[exp.to_identifier(col_name) for col_name in column_mapping], ), - ) - ] + ), + ], ), } diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py index 0816831..2e42e7d 100644 --- a/sqlglot/dialects/__init__.py +++ b/sqlglot/dialects/__init__.py @@ -2,6 +2,7 @@ from sqlglot.dialects.bigquery import BigQuery from sqlglot.dialects.clickhouse import ClickHouse from sqlglot.dialects.databricks import Databricks from sqlglot.dialects.dialect import Dialect, Dialects +from sqlglot.dialects.drill import Drill from sqlglot.dialects.duckdb import DuckDB from sqlglot.dialects.hive import Hive from sqlglot.dialects.mysql import MySQL diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 5bbff9d..4550d65 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -119,6 +119,8 @@ class BigQuery(Dialect): "UNKNOWN": TokenType.NULL, "WINDOW": TokenType.WINDOW, "NOT DETERMINISTIC": TokenType.VOLATILE, + "BEGIN": TokenType.COMMAND, + "BEGIN TRANSACTION": TokenType.BEGIN, } KEYWORDS.pop("DIV") @@ -204,6 +206,15 @@ class BigQuery(Dialect): EXPLICIT_UNION = True + def transaction_sql(self, *_): + return "BEGIN TRANSACTION" + + def commit_sql(self, *_): + return "COMMIT TRANSACTION" + + def rollback_sql(self, *_): + return "ROLLBACK TRANSACTION" + def in_unnest_op(self, unnest): return self.sql(unnest) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 3af08bb..8c497ab 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -32,6 +32,7 @@ class Dialects(str, Enum): TRINO = "trino" TSQL = "tsql" DATABRICKS = "databricks" + DRILL = "drill" class _Dialect(type): @@ -362,3 +363,18 @@ def parse_date_delta(exp_class, unit_mapping=None): return exp_class(this=this, expression=expression, unit=unit) return inner_func + + +def locate_to_strposition(args): + return exp.StrPosition( + this=seq_get(args, 1), + substr=seq_get(args, 0), + position=seq_get(args, 2), + ) + + +def strposition_to_local_sql(self, expression): + args = self.format_args( + expression.args.get("substr"), expression.this, expression.args.get("position") + ) + return f"LOCATE({args})" diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py new file mode 100644 index 0000000..eb420aa --- /dev/null +++ b/sqlglot/dialects/drill.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import re + +from sqlglot import exp, generator, parser, tokens +from sqlglot.dialects.dialect import ( + Dialect, + create_with_partitions_sql, + format_time_lambda, + no_pivot_sql, + no_trycast_sql, + rename_func, + str_position_sql, +) +from sqlglot.dialects.postgres import _lateral_sql + + +def _to_timestamp(args): + # TO_TIMESTAMP accepts either a single double argument or (text, text) + if len(args) == 1 and args[0].is_number: + return exp.UnixToTime.from_arg_list(args) + return format_time_lambda(exp.StrToTime, "drill")(args) + + +def _str_to_time_sql(self, expression): + return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})" + + +def _ts_or_ds_to_date_sql(self, expression): + time_format = self.format_time(expression) + if time_format and time_format not in (Drill.time_format, Drill.date_format): + return f"CAST({_str_to_time_sql(self, expression)} AS DATE)" + return f"CAST({self.sql(expression, 'this')} AS DATE)" + + +def _date_add_sql(kind): + def func(self, expression): + this = self.sql(expression, "this") + unit = expression.text("unit").upper() or "DAY" + expression = self.sql(expression, "expression") + return f"DATE_{kind}({this}, INTERVAL '{expression}' {unit})" + + return func + + +def if_sql(self, expression): + """ + Drill requires backticks around certain SQL reserved words, IF being one of them, This function + adds the backticks around the keyword IF. + Args: + self: The Drill dialect + expression: The input IF expression + + Returns: The expression with IF in backticks. + + """ + expressions = self.format_args( + expression.this, expression.args.get("true"), expression.args.get("false") + ) + return f"`IF`({expressions})" + + +def _str_to_date(self, expression): + this = self.sql(expression, "this") + time_format = self.format_time(expression) + if time_format == Drill.date_format: + return f"CAST({this} AS DATE)" + return f"TO_DATE({this}, {time_format})" + + +class Drill(Dialect): + normalize_functions = None + null_ordering = "nulls_are_last" + date_format = "'yyyy-MM-dd'" + dateint_format = "'yyyyMMdd'" + time_format = "'yyyy-MM-dd HH:mm:ss'" + + time_mapping = { + "y": "%Y", + "Y": "%Y", + "YYYY": "%Y", + "yyyy": "%Y", + "YY": "%y", + "yy": "%y", + "MMMM": "%B", + "MMM": "%b", + "MM": "%m", + "M": "%-m", + "dd": "%d", + "d": "%-d", + "HH": "%H", + "H": "%-H", + "hh": "%I", + "h": "%-I", + "mm": "%M", + "m": "%-M", + "ss": "%S", + "s": "%-S", + "SSSSSS": "%f", + "a": "%p", + "DD": "%j", + "D": "%-j", + "E": "%a", + "EE": "%a", + "EEE": "%a", + "EEEE": "%A", + "''T''": "T", + } + + class Tokenizer(tokens.Tokenizer): + QUOTES = ["'"] + IDENTIFIERS = ["`"] + ESCAPES = ["\\"] + ENCODE = "utf-8" + + class Parser(parser.Parser): + STRICT_CAST = False + + FUNCTIONS = { + **parser.Parser.FUNCTIONS, + "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list, + "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"), + } + + class Generator(generator.Generator): + TYPE_MAPPING = { + **generator.Generator.TYPE_MAPPING, + exp.DataType.Type.INT: "INTEGER", + exp.DataType.Type.SMALLINT: "INTEGER", + exp.DataType.Type.TINYINT: "INTEGER", + exp.DataType.Type.BINARY: "VARBINARY", + exp.DataType.Type.TEXT: "VARCHAR", + exp.DataType.Type.NCHAR: "VARCHAR", + exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP", + exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", + exp.DataType.Type.DATETIME: "TIMESTAMP", + } + + ROOT_PROPERTIES = {exp.PartitionedByProperty} + + TRANSFORMS = { + **generator.Generator.TRANSFORMS, + exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", + exp.Lateral: _lateral_sql, + exp.ArrayContains: rename_func("REPEATED_CONTAINS"), + exp.ArraySize: rename_func("REPEATED_COUNT"), + exp.Create: create_with_partitions_sql, + exp.DateAdd: _date_add_sql("ADD"), + exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", + exp.DateSub: _date_add_sql("SUB"), + exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)", + exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})", + exp.If: if_sql, + exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}", + exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), + exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}", + exp.Pivot: no_pivot_sql, + exp.RegexpLike: rename_func("REGEXP_MATCHES"), + exp.StrPosition: str_position_sql, + exp.StrToDate: _str_to_date, + exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", + exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", + exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), + exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), + exp.TryCast: no_trycast_sql, + exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), INTERVAL '{self.sql(e, 'expression')}' DAY)", + exp.TsOrDsToDate: _ts_or_ds_to_date_sql, + exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", + } + + def normalize_func(self, name): + return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`" diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 781edff..f1da72b 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -55,13 +55,13 @@ def _array_sort_sql(self, expression): def _sort_array_sql(self, expression): this = self.sql(expression, "this") - if expression.args.get("asc") == exp.FALSE: + if expression.args.get("asc") == exp.false(): return f"ARRAY_REVERSE_SORT({this})" return f"ARRAY_SORT({this})" def _sort_array_reverse(args): - return exp.SortArray(this=seq_get(args, 0), asc=exp.FALSE) + return exp.SortArray(this=seq_get(args, 0), asc=exp.false()) def _struct_pack_sql(self, expression): diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index ed7357c..cff7139 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -7,16 +7,19 @@ from sqlglot.dialects.dialect import ( create_with_partitions_sql, format_time_lambda, if_sql, + locate_to_strposition, no_ilike_sql, no_recursive_cte_sql, no_safe_divide_sql, no_trycast_sql, rename_func, + strposition_to_local_sql, struct_extract_sql, var_map_sql, ) from sqlglot.helper import seq_get from sqlglot.parser import parse_var_map +from sqlglot.tokens import TokenType # (FuncType, Multiplier) DATE_DELTA_INTERVAL = { @@ -181,6 +184,15 @@ class Hive(Dialect): "F": "FLOAT", "BD": "DECIMAL", } + KEYWORDS = { + **tokens.Tokenizer.KEYWORDS, + "ADD ARCHIVE": TokenType.COMMAND, + "ADD ARCHIVES": TokenType.COMMAND, + "ADD FILE": TokenType.COMMAND, + "ADD FILES": TokenType.COMMAND, + "ADD JAR": TokenType.COMMAND, + "ADD JARS": TokenType.COMMAND, + } class Parser(parser.Parser): STRICT_CAST = False @@ -210,11 +222,7 @@ class Hive(Dialect): "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True), "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list, - "LOCATE": lambda args: exp.StrPosition( - this=seq_get(args, 1), - substr=seq_get(args, 0), - position=seq_get(args, 2), - ), + "LOCATE": locate_to_strposition, "LOG": ( lambda args: exp.Log.from_arg_list(args) if len(args) > 1 @@ -272,7 +280,7 @@ class Hive(Dialect): exp.SchemaCommentProperty: lambda self, e: self.naked_property(e), exp.SetAgg: rename_func("COLLECT_SET"), exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))", - exp.StrPosition: lambda self, e: f"LOCATE({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})", + exp.StrPosition: strposition_to_local_sql, exp.StrToDate: _str_to_date, exp.StrToTime: _str_to_time, exp.StrToUnix: _str_to_unix, diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index e742640..93a60f4 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -5,10 +5,12 @@ import typing as t from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, + locate_to_strposition, no_ilike_sql, no_paren_current_date_sql, no_tablesample_sql, no_trycast_sql, + strposition_to_local_sql, ) from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -120,6 +122,7 @@ class MySQL(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "START": TokenType.BEGIN, "SEPARATOR": TokenType.SEPARATOR, "_ARMSCII8": TokenType.INTRODUCER, "_ASCII": TokenType.INTRODUCER, @@ -172,13 +175,18 @@ class MySQL(Dialect): COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW} class Parser(parser.Parser): - STRICT_CAST = False + FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA} FUNCTIONS = { **parser.Parser.FUNCTIONS, "DATE_ADD": _date_add(exp.DateAdd), "DATE_SUB": _date_add(exp.DateSub), "STR_TO_DATE": _str_to_date, + "LOCATE": locate_to_strposition, + "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)), + "LEFT": lambda args: exp.Substring( + this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1) + ), } FUNCTION_PARSERS = { @@ -264,6 +272,7 @@ class MySQL(Dialect): "CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"), "CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"), "NAMES": lambda self: self._parse_set_item_names(), + "TRANSACTION": lambda self: self._parse_set_transaction(), } PROFILE_TYPES = { @@ -278,39 +287,48 @@ class MySQL(Dialect): "SWAPS", } + TRANSACTION_CHARACTERISTICS = { + "ISOLATION LEVEL REPEATABLE READ", + "ISOLATION LEVEL READ COMMITTED", + "ISOLATION LEVEL READ UNCOMMITTED", + "ISOLATION LEVEL SERIALIZABLE", + "READ WRITE", + "READ ONLY", + } + def _parse_show_mysql(self, this, target=False, full=None, global_=None): if target: if isinstance(target, str): - self._match_text(target) + self._match_text_seq(target) target_id = self._parse_id_var() else: target_id = None - log = self._parse_string() if self._match_text("IN") else None + log = self._parse_string() if self._match_text_seq("IN") else None if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}: - position = self._parse_number() if self._match_text("FROM") else None + position = self._parse_number() if self._match_text_seq("FROM") else None db = None else: position = None - db = self._parse_id_var() if self._match_text("FROM") else None + db = self._parse_id_var() if self._match_text_seq("FROM") else None - channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None + channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None - like = self._parse_string() if self._match_text("LIKE") else None + like = self._parse_string() if self._match_text_seq("LIKE") else None where = self._parse_where() if this == "PROFILE": - types = self._parse_csv(self._parse_show_profile_type) - query = self._parse_number() if self._match_text("FOR", "QUERY") else None - offset = self._parse_number() if self._match_text("OFFSET") else None - limit = self._parse_number() if self._match_text("LIMIT") else None + types = self._parse_csv(lambda: self._parse_var_from_options(self.PROFILE_TYPES)) + query = self._parse_number() if self._match_text_seq("FOR", "QUERY") else None + offset = self._parse_number() if self._match_text_seq("OFFSET") else None + limit = self._parse_number() if self._match_text_seq("LIMIT") else None else: types, query = None, None offset, limit = self._parse_oldstyle_limit() - mutex = True if self._match_text("MUTEX") else None - mutex = False if self._match_text("STATUS") else mutex + mutex = True if self._match_text_seq("MUTEX") else None + mutex = False if self._match_text_seq("STATUS") else mutex return self.expression( exp.Show, @@ -331,16 +349,16 @@ class MySQL(Dialect): **{"global": global_}, ) - def _parse_show_profile_type(self): - for type_ in self.PROFILE_TYPES: - if self._match_text(*type_.split(" ")): - return exp.Var(this=type_) + def _parse_var_from_options(self, options): + for option in options: + if self._match_text_seq(*option.split(" ")): + return exp.Var(this=option) return None def _parse_oldstyle_limit(self): limit = None offset = None - if self._match_text("LIMIT"): + if self._match_text_seq("LIMIT"): parts = self._parse_csv(self._parse_number) if len(parts) == 1: limit = parts[0] @@ -353,6 +371,9 @@ class MySQL(Dialect): return self._parse_set_item_assignment(kind=None) def _parse_set_item_assignment(self, kind): + if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"): + return self._parse_set_transaction(global_=kind == "GLOBAL") + left = self._parse_primary() or self._parse_id_var() if not self._match(TokenType.EQ): self.raise_error("Expected =") @@ -381,7 +402,7 @@ class MySQL(Dialect): def _parse_set_item_names(self): charset = self._parse_string() or self._parse_id_var() - if self._match_text("COLLATE"): + if self._match_text_seq("COLLATE"): collate = self._parse_string() or self._parse_id_var() else: collate = None @@ -392,6 +413,18 @@ class MySQL(Dialect): kind="NAMES", ) + def _parse_set_transaction(self, global_=False): + self._match_text_seq("TRANSACTION") + characteristics = self._parse_csv( + lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS) + ) + return self.expression( + exp.SetItem, + expressions=characteristics, + kind="TRANSACTION", + **{"global": global_}, + ) + class Generator(generator.Generator): NULL_ORDERING_SUPPORTED = False @@ -411,6 +444,7 @@ class MySQL(Dialect): exp.Trim: _trim_sql, exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), + exp.StrPosition: strposition_to_local_sql, } ROOT_PROPERTIES = { @@ -481,9 +515,11 @@ class MySQL(Dialect): kind = self.sql(expression, "kind") kind = f"{kind} " if kind else "" this = self.sql(expression, "this") + expressions = self.expressions(expression) collate = self.sql(expression, "collate") collate = f" COLLATE {collate}" if collate else "" - return f"{kind}{this}{collate}" + global_ = "GLOBAL " if expression.args.get("global") else "" + return f"{global_}{kind}{this}{expressions}{collate}" def set_sql(self, expression): return f"SET {self.expressions(expression)}" diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 3bc1109..870d2b9 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -91,6 +91,7 @@ class Oracle(Dialect): class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "START": TokenType.BEGIN, "TOP": TokenType.TOP, "VARCHAR2": TokenType.VARCHAR, "NVARCHAR2": TokenType.NVARCHAR, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 553a73b..4353164 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -164,11 +164,34 @@ class Postgres(Dialect): BIT_STRINGS = [("b'", "'"), ("B'", "'")] HEX_STRINGS = [("x'", "'"), ("X'", "'")] BYTE_STRINGS = [("e'", "'"), ("E'", "'")] + + CREATABLES = ( + "AGGREGATE", + "CAST", + "CONVERSION", + "COLLATION", + "DEFAULT CONVERSION", + "CONSTRAINT", + "DOMAIN", + "EXTENSION", + "FOREIGN", + "FUNCTION", + "OPERATOR", + "POLICY", + "ROLE", + "RULE", + "SEQUENCE", + "TEXT", + "TRIGGER", + "TYPE", + "UNLOGGED", + "USER", + ) + KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "ALWAYS": TokenType.ALWAYS, "BY DEFAULT": TokenType.BY_DEFAULT, - "COMMENT ON": TokenType.COMMENT_ON, "IDENTITY": TokenType.IDENTITY, "GENERATED": TokenType.GENERATED, "DOUBLE PRECISION": TokenType.DOUBLE, @@ -176,6 +199,19 @@ class Postgres(Dialect): "SERIAL": TokenType.SERIAL, "SMALLSERIAL": TokenType.SMALLSERIAL, "UUID": TokenType.UUID, + "TEMP": TokenType.TEMPORARY, + "BEGIN TRANSACTION": TokenType.BEGIN, + "BEGIN": TokenType.COMMAND, + "COMMENT ON": TokenType.COMMAND, + "DECLARE": TokenType.COMMAND, + "DO": TokenType.COMMAND, + "REFRESH": TokenType.COMMAND, + "REINDEX": TokenType.COMMAND, + "RESET": TokenType.COMMAND, + "REVOKE": TokenType.COMMAND, + "GRANT": TokenType.COMMAND, + **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES}, + **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES}, } QUOTES = ["'", "$$"] SINGLE_TOKENS = { diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 11ea778..9d5cc11 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import ( struct_extract_sql, ) from sqlglot.dialects.mysql import MySQL +from sqlglot.errors import UnsupportedError from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -61,8 +62,18 @@ def _initcap_sql(self, expression): return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))" +def _decode_sql(self, expression): + _ensure_utf8(expression.args.get("charset")) + return f"FROM_UTF8({self.sql(expression, 'this')})" + + +def _encode_sql(self, expression): + _ensure_utf8(expression.args.get("charset")) + return f"TO_UTF8({self.sql(expression, 'this')})" + + def _no_sort_array(self, expression): - if expression.args.get("asc") == exp.FALSE: + if expression.args.get("asc") == exp.false(): comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END" else: comparator = None @@ -72,7 +83,7 @@ def _no_sort_array(self, expression): def _schema_sql(self, expression): if isinstance(expression.parent, exp.Property): - columns = ", ".join(f"'{c.text('this')}'" for c in expression.expressions) + columns = ", ".join(f"'{c.name}'" for c in expression.expressions) return f"ARRAY[{columns}]" for schema in expression.parent.find_all(exp.Schema): @@ -106,6 +117,11 @@ def _ts_or_ds_add_sql(self, expression): return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))" +def _ensure_utf8(charset): + if charset.name.lower() != "utf-8": + raise UnsupportedError(f"Unsupported charset {charset}") + + class Presto(Dialect): index_offset = 1 null_ordering = "nulls_are_last" @@ -115,6 +131,7 @@ class Presto(Dialect): class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "START": TokenType.BEGIN, "ROW": TokenType.STRUCT, } @@ -140,6 +157,14 @@ class Presto(Dialect): "STRPOS": exp.StrPosition.from_arg_list, "TO_UNIXTIME": exp.TimeToUnix.from_arg_list, "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, + "FROM_HEX": exp.Unhex.from_arg_list, + "TO_HEX": exp.Hex.from_arg_list, + "TO_UTF8": lambda args: exp.Encode( + this=seq_get(args, 0), charset=exp.Literal.string("utf-8") + ), + "FROM_UTF8": lambda args: exp.Decode( + this=seq_get(args, 0), charset=exp.Literal.string("utf-8") + ), } class Generator(generator.Generator): @@ -187,7 +212,10 @@ class Presto(Dialect): exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)", exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)", + exp.Decode: _decode_sql, exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)", + exp.Encode: _encode_sql, + exp.Hex: rename_func("TO_HEX"), exp.If: if_sql, exp.ILike: no_ilike_sql, exp.Initcap: _initcap_sql, @@ -212,7 +240,13 @@ class Presto(Dialect): exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", exp.TsOrDsAdd: _ts_or_ds_add_sql, exp.TsOrDsToDate: _ts_or_ds_to_date_sql, + exp.Unhex: rename_func("FROM_HEX"), exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})", exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)", } + + def transaction_sql(self, expression): + modes = expression.args.get("modes") + modes = f" {', '.join(modes)}" if modes else "" + return f"START TRANSACTION{modes}" diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index d1aaded..a96bd80 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -148,6 +148,7 @@ class Snowflake(Dialect): **parser.Parser.FUNCTION_PARSERS, "DATE_PART": _parse_date_part, } + FUNCTION_PARSERS.pop("TRIM") FUNC_TOKENS = { *parser.Parser.FUNC_TOKENS, @@ -203,6 +204,7 @@ class Snowflake(Dialect): exp.StrPosition: rename_func("POSITION"), exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}", exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}", + exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})", } TYPE_MAPPING = { diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 8c9fb76..87b98a5 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -63,3 +63,8 @@ class SQLite(Dialect): exp.TableSample: no_tablesample_sql, exp.TryCast: no_trycast_sql, } + + def transaction_sql(self, expression): + this = expression.this + this = f" {this}" if this else "" + return f"BEGIN{this} TRANSACTION" diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index a233d4b..d3b83de 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -248,7 +248,7 @@ class TSQL(Dialect): def _parse_convert(self, strict): to = self._parse_types() self._match(TokenType.COMMA) - this = self._parse_column() + this = self._parse_conjunction() # Retrieve length of datatype and override to default if not specified if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES: diff --git a/sqlglot/diff.py b/sqlglot/diff.py index 2d959ab..758ad1b 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import typing as t from collections import defaultdict from dataclasses import dataclass from heapq import heappop, heappush @@ -6,6 +9,10 @@ from sqlglot import Dialect from sqlglot import expressions as exp from sqlglot.helper import ensure_collection +if t.TYPE_CHECKING: + T = t.TypeVar("T") + Edit = t.Union[Insert, Remove, Move, Update, Keep] + @dataclass(frozen=True) class Insert: @@ -44,7 +51,7 @@ class Keep: target: exp.Expression -def diff(source, target): +def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]: """ Returns the list of changes between the source and the target expressions. @@ -89,25 +96,25 @@ class ChangeDistiller: Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf. """ - def __init__(self, f=0.6, t=0.6): + def __init__(self, f: float = 0.6, t: float = 0.6) -> None: self.f = f self.t = t self._sql_generator = Dialect().generator() - def diff(self, source, target): + def diff(self, source: exp.Expression, target: exp.Expression) -> t.List[Edit]: self._source = source self._target = target self._source_index = {id(n[0]): n[0] for n in source.bfs()} self._target_index = {id(n[0]): n[0] for n in target.bfs()} self._unmatched_source_nodes = set(self._source_index) self._unmatched_target_nodes = set(self._target_index) - self._bigram_histo_cache = {} + self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {} matching_set = self._compute_matching_set() return self._generate_edit_script(matching_set) - def _generate_edit_script(self, matching_set): - edit_script = [] + def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.List[Edit]: + edit_script: t.List[Edit] = [] for removed_node_id in self._unmatched_source_nodes: edit_script.append(Remove(self._source_index[removed_node_id])) for inserted_node_id in self._unmatched_target_nodes: @@ -125,7 +132,9 @@ class ChangeDistiller: return edit_script - def _generate_move_edits(self, source, target, matching_set): + def _generate_move_edits( + self, source: exp.Expression, target: exp.Expression, matching_set: t.Set[t.Tuple[int, int]] + ) -> t.List[Move]: source_args = [id(e) for e in _expression_only_args(source)] target_args = [id(e) for e in _expression_only_args(target)] @@ -138,7 +147,7 @@ class ChangeDistiller: return move_edits - def _compute_matching_set(self): + def _compute_matching_set(self) -> t.Set[t.Tuple[int, int]]: leaves_matching_set = self._compute_leaf_matching_set() matching_set = leaves_matching_set.copy() @@ -183,8 +192,8 @@ class ChangeDistiller: return matching_set - def _compute_leaf_matching_set(self): - candidate_matchings = [] + def _compute_leaf_matching_set(self) -> t.Set[t.Tuple[int, int]]: + candidate_matchings: t.List[t.Tuple[float, int, exp.Expression, exp.Expression]] = [] source_leaves = list(_get_leaves(self._source)) target_leaves = list(_get_leaves(self._target)) for source_leaf in source_leaves: @@ -216,7 +225,7 @@ class ChangeDistiller: return matching_set - def _dice_coefficient(self, source, target): + def _dice_coefficient(self, source: exp.Expression, target: exp.Expression) -> float: source_histo = self._bigram_histo(source) target_histo = self._bigram_histo(target) @@ -231,13 +240,13 @@ class ChangeDistiller: return 2 * overlap_len / total_grams - def _bigram_histo(self, expression): + def _bigram_histo(self, expression: exp.Expression) -> t.DefaultDict[str, int]: if id(expression) in self._bigram_histo_cache: return self._bigram_histo_cache[id(expression)] expression_str = self._sql_generator.generate(expression) count = max(0, len(expression_str) - 1) - bigram_histo = defaultdict(int) + bigram_histo: t.DefaultDict[str, int] = defaultdict(int) for i in range(count): bigram_histo[expression_str[i : i + 2]] += 1 @@ -245,7 +254,7 @@ class ChangeDistiller: return bigram_histo -def _get_leaves(expression): +def _get_leaves(expression: exp.Expression) -> t.Generator[exp.Expression, None, None]: has_child_exprs = False for a in expression.args.values(): @@ -258,7 +267,7 @@ def _get_leaves(expression): yield expression -def _is_same_type(source, target): +def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool: if type(source) is type(target): if isinstance(source, exp.Join): return source.args.get("side") == target.args.get("side") @@ -271,15 +280,17 @@ def _is_same_type(source, target): return False -def _expression_only_args(expression): - args = [] +def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]: + args: t.List[t.Union[exp.Expression, t.List]] = [] if expression: for a in expression.args.values(): args.extend(ensure_collection(a)) return [a for a in args if isinstance(a, exp.Expression)] -def _lcs(seq_a, seq_b, equal): +def _lcs( + seq_a: t.Sequence[T], seq_b: t.Sequence[T], equal: t.Callable[[T, T], bool] +) -> t.Sequence[t.Optional[T]]: """Calculates the longest common subsequence""" len_a = len(seq_a) @@ -289,14 +300,14 @@ def _lcs(seq_a, seq_b, equal): for i in range(len_a + 1): for j in range(len_b + 1): if i == 0 or j == 0: - lcs_result[i][j] = [] + lcs_result[i][j] = [] # type: ignore elif equal(seq_a[i - 1], seq_b[j - 1]): - lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]] + lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]] # type: ignore else: lcs_result[i][j] = ( lcs_result[i - 1][j] - if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1]) + if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1]) # type: ignore else lcs_result[i][j - 1] ) - return lcs_result[len_a][len_b] + return lcs_result[len_a][len_b] # type: ignore diff --git a/sqlglot/errors.py b/sqlglot/errors.py index 2ef908f..23a08bd 100644 --- a/sqlglot/errors.py +++ b/sqlglot/errors.py @@ -37,6 +37,10 @@ class SchemaError(SqlglotError): pass +class ExecuteError(SqlglotError): + pass + + def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str: msg = [str(e) for e in errors[:maximum]] remaining = len(errors) - maximum diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py index e765616..04621b5 100644 --- a/sqlglot/executor/__init__.py +++ b/sqlglot/executor/__init__.py @@ -1,20 +1,23 @@ import logging import time -from sqlglot import parse_one +from sqlglot import maybe_parse +from sqlglot.errors import ExecuteError from sqlglot.executor.python import PythonExecutor +from sqlglot.executor.table import Table, ensure_tables from sqlglot.optimizer import optimize from sqlglot.planner import Plan +from sqlglot.schema import ensure_schema logger = logging.getLogger("sqlglot") -def execute(sql, schema, read=None): +def execute(sql, schema=None, read=None, tables=None): """ Run a sql query against data. Args: - sql (str): a sql statement + sql (str|sqlglot.Expression): a sql statement schema (dict|sqlglot.optimizer.Schema): database schema. This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of the following forms: @@ -23,10 +26,20 @@ def execute(sql, schema, read=None): 3. {catalog: {db: {table: {col: type}}}} read (str): the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql"). + tables (dict): additional tables to register. Returns: sqlglot.executor.Table: Simple columnar data structure. """ - expression = parse_one(sql, read=read) + tables = ensure_tables(tables) + if not schema: + schema = { + name: {column: type(table[0][column]).__name__ for column in table.columns} + for name, table in tables.mapping.items() + } + schema = ensure_schema(schema) + if tables.supported_table_args and tables.supported_table_args != schema.supported_table_args: + raise ExecuteError("Tables must support the same table args as schema") + expression = maybe_parse(sql, dialect=read) now = time.time() expression = optimize(expression, schema, leave_tables_isolated=True) logger.debug("Optimization finished: %f", time.time() - now) @@ -34,6 +47,6 @@ def execute(sql, schema, read=None): plan = Plan(expression) logger.debug("Logical Plan: %s", plan) now = time.time() - result = PythonExecutor().execute(plan) + result = PythonExecutor(tables=tables).execute(plan) logger.debug("Query finished: %f", time.time() - now) return result diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py index 393347b..e9ff75b 100644 --- a/sqlglot/executor/context.py +++ b/sqlglot/executor/context.py @@ -1,5 +1,12 @@ +from __future__ import annotations + +import typing as t + from sqlglot.executor.env import ENV +if t.TYPE_CHECKING: + from sqlglot.executor.table import Table, TableIter + class Context: """ @@ -12,14 +19,14 @@ class Context: evaluation of aggregation functions. """ - def __init__(self, tables, env=None): + def __init__(self, tables: t.Dict[str, Table], env: t.Optional[t.Dict] = None) -> None: """ Args - tables (dict): table_name -> Table, representing the scope of the current execution context - env (Optional[dict]): dictionary of functions within the execution context + tables: representing the scope of the current execution context. + env: dictionary of functions within the execution context. """ self.tables = tables - self._table = None + self._table: t.Optional[Table] = None self.range_readers = {name: table.range_reader for name, table in self.tables.items()} self.row_readers = {name: table.reader for name, table in tables.items()} self.env = {**(env or {}), "scope": self.row_readers} @@ -31,7 +38,7 @@ class Context: return tuple(self.eval(code) for code in codes) @property - def table(self): + def table(self) -> Table: if self._table is None: self._table = list(self.tables.values())[0] for other in self.tables.values(): @@ -41,8 +48,12 @@ class Context: raise Exception(f"Rows are different.") return self._table + def add_columns(self, *columns: str) -> None: + for table in self.tables.values(): + table.add_columns(*columns) + @property - def columns(self): + def columns(self) -> t.Tuple: return self.table.columns def __iter__(self): @@ -52,35 +63,39 @@ class Context: reader = table[i] yield reader, self - def table_iter(self, table): + def table_iter(self, table: str) -> t.Generator[t.Tuple[TableIter, Context], None, None]: self.env["scope"] = self.row_readers for reader in self.tables[table]: yield reader, self - def sort(self, key): - table = self.table + def filter(self, condition) -> None: + rows = [reader.row for reader, _ in self if self.eval(condition)] - def sort_key(row): - table.reader.row = row + for table in self.tables.values(): + table.rows = rows + + def sort(self, key) -> None: + def sort_key(row: t.Tuple) -> t.Tuple: + self.set_row(row) return self.eval_tuple(key) - table.rows.sort(key=sort_key) + self.table.rows.sort(key=sort_key) - def set_row(self, row): + def set_row(self, row: t.Tuple) -> None: for table in self.tables.values(): table.reader.row = row self.env["scope"] = self.row_readers - def set_index(self, index): + def set_index(self, index: int) -> None: for table in self.tables.values(): table[index] self.env["scope"] = self.row_readers - def set_range(self, start, end): + def set_range(self, start: int, end: int) -> None: for name in self.tables: self.range_readers[name].range = range(start, end) self.env["scope"] = self.range_readers - def __contains__(self, table): + def __contains__(self, table: str) -> bool: return table in self.tables diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index bbe6c81..ed80cc9 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -1,7 +1,10 @@ import datetime +import inspect import re import statistics +from functools import wraps +from sqlglot import exp from sqlglot.helper import PYTHON_VERSION @@ -16,20 +19,153 @@ class reverse_key: return other.obj < self.obj +def filter_nulls(func): + @wraps(func) + def _func(values): + return func(v for v in values if v is not None) + + return _func + + +def null_if_any(*required): + """ + Decorator that makes a function return `None` if any of the `required` arguments are `None`. + + This also supports decoration with no arguments, e.g.: + + @null_if_any + def foo(a, b): ... + + In which case all arguments are required. + """ + f = None + if len(required) == 1 and callable(required[0]): + f = required[0] + required = () + + def decorator(func): + if required: + required_indices = [ + i for i, param in enumerate(inspect.signature(func).parameters) if param in required + ] + + def predicate(*args): + return any(args[i] is None for i in required_indices) + + else: + + def predicate(*args): + return any(a is None for a in args) + + @wraps(func) + def _func(*args): + if predicate(*args): + return None + return func(*args) + + return _func + + if f: + return decorator(f) + + return decorator + + +@null_if_any("substr", "this") +def str_position(substr, this, position=None): + position = position - 1 if position is not None else position + return this.find(substr, position) + 1 + + +@null_if_any("this") +def substring(this, start=None, length=None): + if start is None: + return this + elif start == 0: + return "" + elif start < 0: + start = len(this) + start + else: + start -= 1 + + end = None if length is None else start + length + + return this[start:end] + + +@null_if_any +def cast(this, to): + if to == exp.DataType.Type.DATE: + return datetime.date.fromisoformat(this) + if to == exp.DataType.Type.DATETIME: + return datetime.datetime.fromisoformat(this) + if to in exp.DataType.TEXT_TYPES: + return str(this) + if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}: + return float(this) + if to in exp.DataType.NUMERIC_TYPES: + return int(this) + raise NotImplementedError(f"Casting to '{to}' not implemented.") + + +def ordered(this, desc, nulls_first): + if desc: + return reverse_key(this) + return this + + +@null_if_any +def interval(this, unit): + if unit == "DAY": + return datetime.timedelta(days=float(this)) + raise NotImplementedError + + ENV = { "__builtins__": {}, - "datetime": datetime, - "locals": locals, - "re": re, - "bool": bool, - "float": float, - "int": int, - "str": str, - "desc": reverse_key, - "SUM": sum, - "AVG": statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean, # type: ignore - "COUNT": lambda acc: sum(1 for e in acc if e is not None), - "MAX": max, - "MIN": min, + "exp": exp, + # aggs + "SUM": filter_nulls(sum), + "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore + "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc)), + "MAX": filter_nulls(max), + "MIN": filter_nulls(min), + # scalar functions + "ABS": null_if_any(lambda this: abs(this)), + "ADD": null_if_any(lambda e, this: e + this), + "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high), + "BITWISEAND": null_if_any(lambda this, e: this & e), + "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e), + "BITWISEOR": null_if_any(lambda this, e: this | e), + "BITWISERIGHTSHIFT": null_if_any(lambda this, e: this >> e), + "BITWISEXOR": null_if_any(lambda this, e: this ^ e), + "CAST": cast, + "COALESCE": lambda *args: next((a for a in args if a is not None), None), + "CONCAT": null_if_any(lambda *args: "".join(args)), + "CONCATWS": null_if_any(lambda this, *args: this.join(args)), + "DIV": null_if_any(lambda e, this: e / this), + "EQ": null_if_any(lambda this, e: this == e), + "EXTRACT": null_if_any(lambda this, e: getattr(e, this)), + "GT": null_if_any(lambda this, e: this > e), + "GTE": null_if_any(lambda this, e: this >= e), + "IFNULL": lambda e, alt: alt if e is None else e, + "IF": lambda predicate, true, false: true if predicate else false, + "INTDIV": null_if_any(lambda e, this: e // this), + "INTERVAL": interval, + "LIKE": null_if_any( + lambda this, e: bool(re.match(e.replace("_", ".").replace("%", ".*"), this)) + ), + "LOWER": null_if_any(lambda arg: arg.lower()), + "LT": null_if_any(lambda this, e: this < e), + "LTE": null_if_any(lambda this, e: this <= e), + "MOD": null_if_any(lambda e, this: e % this), + "MUL": null_if_any(lambda e, this: e * this), + "NEQ": null_if_any(lambda this, e: this != e), + "ORD": null_if_any(ord), + "ORDERED": ordered, "POW": pow, + "STRPOSITION": str_position, + "SUB": null_if_any(lambda e, this: e - this), + "SUBSTRING": substring, + "UPPER": null_if_any(lambda arg: arg.upper()), } diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index 7d1db32..cb2543c 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -5,16 +5,18 @@ import math from sqlglot import exp, generator, planner, tokens from sqlglot.dialects.dialect import Dialect, inline_array_sql +from sqlglot.errors import ExecuteError from sqlglot.executor.context import Context from sqlglot.executor.env import ENV -from sqlglot.executor.table import Table -from sqlglot.helper import csv_reader +from sqlglot.executor.table import RowReader, Table +from sqlglot.helper import csv_reader, subclasses class PythonExecutor: - def __init__(self, env=None): - self.generator = Python().generator(identify=True) + def __init__(self, env=None, tables=None): + self.generator = Python().generator(identify=True, comments=False) self.env = {**ENV, **(env or {})} + self.tables = tables or {} def execute(self, plan): running = set() @@ -24,36 +26,41 @@ class PythonExecutor: while queue: node = queue.pop() - context = self.context( - { - name: table - for dep in node.dependencies - for name, table in contexts[dep].tables.items() - } - ) - running.add(node) - - if isinstance(node, planner.Scan): - contexts[node] = self.scan(node, context) - elif isinstance(node, planner.Aggregate): - contexts[node] = self.aggregate(node, context) - elif isinstance(node, planner.Join): - contexts[node] = self.join(node, context) - elif isinstance(node, planner.Sort): - contexts[node] = self.sort(node, context) - else: - raise NotImplementedError - - running.remove(node) - finished.add(node) - - for dep in node.dependents: - if dep not in running and all(d in contexts for d in dep.dependencies): - queue.add(dep) - - for dep in node.dependencies: - if all(d in finished for d in dep.dependents): - contexts.pop(dep) + try: + context = self.context( + { + name: table + for dep in node.dependencies + for name, table in contexts[dep].tables.items() + } + ) + running.add(node) + + if isinstance(node, planner.Scan): + contexts[node] = self.scan(node, context) + elif isinstance(node, planner.Aggregate): + contexts[node] = self.aggregate(node, context) + elif isinstance(node, planner.Join): + contexts[node] = self.join(node, context) + elif isinstance(node, planner.Sort): + contexts[node] = self.sort(node, context) + elif isinstance(node, planner.SetOperation): + contexts[node] = self.set_operation(node, context) + else: + raise NotImplementedError + + running.remove(node) + finished.add(node) + + for dep in node.dependents: + if dep not in running and all(d in contexts for d in dep.dependencies): + queue.add(dep) + + for dep in node.dependencies: + if all(d in finished for d in dep.dependents): + contexts.pop(dep) + except Exception as e: + raise ExecuteError(f"Step '{node.id}' failed: {e}") from e root = plan.root return contexts[root].tables[root.name] @@ -76,38 +83,43 @@ class PythonExecutor: return Context(tables, env=self.env) def table(self, expressions): - return Table(expression.alias_or_name for expression in expressions) + return Table( + expression.alias_or_name if isinstance(expression, exp.Expression) else expression + for expression in expressions + ) def scan(self, step, context): source = step.source - if isinstance(source, exp.Expression): + if source and isinstance(source, exp.Expression): source = source.name or source.alias condition = self.generate(step.condition) projections = self.generate_tuple(step.projections) - if source in context: + if source is None: + context, table_iter = self.static() + elif source in context: if not projections and not condition: return self.context({step.name: context.tables[source]}) table_iter = context.table_iter(source) - else: + elif isinstance(step.source, exp.Table) and isinstance(step.source.this, exp.ReadCSV): table_iter = self.scan_csv(step) + context = next(table_iter) + else: + context, table_iter = self.scan_table(step) if projections: sink = self.table(step.projections) else: - sink = None - - for reader, ctx in table_iter: - if sink is None: - sink = Table(reader.columns) + sink = self.table(context.columns) - if condition and not ctx.eval(condition): + for reader in table_iter: + if condition and not context.eval(condition): continue if projections: - sink.append(ctx.eval_tuple(projections)) + sink.append(context.eval_tuple(projections)) else: sink.append(reader.row) @@ -116,14 +128,23 @@ class PythonExecutor: return self.context({step.name: sink}) + def static(self): + return self.context({}), [RowReader(())] + + def scan_table(self, step): + table = self.tables.find(step.source) + context = self.context({step.source.alias_or_name: table}) + return context, iter(table) + def scan_csv(self, step): - source = step.source - alias = source.alias + alias = step.source.alias + source = step.source.this with csv_reader(source) as reader: columns = next(reader) table = Table(columns) context = self.context({alias: table}) + yield context types = [] for row in reader: @@ -134,7 +155,7 @@ class PythonExecutor: except (ValueError, SyntaxError): types.append(str) context.set_row(tuple(t(v) for t, v in zip(types, row))) - yield context.table.reader, context + yield context.table.reader def join(self, step, context): source = step.name @@ -160,16 +181,19 @@ class PythonExecutor: for name, column_range in column_ranges.items() } ) + condition = self.generate(join["condition"]) + if condition: + source_context.filter(condition) condition = self.generate(step.condition) projections = self.generate_tuple(step.projections) - if not condition or not projections: + if not condition and not projections: return source_context sink = self.table(step.projections if projections else source_context.columns) - for reader, ctx in join_context: + for reader, ctx in source_context: if condition and not ctx.eval(condition): continue @@ -181,7 +205,15 @@ class PythonExecutor: if len(sink) >= step.limit: break - return self.context({step.name: sink}) + if projections: + return self.context({step.name: sink}) + else: + return self.context( + { + name: Table(table.columns, sink.rows, table.column_range) + for name, table in source_context.tables.items() + } + ) def nested_loop_join(self, _join, source_context, join_context): table = Table(source_context.columns + join_context.columns) @@ -195,6 +227,8 @@ class PythonExecutor: def hash_join(self, join, source_context, join_context): source_key = self.generate_tuple(join["source_key"]) join_key = self.generate_tuple(join["join_key"]) + left = join.get("side") == "LEFT" + right = join.get("side") == "RIGHT" results = collections.defaultdict(lambda: ([], [])) @@ -204,28 +238,47 @@ class PythonExecutor: results[ctx.eval_tuple(join_key)][1].append(reader.row) table = Table(source_context.columns + join_context.columns) + nulls = [(None,) * len(join_context.columns if left else source_context.columns)] for a_group, b_group in results.values(): + if left: + b_group = b_group or nulls + elif right: + a_group = a_group or nulls + for a_row, b_row in itertools.product(a_group, b_group): table.append(a_row + b_row) return table def aggregate(self, step, context): - source = step.source - group_by = self.generate_tuple(step.group) + group_by = self.generate_tuple(step.group.values()) aggregations = self.generate_tuple(step.aggregations) operands = self.generate_tuple(step.operands) if operands: - source_table = context.tables[source] - operand_table = Table(source_table.columns + self.table(step.operands).columns) + operand_table = Table(self.table(step.operands).columns) for reader, ctx in context: - operand_table.append(reader.row + ctx.eval_tuple(operands)) + operand_table.append(ctx.eval_tuple(operands)) + + for i, (a, b) in enumerate(zip(context.table.rows, operand_table.rows)): + context.table.rows[i] = a + b + + width = len(context.columns) + context.add_columns(*operand_table.columns) + + operand_table = Table( + context.columns, + context.table.rows, + range(width, width + len(operand_table.columns)), + ) context = self.context( - {None: operand_table, **{table: operand_table for table in context.tables}} + { + None: operand_table, + **context.tables, + } ) context.sort(group_by) @@ -233,25 +286,22 @@ class PythonExecutor: group = None start = 0 end = 1 - length = len(context.tables[source]) - table = self.table(step.group + step.aggregations) + length = len(context.table) + table = self.table(list(step.group) + step.aggregations) for i in range(length): context.set_index(i) key = context.eval_tuple(group_by) group = key if group is None else group end += 1 - + if key != group: + context.set_range(start, end - 2) + table.append(group + context.eval_tuple(aggregations)) + group = key + start = end - 2 if i == length - 1: context.set_range(start, end - 1) - elif key != group: - context.set_range(start, end - 2) - else: - continue - - table.append(group + context.eval_tuple(aggregations)) - group = key - start = end - 2 + table.append(group + context.eval_tuple(aggregations)) context = self.context({step.name: table, **{name: table for name in context.tables}}) @@ -262,60 +312,77 @@ class PythonExecutor: def sort(self, step, context): projections = self.generate_tuple(step.projections) - sink = self.table(step.projections) + projection_columns = [p.alias_or_name for p in step.projections] + all_columns = list(context.columns) + projection_columns + sink = self.table(all_columns) for reader, ctx in context: - sink.append(ctx.eval_tuple(projections)) + sink.append(reader.row + ctx.eval_tuple(projections)) - context = self.context( + sort_ctx = self.context( { None: sink, **{table: sink for table in context.tables}, } ) - context.sort(self.generate_tuple(step.key)) + sort_ctx.sort(self.generate_tuple(step.key)) if not math.isinf(step.limit): - context.table.rows = context.table.rows[0 : step.limit] + sort_ctx.table.rows = sort_ctx.table.rows[0 : step.limit] - return self.context({step.name: context.table}) + output = Table( + projection_columns, + rows=[r[len(context.columns) : len(all_columns)] for r in sort_ctx.table.rows], + ) + return self.context({step.name: output}) + def set_operation(self, step, context): + left = context.tables[step.left] + right = context.tables[step.right] -def _cast_py(self, expression): - to = expression.args["to"].this - this = self.sql(expression, "this") + sink = self.table(left.columns) + + if issubclass(step.op, exp.Intersect): + sink.rows = list(set(left.rows).intersection(set(right.rows))) + elif issubclass(step.op, exp.Except): + sink.rows = list(set(left.rows).difference(set(right.rows))) + elif issubclass(step.op, exp.Union) and step.distinct: + sink.rows = list(set(left.rows).union(set(right.rows))) + else: + sink.rows = left.rows + right.rows - if to == exp.DataType.Type.DATE: - return f"datetime.date.fromisoformat({this})" - if to == exp.DataType.Type.TEXT: - return f"str({this})" - raise NotImplementedError + return self.context({step.name: sink}) -def _column_py(self, expression): - table = self.sql(expression, "table") or None +def _ordered_py(self, expression): this = self.sql(expression, "this") - return f"scope[{table}][{this}]" + desc = "True" if expression.args.get("desc") else "False" + nulls_first = "True" if expression.args.get("nulls_first") else "False" + return f"ORDERED({this}, {desc}, {nulls_first})" -def _interval_py(self, expression): - this = self.sql(expression, "this") - unit = expression.text("unit").upper() - if unit == "DAY": - return f"datetime.timedelta(days=float({this}))" - raise NotImplementedError +def _rename(self, e): + try: + if "expressions" in e.args: + this = self.sql(e, "this") + this = f"{this}, " if this else "" + return f"{e.key.upper()}({this}{self.expressions(e)})" + return f"{e.key.upper()}({self.format_args(*e.args.values())})" + except Exception as ex: + raise Exception(f"Could not rename {repr(e)}") from ex -def _like_py(self, expression): +def _case_sql(self, expression): this = self.sql(expression, "this") - expression = self.sql(expression, "expression") - return f"""bool(re.match({expression}.replace("_", ".").replace("%", ".*"), {this}))""" + chain = self.sql(expression, "default") or "None" + for e in reversed(expression.args["ifs"]): + true = self.sql(e, "true") + condition = self.sql(e, "this") + condition = f"{this} = ({condition})" if this else condition + chain = f"{true} if {condition} else ({chain})" -def _ordered_py(self, expression): - this = self.sql(expression, "this") - desc = expression.args.get("desc") - return f"desc({this})" if desc else this + return chain class Python(Dialect): @@ -324,32 +391,22 @@ class Python(Dialect): class Generator(generator.Generator): TRANSFORMS = { + **{klass: _rename for klass in subclasses(exp.__name__, exp.Binary)}, + **{klass: _rename for klass in exp.ALL_FUNCTIONS}, + exp.Case: _case_sql, exp.Alias: lambda self, e: self.sql(e.this), exp.Array: inline_array_sql, exp.And: lambda self, e: self.binary(e, "and"), + exp.Between: _rename, exp.Boolean: lambda self, e: "True" if e.this else "False", - exp.Cast: _cast_py, - exp.Column: _column_py, - exp.EQ: lambda self, e: self.binary(e, "=="), + exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})", + exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]", + exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})", exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}", - exp.Interval: _interval_py, exp.Is: lambda self, e: self.binary(e, "is"), - exp.Like: _like_py, exp.Not: lambda self, e: f"not {self.sql(e.this)}", exp.Null: lambda *_: "None", exp.Or: lambda self, e: self.binary(e, "or"), exp.Ordered: _ordered_py, exp.Star: lambda *_: "1", } - - def case_sql(self, expression): - this = self.sql(expression, "this") - chain = self.sql(expression, "default") or "None" - - for e in reversed(expression.args["ifs"]): - true = self.sql(e, "true") - condition = self.sql(e, "this") - condition = f"{this} = ({condition})" if this else condition - chain = f"{true} if {condition} else ({chain})" - - return chain diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py index 6796740..f1b5b54 100644 --- a/sqlglot/executor/table.py +++ b/sqlglot/executor/table.py @@ -1,14 +1,27 @@ +from __future__ import annotations + +from sqlglot.helper import dict_depth +from sqlglot.schema import AbstractMappingSchema + + class Table: def __init__(self, columns, rows=None, column_range=None): self.columns = tuple(columns) self.column_range = column_range self.reader = RowReader(self.columns, self.column_range) - self.rows = rows or [] if rows: assert len(rows[0]) == len(self.columns) self.range_reader = RangeReader(self) + def add_columns(self, *columns: str) -> None: + self.columns += columns + if self.column_range: + self.column_range = range( + self.column_range.start, self.column_range.stop + len(columns) + ) + self.reader = RowReader(self.columns, self.column_range) + def append(self, row): assert len(row) == len(self.columns) self.rows.append(row) @@ -87,3 +100,31 @@ class RowReader: def __getitem__(self, column): return self.row[self.columns[column]] + + +class Tables(AbstractMappingSchema[Table]): + pass + + +def ensure_tables(d: dict | None) -> Tables: + return Tables(_ensure_tables(d)) + + +def _ensure_tables(d: dict | None) -> dict: + if not d: + return {} + + depth = dict_depth(d) + + if depth > 1: + return {k: _ensure_tables(v) for k, v in d.items()} + + result = {} + for name, table in d.items(): + if isinstance(table, Table): + result[name] = table + else: + columns = tuple(table[0]) if table else () + rows = [tuple(row[c] for c in columns) for row in table] + result[name] = Table(columns=columns, rows=rows) + return result diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 57a2c88..beafca8 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -641,9 +641,11 @@ class Set(Expression): class SetItem(Expression): arg_types = { - "this": True, + "this": False, + "expressions": False, "kind": False, "collate": False, # MySQL SET NAMES statement + "global": False, } @@ -787,6 +789,7 @@ class Drop(Expression): "exists": False, "temporary": False, "materialized": False, + "cascade": False, } @@ -1073,6 +1076,18 @@ class FileFormatProperty(Property): pass +class DistKeyProperty(Property): + pass + + +class SortKeyProperty(Property): + pass + + +class DistStyleProperty(Property): + pass + + class LocationProperty(Property): pass @@ -1130,6 +1145,9 @@ class Properties(Expression): "LOCATION": LocationProperty, "PARTITIONED_BY": PartitionedByProperty, "TABLE_FORMAT": TableFormatProperty, + "DISTKEY": DistKeyProperty, + "DISTSTYLE": DistStyleProperty, + "SORTKEY": SortKeyProperty, } @classmethod @@ -1356,7 +1374,7 @@ class Var(Expression): class Schema(Expression): - arg_types = {"this": False, "expressions": True} + arg_types = {"this": False, "expressions": False} class Select(Subqueryable): @@ -1741,7 +1759,7 @@ class Select(Subqueryable): ) if join_alias: - join.set("this", alias_(join.args["this"], join_alias, table=True)) + join.set("this", alias_(join.this, join_alias, table=True)) return _apply_list_builder( join, instance=self, @@ -1884,6 +1902,7 @@ class Subquery(DerivedTable, Unionable): arg_types = { "this": True, "alias": False, + "with": False, **QUERY_MODIFIERS, } @@ -2025,6 +2044,31 @@ class DataType(Expression): NULL = auto() UNKNOWN = auto() # Sentinel value, useful for type annotation + TEXT_TYPES = { + Type.CHAR, + Type.NCHAR, + Type.VARCHAR, + Type.NVARCHAR, + Type.TEXT, + } + + NUMERIC_TYPES = { + Type.INT, + Type.TINYINT, + Type.SMALLINT, + Type.BIGINT, + Type.FLOAT, + Type.DOUBLE, + } + + TEMPORAL_TYPES = { + Type.TIMESTAMP, + Type.TIMESTAMPTZ, + Type.TIMESTAMPLTZ, + Type.DATE, + Type.DATETIME, + } + @classmethod def build(cls, dtype, **kwargs) -> DataType: return DataType( @@ -2054,16 +2098,25 @@ class Exists(SubqueryPredicate): pass -# Commands to interact with the databases or engines -# These expressions don't truly parse the expression and consume -# whatever exists as a string until the end or a semicolon +# Commands to interact with the databases or engines. For most of the command +# expressions we parse whatever comes after the command's name as a string. class Command(Expression): arg_types = {"this": True, "expression": False} -# Binary Expressions -# (ADD a b) -# (FROM table selects) +class Transaction(Command): + arg_types = {"this": False, "modes": False} + + +class Commit(Command): + arg_types = {} # type: ignore + + +class Rollback(Command): + arg_types = {"savepoint": False} + + +# Binary expressions like (ADD a b) class Binary(Expression): arg_types = {"this": True, "expression": True} @@ -2215,7 +2268,7 @@ class Not(Unary, Condition): class Paren(Unary, Condition): - pass + arg_types = {"this": True, "with": False} class Neg(Unary): @@ -2428,6 +2481,10 @@ class Cast(Func): return self.args["to"] +class Collate(Binary): + pass + + class TryCast(Cast): pass @@ -2442,13 +2499,17 @@ class Coalesce(Func): is_var_len_args = True -class ConcatWs(Func): - arg_types = {"expressions": False} +class Concat(Func): + arg_types = {"expressions": True} is_var_len_args = True +class ConcatWs(Concat): + _sql_names = ["CONCAT_WS"] + + class Count(AggFunc): - pass + arg_types = {"this": False} class CurrentDate(Func): @@ -2556,10 +2617,18 @@ class Day(Func): pass +class Decode(Func): + arg_types = {"this": True, "charset": True} + + class DiToDate(Func): pass +class Encode(Func): + arg_types = {"this": True, "charset": True} + + class Exp(Func): pass @@ -2581,6 +2650,10 @@ class GroupConcat(Func): arg_types = {"this": True, "separator": False} +class Hex(Func): + pass + + class If(Func): arg_types = {"this": True, "true": True, "false": False} @@ -2641,7 +2714,7 @@ class Log10(Func): class Lower(Func): - pass + _sql_names = ["LOWER", "LCASE"] class Map(Func): @@ -2686,6 +2759,12 @@ class ApproxQuantile(Quantile): arg_types = {"this": True, "quantile": True, "accuracy": False} +class ReadCSV(Func): + _sql_names = ["READ_CSV"] + is_var_len_args = True + arg_types = {"this": True, "expressions": False} + + class Reduce(Func): arg_types = {"this": True, "initial": True, "merge": True, "finish": True} @@ -2804,8 +2883,8 @@ class TimeStrToUnix(Func): class Trim(Func): arg_types = { "this": True, - "position": False, "expression": False, + "position": False, "collation": False, } @@ -2826,6 +2905,10 @@ class TsOrDiToDi(Func): pass +class Unhex(Func): + pass + + class UnixToStr(Func): arg_types = {"this": True, "format": False} @@ -2843,7 +2926,7 @@ class UnixToTimeStr(Func): class Upper(Func): - pass + _sql_names = ["UPPER", "UCASE"] class Variance(AggFunc): @@ -3701,6 +3784,19 @@ def replace_placeholders(expression, *args, **kwargs): return expression.transform(_replace_placeholders, iter(args), **kwargs) +def true(): + return Boolean(this=True) + + +def false(): + return Boolean(this=False) + + +def null(): + return Null() + + +# TODO: deprecate this TRUE = Boolean(this=True) FALSE = Boolean(this=False) NULL = Null() diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 11d9073..ffb34eb 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -67,7 +67,7 @@ class Generator: exp.LocationProperty: lambda self, e: self.naked_property(e), exp.ReturnsProperty: lambda self, e: self.naked_property(e), exp.ExecuteAsProperty: lambda self, e: self.naked_property(e), - exp.VolatilityProperty: lambda self, e: self.sql(e.name), + exp.VolatilityProperty: lambda self, e: e.name, } # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed @@ -94,6 +94,9 @@ class Generator: ROOT_PROPERTIES = { exp.ReturnsProperty, exp.LanguageProperty, + exp.DistStyleProperty, + exp.DistKeyProperty, + exp.SortKeyProperty, } WITH_PROPERTIES = { @@ -241,7 +244,7 @@ class Generator: if not NEWLINE_RE.search(comment): return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/" - return f"/*{comment}*/\n{sql}" + return f"/*{comment}*/\n{sql}" if sql else f" /*{comment}*/" def wrap(self, expression): this_sql = self.indent( @@ -475,7 +478,8 @@ class Generator: exists_sql = " IF EXISTS " if expression.args.get("exists") else " " temporary = " TEMPORARY" if expression.args.get("temporary") else "" materialized = " MATERIALIZED" if expression.args.get("materialized") else "" - return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}" + cascade = " CASCADE" if expression.args.get("cascade") else "" + return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}" def except_sql(self, expression): return self.prepend_ctes( @@ -915,13 +919,15 @@ class Generator: def subquery_sql(self, expression): alias = self.sql(expression, "alias") - return self.query_modifiers( + sql = self.query_modifiers( expression, self.wrap(expression), self.expressions(expression, key="pivots", sep=" "), f" AS {alias}" if alias else "", ) + return self.prepend_ctes(expression, sql) + def qualify_sql(self, expression): this = self.indent(self.sql(expression, "this")) return f"{self.seg('QUALIFY')}{self.sep()}{this}" @@ -1111,9 +1117,12 @@ class Generator: def paren_sql(self, expression): if isinstance(expression.unnest(), exp.Select): - return self.wrap(expression) - sql = self.seg(self.indent(self.sql(expression, "this")), sep="") - return f"({sql}{self.seg(')', sep='')}" + sql = self.wrap(expression) + else: + sql = self.seg(self.indent(self.sql(expression, "this")), sep="") + sql = f"({sql}{self.seg(')', sep='')}" + + return self.prepend_ctes(expression, sql) def neg_sql(self, expression): return f"-{self.sql(expression, 'this')}" @@ -1173,9 +1182,23 @@ class Generator: zone = self.sql(expression, "this") return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE" + def collate_sql(self, expression): + return self.binary(expression, "COLLATE") + def command_sql(self, expression): return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}" + def transaction_sql(self, *_): + return "BEGIN" + + def commit_sql(self, *_): + return "COMMIT" + + def rollback_sql(self, expression): + savepoint = expression.args.get("savepoint") + savepoint = f" TO {savepoint}" if savepoint else "" + return f"ROLLBACK{savepoint}" + def distinct_sql(self, expression): this = self.expressions(expression, flat=True) this = f" {this}" if this else "" @@ -1193,10 +1216,7 @@ class Generator: def intdiv_sql(self, expression): return self.sql( exp.Cast( - this=exp.Div( - this=expression.args["this"], - expression=expression.args["expression"], - ), + this=exp.Div(this=expression.this, expression=expression.expression), to=exp.DataType(this=exp.DataType.Type.INT), ) ) diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 379c2e7..8c5808d 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -11,7 +11,8 @@ from copy import copy from enum import Enum if t.TYPE_CHECKING: - from sqlglot.expressions import Expression, Table + from sqlglot import exp + from sqlglot.expressions import Expression T = t.TypeVar("T") E = t.TypeVar("E", bound=Expression) @@ -150,7 +151,7 @@ def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]: if expression.is_int: expression = expression.copy() logger.warning("Applying array index offset (%s)", offset) - expression.args["this"] = str(int(expression.args["this"]) + offset) + expression.args["this"] = str(int(expression.this) + offset) return [expression] return expressions @@ -228,19 +229,18 @@ def open_file(file_name: str) -> t.TextIO: @contextmanager -def csv_reader(table: Table) -> t.Any: +def csv_reader(read_csv: exp.ReadCSV) -> t.Any: """ Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`. Args: - table: a `Table` expression with an anonymous function `READ_CSV` in it. + read_csv: a `ReadCSV` function call Yields: A python csv reader. """ - file, *args = table.this.expressions - file = file.name - file = open_file(file) + args = read_csv.expressions + file = open_file(read_csv.name) delimiter = "," args = iter(arg.name for arg in args) @@ -354,3 +354,34 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any, yield from flatten(value) else: yield value + + +def dict_depth(d: t.Dict) -> int: + """ + Get the nesting depth of a dictionary. + + For example: + >>> dict_depth(None) + 0 + >>> dict_depth({}) + 1 + >>> dict_depth({"a": "b"}) + 1 + >>> dict_depth({"a": {}}) + 2 + >>> dict_depth({"a": {"b": {}}}) + 3 + + Args: + d (dict): dictionary + Returns: + int: depth + """ + try: + return 1 + dict_depth(next(iter(d.values()))) + except AttributeError: + # d doesn't have attribute "values" + return 0 + except StopIteration: + # d.values() returns an empty sequence + return 1 diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 96331e2..191ea52 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -245,23 +245,31 @@ class TypeAnnotator: def annotate(self, expression): if isinstance(expression, self.TRAVERSABLES): for scope in traverse_scope(expression): - subscope_selects = { - name: {select.alias_or_name: select for select in source.selects} - for name, source in scope.sources.items() - if isinstance(source, Scope) - } - + selects = {} + for name, source in scope.sources.items(): + if not isinstance(source, Scope): + continue + if isinstance(source.expression, exp.Values): + selects[name] = { + alias: column + for alias, column in zip( + source.expression.alias_column_names, + source.expression.expressions[0].expressions, + ) + } + else: + selects[name] = { + select.alias_or_name: select for select in source.expression.selects + } # First annotate the current scope's column references for col in scope.columns: source = scope.sources[col.table] if isinstance(source, exp.Table): col.type = self.schema.get_column_type(source, col) else: - col.type = subscope_selects[col.table][col.name].type - + col.type = selects[col.table][col.name].type # Then (possibly) annotate the remaining expressions in the scope self._maybe_annotate(scope.expression) - return self._maybe_annotate(expression) # This takes care of non-traversable expressions def _maybe_annotate(self, expression): diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py new file mode 100644 index 0000000..9b3d98a --- /dev/null +++ b/sqlglot/optimizer/canonicalize.py @@ -0,0 +1,48 @@ +import itertools + +from sqlglot import exp + + +def canonicalize(expression: exp.Expression) -> exp.Expression: + """Converts a sql expression into a standard form. + + This method relies on annotate_types because many of the + conversions rely on type inference. + + Args: + expression: The expression to canonicalize. + """ + exp.replace_children(expression, canonicalize) + expression = add_text_to_concat(expression) + expression = coerce_type(expression) + return expression + + +def add_text_to_concat(node: exp.Expression) -> exp.Expression: + if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES: + node = exp.Concat(this=node.this, expression=node.expression) + return node + + +def coerce_type(node: exp.Expression) -> exp.Expression: + if isinstance(node, exp.Binary): + _coerce_date(node.left, node.right) + elif isinstance(node, exp.Between): + _coerce_date(node.this, node.args["low"]) + elif isinstance(node, exp.Extract): + if node.expression.type not in exp.DataType.TEMPORAL_TYPES: + _replace_cast(node.expression, "datetime") + return node + + +def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: + for a, b in itertools.permutations([a, b]): + if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE: + _replace_cast(b, "date") + + +def _replace_cast(node: exp.Expression, to: str) -> None: + data_type = exp.DataType.build(to) + cast = exp.Cast(this=node.copy(), to=data_type) + cast.type = data_type + node.replace(cast) diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index 29621af..de4e011 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -128,8 +128,8 @@ def join_condition(join): Tuple of (source key, join key, remaining predicate) """ name = join.this.alias_or_name - on = join.args.get("on") or exp.TRUE - on = on.copy() + on = (join.args.get("on") or exp.true()).copy() + on = on if isinstance(on, exp.And) else exp.and_(on, exp.true()) source_key = [] join_key = [] @@ -141,7 +141,7 @@ def join_condition(join): # # should pull y.b as the join key and x.a as the source key if normalized(on): - for condition in on.flatten() if isinstance(on, exp.And) else [on]: + for condition in on.flatten(): if isinstance(condition, exp.EQ): left, right = condition.unnest_operands() left_tables = exp.column_table_names(left) @@ -150,13 +150,12 @@ def join_condition(join): if name in left_tables and name not in right_tables: join_key.append(left) source_key.append(right) - condition.replace(exp.TRUE) + condition.replace(exp.true()) elif name in right_tables and name not in left_tables: join_key.append(right) source_key.append(left) - condition.replace(exp.TRUE) + condition.replace(exp.true()) on = simplify(on) - remaining_condition = None if on == exp.TRUE else on - + remaining_condition = None if on == exp.true() else on return source_key, join_key, remaining_condition diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index 40e4ab1..fd69832 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -29,7 +29,7 @@ def optimize_joins(expression): if isinstance(on, exp.Connector): for predicate in on.flatten(): if name in exp.column_table_names(predicate): - predicate.replace(exp.TRUE) + predicate.replace(exp.true()) join.on(predicate, copy=False) expression = reorder_joins(expression) @@ -70,6 +70,6 @@ def normalize(expression): def other_table_names(join, exclude): return [ name - for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) + for name in (exp.column_table_names(join.args.get("on") or exp.true())) if name != exclude ] diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index b2ed062..d0e38cd 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -1,4 +1,6 @@ import sqlglot +from sqlglot.optimizer.annotate_types import annotate_types +from sqlglot.optimizer.canonicalize import canonicalize from sqlglot.optimizer.eliminate_ctes import eliminate_ctes from sqlglot.optimizer.eliminate_joins import eliminate_joins from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries @@ -28,6 +30,8 @@ RULES = ( merge_subqueries, eliminate_joins, eliminate_ctes, + annotate_types, + canonicalize, quote_identities, ) diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index 6364f65..f92e5c3 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -64,11 +64,11 @@ def pushdown_cnf(predicates, scope, scope_ref_count): for predicate in predicates: for node in nodes_for_predicate(predicate, scope, scope_ref_count).values(): if isinstance(node, exp.Join): - predicate.replace(exp.TRUE) + predicate.replace(exp.true()) node.on(predicate, copy=False) break if isinstance(node, exp.Select): - predicate.replace(exp.TRUE) + predicate.replace(exp.true()) node.where(replace_aliases(node, predicate), copy=False) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 69fe2b8..e6e6dc9 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -382,9 +382,7 @@ class _Resolver: raise OptimizeError(str(e)) from e if isinstance(source, Scope) and isinstance(source.expression, exp.Values): - values_alias = source.expression.parent - if hasattr(values_alias, "alias_column_names"): - return values_alias.alias_column_names + return source.expression.alias_column_names # Otherwise, if referencing another scope, return that scope's named selects return source.expression.named_selects diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 0e467d3..5d8e0d9 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -1,10 +1,11 @@ import itertools from sqlglot import alias, exp +from sqlglot.helper import csv_reader from sqlglot.optimizer.scope import traverse_scope -def qualify_tables(expression, db=None, catalog=None): +def qualify_tables(expression, db=None, catalog=None, schema=None): """ Rewrite sqlglot AST to have fully qualified tables. @@ -18,6 +19,7 @@ def qualify_tables(expression, db=None, catalog=None): expression (sqlglot.Expression): expression to qualify db (str): Database name catalog (str): Catalog name + schema: A schema to populate Returns: sqlglot.Expression: qualified expression """ @@ -41,7 +43,7 @@ def qualify_tables(expression, db=None, catalog=None): source.set("catalog", exp.to_identifier(catalog)) if not source.alias: - source.replace( + source = source.replace( alias( source.copy(), source.this if identifier else f"_q_{next(sequence)}", @@ -49,4 +51,12 @@ def qualify_tables(expression, db=None, catalog=None): ) ) + if schema and isinstance(source.this, exp.ReadCSV): + with csv_reader(source.this) as reader: + header = next(reader) + columns = next(reader) + schema.add_table( + source, {k: type(v).__name__ for k, v in zip(header, columns)} + ) + return expression diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index d759e86..c432c59 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -189,11 +189,11 @@ def absorb_and_eliminate(expression): # absorb if is_complement(b, aa): - aa.replace(exp.TRUE if kind == exp.And else exp.FALSE) + aa.replace(exp.true() if kind == exp.And else exp.false()) elif is_complement(b, ab): - ab.replace(exp.TRUE if kind == exp.And else exp.FALSE) + ab.replace(exp.true() if kind == exp.And else exp.false()) elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()): - a.replace(exp.FALSE if kind == exp.And else exp.TRUE) + a.replace(exp.false() if kind == exp.And else exp.true()) elif isinstance(b, kind): # eliminate rhs = b.unnest_operands() diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index f41a84e..dbd680b 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -169,7 +169,7 @@ def decorrelate(select, parent_select, external_columns, sequence): select.parent.replace(alias) for key, column, predicate in keys: - predicate.replace(exp.TRUE) + predicate.replace(exp.true()) nested = exp.column(key_aliases[key], table_alias) if key in group_by: diff --git a/sqlglot/parser.py b/sqlglot/parser.py index bbea0e5..5b93510 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -141,26 +141,29 @@ class Parser(metaclass=_Parser): ID_VAR_TOKENS = { TokenType.VAR, - TokenType.ALTER, TokenType.ALWAYS, TokenType.ANTI, TokenType.APPLY, + TokenType.AUTO_INCREMENT, TokenType.BEGIN, TokenType.BOTH, TokenType.BUCKET, TokenType.CACHE, - TokenType.CALL, + TokenType.CASCADE, TokenType.COLLATE, + TokenType.COMMAND, TokenType.COMMIT, TokenType.CONSTRAINT, + TokenType.CURRENT_TIME, TokenType.DEFAULT, TokenType.DELETE, TokenType.DESCRIBE, TokenType.DETERMINISTIC, + TokenType.DISTKEY, + TokenType.DISTSTYLE, TokenType.EXECUTE, TokenType.ENGINE, TokenType.ESCAPE, - TokenType.EXPLAIN, TokenType.FALSE, TokenType.FIRST, TokenType.FOLLOWING, @@ -182,7 +185,6 @@ class Parser(metaclass=_Parser): TokenType.NATURAL, TokenType.NEXT, TokenType.ONLY, - TokenType.OPTIMIZE, TokenType.OPTIONS, TokenType.ORDINALITY, TokenType.PARTITIONED_BY, @@ -199,6 +201,7 @@ class Parser(metaclass=_Parser): TokenType.SEMI, TokenType.SET, TokenType.SHOW, + TokenType.SORTKEY, TokenType.STABLE, TokenType.STORED, TokenType.TABLE, @@ -207,7 +210,6 @@ class Parser(metaclass=_Parser): TokenType.TRANSIENT, TokenType.TOP, TokenType.TRAILING, - TokenType.TRUNCATE, TokenType.TRUE, TokenType.UNBOUNDED, TokenType.UNIQUE, @@ -217,6 +219,7 @@ class Parser(metaclass=_Parser): TokenType.VOLATILE, *SUBQUERY_PREDICATES, *TYPE_TOKENS, + *NO_PAREN_FUNCTIONS, } TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY} @@ -231,6 +234,7 @@ class Parser(metaclass=_Parser): TokenType.FILTER, TokenType.FIRST, TokenType.FORMAT, + TokenType.IDENTIFIER, TokenType.ISNULL, TokenType.OFFSET, TokenType.PRIMARY_KEY, @@ -242,6 +246,7 @@ class Parser(metaclass=_Parser): TokenType.RIGHT, TokenType.DATE, TokenType.DATETIME, + TokenType.TABLE, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, *TYPE_TOKENS, @@ -277,6 +282,7 @@ class Parser(metaclass=_Parser): TokenType.DASH: exp.Sub, TokenType.PLUS: exp.Add, TokenType.MOD: exp.Mod, + TokenType.COLLATE: exp.Collate, } FACTOR = { @@ -391,7 +397,10 @@ class Parser(metaclass=_Parser): TokenType.DELETE: lambda self: self._parse_delete(), TokenType.CACHE: lambda self: self._parse_cache(), TokenType.UNCACHE: lambda self: self._parse_uncache(), - TokenType.USE: lambda self: self._parse_use(), + TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()), + TokenType.BEGIN: lambda self: self._parse_transaction(), + TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), + TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), } PRIMARY_PARSERS = { @@ -402,7 +411,8 @@ class Parser(metaclass=_Parser): exp.Literal, this=token.text, is_string=False ), TokenType.STAR: lambda self, _: self.expression( - exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()} + exp.Star, + **{"except": self._parse_except(), "replace": self._parse_replace()}, ), TokenType.NULL: lambda self, _: self.expression(exp.Null), TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True), @@ -446,6 +456,9 @@ class Parser(metaclass=_Parser): TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(), TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(), TokenType.STORED: lambda self: self._parse_stored(), + TokenType.DISTKEY: lambda self: self._parse_distkey(), + TokenType.DISTSTYLE: lambda self: self._parse_diststyle(), + TokenType.SORTKEY: lambda self: self._parse_sortkey(), TokenType.RETURNS: lambda self: self._parse_returns(), TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty), TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), @@ -471,7 +484,9 @@ class Parser(metaclass=_Parser): } CONSTRAINT_PARSERS = { - TokenType.CHECK: lambda self: self._parse_check(), + TokenType.CHECK: lambda self: self.expression( + exp.Check, this=self._parse_wrapped(self._parse_conjunction) + ), TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(), TokenType.UNIQUE: lambda self: self._parse_unique(), } @@ -521,6 +536,8 @@ class Parser(metaclass=_Parser): TokenType.SCHEMA, } + TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} + STRICT_CAST = True __slots__ = ( @@ -740,6 +757,7 @@ class Parser(metaclass=_Parser): kind=kind, temporary=temporary, materialized=materialized, + cascade=self._match(TokenType.CASCADE), ) def _parse_exists(self, not_=False): @@ -777,7 +795,11 @@ class Parser(metaclass=_Parser): expression = self._parse_select_or_expression() elif create_token.token_type == TokenType.INDEX: this = self._parse_index() - elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW, TokenType.SCHEMA): + elif create_token.token_type in ( + TokenType.TABLE, + TokenType.VIEW, + TokenType.SCHEMA, + ): this = self._parse_table(schema=True) properties = self._parse_properties() if self._match(TokenType.ALIAS): @@ -834,7 +856,38 @@ class Parser(metaclass=_Parser): return self.expression( exp.FileFormatProperty, this=exp.Literal.string("FORMAT"), - value=exp.Literal.string(self._parse_var().name), + value=exp.Literal.string(self._parse_var_or_string().name), + ) + + def _parse_distkey(self): + self._match_l_paren() + this = exp.Literal.string("DISTKEY") + value = exp.Literal.string(self._parse_var().name) + self._match_r_paren() + return self.expression( + exp.DistKeyProperty, + this=this, + value=value, + ) + + def _parse_sortkey(self): + self._match_l_paren() + this = exp.Literal.string("SORTKEY") + value = exp.Literal.string(self._parse_var().name) + self._match_r_paren() + return self.expression( + exp.SortKeyProperty, + this=this, + value=value, + ) + + def _parse_diststyle(self): + this = exp.Literal.string("DISTSTYLE") + value = exp.Literal.string(self._parse_var().name) + return self.expression( + exp.DistStyleProperty, + this=this, + value=value, ) def _parse_auto_increment(self): @@ -842,7 +895,7 @@ class Parser(metaclass=_Parser): return self.expression( exp.AutoIncrementProperty, this=exp.Literal.string("AUTO_INCREMENT"), - value=self._parse_var() or self._parse_number(), + value=self._parse_number(), ) def _parse_schema_comment(self): @@ -898,13 +951,10 @@ class Parser(metaclass=_Parser): while True: if self._match(TokenType.WITH): - self._match_l_paren() - properties.extend(self._parse_csv(lambda: self._parse_property())) - self._match_r_paren() + properties.extend(self._parse_wrapped_csv(self._parse_property)) elif self._match(TokenType.PROPERTIES): - self._match_l_paren() properties.extend( - self._parse_csv( + self._parse_wrapped_csv( lambda: self.expression( exp.AnonymousProperty, this=self._parse_string(), @@ -912,25 +962,24 @@ class Parser(metaclass=_Parser): ) ) ) - self._match_r_paren() else: identified_property = self._parse_property() if not identified_property: break properties.append(identified_property) + if properties: return self.expression(exp.Properties, expressions=properties) return None def _parse_describe(self): self._match(TokenType.TABLE) - return self.expression(exp.Describe, this=self._parse_id_var()) def _parse_insert(self): overwrite = self._match(TokenType.OVERWRITE) local = self._match(TokenType.LOCAL) - if self._match_text("DIRECTORY"): + if self._match_text_seq("DIRECTORY"): this = self.expression( exp.Directory, this=self._parse_var_or_string(), @@ -954,27 +1003,27 @@ class Parser(metaclass=_Parser): if not self._match_pair(TokenType.ROW, TokenType.FORMAT): return None - self._match_text("DELIMITED") + self._match_text_seq("DELIMITED") kwargs = {} - if self._match_text("FIELDS", "TERMINATED", "BY"): + if self._match_text_seq("FIELDS", "TERMINATED", "BY"): kwargs["fields"] = self._parse_string() - if self._match_text("ESCAPED", "BY"): + if self._match_text_seq("ESCAPED", "BY"): kwargs["escaped"] = self._parse_string() - if self._match_text("COLLECTION", "ITEMS", "TERMINATED", "BY"): + if self._match_text_seq("COLLECTION", "ITEMS", "TERMINATED", "BY"): kwargs["collection_items"] = self._parse_string() - if self._match_text("MAP", "KEYS", "TERMINATED", "BY"): + if self._match_text_seq("MAP", "KEYS", "TERMINATED", "BY"): kwargs["map_keys"] = self._parse_string() - if self._match_text("LINES", "TERMINATED", "BY"): + if self._match_text_seq("LINES", "TERMINATED", "BY"): kwargs["lines"] = self._parse_string() - if self._match_text("NULL", "DEFINED", "AS"): + if self._match_text_seq("NULL", "DEFINED", "AS"): kwargs["null"] = self._parse_string() return self.expression(exp.RowFormat, **kwargs) def _parse_load_data(self): local = self._match(TokenType.LOCAL) - self._match_text("INPATH") + self._match_text_seq("INPATH") inpath = self._parse_string() overwrite = self._match(TokenType.OVERWRITE) self._match_pair(TokenType.INTO, TokenType.TABLE) @@ -986,8 +1035,8 @@ class Parser(metaclass=_Parser): overwrite=overwrite, inpath=inpath, partition=self._parse_partition(), - input_format=self._match_text("INPUTFORMAT") and self._parse_string(), - serde=self._match_text("SERDE") and self._parse_string(), + input_format=self._match_text_seq("INPUTFORMAT") and self._parse_string(), + serde=self._match_text_seq("SERDE") and self._parse_string(), ) def _parse_delete(self): @@ -996,9 +1045,7 @@ class Parser(metaclass=_Parser): return self.expression( exp.Delete, this=self._parse_table(schema=True), - using=self._parse_csv( - lambda: self._match(TokenType.USING) and self._parse_table(schema=True) - ), + using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()), where=self._parse_where(), ) @@ -1029,12 +1076,7 @@ class Parser(metaclass=_Parser): options = [] if self._match(TokenType.OPTIONS): - self._match_l_paren() - k = self._parse_string() - self._match(TokenType.EQ) - v = self._parse_string() - options = [k, v] - self._match_r_paren() + options = self._parse_wrapped_csv(self._parse_string, sep=TokenType.EQ) self._match(TokenType.ALIAS) return self.expression( @@ -1050,27 +1092,13 @@ class Parser(metaclass=_Parser): return None def parse_values(): - key = self._parse_var() - value = None - - if self._match(TokenType.EQ): - value = self._parse_string() - - return exp.Property(this=key, value=value) - - self._match_l_paren() - values = self._parse_csv(parse_values) - self._match_r_paren() + props = self._parse_csv(self._parse_var_or_string, sep=TokenType.EQ) + return exp.Property(this=seq_get(props, 0), value=seq_get(props, 1)) - return self.expression( - exp.Partition, - this=values, - ) + return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values)) def _parse_value(self): - self._match_l_paren() - expressions = self._parse_csv(self._parse_conjunction) - self._match_r_paren() + expressions = self._parse_wrapped_csv(self._parse_conjunction) return self.expression(exp.Tuple, expressions=expressions) def _parse_select(self, nested=False, table=False): @@ -1124,10 +1152,11 @@ class Parser(metaclass=_Parser): self._match_r_paren() this = self._parse_subquery(this) elif self._match(TokenType.VALUES): - this = self.expression(exp.Values, expressions=self._parse_csv(self._parse_value)) - alias = self._parse_table_alias() - if alias: - this = self.expression(exp.Subquery, this=this, alias=alias) + this = self.expression( + exp.Values, + expressions=self._parse_csv(self._parse_value), + alias=self._parse_table_alias(), + ) else: this = None @@ -1140,7 +1169,6 @@ class Parser(metaclass=_Parser): recursive = self._match(TokenType.RECURSIVE) expressions = [] - while True: expressions.append(self._parse_cte()) @@ -1149,11 +1177,7 @@ class Parser(metaclass=_Parser): else: self._match(TokenType.WITH) - return self.expression( - exp.With, - expressions=expressions, - recursive=recursive, - ) + return self.expression(exp.With, expressions=expressions, recursive=recursive) def _parse_cte(self): alias = self._parse_table_alias() @@ -1163,13 +1187,9 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.ALIAS): self.raise_error("Expected AS in CTE") - self._match_l_paren() - expression = self._parse_statement() - self._match_r_paren() - return self.expression( exp.CTE, - this=expression, + this=self._parse_wrapped(self._parse_statement), alias=alias, ) @@ -1223,7 +1243,7 @@ class Parser(metaclass=_Parser): def _parse_hint(self): if self._match(TokenType.HINT): hints = self._parse_csv(self._parse_function) - if not self._match(TokenType.HINT): + if not self._match_pair(TokenType.STAR, TokenType.SLASH): self.raise_error("Expected */ after HINT") return self.expression(exp.Hint, expressions=hints) return None @@ -1259,26 +1279,18 @@ class Parser(metaclass=_Parser): columns = self._parse_csv(self._parse_id_var) elif self._match(TokenType.L_PAREN): columns = self._parse_csv(self._parse_id_var) - self._match(TokenType.R_PAREN) + self._match_r_paren() expression = self.expression( exp.Lateral, this=this, view=view, outer=outer, - alias=self.expression( - exp.TableAlias, - this=table_alias, - columns=columns, - ), + alias=self.expression(exp.TableAlias, this=table_alias, columns=columns), ) if outer_apply or cross_apply: - return self.expression( - exp.Join, - this=expression, - side=None if cross_apply else "LEFT", - ) + return self.expression(exp.Join, this=expression, side=None if cross_apply else "LEFT") return expression @@ -1387,12 +1399,8 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.UNNEST): return None - self._match_l_paren() - expressions = self._parse_csv(self._parse_column) - self._match_r_paren() - + expressions = self._parse_wrapped_csv(self._parse_column) ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY)) - alias = self._parse_table_alias() if alias and self.unnest_column_only: @@ -1402,10 +1410,7 @@ class Parser(metaclass=_Parser): alias.set("this", None) return self.expression( - exp.Unnest, - expressions=expressions, - ordinality=ordinality, - alias=alias, + exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias ) def _parse_derived_table_values(self): @@ -1418,13 +1423,7 @@ class Parser(metaclass=_Parser): if is_derived: self._match_r_paren() - alias = self._parse_table_alias() - - return self.expression( - exp.Values, - expressions=expressions, - alias=alias, - ) + return self.expression(exp.Values, expressions=expressions, alias=self._parse_table_alias()) def _parse_table_sample(self): if not self._match(TokenType.TABLE_SAMPLE): @@ -1460,9 +1459,7 @@ class Parser(metaclass=_Parser): self._match_r_paren() if self._match(TokenType.SEED): - self._match_l_paren() - seed = self._parse_number() - self._match_r_paren() + seed = self._parse_wrapped(self._parse_number) return self.expression( exp.TableSample, @@ -1513,12 +1510,7 @@ class Parser(metaclass=_Parser): self._match_r_paren() - return self.expression( - exp.Pivot, - expressions=expressions, - field=field, - unpivot=unpivot, - ) + return self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot) def _parse_where(self, skip_where_token=False): if not skip_where_token and not self._match(TokenType.WHERE): @@ -1539,11 +1531,7 @@ class Parser(metaclass=_Parser): def _parse_grouping_sets(self): if not self._match(TokenType.GROUPING_SETS): return None - - self._match_l_paren() - grouping_sets = self._parse_csv(self._parse_grouping_set) - self._match_r_paren() - return grouping_sets + return self._parse_wrapped_csv(self._parse_grouping_set) def _parse_grouping_set(self): if self._match(TokenType.L_PAREN): @@ -1573,7 +1561,6 @@ class Parser(metaclass=_Parser): def _parse_sort(self, token_type, exp_class): if not self._match(token_type): return None - return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered)) def _parse_ordered(self): @@ -1602,9 +1589,12 @@ class Parser(metaclass=_Parser): if self._match(TokenType.TOP if top else TokenType.LIMIT): limit_paren = self._match(TokenType.L_PAREN) limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number()) + if limit_paren: - self._match(TokenType.R_PAREN) + self._match_r_paren() + return limit_exp + if self._match(TokenType.FETCH): direction = self._match_set((TokenType.FIRST, TokenType.NEXT)) direction = self._prev.text if direction else "FIRST" @@ -1612,11 +1602,13 @@ class Parser(metaclass=_Parser): self._match_set((TokenType.ROW, TokenType.ROWS)) self._match(TokenType.ONLY) return self.expression(exp.Fetch, direction=direction, count=count) + return this def _parse_offset(self, this=None): if not self._match_set((TokenType.OFFSET, TokenType.COMMA)): return this + count = self._parse_number() self._match_set((TokenType.ROW, TokenType.ROWS)) return self.expression(exp.Offset, this=this, expression=count) @@ -1678,6 +1670,7 @@ class Parser(metaclass=_Parser): if self._match(TokenType.DISTINCT_FROM): klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ return self.expression(klass, this=this, expression=self._parse_expression()) + this = self.expression( exp.Is, this=this, @@ -1754,11 +1747,7 @@ class Parser(metaclass=_Parser): def _parse_type(self): if self._match(TokenType.INTERVAL): - return self.expression( - exp.Interval, - this=self._parse_term(), - unit=self._parse_var(), - ) + return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_var()) index = self._index type_token = self._parse_types(check_func=True) @@ -1824,30 +1813,18 @@ class Parser(metaclass=_Parser): value = None if type_token in self.TIMESTAMPS: if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ: - value = exp.DataType( - this=exp.DataType.Type.TIMESTAMPTZ, - expressions=expressions, - ) + value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions) elif ( self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ ): - value = exp.DataType( - this=exp.DataType.Type.TIMESTAMPLTZ, - expressions=expressions, - ) + value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions) elif self._match(TokenType.WITHOUT_TIME_ZONE): - value = exp.DataType( - this=exp.DataType.Type.TIMESTAMP, - expressions=expressions, - ) + value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions) maybe_func = maybe_func and value is None if value is None: - value = exp.DataType( - this=exp.DataType.Type.TIMESTAMP, - expressions=expressions, - ) + value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions) if maybe_func and check_func: index2 = self._index @@ -1872,6 +1849,7 @@ class Parser(metaclass=_Parser): this = self._parse_id_var() self._match(TokenType.COLON) data_type = self._parse_types() + if not data_type: return None return self.expression(exp.StructKwarg, this=this, expression=data_type) @@ -1879,7 +1857,6 @@ class Parser(metaclass=_Parser): def _parse_at_time_zone(self, this): if not self._match(TokenType.AT_TIME_ZONE): return this - return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary()) def _parse_column(self): @@ -1984,16 +1961,14 @@ class Parser(metaclass=_Parser): else: subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type) - if subquery_predicate and self._curr.token_type in ( - TokenType.SELECT, - TokenType.WITH, - ): + if subquery_predicate and self._curr.token_type in (TokenType.SELECT, TokenType.WITH): this = self.expression(subquery_predicate, this=self._parse_select()) self._match_r_paren() return this if functions is None: functions = self.FUNCTIONS + function = functions.get(upper) args = self._parse_csv(self._parse_lambda) @@ -2014,6 +1989,7 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.L_PAREN): return this + expressions = self._parse_csv(self._parse_udf_kwarg) self._match_r_paren() return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) @@ -2021,25 +1997,19 @@ class Parser(metaclass=_Parser): def _parse_introducer(self, token): literal = self._parse_primary() if literal: - return self.expression( - exp.Introducer, - this=token.text, - expression=literal, - ) + return self.expression(exp.Introducer, this=token.text, expression=literal) return self.expression(exp.Identifier, this=token.text) def _parse_session_parameter(self): kind = None this = self._parse_id_var() or self._parse_primary() + if self._match(TokenType.DOT): kind = this.name this = self._parse_var() or self._parse_primary() - return self.expression( - exp.SessionParameter, - this=this, - kind=kind, - ) + + return self.expression(exp.SessionParameter, this=this, kind=kind) def _parse_udf_kwarg(self): this = self._parse_id_var() @@ -2106,7 +2076,10 @@ class Parser(metaclass=_Parser): return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints) def _parse_column_constraint(self): - this = None + this = self._parse_references() + + if this: + return this if self._match(TokenType.CONSTRAINT): this = self._parse_id_var() @@ -2114,13 +2087,12 @@ class Parser(metaclass=_Parser): if self._match(TokenType.AUTO_INCREMENT): kind = exp.AutoIncrementColumnConstraint() elif self._match(TokenType.CHECK): - self._match_l_paren() - kind = self.expression(exp.CheckColumnConstraint, this=self._parse_conjunction()) - self._match_r_paren() + constraint = self._parse_wrapped(self._parse_conjunction) + kind = self.expression(exp.CheckColumnConstraint, this=constraint) elif self._match(TokenType.COLLATE): kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var()) elif self._match(TokenType.DEFAULT): - kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_field()) + kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction()) elif self._match_pair(TokenType.NOT, TokenType.NULL): kind = exp.NotNullColumnConstraint() elif self._match(TokenType.SCHEMA_COMMENT): @@ -2137,7 +2109,7 @@ class Parser(metaclass=_Parser): kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True) self._match_pair(TokenType.ALIAS, TokenType.IDENTITY) else: - return None + return this return self.expression(exp.ColumnConstraint, this=this, kind=kind) @@ -2159,37 +2131,29 @@ class Parser(metaclass=_Parser): def _parse_unnamed_constraint(self): if not self._match_set(self.CONSTRAINT_PARSERS): return None - return self.CONSTRAINT_PARSERS[self._prev.token_type](self) - def _parse_check(self): - self._match(TokenType.CHECK) - self._match_l_paren() - expression = self._parse_conjunction() - self._match_r_paren() - - return self.expression(exp.Check, this=expression) - def _parse_unique(self): - self._match(TokenType.UNIQUE) - columns = self._parse_wrapped_id_vars() - - return self.expression(exp.Unique, expressions=columns) + return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars()) - def _parse_foreign_key(self): - self._match(TokenType.FOREIGN_KEY) - - expressions = self._parse_wrapped_id_vars() - reference = self._match(TokenType.REFERENCES) and self.expression( + def _parse_references(self): + if not self._match(TokenType.REFERENCES): + return None + return self.expression( exp.Reference, this=self._parse_id_var(), expressions=self._parse_wrapped_id_vars(), ) + + def _parse_foreign_key(self): + expressions = self._parse_wrapped_id_vars() + reference = self._parse_references() options = {} while self._match(TokenType.ON): if not self._match_set((TokenType.DELETE, TokenType.UPDATE)): self.raise_error("Expected DELETE or UPDATE") + kind = self._prev.text.lower() if self._match(TokenType.NO_ACTION): @@ -2200,6 +2164,7 @@ class Parser(metaclass=_Parser): else: self._advance() action = self._prev.text.upper() + options[kind] = action return self.expression( @@ -2363,20 +2328,14 @@ class Parser(metaclass=_Parser): def _parse_window(self, this, alias=False): if self._match(TokenType.FILTER): - self._match_l_paren() - this = self.expression(exp.Filter, this=this, expression=self._parse_where()) - self._match_r_paren() + where = self._parse_wrapped(self._parse_where) + this = self.expression(exp.Filter, this=this, expression=where) # T-SQL allows the OVER (...) syntax after WITHIN GROUP. # https://learn.microsoft.com/en-us/sql/t-sql/functions/percentile-disc-transact-sql?view=sql-server-ver16 if self._match(TokenType.WITHIN_GROUP): - self._match_l_paren() - this = self.expression( - exp.WithinGroup, - this=this, - expression=self._parse_order(), - ) - self._match_r_paren() + order = self._parse_wrapped(self._parse_order) + this = self.expression(exp.WithinGroup, this=this, expression=order) # SQL spec defines an optional [ { IGNORE | RESPECT } NULLS ] OVER # Some dialects choose to implement and some do not. @@ -2404,18 +2363,11 @@ class Parser(metaclass=_Parser): return this if not self._match(TokenType.L_PAREN): - alias = self._parse_id_var(False) - - return self.expression( - exp.Window, - this=this, - alias=alias, - ) - - partition = None + return self.expression(exp.Window, this=this, alias=self._parse_id_var(False)) alias = self._parse_id_var(False) + partition = None if self._match(TokenType.PARTITION_BY): partition = self._parse_csv(self._parse_conjunction) @@ -2552,17 +2504,13 @@ class Parser(metaclass=_Parser): def _parse_replace(self): if not self._match(TokenType.REPLACE): return None + return self._parse_wrapped_csv(lambda: self._parse_alias(self._parse_expression())) - self._match_l_paren() - columns = self._parse_csv(lambda: self._parse_alias(self._parse_expression())) - self._match_r_paren() - return columns - - def _parse_csv(self, parse_method): + def _parse_csv(self, parse_method, sep=TokenType.COMMA): parse_result = parse_method() items = [parse_result] if parse_result is not None else [] - while self._match(TokenType.COMMA): + while self._match(sep): if parse_result and self._prev_comment is not None: parse_result.comment = self._prev_comment @@ -2583,16 +2531,53 @@ class Parser(metaclass=_Parser): return this def _parse_wrapped_id_vars(self): + return self._parse_wrapped_csv(self._parse_id_var) + + def _parse_wrapped_csv(self, parse_method, sep=TokenType.COMMA): + return self._parse_wrapped(lambda: self._parse_csv(parse_method, sep=sep)) + + def _parse_wrapped(self, parse_method): self._match_l_paren() - expressions = self._parse_csv(self._parse_id_var) + parse_result = parse_method() self._match_r_paren() - return expressions + return parse_result def _parse_select_or_expression(self): return self._parse_select() or self._parse_expression() - def _parse_use(self): - return self.expression(exp.Use, this=self._parse_id_var()) + def _parse_transaction(self): + this = None + if self._match_texts(self.TRANSACTION_KIND): + this = self._prev.text + + self._match_texts({"TRANSACTION", "WORK"}) + + modes = [] + while True: + mode = [] + while self._match(TokenType.VAR): + mode.append(self._prev.text) + + if mode: + modes.append(" ".join(mode)) + if not self._match(TokenType.COMMA): + break + + return self.expression(exp.Transaction, this=this, modes=modes) + + def _parse_commit_or_rollback(self): + savepoint = None + is_rollback = self._prev.token_type == TokenType.ROLLBACK + + self._match_texts({"TRANSACTION", "WORK"}) + + if self._match_text_seq("TO"): + self._match_text_seq("SAVEPOINT") + savepoint = self._parse_id_var() + + if is_rollback: + return self.expression(exp.Rollback, savepoint=savepoint) + return self.expression(exp.Commit) def _parse_show(self): parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) @@ -2675,7 +2660,13 @@ class Parser(metaclass=_Parser): if expression and self._prev_comment: expression.comment = self._prev_comment - def _match_text(self, *texts): + def _match_texts(self, texts): + if self._curr and self._curr.text.upper() in texts: + self._advance() + return True + return False + + def _match_text_seq(self, *texts): index = self._index for text in texts: if self._curr and self._curr.text.upper() == text: diff --git a/sqlglot/planner.py b/sqlglot/planner.py index cd1de5e..51db2d4 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import itertools import math +import typing as t from sqlglot import alias, exp from sqlglot.errors import UnsupportedError @@ -7,15 +10,15 @@ from sqlglot.optimizer.eliminate_joins import join_condition class Plan: - def __init__(self, expression): - self.expression = expression + def __init__(self, expression: exp.Expression) -> None: + self.expression = expression.copy() self.root = Step.from_expression(self.expression) - self._dag = {} + self._dag: t.Dict[Step, t.Set[Step]] = {} @property - def dag(self): + def dag(self) -> t.Dict[Step, t.Set[Step]]: if not self._dag: - dag = {} + dag: t.Dict[Step, t.Set[Step]] = {} nodes = {self.root} while nodes: @@ -29,32 +32,64 @@ class Plan: return self._dag @property - def leaves(self): + def leaves(self) -> t.Generator[Step, None, None]: return (node for node, deps in self.dag.items() if not deps) + def __repr__(self) -> str: + return f"Plan\n----\n{repr(self.root)}" + class Step: @classmethod - def from_expression(cls, expression, ctes=None): + def from_expression( + cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None + ) -> Step: """ - Build a DAG of Steps from a SQL expression. - - Giving an expression like: - - SELECT x.a, SUM(x.b) - FROM x - JOIN y - ON x.a = y.a + Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine. + Note: the expression's tables and subqueries must be aliased for this method to work. For + example, given the following expression: + + SELECT + x.a, + SUM(x.b) + FROM x AS x + JOIN y AS y + ON x.a = y.a GROUP BY x.a - Transform it into a DAG of the form: - - Aggregate(x.a, SUM(x.b)) - Join(y) - Scan(x) - Scan(y) - - This can then more easily be executed on by an engine. + the following DAG is produced (the expression IDs might differ per execution): + + - Aggregate: x (4347984624) + Context: + Aggregations: + - SUM(x.b) + Group: + - x.a + Projections: + - x.a + - "x"."" + Dependencies: + - Join: x (4347985296) + Context: + y: + On: x.a = y.a + Projections: + Dependencies: + - Scan: x (4347983136) + Context: + Source: x AS x + Projections: + - Scan: y (4343416624) + Context: + Source: y AS y + Projections: + + Args: + expression: the expression to build the DAG from. + ctes: a dictionary that maps CTEs to their corresponding Step DAG by name. + + Returns: + A Step DAG corresponding to `expression`. """ ctes = ctes or {} with_ = expression.args.get("with") @@ -65,11 +100,11 @@ class Step: for cte in with_.expressions: step = Step.from_expression(cte.this, ctes) step.name = cte.alias - ctes[step.name] = step + ctes[step.name] = step # type: ignore from_ = expression.args.get("from") - if from_: + if isinstance(expression, exp.Select) and from_: from_ = from_.expressions if len(from_) > 1: raise UnsupportedError( @@ -77,8 +112,10 @@ class Step: ) step = Scan.from_expression(from_[0], ctes) + elif isinstance(expression, exp.Union): + step = SetOperation.from_expression(expression, ctes) else: - raise UnsupportedError("Static selects are unsupported.") + step = Scan() joins = expression.args.get("joins") @@ -115,7 +152,7 @@ class Step: group = expression.args.get("group") - if group: + if group or aggregations: aggregate = Aggregate() aggregate.source = step.name aggregate.name = step.name @@ -123,7 +160,15 @@ class Step: alias(operand, alias_) for operand, alias_ in operands.items() ) aggregate.aggregations = aggregations - aggregate.group = group.expressions + # give aggregates names and replace projections with references to them + aggregate.group = { + f"_g{i}": e for i, e in enumerate(group.expressions if group else []) + } + for projection in projections: + for i, e in aggregate.group.items(): + for child, _, _ in projection.walk(): + if child == e: + child.replace(exp.column(i, step.name)) aggregate.add_dependency(step) step = aggregate @@ -150,22 +195,22 @@ class Step: return step - def __init__(self): - self.name = None - self.dependencies = set() - self.dependents = set() - self.projections = [] - self.limit = math.inf - self.condition = None + def __init__(self) -> None: + self.name: t.Optional[str] = None + self.dependencies: t.Set[Step] = set() + self.dependents: t.Set[Step] = set() + self.projections: t.Sequence[exp.Expression] = [] + self.limit: float = math.inf + self.condition: t.Optional[exp.Expression] = None - def add_dependency(self, dependency): + def add_dependency(self, dependency: Step) -> None: self.dependencies.add(dependency) dependency.dependents.add(self) - def __repr__(self): + def __repr__(self) -> str: return self.to_s() - def to_s(self, level=0): + def to_s(self, level: int = 0) -> str: indent = " " * level nested = f"{indent} " @@ -175,7 +220,7 @@ class Step: context = [f"{nested}Context:"] + context lines = [ - f"{indent}- {self.__class__.__name__}: {self.name}", + f"{indent}- {self.id}", *context, f"{nested}Projections:", ] @@ -193,13 +238,25 @@ class Step: return "\n".join(lines) - def _to_s(self, _indent): + @property + def type_name(self) -> str: + return self.__class__.__name__ + + @property + def id(self) -> str: + name = self.name + name = f" {name}" if name else "" + return f"{self.type_name}:{name} ({id(self)})" + + def _to_s(self, _indent: str) -> t.List[str]: return [] class Scan(Step): @classmethod - def from_expression(cls, expression, ctes=None): + def from_expression( + cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None + ) -> Step: table = expression alias_ = expression.alias @@ -217,26 +274,24 @@ class Scan(Step): step = Scan() step.name = alias_ step.source = expression - if table.name in ctes: + if ctes and table.name in ctes: step.add_dependency(ctes[table.name]) return step - def __init__(self): + def __init__(self) -> None: super().__init__() - self.source = None - - def _to_s(self, indent): - return [f"{indent}Source: {self.source.sql()}"] + self.source: t.Optional[exp.Expression] = None - -class Write(Step): - pass + def _to_s(self, indent: str) -> t.List[str]: + return [f"{indent}Source: {self.source.sql() if self.source else '-static-'}"] # type: ignore class Join(Step): @classmethod - def from_joins(cls, joins, ctes=None): + def from_joins( + cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None + ) -> Step: step = Join() for join in joins: @@ -252,28 +307,28 @@ class Join(Step): return step - def __init__(self): + def __init__(self) -> None: super().__init__() - self.joins = {} + self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {} - def _to_s(self, indent): + def _to_s(self, indent: str) -> t.List[str]: lines = [] for name, join in self.joins.items(): lines.append(f"{indent}{name}: {join['side']}") if join.get("condition"): - lines.append(f"{indent}On: {join['condition'].sql()}") + lines.append(f"{indent}On: {join['condition'].sql()}") # type: ignore return lines class Aggregate(Step): - def __init__(self): + def __init__(self) -> None: super().__init__() - self.aggregations = [] - self.operands = [] - self.group = [] - self.source = None + self.aggregations: t.List[exp.Expression] = [] + self.operands: t.Tuple[exp.Expression, ...] = () + self.group: t.Dict[str, exp.Expression] = {} + self.source: t.Optional[str] = None - def _to_s(self, indent): + def _to_s(self, indent: str) -> t.List[str]: lines = [f"{indent}Aggregations:"] for expression in self.aggregations: @@ -281,7 +336,7 @@ class Aggregate(Step): if self.group: lines.append(f"{indent}Group:") - for expression in self.group: + for expression in self.group.values(): lines.append(f"{indent} - {expression.sql()}") if self.operands: lines.append(f"{indent}Operands:") @@ -292,14 +347,56 @@ class Aggregate(Step): class Sort(Step): - def __init__(self): + def __init__(self) -> None: super().__init__() self.key = None - def _to_s(self, indent): + def _to_s(self, indent: str) -> t.List[str]: lines = [f"{indent}Key:"] - for expression in self.key: + for expression in self.key: # type: ignore lines.append(f"{indent} - {expression.sql()}") return lines + + +class SetOperation(Step): + def __init__( + self, + op: t.Type[exp.Expression], + left: str | None, + right: str | None, + distinct: bool = False, + ) -> None: + super().__init__() + self.op = op + self.left = left + self.right = right + self.distinct = distinct + + @classmethod + def from_expression( + cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None + ) -> Step: + assert isinstance(expression, exp.Union) + left = Step.from_expression(expression.left, ctes) + right = Step.from_expression(expression.right, ctes) + step = cls( + op=expression.__class__, + left=left.name, + right=right.name, + distinct=expression.args.get("distinct"), + ) + step.add_dependency(left) + step.add_dependency(right) + return step + + def _to_s(self, indent: str) -> t.List[str]: + lines = [] + if self.distinct: + lines.append(f"{indent}Distinct: {self.distinct}") + return lines + + @property + def type_name(self) -> str: + return self.op.__name__ diff --git a/sqlglot/schema.py b/sqlglot/schema.py index fcf7291..f6f303b 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -5,7 +5,7 @@ import typing as t from sqlglot import expressions as exp from sqlglot.errors import SchemaError -from sqlglot.helper import csv_reader +from sqlglot.helper import dict_depth from sqlglot.trie import in_trie, new_trie if t.TYPE_CHECKING: @@ -15,6 +15,8 @@ if t.TYPE_CHECKING: TABLE_ARGS = ("this", "db", "catalog") +T = t.TypeVar("T") + class Schema(abc.ABC): """Abstract base class for database schemas""" @@ -57,8 +59,81 @@ class Schema(abc.ABC): The resulting column type. """ + @property + def supported_table_args(self) -> t.Tuple[str, ...]: + """ + Table arguments this schema support, e.g. `("this", "db", "catalog")` + """ + raise NotImplementedError + + +class AbstractMappingSchema(t.Generic[T]): + def __init__( + self, + mapping: dict | None = None, + ) -> None: + self.mapping = mapping or {} + self.mapping_trie = self._build_trie(self.mapping) + self._supported_table_args: t.Tuple[str, ...] = tuple() + + def _build_trie(self, schema: t.Dict) -> t.Dict: + return new_trie(tuple(reversed(t)) for t in flatten_schema(schema, depth=self._depth())) + + def _depth(self) -> int: + return dict_depth(self.mapping) + + @property + def supported_table_args(self) -> t.Tuple[str, ...]: + if not self._supported_table_args and self.mapping: + depth = self._depth() + + if not depth: # None + self._supported_table_args = tuple() + elif 1 <= depth <= 3: + self._supported_table_args = TABLE_ARGS[:depth] + else: + raise SchemaError(f"Invalid mapping shape. Depth: {depth}") + + return self._supported_table_args + + def table_parts(self, table: exp.Table) -> t.List[str]: + if isinstance(table.this, exp.ReadCSV): + return [table.this.name] + return [table.text(part) for part in TABLE_ARGS if table.text(part)] + + def find( + self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True + ) -> t.Optional[T]: + parts = self.table_parts(table)[0 : len(self.supported_table_args)] + value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) + + if value == 0: + if raise_on_missing: + raise SchemaError(f"Cannot find mapping for {table}.") + else: + return None + elif value == 1: + possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) + if len(possibilities) == 1: + parts.extend(possibilities[0]) + else: + message = ", ".join(".".join(parts) for parts in possibilities) + if raise_on_missing: + raise SchemaError(f"Ambiguous mapping for {table}: {message}.") + return None + return self._nested_get(parts, raise_on_missing=raise_on_missing) -class MappingSchema(Schema): + def _nested_get( + self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True + ) -> t.Optional[t.Any]: + return _nested_get( + d or self.mapping, + *zip(self.supported_table_args, reversed(parts)), + raise_on_missing=raise_on_missing, + ) + + +class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): """ Schema based on a nested mapping. @@ -82,17 +157,17 @@ class MappingSchema(Schema): visible: t.Optional[t.Dict] = None, dialect: t.Optional[str] = None, ) -> None: - self.schema = schema or {} + super().__init__(schema) self.visible = visible or {} - self.schema_trie = self._build_trie(self.schema) self.dialect = dialect - self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {} - self._supported_table_args: t.Tuple[str, ...] = tuple() + self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = { + "STR": exp.DataType.Type.TEXT, + } @classmethod def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: return MappingSchema( - schema=mapping_schema.schema, + schema=mapping_schema.mapping, visible=mapping_schema.visible, dialect=mapping_schema.dialect, ) @@ -100,27 +175,13 @@ class MappingSchema(Schema): def copy(self, **kwargs) -> MappingSchema: return MappingSchema( **{ # type: ignore - "schema": self.schema.copy(), + "schema": self.mapping.copy(), "visible": self.visible.copy(), "dialect": self.dialect, **kwargs, } ) - @property - def supported_table_args(self): - if not self._supported_table_args and self.schema: - depth = _dict_depth(self.schema) - - if not depth or depth == 1: # {} - self._supported_table_args = tuple() - elif 2 <= depth <= 4: - self._supported_table_args = TABLE_ARGS[: depth - 1] - else: - raise SchemaError(f"Invalid schema shape. Depth: {depth}") - - return self._supported_table_args - def add_table( self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None ) -> None: @@ -133,17 +194,21 @@ class MappingSchema(Schema): """ table_ = self._ensure_table(table) column_mapping = ensure_column_mapping(column_mapping) - schema = self.find_schema(table_, raise_on_missing=False) + schema = self.find(table_, raise_on_missing=False) if schema and not column_mapping: return _nested_set( - self.schema, + self.mapping, list(reversed(self.table_parts(table_))), column_mapping, ) - self.schema_trie = self._build_trie(self.schema) + self.mapping_trie = self._build_trie(self.mapping) + + def _depth(self) -> int: + # The columns themselves are a mapping, but we don't want to include those + return super()._depth() - 1 def _ensure_table(self, table: exp.Table | str) -> exp.Table: table_ = exp.to_table(table) @@ -153,16 +218,9 @@ class MappingSchema(Schema): return table_ - def table_parts(self, table: exp.Table) -> t.List[str]: - return [table.text(part) for part in TABLE_ARGS if table.text(part)] - def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]: table_ = self._ensure_table(table) - - if not isinstance(table_.this, exp.Identifier): - return fs_get(table) # type: ignore - - schema = self.find_schema(table_) + schema = self.find(table_) if schema is None: raise SchemaError(f"Could not find table schema {table}") @@ -173,36 +231,13 @@ class MappingSchema(Schema): visible = self._nested_get(self.table_parts(table_), self.visible) return [col for col in schema if col in visible] # type: ignore - def find_schema( - self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True - ) -> t.Optional[t.Dict[str, str]]: - parts = self.table_parts(table)[0 : len(self.supported_table_args)] - value, trie = in_trie(self.schema_trie if trie is None else trie, parts) - - if value == 0: - if raise_on_missing: - raise SchemaError(f"Cannot find schema for {table}.") - else: - return None - elif value == 1: - possibilities = flatten_schema(trie) - if len(possibilities) == 1: - parts.extend(possibilities[0]) - else: - message = ", ".join(".".join(parts) for parts in possibilities) - if raise_on_missing: - raise SchemaError(f"Ambiguous schema for {table}: {message}.") - return None - - return self._nested_get(parts, raise_on_missing=raise_on_missing) - def get_column_type( self, table: exp.Table | str, column: exp.Column | str ) -> exp.DataType.Type: column_name = column if isinstance(column, str) else column.name table_ = exp.to_table(table) if table_: - table_schema = self.find_schema(table_) + table_schema = self.find(table_) schema_type = table_schema.get(column_name).upper() # type: ignore return self._convert_type(schema_type) raise SchemaError(f"Could not convert table '{table}'") @@ -228,18 +263,6 @@ class MappingSchema(Schema): return self._type_mapping_cache[schema_type] - def _build_trie(self, schema: t.Dict): - return new_trie(tuple(reversed(t)) for t in flatten_schema(schema)) - - def _nested_get( - self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True - ) -> t.Optional[t.Any]: - return _nested_get( - d or self.schema, - *zip(self.supported_table_args, reversed(parts)), - raise_on_missing=raise_on_missing, - ) - def ensure_schema(schema: t.Any) -> Schema: if isinstance(schema, Schema): @@ -267,29 +290,20 @@ def ensure_column_mapping(mapping: t.Optional[ColumnMapping]): raise ValueError(f"Invalid mapping provided: {type(mapping)}") -def flatten_schema(schema: t.Dict, keys: t.Optional[t.List[str]] = None) -> t.List[t.List[str]]: +def flatten_schema( + schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None +) -> t.List[t.List[str]]: tables = [] keys = keys or [] - depth = _dict_depth(schema) for k, v in schema.items(): - if depth >= 3: - tables.extend(flatten_schema(v, keys + [k])) - elif depth == 2: + if depth >= 2: + tables.extend(flatten_schema(v, depth - 1, keys + [k])) + elif depth == 1: tables.append(keys + [k]) return tables -def fs_get(table: exp.Table) -> t.List[str]: - name = table.this.name - - if name.upper() == "READ_CSV": - with csv_reader(table) as reader: - return next(reader) - - raise ValueError(f"Cannot read schema for {table}") - - def _nested_get( d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True ) -> t.Optional[t.Any]: @@ -310,7 +324,7 @@ def _nested_get( if d is None: if raise_on_missing: name = "table" if name == "this" else name - raise ValueError(f"Unknown {name}") + raise ValueError(f"Unknown {name}: {key}") return None return d @@ -350,34 +364,3 @@ def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict: subd[keys[-1]] = value return d - - -def _dict_depth(d: t.Dict) -> int: - """ - Get the nesting depth of a dictionary. - - For example: - >>> _dict_depth(None) - 0 - >>> _dict_depth({}) - 1 - >>> _dict_depth({"a": "b"}) - 1 - >>> _dict_depth({"a": {}}) - 2 - >>> _dict_depth({"a": {"b": {}}}) - 3 - - Args: - d (dict): dictionary - Returns: - int: depth - """ - try: - return 1 + _dict_depth(next(iter(d.values()))) - except AttributeError: - # d doesn't have attribute "values" - return 0 - except StopIteration: - # d.values() returns an empty sequence - return 1 diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 95d84d6..ec8cd91 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -105,12 +105,9 @@ class TokenType(AutoName): OBJECT = auto() # keywords - ADD_FILE = auto() ALIAS = auto() ALWAYS = auto() ALL = auto() - ALTER = auto() - ANALYZE = auto() ANTI = auto() ANY = auto() APPLY = auto() @@ -124,14 +121,14 @@ class TokenType(AutoName): BUCKET = auto() BY_DEFAULT = auto() CACHE = auto() - CALL = auto() + CASCADE = auto() CASE = auto() CHARACTER_SET = auto() CHECK = auto() CLUSTER_BY = auto() COLLATE = auto() + COMMAND = auto() COMMENT = auto() - COMMENT_ON = auto() COMMIT = auto() CONSTRAINT = auto() CREATE = auto() @@ -149,7 +146,9 @@ class TokenType(AutoName): DETERMINISTIC = auto() DISTINCT = auto() DISTINCT_FROM = auto() + DISTKEY = auto() DISTRIBUTE_BY = auto() + DISTSTYLE = auto() DIV = auto() DROP = auto() ELSE = auto() @@ -159,7 +158,6 @@ class TokenType(AutoName): EXCEPT = auto() EXECUTE = auto() EXISTS = auto() - EXPLAIN = auto() FALSE = auto() FETCH = auto() FILTER = auto() @@ -216,7 +214,6 @@ class TokenType(AutoName): OFFSET = auto() ON = auto() ONLY = auto() - OPTIMIZE = auto() OPTIONS = auto() ORDER_BY = auto() ORDERED = auto() @@ -258,6 +255,7 @@ class TokenType(AutoName): SHOW = auto() SIMILAR_TO = auto() SOME = auto() + SORTKEY = auto() SORT_BY = auto() STABLE = auto() STORED = auto() @@ -268,9 +266,8 @@ class TokenType(AutoName): TRANSIENT = auto() TOP = auto() THEN = auto() - TRUE = auto() TRAILING = auto() - TRUNCATE = auto() + TRUE = auto() UNBOUNDED = auto() UNCACHE = auto() UNION = auto() @@ -280,7 +277,6 @@ class TokenType(AutoName): USE = auto() USING = auto() VALUES = auto() - VACUUM = auto() VIEW = auto() VOLATILE = auto() WHEN = auto() @@ -420,7 +416,6 @@ class Tokenizer(metaclass=_Tokenizer): KEYWORDS = { "/*+": TokenType.HINT, - "*/": TokenType.HINT, "==": TokenType.EQ, "::": TokenType.DCOLON, "||": TokenType.DPIPE, @@ -435,15 +430,7 @@ class Tokenizer(metaclass=_Tokenizer): "#>": TokenType.HASH_ARROW, "#>>": TokenType.DHASH_ARROW, "<->": TokenType.LR_ARROW, - "ADD ARCHIVE": TokenType.ADD_FILE, - "ADD ARCHIVES": TokenType.ADD_FILE, - "ADD FILE": TokenType.ADD_FILE, - "ADD FILES": TokenType.ADD_FILE, - "ADD JAR": TokenType.ADD_FILE, - "ADD JARS": TokenType.ADD_FILE, "ALL": TokenType.ALL, - "ALTER": TokenType.ALTER, - "ANALYZE": TokenType.ANALYZE, "AND": TokenType.AND, "ANTI": TokenType.ANTI, "ANY": TokenType.ANY, @@ -455,10 +442,10 @@ class Tokenizer(metaclass=_Tokenizer): "BETWEEN": TokenType.BETWEEN, "BOTH": TokenType.BOTH, "BUCKET": TokenType.BUCKET, - "CALL": TokenType.CALL, "CACHE": TokenType.CACHE, "UNCACHE": TokenType.UNCACHE, "CASE": TokenType.CASE, + "CASCADE": TokenType.CASCADE, "CHARACTER SET": TokenType.CHARACTER_SET, "CHECK": TokenType.CHECK, "CLUSTER BY": TokenType.CLUSTER_BY, @@ -479,7 +466,9 @@ class Tokenizer(metaclass=_Tokenizer): "DETERMINISTIC": TokenType.DETERMINISTIC, "DISTINCT": TokenType.DISTINCT, "DISTINCT FROM": TokenType.DISTINCT_FROM, + "DISTKEY": TokenType.DISTKEY, "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY, + "DISTSTYLE": TokenType.DISTSTYLE, "DIV": TokenType.DIV, "DROP": TokenType.DROP, "ELSE": TokenType.ELSE, @@ -489,7 +478,6 @@ class Tokenizer(metaclass=_Tokenizer): "EXCEPT": TokenType.EXCEPT, "EXECUTE": TokenType.EXECUTE, "EXISTS": TokenType.EXISTS, - "EXPLAIN": TokenType.EXPLAIN, "FALSE": TokenType.FALSE, "FETCH": TokenType.FETCH, "FILTER": TokenType.FILTER, @@ -541,7 +529,6 @@ class Tokenizer(metaclass=_Tokenizer): "OFFSET": TokenType.OFFSET, "ON": TokenType.ON, "ONLY": TokenType.ONLY, - "OPTIMIZE": TokenType.OPTIMIZE, "OPTIONS": TokenType.OPTIONS, "OR": TokenType.OR, "ORDER BY": TokenType.ORDER_BY, @@ -579,6 +566,7 @@ class Tokenizer(metaclass=_Tokenizer): "SET": TokenType.SET, "SHOW": TokenType.SHOW, "SOME": TokenType.SOME, + "SORTKEY": TokenType.SORTKEY, "SORT BY": TokenType.SORT_BY, "STABLE": TokenType.STABLE, "STORED": TokenType.STORED, @@ -592,7 +580,6 @@ class Tokenizer(metaclass=_Tokenizer): "THEN": TokenType.THEN, "TRUE": TokenType.TRUE, "TRAILING": TokenType.TRAILING, - "TRUNCATE": TokenType.TRUNCATE, "UNBOUNDED": TokenType.UNBOUNDED, "UNION": TokenType.UNION, "UNPIVOT": TokenType.UNPIVOT, @@ -600,7 +587,6 @@ class Tokenizer(metaclass=_Tokenizer): "UPDATE": TokenType.UPDATE, "USE": TokenType.USE, "USING": TokenType.USING, - "VACUUM": TokenType.VACUUM, "VALUES": TokenType.VALUES, "VIEW": TokenType.VIEW, "VOLATILE": TokenType.VOLATILE, @@ -659,6 +645,14 @@ class Tokenizer(metaclass=_Tokenizer): "UNIQUE": TokenType.UNIQUE, "STRUCT": TokenType.STRUCT, "VARIANT": TokenType.VARIANT, + "ALTER": TokenType.COMMAND, + "ANALYZE": TokenType.COMMAND, + "CALL": TokenType.COMMAND, + "EXPLAIN": TokenType.COMMAND, + "OPTIMIZE": TokenType.COMMAND, + "PREPARE": TokenType.COMMAND, + "TRUNCATE": TokenType.COMMAND, + "VACUUM": TokenType.COMMAND, } WHITE_SPACE = { @@ -670,20 +664,11 @@ class Tokenizer(metaclass=_Tokenizer): } COMMANDS = { - TokenType.ALTER, - TokenType.ADD_FILE, - TokenType.ANALYZE, - TokenType.BEGIN, - TokenType.CALL, - TokenType.COMMENT_ON, - TokenType.COMMIT, - TokenType.EXPLAIN, - TokenType.OPTIMIZE, + TokenType.COMMAND, + TokenType.EXECUTE, + TokenType.FETCH, TokenType.SET, TokenType.SHOW, - TokenType.TRUNCATE, - TokenType.VACUUM, - TokenType.ROLLBACK, } # handle numeric literals like in hive (3L = BIGINT) @@ -885,6 +870,7 @@ class Tokenizer(metaclass=_Tokenizer): if comment_start_line == self._prev_token_line: if self._prev_token_comment is None: self.tokens[-1].comment = self._comment + self._prev_token_comment = self._comment self._comment = None diff --git a/tests/dataframe/unit/test_dataframe.py b/tests/dataframe/unit/test_dataframe.py index e36667b..24850bc 100644 --- a/tests/dataframe/unit/test_dataframe.py +++ b/tests/dataframe/unit/test_dataframe.py @@ -4,6 +4,8 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator class TestDataframe(DataFrameSQLValidator): + maxDiff = None + def test_hash_select_expression(self): expression = exp.select("cola").from_("table") self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression)) @@ -16,26 +18,26 @@ class TestDataframe(DataFrameSQLValidator): def test_cache(self): df = self.df_employee.select("fname").cache() expected_statements = [ - "DROP VIEW IF EXISTS t11623", - "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", - "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`", + "DROP VIEW IF EXISTS t31563", + "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`", ] self.compare_sql(df, expected_statements) def test_persist_default(self): df = self.df_employee.select("fname").persist() expected_statements = [ - "DROP VIEW IF EXISTS t11623", - "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", - "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`", + "DROP VIEW IF EXISTS t31563", + "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`", ] self.compare_sql(df, expected_statements) def test_persist_storagelevel(self): df = self.df_employee.select("fname").persist("DISK_ONLY_2") expected_statements = [ - "DROP VIEW IF EXISTS t11623", - "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", - "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`", + "DROP VIEW IF EXISTS t31563", + "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`", ] self.compare_sql(df, expected_statements) diff --git a/tests/dataframe/unit/test_dataframe_writer.py b/tests/dataframe/unit/test_dataframe_writer.py index 14b4a0a..7c646f5 100644 --- a/tests/dataframe/unit/test_dataframe_writer.py +++ b/tests/dataframe/unit/test_dataframe_writer.py @@ -6,39 +6,41 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator class TestDataFrameWriter(DataFrameSQLValidator): + maxDiff = None + def test_insertInto_full_path(self): df = self.df_employee.write.insertInto("catalog.db.table_name") - expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_insertInto_db_table(self): df = self.df_employee.write.insertInto("db.table_name") - expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_insertInto_table(self): df = self.df_employee.write.insertInto("table_name") - expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_insertInto_overwrite(self): df = self.df_employee.write.insertInto("table_name", overwrite=True) - expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) @mock.patch("sqlglot.schema", MappingSchema()) def test_insertInto_byName(self): sqlglot.schema.add_table("table_name", {"employee_id": "INT"}) df = self.df_employee.write.byName.insertInto("table_name") - expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_insertInto_cache(self): df = self.df_employee.cache().write.insertInto("table_name") expected_statements = [ - "DROP VIEW IF EXISTS t35612", - "CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", - "INSERT INTO table_name SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`", + "DROP VIEW IF EXISTS t37164", + "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "INSERT INTO table_name SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`", ] self.compare_sql(df, expected_statements) @@ -48,39 +50,39 @@ class TestDataFrameWriter(DataFrameSQLValidator): def test_saveAsTable_append(self): df = self.df_employee.write.saveAsTable("table_name", mode="append") - expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_saveAsTable_overwrite(self): df = self.df_employee.write.saveAsTable("table_name", mode="overwrite") - expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_saveAsTable_error(self): df = self.df_employee.write.saveAsTable("table_name", mode="error") - expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_saveAsTable_ignore(self): df = self.df_employee.write.saveAsTable("table_name", mode="ignore") - expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_mode_standalone(self): df = self.df_employee.write.mode("ignore").saveAsTable("table_name") - expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_mode_override(self): df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite") - expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_saveAsTable_cache(self): df = self.df_employee.cache().write.saveAsTable("table_name") expected_statements = [ - "DROP VIEW IF EXISTS t35612", - "CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", - "CREATE TABLE table_name AS SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`", + "DROP VIEW IF EXISTS t37164", + "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "CREATE TABLE table_name AS SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`", ] self.compare_sql(df, expected_statements) diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py index 7e8bfad..55aa547 100644 --- a/tests/dataframe/unit/test_session.py +++ b/tests/dataframe/unit/test_session.py @@ -11,32 +11,32 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator class TestDataframeSession(DataFrameSQLValidator): def test_cdf_one_row(self): df = self.spark.createDataFrame([[1, 2]], ["cola", "colb"]) - expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2)) AS `a2`(`cola`, `colb`)" + expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2) AS `a2`(`cola`, `colb`)" self.compare_sql(df, expected) def test_cdf_multiple_rows(self): df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]], ["cola", "colb"]) - expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`cola`, `colb`)" + expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`cola`, `colb`)" self.compare_sql(df, expected) def test_cdf_no_schema(self): df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]]) - expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)" + expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`_1`, `_2`)" self.compare_sql(df, expected) def test_cdf_row_mixed_primitives(self): df = self.spark.createDataFrame([[1, 10.1, "test", False, None]]) - expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM (VALUES (1, 10.1, 'test', FALSE, NULL)) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)" + expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM VALUES (1, 10.1, 'test', FALSE, NULL) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)" self.compare_sql(df, expected) def test_cdf_dict_rows(self): df = self.spark.createDataFrame([{"cola": 1, "colb": "test"}, {"cola": 2, "colb": "test2"}]) - expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 'test'), (2, 'test2')) AS `a2`(`cola`, `colb`)" + expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 'test'), (2, 'test2') AS `a2`(`cola`, `colb`)" self.compare_sql(df, expected) def test_cdf_str_schema(self): df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING") - expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)" + expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)" self.compare_sql(df, expected) def test_typed_schema_basic(self): @@ -47,7 +47,7 @@ class TestDataframeSession(DataFrameSQLValidator): ] ) df = self.spark.createDataFrame([[1, "test"]], schema) - expected = "SELECT CAST(`a2`.`cola` AS int) AS `cola`, CAST(`a2`.`colb` AS string) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)" + expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)" self.compare_sql(df, expected) def test_typed_schema_nested(self): @@ -65,7 +65,8 @@ class TestDataframeSession(DataFrameSQLValidator): ] ) df = self.spark.createDataFrame([[{"sub_cola": 1, "sub_colb": "test"}]], schema) - expected = "SELECT CAST(`a2`.`cola` AS struct<sub_cola:int, sub_colb:string>) AS `cola` FROM (VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`))) AS `a2`(`cola`)" + expected = "SELECT CAST(`a2`.`cola` AS STRUCT<`sub_cola`: INT, `sub_colb`: STRING>) AS `cola` FROM VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`)) AS `a2`(`cola`)" + self.compare_sql(df, expected) @mock.patch("sqlglot.schema", MappingSchema()) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index a0ebc45..cc44311 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -286,6 +286,10 @@ class TestBigQuery(Validator): "bigquery": "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))", }, ) + self.validate_identity("BEGIN A B C D E F") + self.validate_identity("BEGIN TRANSACTION") + self.validate_identity("COMMIT TRANSACTION") + self.validate_identity("ROLLBACK TRANSACTION") def test_user_defined_functions(self): self.validate_identity( diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 1913f53..1b2f9c1 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -69,6 +69,7 @@ class TestDialect(Validator): write={ "bigquery": "CAST(a AS STRING)", "clickhouse": "CAST(a AS TEXT)", + "drill": "CAST(a AS VARCHAR)", "duckdb": "CAST(a AS TEXT)", "mysql": "CAST(a AS TEXT)", "hive": "CAST(a AS STRING)", @@ -86,6 +87,7 @@ class TestDialect(Validator): write={ "bigquery": "CAST(a AS BINARY(4))", "clickhouse": "CAST(a AS BINARY(4))", + "drill": "CAST(a AS VARBINARY(4))", "duckdb": "CAST(a AS BINARY(4))", "mysql": "CAST(a AS BINARY(4))", "hive": "CAST(a AS BINARY(4))", @@ -146,6 +148,7 @@ class TestDialect(Validator): "CAST(a AS STRING)", write={ "bigquery": "CAST(a AS STRING)", + "drill": "CAST(a AS VARCHAR)", "duckdb": "CAST(a AS TEXT)", "mysql": "CAST(a AS TEXT)", "hive": "CAST(a AS STRING)", @@ -162,6 +165,7 @@ class TestDialect(Validator): "CAST(a AS VARCHAR)", write={ "bigquery": "CAST(a AS STRING)", + "drill": "CAST(a AS VARCHAR)", "duckdb": "CAST(a AS TEXT)", "mysql": "CAST(a AS VARCHAR)", "hive": "CAST(a AS STRING)", @@ -178,6 +182,7 @@ class TestDialect(Validator): "CAST(a AS VARCHAR(3))", write={ "bigquery": "CAST(a AS STRING(3))", + "drill": "CAST(a AS VARCHAR(3))", "duckdb": "CAST(a AS TEXT(3))", "mysql": "CAST(a AS VARCHAR(3))", "hive": "CAST(a AS VARCHAR(3))", @@ -194,6 +199,7 @@ class TestDialect(Validator): "CAST(a AS SMALLINT)", write={ "bigquery": "CAST(a AS INT64)", + "drill": "CAST(a AS INTEGER)", "duckdb": "CAST(a AS SMALLINT)", "mysql": "CAST(a AS SMALLINT)", "hive": "CAST(a AS SMALLINT)", @@ -215,6 +221,7 @@ class TestDialect(Validator): }, write={ "duckdb": "TRY_CAST(a AS DOUBLE)", + "drill": "CAST(a AS DOUBLE)", "postgres": "CAST(a AS DOUBLE PRECISION)", "redshift": "CAST(a AS DOUBLE PRECISION)", }, @@ -225,6 +232,7 @@ class TestDialect(Validator): write={ "bigquery": "CAST(a AS FLOAT64)", "clickhouse": "CAST(a AS Float64)", + "drill": "CAST(a AS DOUBLE)", "duckdb": "CAST(a AS DOUBLE)", "mysql": "CAST(a AS DOUBLE)", "hive": "CAST(a AS DOUBLE)", @@ -279,6 +287,7 @@ class TestDialect(Validator): "duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)", "presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')", + "drill": "TO_TIMESTAMP(x, 'yyyy-MM-dd''T''HH:mm:ss')", "redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH:MI:SS')", "spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')", }, @@ -286,6 +295,7 @@ class TestDialect(Validator): self.validate_all( "STR_TO_TIME('2020-01-01', '%Y-%m-%d')", write={ + "drill": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')", "duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')", "hive": "CAST('2020-01-01' AS TIMESTAMP)", "oracle": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')", @@ -298,6 +308,7 @@ class TestDialect(Validator): self.validate_all( "STR_TO_TIME(x, '%y')", write={ + "drill": "TO_TIMESTAMP(x, 'yy')", "duckdb": "STRPTIME(x, '%y')", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)", "presto": "DATE_PARSE(x, '%y')", @@ -319,6 +330,7 @@ class TestDialect(Validator): self.validate_all( "TIME_STR_TO_DATE('2020-01-01')", write={ + "drill": "CAST('2020-01-01' AS DATE)", "duckdb": "CAST('2020-01-01' AS DATE)", "hive": "TO_DATE('2020-01-01')", "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')", @@ -328,6 +340,7 @@ class TestDialect(Validator): self.validate_all( "TIME_STR_TO_TIME('2020-01-01')", write={ + "drill": "CAST('2020-01-01' AS TIMESTAMP)", "duckdb": "CAST('2020-01-01' AS TIMESTAMP)", "hive": "CAST('2020-01-01' AS TIMESTAMP)", "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')", @@ -344,6 +357,7 @@ class TestDialect(Validator): self.validate_all( "TIME_TO_STR(x, '%Y-%m-%d')", write={ + "drill": "TO_CHAR(x, 'yyyy-MM-dd')", "duckdb": "STRFTIME(x, '%Y-%m-%d')", "hive": "DATE_FORMAT(x, 'yyyy-MM-dd')", "oracle": "TO_CHAR(x, 'YYYY-MM-DD')", @@ -355,6 +369,7 @@ class TestDialect(Validator): self.validate_all( "TIME_TO_TIME_STR(x)", write={ + "drill": "CAST(x AS VARCHAR)", "duckdb": "CAST(x AS TEXT)", "hive": "CAST(x AS STRING)", "presto": "CAST(x AS VARCHAR)", @@ -364,6 +379,7 @@ class TestDialect(Validator): self.validate_all( "TIME_TO_UNIX(x)", write={ + "drill": "UNIX_TIMESTAMP(x)", "duckdb": "EPOCH(x)", "hive": "UNIX_TIMESTAMP(x)", "presto": "TO_UNIXTIME(x)", @@ -425,6 +441,7 @@ class TestDialect(Validator): self.validate_all( "DATE_TO_DATE_STR(x)", write={ + "drill": "CAST(x AS VARCHAR)", "duckdb": "CAST(x AS TEXT)", "hive": "CAST(x AS STRING)", "presto": "CAST(x AS VARCHAR)", @@ -433,6 +450,7 @@ class TestDialect(Validator): self.validate_all( "DATE_TO_DI(x)", write={ + "drill": "CAST(TO_DATE(x, 'yyyyMMdd') AS INT)", "duckdb": "CAST(STRFTIME(x, '%Y%m%d') AS INT)", "hive": "CAST(DATE_FORMAT(x, 'yyyyMMdd') AS INT)", "presto": "CAST(DATE_FORMAT(x, '%Y%m%d') AS INT)", @@ -441,6 +459,7 @@ class TestDialect(Validator): self.validate_all( "DI_TO_DATE(x)", write={ + "drill": "TO_DATE(CAST(x AS VARCHAR), 'yyyyMMdd')", "duckdb": "CAST(STRPTIME(CAST(x AS TEXT), '%Y%m%d') AS DATE)", "hive": "TO_DATE(CAST(x AS STRING), 'yyyyMMdd')", "presto": "CAST(DATE_PARSE(CAST(x AS VARCHAR), '%Y%m%d') AS DATE)", @@ -463,6 +482,7 @@ class TestDialect(Validator): }, write={ "bigquery": "DATE_ADD(x, INTERVAL 1 'day')", + "drill": "DATE_ADD(x, INTERVAL '1' DAY)", "duckdb": "x + INTERVAL 1 day", "hive": "DATE_ADD(x, 1)", "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", @@ -477,6 +497,7 @@ class TestDialect(Validator): "DATE_ADD(x, 1)", write={ "bigquery": "DATE_ADD(x, INTERVAL 1 'day')", + "drill": "DATE_ADD(x, INTERVAL '1' DAY)", "duckdb": "x + INTERVAL 1 DAY", "hive": "DATE_ADD(x, 1)", "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", @@ -546,6 +567,7 @@ class TestDialect(Validator): "starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", }, write={ + "drill": "TO_DATE(x, 'yyyy-MM-dd''T''HH:mm:ss')", "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", "starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)", @@ -556,6 +578,7 @@ class TestDialect(Validator): self.validate_all( "STR_TO_DATE(x, '%Y-%m-%d')", write={ + "drill": "CAST(x AS DATE)", "mysql": "STR_TO_DATE(x, '%Y-%m-%d')", "starrocks": "STR_TO_DATE(x, '%Y-%m-%d')", "hive": "CAST(x AS DATE)", @@ -566,6 +589,7 @@ class TestDialect(Validator): self.validate_all( "DATE_STR_TO_DATE(x)", write={ + "drill": "CAST(x AS DATE)", "duckdb": "CAST(x AS DATE)", "hive": "TO_DATE(x)", "presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)", @@ -575,6 +599,7 @@ class TestDialect(Validator): self.validate_all( "TS_OR_DS_ADD('2021-02-01', 1, 'DAY')", write={ + "drill": "DATE_ADD(CAST('2021-02-01' AS DATE), INTERVAL '1' DAY)", "duckdb": "CAST('2021-02-01' AS DATE) + INTERVAL 1 DAY", "hive": "DATE_ADD('2021-02-01', 1)", "presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR('2021-02-01', 1, 10), '%Y-%m-%d'))", @@ -584,6 +609,7 @@ class TestDialect(Validator): self.validate_all( "DATE_ADD(CAST('2020-01-01' AS DATE), 1)", write={ + "drill": "DATE_ADD(CAST('2020-01-01' AS DATE), INTERVAL '1' DAY)", "duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY", "hive": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)", "presto": "DATE_ADD('day', 1, CAST('2020-01-01' AS DATE))", @@ -593,6 +619,7 @@ class TestDialect(Validator): self.validate_all( "TIMESTAMP '2022-01-01'", write={ + "drill": "CAST('2022-01-01' AS TIMESTAMP)", "mysql": "CAST('2022-01-01' AS TIMESTAMP)", "starrocks": "CAST('2022-01-01' AS DATETIME)", "hive": "CAST('2022-01-01' AS TIMESTAMP)", @@ -614,6 +641,7 @@ class TestDialect(Validator): dialect: f"{unit}(x)" for dialect in ( "bigquery", + "drill", "duckdb", "mysql", "presto", @@ -624,6 +652,7 @@ class TestDialect(Validator): dialect: f"{unit}(x)" for dialect in ( "bigquery", + "drill", "duckdb", "mysql", "presto", @@ -649,6 +678,7 @@ class TestDialect(Validator): write={ "bigquery": "ARRAY_LENGTH(x)", "duckdb": "ARRAY_LENGTH(x)", + "drill": "REPEATED_COUNT(x)", "presto": "CARDINALITY(x)", "spark": "SIZE(x)", }, @@ -736,6 +766,7 @@ class TestDialect(Validator): self.validate_all( "SELECT a FROM x CROSS JOIN UNNEST(y) AS t (a)", write={ + "drill": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)", "presto": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)", "spark": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a", }, @@ -743,6 +774,7 @@ class TestDialect(Validator): self.validate_all( "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t (a, b)", write={ + "drill": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)", "presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)", "spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b", }, @@ -775,6 +807,7 @@ class TestDialect(Validator): }, write={ "bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b", + "drill": "SELECT * FROM a UNION SELECT * FROM b", "duckdb": "SELECT * FROM a UNION SELECT * FROM b", "presto": "SELECT * FROM a UNION SELECT * FROM b", "spark": "SELECT * FROM a UNION SELECT * FROM b", @@ -887,6 +920,7 @@ class TestDialect(Validator): write={ "bigquery": "LOWER(x) LIKE '%y'", "clickhouse": "x ILIKE '%y'", + "drill": "x `ILIKE` '%y'", "duckdb": "x ILIKE '%y'", "hive": "LOWER(x) LIKE '%y'", "mysql": "LOWER(x) LIKE '%y'", @@ -910,32 +944,38 @@ class TestDialect(Validator): self.validate_all( "POSITION(' ' in x)", write={ + "drill": "STRPOS(x, ' ')", "duckdb": "STRPOS(x, ' ')", "postgres": "STRPOS(x, ' ')", "presto": "STRPOS(x, ' ')", "spark": "LOCATE(' ', x)", "clickhouse": "position(x, ' ')", "snowflake": "POSITION(' ', x)", + "mysql": "LOCATE(' ', x)", }, ) self.validate_all( "STR_POSITION('a', x)", write={ + "drill": "STRPOS(x, 'a')", "duckdb": "STRPOS(x, 'a')", "postgres": "STRPOS(x, 'a')", "presto": "STRPOS(x, 'a')", "spark": "LOCATE('a', x)", "clickhouse": "position(x, 'a')", "snowflake": "POSITION('a', x)", + "mysql": "LOCATE('a', x)", }, ) self.validate_all( "POSITION('a', x, 3)", write={ + "drill": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", "presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", "spark": "LOCATE('a', x, 3)", "clickhouse": "position(x, 'a', 3)", "snowflake": "POSITION('a', x, 3)", + "mysql": "LOCATE('a', x, 3)", }, ) self.validate_all( @@ -960,6 +1000,7 @@ class TestDialect(Validator): self.validate_all( "IF(x > 1, 1, 0)", write={ + "drill": "`IF`(x > 1, 1, 0)", "duckdb": "CASE WHEN x > 1 THEN 1 ELSE 0 END", "presto": "IF(x > 1, 1, 0)", "hive": "IF(x > 1, 1, 0)", @@ -970,6 +1011,7 @@ class TestDialect(Validator): self.validate_all( "CASE WHEN 1 THEN x ELSE 0 END", write={ + "drill": "CASE WHEN 1 THEN x ELSE 0 END", "duckdb": "CASE WHEN 1 THEN x ELSE 0 END", "presto": "CASE WHEN 1 THEN x ELSE 0 END", "hive": "CASE WHEN 1 THEN x ELSE 0 END", @@ -980,6 +1022,7 @@ class TestDialect(Validator): self.validate_all( "x[y]", write={ + "drill": "x[y]", "duckdb": "x[y]", "presto": "x[y]", "hive": "x[y]", @@ -1000,6 +1043,7 @@ class TestDialect(Validator): 'true or null as "foo"', write={ "bigquery": "TRUE OR NULL AS `foo`", + "drill": "TRUE OR NULL AS `foo`", "duckdb": 'TRUE OR NULL AS "foo"', "presto": 'TRUE OR NULL AS "foo"', "hive": "TRUE OR NULL AS `foo`", @@ -1020,6 +1064,7 @@ class TestDialect(Validator): "LEVENSHTEIN(col1, col2)", write={ "duckdb": "LEVENSHTEIN(col1, col2)", + "drill": "LEVENSHTEIN_DISTANCE(col1, col2)", "presto": "LEVENSHTEIN_DISTANCE(col1, col2)", "hive": "LEVENSHTEIN(col1, col2)", "spark": "LEVENSHTEIN(col1, col2)", @@ -1029,6 +1074,7 @@ class TestDialect(Validator): "LEVENSHTEIN(coalesce(col1, col2), coalesce(col2, col1))", write={ "duckdb": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))", + "drill": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))", "presto": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))", "hive": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))", "spark": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))", @@ -1152,6 +1198,7 @@ class TestDialect(Validator): self.validate_all( "SELECT a AS b FROM x GROUP BY b", write={ + "drill": "SELECT a AS b FROM x GROUP BY b", "duckdb": "SELECT a AS b FROM x GROUP BY b", "presto": "SELECT a AS b FROM x GROUP BY 1", "hive": "SELECT a AS b FROM x GROUP BY 1", @@ -1162,6 +1209,7 @@ class TestDialect(Validator): self.validate_all( "SELECT y x FROM my_table t", write={ + "drill": "SELECT y AS x FROM my_table AS t", "hive": "SELECT y AS x FROM my_table AS t", "oracle": "SELECT y AS x FROM my_table t", "postgres": "SELECT y AS x FROM my_table AS t", @@ -1230,3 +1278,36 @@ SELECT }, pretty=True, ) + + def test_transactions(self): + self.validate_all( + "BEGIN TRANSACTION", + write={ + "bigquery": "BEGIN TRANSACTION", + "mysql": "BEGIN", + "postgres": "BEGIN", + "presto": "START TRANSACTION", + "trino": "START TRANSACTION", + "redshift": "BEGIN", + "snowflake": "BEGIN", + "sqlite": "BEGIN TRANSACTION", + }, + ) + self.validate_all( + "BEGIN", + read={ + "presto": "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE", + "trino": "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE", + }, + ) + self.validate_all( + "BEGIN", + read={ + "presto": "START TRANSACTION ISOLATION LEVEL REPEATABLE READ", + "trino": "START TRANSACTION ISOLATION LEVEL REPEATABLE READ", + }, + ) + self.validate_all( + "BEGIN IMMEDIATE TRANSACTION", + write={"sqlite": "BEGIN IMMEDIATE TRANSACTION"}, + ) diff --git a/tests/dialects/test_drill.py b/tests/dialects/test_drill.py new file mode 100644 index 0000000..9819daa --- /dev/null +++ b/tests/dialects/test_drill.py @@ -0,0 +1,53 @@ +from tests.dialects.test_dialect import Validator + + +class TestDrill(Validator): + dialect = "drill" + + def test_string_literals(self): + self.validate_all( + "SELECT '2021-01-01' + INTERVAL 1 MONTH", + write={ + "mysql": "SELECT '2021-01-01' + INTERVAL 1 MONTH", + }, + ) + + def test_quotes(self): + self.validate_all( + "'\\''", + write={ + "duckdb": "''''", + "presto": "''''", + "hive": "'\\''", + "spark": "'\\''", + }, + ) + self.validate_all( + "'\"x\"'", + write={ + "duckdb": "'\"x\"'", + "presto": "'\"x\"'", + "hive": "'\"x\"'", + "spark": "'\"x\"'", + }, + ) + self.validate_all( + "'\\\\a'", + read={ + "presto": "'\\a'", + }, + write={ + "duckdb": "'\\a'", + "presto": "'\\a'", + "hive": "'\\\\a'", + "spark": "'\\\\a'", + }, + ) + + def test_table_function(self): + self.validate_all( + "SELECT * FROM table( dfs.`test_data.xlsx` (type => 'excel', sheetName => 'secondSheet'))", + write={ + "drill": "SELECT * FROM table(dfs.`test_data.xlsx`(type => 'excel', sheetName => 'secondSheet'))", + }, + ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 1ba118b..af98249 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -58,6 +58,16 @@ class TestMySQL(Validator): self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'") self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci") self.validate_identity("SET autocommit = ON") + self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL SERIALIZABLE") + self.validate_identity("SET TRANSACTION READ ONLY") + self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE") + self.validate_identity("SELECT SCHEMA()") + + def test_canonical_functions(self): + self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)") + self.validate_identity("SELECT INSTR('str', 'substr')", "SELECT LOCATE('substr', 'str')") + self.validate_identity("SELECT UCASE('foo')", "SELECT UPPER('foo')") + self.validate_identity("SELECT LCASE('foo')", "SELECT LOWER('foo')") def test_escape(self): self.validate_all( diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 098ad2b..8179cf7 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -178,6 +178,15 @@ class TestPresto(Validator): }, ) self.validate_all( + "CREATE TABLE test STORED = 'PARQUET' AS SELECT 1", + write={ + "duckdb": "CREATE TABLE test AS SELECT 1", + "presto": "CREATE TABLE test WITH (FORMAT='PARQUET') AS SELECT 1", + "hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", + "spark": "CREATE TABLE test USING PARQUET AS SELECT 1", + }, + ) + self.validate_all( "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1", write={ "duckdb": "CREATE TABLE test AS SELECT 1", @@ -427,3 +436,69 @@ class TestPresto(Validator): "spark": UnsupportedError, }, ) + self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE") + self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ") + + def test_encode_decode(self): + self.validate_all( + "TO_UTF8(x)", + write={ + "spark": "ENCODE(x, 'utf-8')", + }, + ) + self.validate_all( + "FROM_UTF8(x)", + write={ + "spark": "DECODE(x, 'utf-8')", + }, + ) + self.validate_all( + "ENCODE(x, 'utf-8')", + write={ + "presto": "TO_UTF8(x)", + }, + ) + self.validate_all( + "DECODE(x, 'utf-8')", + write={ + "presto": "FROM_UTF8(x)", + }, + ) + self.validate_all( + "ENCODE(x, 'invalid')", + write={ + "presto": UnsupportedError, + }, + ) + self.validate_all( + "DECODE(x, 'invalid')", + write={ + "presto": UnsupportedError, + }, + ) + + def test_hex_unhex(self): + self.validate_all( + "TO_HEX(x)", + write={ + "spark": "HEX(x)", + }, + ) + self.validate_all( + "FROM_HEX(x)", + write={ + "spark": "UNHEX(x)", + }, + ) + self.validate_all( + "HEX(x)", + write={ + "presto": "TO_HEX(x)", + }, + ) + self.validate_all( + "UNHEX(x)", + write={ + "presto": "FROM_HEX(x)", + }, + ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 1846b17..0e69f4e 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -169,6 +169,17 @@ class TestSnowflake(Validator): "snowflake": "SELECT a FROM test AS unpivot", }, ) + self.validate_all( + "trim(date_column, 'UTC')", + write={ + "snowflake": "TRIM(date_column, 'UTC')", + "postgres": "TRIM('UTC' FROM date_column)", + }, + ) + self.validate_all( + "trim(date_column)", + write={"snowflake": "TRIM(date_column)"}, + ) def test_null_treatment(self): self.validate_all( diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 836ab28..75bd25d 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -122,13 +122,6 @@ x AT TIME ZONE 'UTC' CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo' SET x = 1 SET -v -ADD JAR s3://bucket -ADD JARS s3://bucket, c -ADD FILE s3://file -ADD FILES s3://file, s3://a -ADD ARCHIVE s3://file -ADD ARCHIVES s3://file, s3://a -BEGIN IMMEDIATE TRANSACTION COMMIT USE db NOT 1 @@ -278,6 +271,7 @@ SELECT CEIL(a, b) FROM test SELECT COUNT(a) FROM test SELECT COUNT(1) FROM test SELECT COUNT(*) FROM test +SELECT COUNT() FROM test SELECT COUNT(DISTINCT a) FROM test SELECT EXP(a) FROM test SELECT FLOOR(a) FROM test @@ -372,6 +366,8 @@ WITH a AS (SELECT 1) SELECT 1 UNION SELECT 2 WITH a AS (SELECT 1) SELECT 1 INTERSECT SELECT 2 WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2 WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2 +WITH sub_query AS (SELECT a FROM table) (SELECT a FROM sub_query) +WITH sub_query AS (SELECT a FROM table) ((((SELECT a FROM sub_query)))) (SELECT 1) UNION (SELECT 2) (SELECT 1) UNION SELECT 2 SELECT 1 UNION (SELECT 2) @@ -463,6 +459,7 @@ CREATE TABLE z (a INT, b VARCHAR COMMENT 'z', c VARCHAR(100) COMMENT 'z', d DECI CREATE TABLE z (a INT(11) DEFAULT UUID()) CREATE TABLE z (a INT(11) DEFAULT NULL COMMENT '客户id') CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1) +CREATE TABLE z (a INT(11) NOT NULL DEFAULT -1) CREATE TABLE z (a INT(11) NOT NULL COLLATE utf8_bin AUTO_INCREMENT) CREATE TABLE z (a INT, PRIMARY KEY(a)) CREATE TABLE z WITH (FORMAT='parquet') AS SELECT 1 @@ -476,6 +473,9 @@ CREATE TABLE z AS ((WITH cte AS (SELECT 1) SELECT * FROM cte)) CREATE TABLE z (a INT UNIQUE) CREATE TABLE z (a INT AUTO_INCREMENT) CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT) +CREATE TABLE z (a INT REFERENCES parent(b, c)) +CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id)) +CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c)) CREATE TEMPORARY FUNCTION f CREATE TEMPORARY FUNCTION f AS 'g' CREATE FUNCTION f @@ -514,17 +514,23 @@ DELETE FROM x WHERE y > 1 DELETE FROM y DELETE FROM event USING sales WHERE event.eventid = sales.eventid DELETE FROM event USING sales, USING bla WHERE event.eventid = sales.eventid +DELETE FROM event USING sales AS s WHERE event.eventid = s.eventid +PREPARE statement +EXECUTE statement DROP TABLE a DROP TABLE a.b DROP TABLE IF EXISTS a DROP TABLE IF EXISTS a.b +DROP TABLE a CASCADE DROP VIEW a DROP VIEW a.b DROP VIEW IF EXISTS a DROP VIEW IF EXISTS a.b SHOW TABLES USE db +BEGIN ROLLBACK +ROLLBACK TO b EXPLAIN SELECT * FROM x INSERT INTO x SELECT * FROM y INSERT INTO x (SELECT * FROM y) @@ -581,3 +587,4 @@ SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */ SELECT x FROM a.b.c /* x */, e.f.g /* x */ SELECT FOO(x /* c */) /* FOO */, b /* b */ SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */ +SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b' diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql new file mode 100644 index 0000000..7fcdbb8 --- /dev/null +++ b/tests/fixtures/optimizer/canonicalize.sql @@ -0,0 +1,5 @@ +SELECT w.d + w.e AS c FROM w AS w; +SELECT CONCAT(w.d, w.e) AS c FROM w AS w; + +SELECT CAST(w.d AS DATE) > w.e AS a FROM w AS w; +SELECT CAST(w.d AS DATE) > CAST(w.e AS DATE) AS a FROM w AS w; diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index eb7e9cb..a1e531b 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -119,7 +119,7 @@ GROUP BY LIMIT 1; # title: Root subquery is union -(SELECT b FROM x UNION SELECT b FROM y) LIMIT 1; +(SELECT b FROM x UNION SELECT b FROM y ORDER BY b) LIMIT 1; ( SELECT "x"."b" AS "b" @@ -128,6 +128,8 @@ LIMIT 1; SELECT "y"."b" AS "b" FROM "y" AS "y" + ORDER BY + "b" ) LIMIT 1; diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql index b91205c..8138b11 100644 --- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql +++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql @@ -15,7 +15,7 @@ select from lineitem where - CAST(l_shipdate AS DATE) <= date '1998-12-01' - interval '90' day + l_shipdate <= date '1998-12-01' - interval '90' day group by l_returnflag, l_linestatus @@ -250,8 +250,8 @@ FROM "orders" AS "orders" LEFT JOIN "_u_0" AS "_u_0" ON "_u_0"."l_orderkey" = "orders"."o_orderkey" WHERE - "orders"."o_orderdate" < CAST('1993-10-01' AS DATE) - AND "orders"."o_orderdate" >= CAST('1993-07-01' AS DATE) + CAST("orders"."o_orderdate" AS DATE) < CAST('1993-10-01' AS DATE) + AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1993-07-01' AS DATE) AND NOT "_u_0"."l_orderkey" IS NULL GROUP BY "orders"."o_orderpriority" @@ -293,8 +293,8 @@ SELECT FROM "customer" AS "customer" JOIN "orders" AS "orders" ON "customer"."c_custkey" = "orders"."o_custkey" - AND "orders"."o_orderdate" < CAST('1995-01-01' AS DATE) - AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE) + AND CAST("orders"."o_orderdate" AS DATE) < CAST('1995-01-01' AS DATE) + AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1994-01-01' AS DATE) JOIN "region" AS "region" ON "region"."r_name" = 'ASIA' JOIN "nation" AS "nation" @@ -328,8 +328,8 @@ FROM "lineitem" AS "lineitem" WHERE "lineitem"."l_discount" BETWEEN 0.05 AND 0.07 AND "lineitem"."l_quantity" < 24 - AND "lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE) - AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE); + AND CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE) + AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE); -------------------------------------- -- TPC-H 7 @@ -384,13 +384,13 @@ WITH "n1" AS ( SELECT "n1"."n_name" AS "supp_nation", "n2"."n_name" AS "cust_nation", - EXTRACT(year FROM "lineitem"."l_shipdate") AS "l_year", + EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME)) AS "l_year", SUM("lineitem"."l_extendedprice" * ( 1 - "lineitem"."l_discount" )) AS "revenue" FROM "supplier" AS "supplier" JOIN "lineitem" AS "lineitem" - ON "lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) + ON CAST("lineitem"."l_shipdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) AND "supplier"."s_suppkey" = "lineitem"."l_suppkey" JOIN "orders" AS "orders" ON "orders"."o_orderkey" = "lineitem"."l_orderkey" @@ -409,7 +409,7 @@ JOIN "n1" AS "n2" GROUP BY "n1"."n_name", "n2"."n_name", - EXTRACT(year FROM "lineitem"."l_shipdate") + EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME)) ORDER BY "supp_nation", "cust_nation", @@ -456,7 +456,7 @@ group by order by o_year; SELECT - EXTRACT(year FROM "orders"."o_orderdate") AS "o_year", + EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year", SUM( CASE WHEN "nation_2"."n_name" = 'BRAZIL' @@ -477,7 +477,7 @@ JOIN "customer" AS "customer" ON "customer"."c_nationkey" = "nation"."n_nationkey" JOIN "orders" AS "orders" ON "orders"."o_custkey" = "customer"."c_custkey" - AND "orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) + AND CAST("orders"."o_orderdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) JOIN "lineitem" AS "lineitem" ON "lineitem"."l_orderkey" = "orders"."o_orderkey" AND "part"."p_partkey" = "lineitem"."l_partkey" @@ -488,7 +488,7 @@ JOIN "nation" AS "nation_2" WHERE "part"."p_type" = 'ECONOMY ANODIZED STEEL' GROUP BY - EXTRACT(year FROM "orders"."o_orderdate") + EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) ORDER BY "o_year"; @@ -529,7 +529,7 @@ order by o_year desc; SELECT "nation"."n_name" AS "nation", - EXTRACT(year FROM "orders"."o_orderdate") AS "o_year", + EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year", SUM( "lineitem"."l_extendedprice" * ( 1 - "lineitem"."l_discount" @@ -551,7 +551,7 @@ WHERE "part"."p_name" LIKE '%green%' GROUP BY "nation"."n_name", - EXTRACT(year FROM "orders"."o_orderdate") + EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) ORDER BY "nation", "o_year" DESC; @@ -606,8 +606,8 @@ SELECT FROM "customer" AS "customer" JOIN "orders" AS "orders" ON "customer"."c_custkey" = "orders"."o_custkey" - AND "orders"."o_orderdate" < CAST('1994-01-01' AS DATE) - AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE) + AND CAST("orders"."o_orderdate" AS DATE) < CAST('1994-01-01' AS DATE) + AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1993-10-01' AS DATE) JOIN "lineitem" AS "lineitem" ON "lineitem"."l_orderkey" = "orders"."o_orderkey" AND "lineitem"."l_returnflag" = 'R' JOIN "nation" AS "nation" @@ -740,8 +740,8 @@ SELECT FROM "orders" AS "orders" JOIN "lineitem" AS "lineitem" ON "lineitem"."l_commitdate" < "lineitem"."l_receiptdate" - AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE) - AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE) + AND CAST("lineitem"."l_receiptdate" AS DATE) < CAST('1995-01-01' AS DATE) + AND CAST("lineitem"."l_receiptdate" AS DATE) >= CAST('1994-01-01' AS DATE) AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate" AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP') AND "orders"."o_orderkey" = "lineitem"."l_orderkey" @@ -832,8 +832,8 @@ FROM "lineitem" AS "lineitem" JOIN "part" AS "part" ON "lineitem"."l_partkey" = "part"."p_partkey" WHERE - "lineitem"."l_shipdate" < CAST('1995-10-01' AS DATE) - AND "lineitem"."l_shipdate" >= CAST('1995-09-01' AS DATE); + CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-10-01' AS DATE) + AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1995-09-01' AS DATE); -------------------------------------- -- TPC-H 15 @@ -876,8 +876,8 @@ WITH "revenue" AS ( )) AS "total_revenue" FROM "lineitem" AS "lineitem" WHERE - "lineitem"."l_shipdate" < CAST('1996-04-01' AS DATE) - AND "lineitem"."l_shipdate" >= CAST('1996-01-01' AS DATE) + CAST("lineitem"."l_shipdate" AS DATE) < CAST('1996-04-01' AS DATE) + AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1996-01-01' AS DATE) GROUP BY "lineitem"."l_suppkey" ) @@ -1220,8 +1220,8 @@ WITH "_u_0" AS ( "lineitem"."l_suppkey" AS "_u_2" FROM "lineitem" AS "lineitem" WHERE - "lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE) - AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE) + CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE) + AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE) GROUP BY "lineitem"."l_partkey", "lineitem"."l_suppkey" diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql index 5e27b5e..067fe77 100644 --- a/tests/fixtures/pretty.sql +++ b/tests/fixtures/pretty.sql @@ -315,3 +315,10 @@ FROM ( WHERE id = 1 ) /* x */; +SELECT * /* multi + line + comment */; +SELECT + * /* multi + line + comment */; diff --git a/tests/helpers.py b/tests/helpers.py index dabaf1c..9abdaae 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -57,79 +57,79 @@ SKIP_INTEGRATION = string_to_bool(os.environ.get("SKIP_INTEGRATION", "0").lower( TPCH_SCHEMA = { "lineitem": { - "l_orderkey": "uint64", - "l_partkey": "uint64", - "l_suppkey": "uint64", - "l_linenumber": "uint64", - "l_quantity": "float64", - "l_extendedprice": "float64", - "l_discount": "float64", - "l_tax": "float64", + "l_orderkey": "bigint", + "l_partkey": "bigint", + "l_suppkey": "bigint", + "l_linenumber": "bigint", + "l_quantity": "double", + "l_extendedprice": "double", + "l_discount": "double", + "l_tax": "double", "l_returnflag": "string", "l_linestatus": "string", - "l_shipdate": "date32", - "l_commitdate": "date32", - "l_receiptdate": "date32", + "l_shipdate": "string", + "l_commitdate": "string", + "l_receiptdate": "string", "l_shipinstruct": "string", "l_shipmode": "string", "l_comment": "string", }, "orders": { - "o_orderkey": "uint64", - "o_custkey": "uint64", + "o_orderkey": "bigint", + "o_custkey": "bigint", "o_orderstatus": "string", - "o_totalprice": "float64", - "o_orderdate": "date32", + "o_totalprice": "double", + "o_orderdate": "string", "o_orderpriority": "string", "o_clerk": "string", - "o_shippriority": "int32", + "o_shippriority": "int", "o_comment": "string", }, "customer": { - "c_custkey": "uint64", + "c_custkey": "bigint", "c_name": "string", "c_address": "string", - "c_nationkey": "uint64", + "c_nationkey": "bigint", "c_phone": "string", - "c_acctbal": "float64", + "c_acctbal": "double", "c_mktsegment": "string", "c_comment": "string", }, "part": { - "p_partkey": "uint64", + "p_partkey": "bigint", "p_name": "string", "p_mfgr": "string", "p_brand": "string", "p_type": "string", - "p_size": "int32", + "p_size": "int", "p_container": "string", - "p_retailprice": "float64", + "p_retailprice": "double", "p_comment": "string", }, "supplier": { - "s_suppkey": "uint64", + "s_suppkey": "bigint", "s_name": "string", "s_address": "string", - "s_nationkey": "uint64", + "s_nationkey": "bigint", "s_phone": "string", - "s_acctbal": "float64", + "s_acctbal": "double", "s_comment": "string", }, "partsupp": { - "ps_partkey": "uint64", - "ps_suppkey": "uint64", - "ps_availqty": "int32", - "ps_supplycost": "float64", + "ps_partkey": "bigint", + "ps_suppkey": "bigint", + "ps_availqty": "int", + "ps_supplycost": "double", "ps_comment": "string", }, "nation": { - "n_nationkey": "uint64", + "n_nationkey": "bigint", "n_name": "string", - "n_regionkey": "uint64", + "n_regionkey": "bigint", "n_comment": "string", }, "region": { - "r_regionkey": "uint64", + "r_regionkey": "bigint", "r_name": "string", "r_comment": "string", }, diff --git a/tests/test_executor.py b/tests/test_executor.py index 49805b9..2c4d7cd 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -1,12 +1,15 @@ import unittest +from datetime import date import duckdb import pandas as pd from pandas.testing import assert_frame_equal from sqlglot import exp, parse_one +from sqlglot.errors import ExecuteError from sqlglot.executor import execute from sqlglot.executor.python import Python +from sqlglot.executor.table import Table, ensure_tables from tests.helpers import ( FIXTURES_DIR, SKIP_INTEGRATION, @@ -67,13 +70,399 @@ class TestExecutor(unittest.TestCase): def to_csv(expression): if isinstance(expression, exp.Table): return parse_one( - f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}" + f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}" ) return expression - for sql, _ in self.sqls[0:3]: - a = self.cached_execute(sql) - sql = parse_one(sql).transform(to_csv).sql(pretty=True) - table = execute(sql, TPCH_SCHEMA) - b = pd.DataFrame(table.rows, columns=table.columns) - assert_frame_equal(a, b, check_dtype=False) + for i, (sql, _) in enumerate(self.sqls[0:7]): + with self.subTest(f"tpch-h {i + 1}"): + a = self.cached_execute(sql) + sql = parse_one(sql).transform(to_csv).sql(pretty=True) + table = execute(sql, TPCH_SCHEMA) + b = pd.DataFrame(table.rows, columns=table.columns) + assert_frame_equal(a, b, check_dtype=False) + + def test_execute_callable(self): + tables = { + "x": [ + {"a": "a", "b": "d"}, + {"a": "b", "b": "e"}, + {"a": "c", "b": "f"}, + ], + "y": [ + {"b": "d", "c": "g"}, + {"b": "e", "c": "h"}, + {"b": "f", "c": "i"}, + ], + "z": [], + } + schema = { + "x": { + "a": "VARCHAR", + "b": "VARCHAR", + }, + "y": { + "b": "VARCHAR", + "c": "VARCHAR", + }, + "z": {"d": "VARCHAR"}, + } + + for sql, cols, rows in [ + ("SELECT * FROM x", ["a", "b"], [("a", "d"), ("b", "e"), ("c", "f")]), + ( + "SELECT * FROM x JOIN y ON x.b = y.b", + ["a", "b", "b", "c"], + [("a", "d", "d", "g"), ("b", "e", "e", "h"), ("c", "f", "f", "i")], + ), + ( + "SELECT j.c AS d FROM x AS i JOIN y AS j ON i.b = j.b", + ["d"], + [("g",), ("h",), ("i",)], + ), + ( + "SELECT CONCAT(x.a, y.c) FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'", + ["_col_0"], + [("bh",)], + ), + ( + "SELECT * FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'", + ["a", "b", "b", "c"], + [("b", "e", "e", "h")], + ), + ( + "SELECT * FROM z", + ["d"], + [], + ), + ( + "SELECT d FROM z ORDER BY d", + ["d"], + [], + ), + ( + "SELECT a FROM x WHERE x.a <> 'b'", + ["a"], + [("a",), ("c",)], + ), + ( + "SELECT a AS i FROM x ORDER BY a", + ["i"], + [("a",), ("b",), ("c",)], + ), + ( + "SELECT a AS i FROM x ORDER BY i", + ["i"], + [("a",), ("b",), ("c",)], + ), + ( + "SELECT 100 - ORD(a) AS a, a AS i FROM x ORDER BY a", + ["a", "i"], + [(1, "c"), (2, "b"), (3, "a")], + ), + ( + "SELECT a /* test */ FROM x LIMIT 1", + ["a"], + [("a",)], + ), + ]: + with self.subTest(sql): + result = execute(sql, schema=schema, tables=tables) + self.assertEqual(result.columns, tuple(cols)) + self.assertEqual(result.rows, rows) + + def test_set_operations(self): + tables = { + "x": [ + {"a": "a"}, + {"a": "b"}, + {"a": "c"}, + ], + "y": [ + {"a": "b"}, + {"a": "c"}, + {"a": "d"}, + ], + } + schema = { + "x": { + "a": "VARCHAR", + }, + "y": { + "a": "VARCHAR", + }, + } + + for sql, cols, rows in [ + ( + "SELECT a FROM x UNION ALL SELECT a FROM y", + ["a"], + [("a",), ("b",), ("c",), ("b",), ("c",), ("d",)], + ), + ( + "SELECT a FROM x UNION SELECT a FROM y", + ["a"], + [("a",), ("b",), ("c",), ("d",)], + ), + ( + "SELECT a FROM x EXCEPT SELECT a FROM y", + ["a"], + [("a",)], + ), + ( + "SELECT a FROM x INTERSECT SELECT a FROM y", + ["a"], + [("b",), ("c",)], + ), + ( + """SELECT i.a + FROM ( + SELECT a FROM x UNION SELECT a FROM y + ) AS i + JOIN ( + SELECT a FROM x UNION SELECT a FROM y + ) AS j + ON i.a = j.a""", + ["a"], + [("a",), ("b",), ("c",), ("d",)], + ), + ( + "SELECT 1 AS a UNION SELECT 2 AS a UNION SELECT 3 AS a", + ["a"], + [(1,), (2,), (3,)], + ), + ]: + with self.subTest(sql): + result = execute(sql, schema=schema, tables=tables) + self.assertEqual(result.columns, tuple(cols)) + self.assertEqual(set(result.rows), set(rows)) + + def test_execute_catalog_db_table(self): + tables = { + "catalog": { + "db": { + "x": [ + {"a": "a"}, + {"a": "b"}, + {"a": "c"}, + ], + } + } + } + schema = { + "catalog": { + "db": { + "x": { + "a": "VARCHAR", + } + } + } + } + result1 = execute("SELECT * FROM x", schema=schema, tables=tables) + result2 = execute("SELECT * FROM catalog.db.x", schema=schema, tables=tables) + assert result1.columns == result2.columns + assert result1.rows == result2.rows + + def test_execute_tables(self): + tables = { + "sushi": [ + {"id": 1, "price": 1.0}, + {"id": 2, "price": 2.0}, + {"id": 3, "price": 3.0}, + ], + "order_items": [ + {"sushi_id": 1, "order_id": 1}, + {"sushi_id": 1, "order_id": 1}, + {"sushi_id": 2, "order_id": 1}, + {"sushi_id": 3, "order_id": 2}, + ], + "orders": [ + {"id": 1, "user_id": 1}, + {"id": 2, "user_id": 2}, + ], + } + + self.assertEqual( + execute( + """ + SELECT + o.user_id, + SUM(s.price) AS price + FROM orders o + JOIN order_items i + ON o.id = i.order_id + JOIN sushi s + ON i.sushi_id = s.id + GROUP BY o.user_id + """, + tables=tables, + ).rows, + [ + (1, 4.0), + (2, 3.0), + ], + ) + + self.assertEqual( + execute( + """ + SELECT + o.id, x.* + FROM orders o + LEFT JOIN ( + SELECT + 1 AS id, 'b' AS x + UNION ALL + SELECT + 3 AS id, 'c' AS x + ) x + ON o.id = x.id + """, + tables=tables, + ).rows, + [(1, 1, "b"), (2, None, None)], + ) + self.assertEqual( + execute( + """ + SELECT + o.id, x.* + FROM orders o + RIGHT JOIN ( + SELECT + 1 AS id, + 'b' AS x + UNION ALL + SELECT + 3 AS id, 'c' AS x + ) x + ON o.id = x.id + """, + tables=tables, + ).rows, + [ + (1, 1, "b"), + (None, 3, "c"), + ], + ) + + def test_table_depth_mismatch(self): + tables = {"table": []} + schema = {"db": {"table": {"col": "VARCHAR"}}} + with self.assertRaises(ExecuteError): + execute("SELECT * FROM table", schema=schema, tables=tables) + + def test_tables(self): + tables = ensure_tables( + { + "catalog1": { + "db1": { + "t1": [ + {"a": 1}, + ], + "t2": [ + {"a": 1}, + ], + }, + "db2": { + "t3": [ + {"a": 1}, + ], + "t4": [ + {"a": 1}, + ], + }, + }, + "catalog2": { + "db3": { + "t5": Table(columns=("a",), rows=[(1,)]), + "t6": Table(columns=("a",), rows=[(1,)]), + }, + "db4": { + "t7": Table(columns=("a",), rows=[(1,)]), + "t8": Table(columns=("a",), rows=[(1,)]), + }, + }, + } + ) + + t1 = tables.find(exp.table_(table="t1", db="db1", catalog="catalog1")) + self.assertEqual(t1.columns, ("a",)) + self.assertEqual(t1.rows, [(1,)]) + + t8 = tables.find(exp.table_(table="t8")) + self.assertEqual(t1.columns, t8.columns) + self.assertEqual(t1.rows, t8.rows) + + def test_static_queries(self): + for sql, cols, rows in [ + ("SELECT 1", ["_col_0"], [(1,)]), + ("SELECT 1 + 2 AS x", ["x"], [(3,)]), + ("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]), + ("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]), + ("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]), + ]: + result = execute(sql) + self.assertEqual(result.columns, tuple(cols)) + self.assertEqual(result.rows, rows) + + def test_aggregate_without_group_by(self): + result = execute("SELECT SUM(x) FROM t", tables={"t": [{"x": 1}, {"x": 2}]}) + self.assertEqual(result.columns, ("_col_0",)) + self.assertEqual(result.rows, [(3,)]) + + def test_scalar_functions(self): + for sql, expected in [ + ("CONCAT('a', 'b')", "ab"), + ("CONCAT('a', NULL)", None), + ("CONCAT_WS('_', 'a', 'b')", "a_b"), + ("STR_POSITION('bar', 'foobarbar')", 4), + ("STR_POSITION('bar', 'foobarbar', 5)", 7), + ("STR_POSITION(NULL, 'foobarbar')", None), + ("STR_POSITION('bar', NULL)", None), + ("UPPER('foo')", "FOO"), + ("UPPER(NULL)", None), + ("LOWER('FOO')", "foo"), + ("LOWER(NULL)", None), + ("IFNULL('a', 'b')", "a"), + ("IFNULL(NULL, 'b')", "b"), + ("IFNULL(NULL, NULL)", None), + ("SUBSTRING('12345')", "12345"), + ("SUBSTRING('12345', 3)", "345"), + ("SUBSTRING('12345', 3, 0)", ""), + ("SUBSTRING('12345', 3, 1)", "3"), + ("SUBSTRING('12345', 3, 2)", "34"), + ("SUBSTRING('12345', 3, 3)", "345"), + ("SUBSTRING('12345', 3, 4)", "345"), + ("SUBSTRING('12345', -3)", "345"), + ("SUBSTRING('12345', -3, 0)", ""), + ("SUBSTRING('12345', -3, 1)", "3"), + ("SUBSTRING('12345', -3, 2)", "34"), + ("SUBSTRING('12345', 0)", ""), + ("SUBSTRING('12345', 0, 1)", ""), + ("SUBSTRING(NULL)", None), + ("SUBSTRING(NULL, 1)", None), + ("CAST(1 AS TEXT)", "1"), + ("CAST('1' AS LONG)", 1), + ("CAST('1.1' AS FLOAT)", 1.1), + ("COALESCE(NULL)", None), + ("COALESCE(NULL, NULL)", None), + ("COALESCE(NULL, 'b')", "b"), + ("COALESCE('a', 'b')", "a"), + ("1 << 1", 2), + ("1 >> 1", 0), + ("1 & 1", 1), + ("1 | 1", 1), + ("1 < 1", False), + ("1 <= 1", True), + ("1 > 1", False), + ("1 >= 1", True), + ("1 + NULL", None), + ("IF(true, 1, 0)", 1), + ("IF(false, 1, 0)", 0), + ("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"), + ("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)), + ]: + with self.subTest(sql): + result = execute(f"SELECT {sql}") + self.assertEqual(result.rows, [(expected,)]) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 63371d8..c0927ad 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -441,6 +441,9 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(parse_one("VARIANCE(a)"), exp.Variance) self.assertIsInstance(parse_one("VARIANCE_POP(a)"), exp.VariancePop) self.assertIsInstance(parse_one("YEAR(a)"), exp.Year) + self.assertIsInstance(parse_one("BEGIN DEFERRED TRANSACTION"), exp.Transaction) + self.assertIsInstance(parse_one("COMMIT"), exp.Commit) + self.assertIsInstance(parse_one("ROLLBACK"), exp.Rollback) def test_column(self): dot = parse_one("a.b.c") @@ -479,9 +482,9 @@ class TestExpressions(unittest.TestCase): self.assertEqual(column.text("expression"), "c") self.assertEqual(column.text("y"), "") self.assertEqual(parse_one("select * from x.y").find(exp.Table).text("db"), "x") - self.assertEqual(parse_one("select *").text("this"), "") - self.assertEqual(parse_one("1 + 1").text("this"), "1") - self.assertEqual(parse_one("'a'").text("this"), "a") + self.assertEqual(parse_one("select *").name, "") + self.assertEqual(parse_one("1 + 1").name, "1") + self.assertEqual(parse_one("'a'").name, "a") def test_alias(self): self.assertEqual(alias("foo", "bar").sql(), "foo AS bar") @@ -538,8 +541,8 @@ class TestExpressions(unittest.TestCase): this=exp.Literal.string("TABLE_FORMAT"), value=exp.to_identifier("test_format"), ), - exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.NULL), - exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.TRUE), + exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.null()), + exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.true()), ] ), ) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index a1b7e70..6637a1d 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -29,6 +29,7 @@ class TestOptimizer(unittest.TestCase): CREATE TABLE x (a INT, b INT); CREATE TABLE y (b INT, c INT); CREATE TABLE z (b INT, c INT); + CREATE TABLE w (d TEXT, e TEXT); INSERT INTO x VALUES (1, 1); INSERT INTO x VALUES (2, 2); @@ -47,6 +48,8 @@ class TestOptimizer(unittest.TestCase): INSERT INTO y VALUES (4, 4); INSERT INTO y VALUES (5, 5); INSERT INTO y VALUES (null, null); + + INSERT INTO w VALUES ('a', 'b'); """ ) @@ -64,6 +67,10 @@ class TestOptimizer(unittest.TestCase): "b": "INT", "c": "INT", }, + "w": { + "d": "TEXT", + "e": "TEXT", + }, } def check_file(self, file, func, pretty=False, execute=False, **kwargs): @@ -224,6 +231,18 @@ class TestOptimizer(unittest.TestCase): def test_eliminate_subqueries(self): self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries) + def test_canonicalize(self): + optimize = partial( + optimizer.optimize, + rules=[ + optimizer.qualify_tables.qualify_tables, + optimizer.qualify_columns.qualify_columns, + annotate_types, + optimizer.canonicalize.canonicalize, + ], + ) + self.check_file("canonicalize", optimize, schema=self.schema) + def test_tpch(self): self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True) diff --git a/tests/test_parser.py b/tests/test_parser.py index 04c20b1..c747ea3 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -41,12 +41,41 @@ class TestParser(unittest.TestCase): ) def test_command(self): - expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1") + expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1", read="hive") self.assertEqual(len(expressions), 3) self.assertEqual(expressions[0].sql(), "SET x = 1") self.assertEqual(expressions[1].sql(), "ADD JAR s3://a") self.assertEqual(expressions[2].sql(), "SELECT 1") + def test_transactions(self): + expression = parse_one("BEGIN TRANSACTION") + self.assertIsNone(expression.this) + self.assertEqual(expression.args["modes"], []) + self.assertEqual(expression.sql(), "BEGIN") + + expression = parse_one("START TRANSACTION", read="mysql") + self.assertIsNone(expression.this) + self.assertEqual(expression.args["modes"], []) + self.assertEqual(expression.sql(), "BEGIN") + + expression = parse_one("BEGIN DEFERRED TRANSACTION") + self.assertEqual(expression.this, "DEFERRED") + self.assertEqual(expression.args["modes"], []) + self.assertEqual(expression.sql(), "BEGIN") + + expression = parse_one( + "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE", read="presto" + ) + self.assertIsNone(expression.this) + self.assertEqual(expression.args["modes"][0], "READ WRITE") + self.assertEqual(expression.args["modes"][1], "ISOLATION LEVEL SERIALIZABLE") + self.assertEqual(expression.sql(), "BEGIN") + + expression = parse_one("BEGIN", read="bigquery") + self.assertNotIsInstance(expression, exp.Transaction) + self.assertIsNone(expression.expression) + self.assertEqual(expression.sql(), "BEGIN") + def test_identify(self): expression = parse_one( """ @@ -55,14 +84,14 @@ class TestParser(unittest.TestCase): """ ) - assert expression.expressions[0].text("this") == "a" - assert expression.expressions[1].text("this") == "b" - assert expression.expressions[2].text("alias") == "c" - assert expression.expressions[3].text("alias") == "D" - assert expression.expressions[4].text("alias") == "y|z'" + assert expression.expressions[0].name == "a" + assert expression.expressions[1].name == "b" + assert expression.expressions[2].alias == "c" + assert expression.expressions[3].alias == "D" + assert expression.expressions[4].alias == "y|z'" table = expression.args["from"].expressions[0] - assert table.args["this"].args["this"] == "z" - assert table.args["db"].args["this"] == "y" + assert table.this.name == "z" + assert table.args["db"].name == "y" def test_multi(self): expressions = parse( @@ -72,8 +101,8 @@ class TestParser(unittest.TestCase): ) assert len(expressions) == 2 - assert expressions[0].args["from"].expressions[0].args["this"].args["this"] == "a" - assert expressions[1].args["from"].expressions[0].args["this"].args["this"] == "b" + assert expressions[0].args["from"].expressions[0].this.name == "a" + assert expressions[1].args["from"].expressions[0].this.name == "b" def test_expression(self): ignore = Parser(error_level=ErrorLevel.IGNORE) @@ -200,7 +229,7 @@ class TestParser(unittest.TestCase): @patch("sqlglot.parser.logger") def test_comment_error_n(self, logger): parse_one( - """CREATE TABLE x + """SUM ( -- test )""", @@ -208,19 +237,19 @@ class TestParser(unittest.TestCase): ) assert_logger_contains( - "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Schema'>. Line 4, Col: 1.", + "Required keyword: 'this' missing for <class 'sqlglot.expressions.Sum'>. Line 4, Col: 1.", logger, ) @patch("sqlglot.parser.logger") def test_comment_error_r(self, logger): parse_one( - """CREATE TABLE x (-- test\r)""", + """SUM(-- test\r)""", error_level=ErrorLevel.WARN, ) assert_logger_contains( - "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Schema'>. Line 2, Col: 1.", + "Required keyword: 'this' missing for <class 'sqlglot.expressions.Sum'>. Line 2, Col: 1.", logger, ) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 943c2b0..d4772ba 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -12,6 +12,7 @@ class TestTokens(unittest.TestCase): ("--comment\nfoo --test", "comment"), ("foo --comment", "comment"), ("foo", None), + ("foo /*comment 1*/ /*comment 2*/", "comment 1"), ] for sql, comment in sql_comment: diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 942053e..1bd2527 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -20,6 +20,13 @@ class TestTranspile(unittest.TestCase): self.assertEqual(transpile(sql, **kwargs)[0], target) def test_alias(self): + self.assertEqual(transpile("SELECT 1 current_time")[0], "SELECT 1 AS current_time") + self.assertEqual( + transpile("SELECT 1 current_timestamp")[0], "SELECT 1 AS current_timestamp" + ) + self.assertEqual(transpile("SELECT 1 current_date")[0], "SELECT 1 AS current_date") + self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime") + for key in ("union", "filter", "over", "from", "join"): with self.subTest(f"alias {key}"): self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}") @@ -69,6 +76,10 @@ class TestTranspile(unittest.TestCase): self.validate("SELECT 3>=3", "SELECT 3 >= 3") def test_comments(self): + self.validate("SELECT */*comment*/", "SELECT * /* comment */") + self.validate( + "SELECT * FROM table /*comment 1*/ /*comment 2*/", "SELECT * FROM table /* comment 1 */" + ) self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */") self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo") self.validate("SELECT --!5\nx FROM foo", "/* !5 */ SELECT x FROM foo") |