From 244a05de60c9417daab9528b51788c3d2a00dc5f Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Tue, 17 Jan 2023 11:32:12 +0100 Subject: Adding upstream version 10.5.2. Signed-off-by: Daniel Baumann --- CHANGELOG.md | 29 + README.md | 21 +- benchmarks/bench.py | 2 +- posts/python_sql_engine.md | 6 +- setup.py | 2 +- sqlglot/__init__.py | 13 +- sqlglot/dialects/bigquery.py | 7 +- sqlglot/dialects/clickhouse.py | 35 +- sqlglot/dialects/dialect.py | 17 + sqlglot/dialects/hive.py | 23 +- sqlglot/dialects/oracle.py | 3 +- sqlglot/dialects/postgres.py | 21 +- sqlglot/dialects/snowflake.py | 8 +- sqlglot/dialects/tsql.py | 22 +- sqlglot/expressions.py | 117 +++- sqlglot/generator.py | 69 ++- sqlglot/helper.py | 20 +- sqlglot/optimizer/annotate_types.py | 2 +- sqlglot/optimizer/eliminate_joins.py | 4 +- sqlglot/optimizer/merge_subqueries.py | 54 +- sqlglot/optimizer/optimizer.py | 6 +- sqlglot/optimizer/pushdown_projections.py | 4 + sqlglot/optimizer/qualify_columns.py | 4 +- sqlglot/optimizer/simplify.py | 19 +- sqlglot/optimizer/unnest_subqueries.py | 38 +- sqlglot/parser.py | 652 +++++++++++++++------- sqlglot/schema.py | 45 +- sqlglot/serde.py | 67 +++ sqlglot/tokens.py | 19 +- sqlglot/transforms.py | 24 + sqlglot/trie.py | 2 +- tests/dataframe/unit/test_functions.py | 10 +- tests/dataframe/unit/test_window.py | 20 +- tests/dialects/test_bigquery.py | 42 +- tests/dialects/test_clickhouse.py | 9 + tests/dialects/test_dialect.py | 6 +- tests/dialects/test_hive.py | 20 +- tests/dialects/test_postgres.py | 4 + tests/dialects/test_presto.py | 6 +- tests/dialects/test_snowflake.py | 30 + tests/dialects/test_spark.py | 1 + tests/dialects/test_tsql.py | 24 + tests/fixtures/identity.sql | 16 + tests/fixtures/optimizer/merge_subqueries.sql | 39 ++ tests/fixtures/optimizer/optimizer.sql | 2 +- tests/fixtures/optimizer/pushdown_projections.sql | 3 + tests/fixtures/optimizer/simplify.sql | 12 + tests/fixtures/optimizer/tpc-h/tpc-h.sql | 8 +- tests/fixtures/optimizer/unnest_subqueries.sql | 59 +- tests/helpers.py | 3 +- tests/test_executor.py | 30 + tests/test_expressions.py | 69 +++ tests/test_optimizer.py | 17 +- tests/test_parser.py | 11 +- tests/test_schema.py | 16 + tests/test_serde.py | 33 ++ tests/test_transforms.py | 13 +- tests/test_transpile.py | 5 + 58 files changed, 1480 insertions(+), 383 deletions(-) create mode 100644 sqlglot/serde.py create mode 100644 tests/test_serde.py diff --git a/CHANGELOG.md b/CHANGELOG.md index bf8699e..762ab92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,35 @@ Changelog ========= +v10.5.0 +------ + +Changes: + +- Breaking: Added python type hints in the parser module, which may result in some mypy errors. + +- New: SQLGlot expressions can [now be serialized / deserialized into JSON](https://github.com/tobymao/sqlglot/commit/bac38151a8d72687247922e6898696be43ff4992). + +- New: Added support for T-SQL [hints](https://github.com/tobymao/sqlglot/commit/3220ec1adb1e1130b109677d03c9be947b03f9ca) and [EOMONTH](https://github.com/tobymao/sqlglot/commit/1ac05d9265667c883b9f6db5d825a6d864c95c73). + +- New: Added support for Clickhouse's parametric function syntax. + +- New: Added [wider support](https://github.com/tobymao/sqlglot/commit/beb660f943b73c730f1b06fce4986e26642ee8dc) for timestr and datestr. + +- New: CLI now accepts a flag [for parsing SQL from the standard input stream](https://github.com/tobymao/sqlglot/commit/f89b38ebf3e24ba951ee8b249d73bbf48685928a). + +- Improvement: Fixed BigQuery transpilation for [parameterized types and unnest](https://github.com/tobymao/sqlglot/pull/924). + +- Improvement: Hive / Spark identifiers can now begin with a digit. + +- Improvement: Bug fixes in [date/datetime simplification](https://github.com/tobymao/sqlglot/commit/b26b8d88af14f72d90c0019ec332d268a23b078f). + +- Improvement: Bug fixes in [merge_subquery](https://github.com/tobymao/sqlglot/commit/e30e21b6c572d0931bfb5873cc6ac3949c6ef5aa). + +- Improvement: Schema identifiers are now [converted to lowercase](https://github.com/tobymao/sqlglot/commit/8212032968a519c199b461eba1a2618e89bf0326) unless they're quoted. + +- Improvement: Identifiers with a leading underscore are now regarded as [safe](https://github.com/tobymao/sqlglot/commit/de3b0804bb7606673d0bbb989997c13744957f7c#diff-7857fedd1d1451b1b9a5b8efaa1cc292c02e7ee4f0d04d7e2f9d5bfb9565802c) and hence are not quoted. + v10.4.0 ------ diff --git a/README.md b/README.md index 06cb791..85a76e5 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ # SQLGlot -SQLGlot is a no dependency Python SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between different dialects like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects. +SQLGlot is a no dependency Python SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between [18 different dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects. It is a very comprehensive generic SQL parser with a robust [test suite](https://github.com/tobymao/sqlglot/blob/main/tests/). It is also quite [performant](#benchmarks) while being written purely in Python. You can easily [customize](#custom-dialects) the parser, [analyze](#metadata) queries, traverse expression trees, and programmatically [build](#build-and-modify-sql) SQL. -Syntax [errors](#parser-errors) are highlighted and dialect incompatibilities can warn or raise depending on configurations. +Syntax [errors](#parser-errors) are highlighted and dialect incompatibilities can warn or raise depending on configurations. However, it should be noted that the parser is very lenient when it comes to detecting errors, because it aims to consume as much SQL as possible. On one hand, this makes its implementation simpler, and thus more comprehensible, but on the other hand it means that syntax errors may sometimes go unnoticed. Contributions are very welcome in SQLGlot; read the [contribution guide](https://github.com/tobymao/sqlglot/blob/main/CONTRIBUTING.md) to get started! @@ -25,6 +25,7 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/ * [AST Diff](#ast-diff) * [Custom Dialects](#custom-dialects) * [SQL Execution](#sql-execution) +* [Used By](#used-by) * [Documentation](#documentation) * [Run Tests and Lint](#run-tests-and-lint) * [Benchmarks](#benchmarks) @@ -165,7 +166,7 @@ for table in parse_one("SELECT * FROM x JOIN y JOIN z").find_all(exp.Table): ### Parser Errors -A syntax error will result in a parser error: +When the parser detects an error in the syntax, it raises a ParserError: ```python import sqlglot @@ -283,13 +284,13 @@ print( ```sql SELECT ( - "x"."A" OR "x"."B" OR "x"."C" + "x"."a" OR "x"."b" OR "x"."c" ) AND ( - "x"."A" OR "x"."B" OR "x"."D" + "x"."a" OR "x"."b" OR "x"."d" ) AS "_col_0" FROM "x" AS "x" WHERE - "x"."Z" = CAST('2021-02-01' AS DATE) + CAST("x"."z" AS DATE) = CAST('2021-02-01' AS DATE) ``` ### AST Introspection @@ -432,6 +433,14 @@ user_id price 2 3.0 ``` +## Used By +* [Fugue](https://github.com/fugue-project/fugue) +* [ibis](https://github.com/ibis-project/ibis) +* [mysql-mimic](https://github.com/kelsin/mysql-mimic) +* [Querybook](https://github.com/pinterest/querybook) +* [Quokka](https://github.com/marsupialtail/quokka) +* [Splink](https://github.com/moj-analytical-services/splink) + ## Documentation SQLGlot uses [pdocs](https://pdoc.dev/) to serve its API documentation: diff --git a/benchmarks/bench.py b/benchmarks/bench.py index 2475608..002d330 100644 --- a/benchmarks/bench.py +++ b/benchmarks/bench.py @@ -23,7 +23,7 @@ SELECT "e"."phone_number" AS "Phone", TO_CHAR("e"."hire_date", 'MM/DD/YYYY') AS "Hire Date", TO_CHAR("e"."salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') AS "Salary", - "e"."commission_pct" AS "Comission %", + "e"."commission_pct" AS "Commission %", 'works as ' || "j"."job_title" || ' in ' || "d"."department_name" || ' department (manager: ' || "dm"."first_name" || ' ' || "dm"."last_name" || ') and immediate supervisor: ' || "m"."first_name" || ' ' || "m"."last_name" AS "Current Job", TO_CHAR("j"."min_salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') || ' - ' || TO_CHAR("j"."max_salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') AS "Current Salary", "l"."street_address" || ', ' || "l"."postal_code" || ', ' || "l"."city" || ', ' || "l"."state_province" || ', ' || "c"."country_name" || ' (' || "r"."region_name" || ')' AS "Location", diff --git a/posts/python_sql_engine.md b/posts/python_sql_engine.md index 1c74680..01ad71d 100644 --- a/posts/python_sql_engine.md +++ b/posts/python_sql_engine.md @@ -14,13 +14,13 @@ This post will cover [why](#why) I went through the effort of creating a Python * [Executing](#executing) ## Why? -I started working on SQLGlot because of my work on the [experimentation and metrics platform](https://netflixtechblog.com/reimagining-experimentation-analysis-at-netflix-71356393af21) at Netflix, where I built tools that allowed data scientists to define and compute SQL-based metrics. Netflix relied on multiple engines to query data (Spark, Presto, and Druid), so my team built the metrics platform around [PyPika](https://github.com/kayak/pypika), a Python SQL query builder. This way, definitions could be reused across multiple engines. However, it became quickly apparent that writing python code to programatically generate SQL was challenging for data scientists, especially those with academic backgrounds, since they were mostly familiar with R and SQL. At the time, the only Python SQL parser was [sqlparse]([https://github.com/andialbrecht/sqlparse), which is not actually a parser but a tokenizer, so having users write raw SQL into the platform wasn't really an option. Some time later, I randomly stumbled across [Crafting Interpreters](https://craftinginterpreters.com/) and realized that I could use it as a guide towards creating my own SQL parser/transpiler. +I started working on SQLGlot because of my work on the [experimentation and metrics platform](https://netflixtechblog.com/reimagining-experimentation-analysis-at-netflix-71356393af21) at Netflix, where I built tools that allowed data scientists to define and compute SQL-based metrics. Netflix relied on multiple engines to query data (Spark, Presto, and Druid), so my team built the metrics platform around [PyPika](https://github.com/kayak/pypika), a Python SQL query builder. This way, definitions could be reused across multiple engines. However, it became quickly apparent that writing python code to programmatically generate SQL was challenging for data scientists, especially those with academic backgrounds, since they were mostly familiar with R and SQL. At the time, the only Python SQL parser was [sqlparse]([https://github.com/andialbrecht/sqlparse), which is not actually a parser but a tokenizer, so having users write raw SQL into the platform wasn't really an option. Some time later, I randomly stumbled across [Crafting Interpreters](https://craftinginterpreters.com/) and realized that I could use it as a guide towards creating my own SQL parser/transpiler. Why did I do this? Isn't a Python SQL engine going to be extremely slow? The main reason why I ended up building a SQL engine was...just for **entertainment**. It's been fun learning about all the things required to actually run a SQL query, and seeing it actually work is extremely rewarding. Before SQLGlot, I had zero experience with lexers, parsers, or compilers. -In terms of practical use cases, I planned to use the Python SQL engine for unit testing SQL pipelines. Big data pipelines are tough to test because many of the engines are not open source and cannot be run locally. With SQLGlot, you can take a SQL query targeting a warehouse such as [Snowflake](https://www.snowflake.com/en/) and seamlessly run it in CI on mock Python data. It's easy to mock data and create arbitrary [UDFs](https://en.wikipedia.org/wiki/User-defined_function) because everything is just Python. Although the implementation is slow and unsuitable for large amounts of data (> 1 millon rows), there's very little overhead/startup and you can run queries on test data in a couple of milliseconds. +In terms of practical use cases, I planned to use the Python SQL engine for unit testing SQL pipelines. Big data pipelines are tough to test because many of the engines are not open source and cannot be run locally. With SQLGlot, you can take a SQL query targeting a warehouse such as [Snowflake](https://www.snowflake.com/en/) and seamlessly run it in CI on mock Python data. It's easy to mock data and create arbitrary [UDFs](https://en.wikipedia.org/wiki/User-defined_function) because everything is just Python. Although the implementation is slow and unsuitable for large amounts of data (> 1 million rows), there's very little overhead/startup and you can run queries on test data in a couple of milliseconds. Finally, the components that have been built to support execution can be used as a **foundation** for a faster engine. I'm inspired by what [Apache Calcite](https://github.com/apache/calcite) has done for the JVM world. Even though Python is commonly used for data, there hasn't been a Calcite for Python. So, you could say that SQLGlot aims to be that framework. For example, it wouldn't take much work to replace the Python execution engine with numpy/pandas/arrow to become a respectably-performing query engine. The implementation would be able to leverage the parser, optimizer, and logical planner, only needing to implement physical execution. There is a lot of work in the Python ecosystem around high performance vectorized computation, which I think could benefit from a pure Python-based [AST](https://en.wikipedia.org/wiki/Abstract_syntax_tree)/[plan](https://en.wikipedia.org/wiki/Query_plan). Parsing and planning doesn't have to be fast when the bottleneck of running queries is processing terabytes of data. So, having a Python-based ecosystem around SQL is beneficial given the ease of development in Python, despite not having bare metal performance. @@ -77,7 +77,7 @@ Once we have our AST, we can transform it into an equivalent query that produces 1. It's easier to debug and [validate](https://github.com/tobymao/sqlglot/blob/main/tests/fixtures/optimizer) the optimizations when the input and output are both SQL. -2. Rules can be applied a la carte to transform SQL into a more desireable form. +2. Rules can be applied a la carte to transform SQL into a more desirable form. 3. I wanted a way to generate 'canonical sql'. Having a canonical representation of SQL is useful for understanding if two queries are semantically equivalent (e.g. `SELECT 1 + 1` and `SELECT 2`). diff --git a/setup.py b/setup.py index 2c0a3be..8d8a923 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ setup( "black", "duckdb", "isort", - "mypy", + "mypy>=0.990", "pandas", "pyspark", "python-dateutil", diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 04c3195..87fa081 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -32,7 +32,7 @@ from sqlglot.parser import Parser from sqlglot.schema import MappingSchema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "10.4.2" +__version__ = "10.5.2" pretty = False @@ -60,9 +60,9 @@ def parse( def parse_one( sql: str, read: t.Optional[str | Dialect] = None, - into: t.Optional[Expression | str] = None, + into: t.Optional[t.Type[Expression] | str] = None, **opts, -) -> t.Optional[Expression]: +) -> Expression: """ Parses the given SQL string and returns a syntax tree for the first parsed SQL statement. @@ -83,7 +83,12 @@ def parse_one( else: result = dialect.parse(sql, **opts) - return result[0] if result else None + for expression in result: + if not expression: + raise ParseError(f"No expression was parsed from '{sql}'") + return expression + else: + raise ParseError(f"No expression was parsed from '{sql}'") def transpile( diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index d10cc54..f0089e1 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -2,7 +2,7 @@ from __future__ import annotations -from sqlglot import exp, generator, parser, tokens +from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, datestrtodate_sql, @@ -46,8 +46,9 @@ def _date_add_sql(data_type, kind): def _derived_table_values_to_unnest(self, expression): if not isinstance(expression.unnest().parent, exp.From): + expression = transforms.remove_precision_parameterized_types(expression) return self.values_sql(expression) - rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.find_all(exp.Tuple)] + rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)] structs = [] for row in rows: aliases = [ @@ -118,6 +119,7 @@ class BigQuery(Dialect): "BEGIN TRANSACTION": TokenType.BEGIN, "CURRENT_DATETIME": TokenType.CURRENT_DATETIME, "CURRENT_TIME": TokenType.CURRENT_TIME, + "DECLARE": TokenType.COMMAND, "GEOGRAPHY": TokenType.GEOGRAPHY, "FLOAT64": TokenType.DOUBLE, "INT64": TokenType.BIGINT, @@ -166,6 +168,7 @@ class BigQuery(Dialect): class Generator(generator.Generator): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore + **transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.DateAdd: _date_add_sql("DATE", "ADD"), exp.DateSub: _date_add_sql("DATE", "SUB"), diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 7136340..04d46d2 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql from sqlglot.parser import parse_var_map @@ -22,6 +24,7 @@ class ClickHouse(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "ASOF": TokenType.ASOF, + "GLOBAL": TokenType.GLOBAL, "DATETIME64": TokenType.DATETIME, "FINAL": TokenType.FINAL, "FLOAT32": TokenType.FLOAT, @@ -37,14 +40,32 @@ class ClickHouse(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore "MAP": parse_var_map, + "QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params), + "QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args), + "QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args), + } + + RANGE_PARSERS = { + **parser.Parser.RANGE_PARSERS, + TokenType.GLOBAL: lambda self, this: self._match(TokenType.IN) + and self._parse_in(this, is_global=True), } JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore - def _parse_table(self, schema=False): - this = super()._parse_table(schema) + def _parse_in( + self, this: t.Optional[exp.Expression], is_global: bool = False + ) -> exp.Expression: + this = super()._parse_in(this) + this.set("is_global", is_global) + return this + + def _parse_table( + self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None + ) -> t.Optional[exp.Expression]: + this = super()._parse_table(schema=schema, alias_tokens=alias_tokens) if self._match(TokenType.FINAL): this = self.expression(exp.Final, this=this) @@ -76,6 +97,16 @@ class ClickHouse(Dialect): exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)), exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), + exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}", + exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}", + exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}", } EXPLICIT_UNION = True + + def _param_args_sql( + self, expression: exp.Expression, params_name: str, args_name: str + ) -> str: + params = self.format_args(self.expressions(expression, params_name)) + args = self.format_args(self.expressions(expression, args_name)) + return f"({params})({args})" diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index e788852..1c840da 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -381,3 +381,20 @@ def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str: def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str: return f"CAST({self.sql(expression, 'this')} AS DATE)" + + +def trim_sql(self, expression): + target = self.sql(expression, "this") + trim_type = self.sql(expression, "position") + remove_chars = self.sql(expression, "expression") + collation = self.sql(expression, "collation") + + # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific + if not remove_chars and not collation: + return self.trim_sql(expression) + + trim_type = f"{trim_type} " if trim_type else "" + remove_chars = f"{remove_chars} " if remove_chars else "" + from_part = "FROM " if trim_type or remove_chars else "" + collation = f" COLLATE {collation}" if collation else "" + return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 088555c..ead13b1 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -175,14 +175,6 @@ class Hive(Dialect): ESCAPES = ["\\"] ENCODE = "utf-8" - NUMERIC_LITERALS = { - "L": "BIGINT", - "S": "SMALLINT", - "Y": "TINYINT", - "D": "DOUBLE", - "F": "FLOAT", - "BD": "DECIMAL", - } KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "ADD ARCHIVE": TokenType.COMMAND, @@ -191,9 +183,21 @@ class Hive(Dialect): "ADD FILES": TokenType.COMMAND, "ADD JAR": TokenType.COMMAND, "ADD JARS": TokenType.COMMAND, + "MSCK REPAIR": TokenType.COMMAND, "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES, } + NUMERIC_LITERALS = { + "L": "BIGINT", + "S": "SMALLINT", + "Y": "TINYINT", + "D": "DOUBLE", + "F": "FLOAT", + "BD": "DECIMAL", + } + + IDENTIFIER_CAN_START_WITH_DIGIT = True + class Parser(parser.Parser): STRICT_CAST = False @@ -315,6 +319,7 @@ class Hive(Dialect): exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}", exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"), exp.NumberToStr: rename_func("FORMAT_NUMBER"), + exp.LastDateOfMonth: rename_func("LAST_DAY"), } WITH_PROPERTIES = {exp.Property} @@ -342,4 +347,6 @@ class Hive(Dialect): and not expression.expressions ): expression = exp.DataType.build("text") + elif expression.this in exp.DataType.TEMPORAL_TYPES: + expression = exp.DataType.build(expression.this) return super().datatype_sql(expression) diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index af3d353..86caa6b 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -1,7 +1,7 @@ from __future__ import annotations from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func +from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql from sqlglot.helper import csv from sqlglot.tokens import TokenType @@ -64,6 +64,7 @@ class Oracle(Dialect): **transforms.UNALIAS_GROUP, # type: ignore exp.ILike: no_ilike_sql, exp.Limit: _limit_sql, + exp.Trim: trim_sql, exp.Matches: rename_func("DECODE"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index a092cad..f3fec31 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import ( no_tablesample_sql, no_trycast_sql, str_position_sql, + trim_sql, ) from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -81,23 +82,6 @@ def _substring_sql(self, expression): return f"SUBSTRING({this}{from_part}{for_part})" -def _trim_sql(self, expression): - target = self.sql(expression, "this") - trim_type = self.sql(expression, "position") - remove_chars = self.sql(expression, "expression") - collation = self.sql(expression, "collation") - - # Use TRIM/LTRIM/RTRIM syntax if the expression isn't postgres-specific - if not remove_chars and not collation: - return self.trim_sql(expression) - - trim_type = f"{trim_type} " if trim_type else "" - remove_chars = f"{remove_chars} " if remove_chars else "" - from_part = "FROM " if trim_type or remove_chars else "" - collation = f" COLLATE {collation}" if collation else "" - return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" - - def _string_agg_sql(self, expression): expression = expression.copy() separator = expression.args.get("separator") or exp.Literal.string(",") @@ -248,7 +232,6 @@ class Postgres(Dialect): "COMMENT ON": TokenType.COMMAND, "DECLARE": TokenType.COMMAND, "DO": TokenType.COMMAND, - "DOUBLE PRECISION": TokenType.DOUBLE, "GENERATED": TokenType.GENERATED, "GRANT": TokenType.COMMAND, "HSTORE": TokenType.HSTORE, @@ -318,7 +301,7 @@ class Postgres(Dialect): exp.Substring: _substring_sql, exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.TableSample: no_tablesample_sql, - exp.Trim: _trim_sql, + exp.Trim: trim_sql, exp.TryCast: no_trycast_sql, exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", exp.DataType: _datatype_sql, diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 77b09e9..24d3bdf 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -195,7 +195,6 @@ class Snowflake(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "QUALIFY": TokenType.QUALIFY, - "DOUBLE PRECISION": TokenType.DOUBLE, "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, "TIMESTAMP_NTZ": TokenType.TIMESTAMP, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, @@ -294,3 +293,10 @@ class Snowflake(Dialect): ) return self.no_identify(lambda: super(self.__class__, self).select_sql(expression)) return super().select_sql(expression) + + def describe_sql(self, expression: exp.Describe) -> str: + # Default to table if kind is unknown + kind_value = expression.args.get("kind") or "TABLE" + kind = f" {kind_value}" if kind_value else "" + this = f" {self.sql(expression, 'this')}" + return f"DESCRIBE{kind}{this}" diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 7f0f2d7..465f534 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -75,6 +75,20 @@ def _parse_format(args): ) +def _parse_eomonth(args): + date = seq_get(args, 0) + month_lag = seq_get(args, 1) + unit = DATE_DELTA_INTERVAL.get("month") + + if month_lag is None: + return exp.LastDateOfMonth(this=date) + + # Remove month lag argument in parser as its compared with the number of arguments of the resulting class + args.remove(month_lag) + + return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit)) + + def generate_date_delta_with_unit_sql(self, e): func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF" return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})" @@ -256,12 +270,14 @@ class TSQL(Dialect): "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True), "DATEPART": _format_time_lambda(exp.TimeToStr), - "GETDATE": exp.CurrentDate.from_arg_list, + "GETDATE": exp.CurrentTimestamp.from_arg_list, + "SYSDATETIME": exp.CurrentTimestamp.from_arg_list, "IIF": exp.If.from_arg_list, "LEN": exp.Length.from_arg_list, "REPLICATE": exp.Repeat.from_arg_list, "JSON_VALUE": exp.JSONExtractScalar.from_arg_list, "FORMAT": _parse_format, + "EOMONTH": _parse_eomonth, } VAR_LENGTH_DATATYPES = { @@ -271,6 +287,9 @@ class TSQL(Dialect): DataType.Type.NCHAR, } + # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table + TABLE_PREFIX_TOKENS = {TokenType.HASH} + def _parse_convert(self, strict): to = self._parse_types() self._match(TokenType.COMMA) @@ -323,6 +342,7 @@ class TSQL(Dialect): exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, exp.CurrentDate: rename_func("GETDATE"), + exp.CurrentTimestamp: rename_func("GETDATE"), exp.If: rename_func("IIF"), exp.NumberToStr: _format_sql, exp.TimeToStr: _format_sql, diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 711ec4b..d093e29 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -22,6 +22,7 @@ from sqlglot.helper import ( split_num_words, subclasses, ) +from sqlglot.tokens import Token if t.TYPE_CHECKING: from sqlglot.dialects.dialect import Dialect @@ -457,6 +458,23 @@ class Expression(metaclass=_Expression): assert isinstance(self, type_) return self + def dump(self): + """ + Dump this Expression to a JSON-serializable dict. + """ + from sqlglot.serde import dump + + return dump(self) + + @classmethod + def load(cls, obj): + """ + Load a dict (as returned by `Expression.dump`) into an Expression instance. + """ + from sqlglot.serde import load + + return load(obj) + class Condition(Expression): def and_(self, *expressions, dialect=None, **opts): @@ -631,11 +649,15 @@ class Create(Expression): "replace": False, "unique": False, "materialized": False, + "data": False, + "statistics": False, + "no_primary_index": False, + "indexes": False, } class Describe(Expression): - pass + arg_types = {"this": True, "kind": False} class Set(Expression): @@ -731,7 +753,7 @@ class Column(Condition): class ColumnDef(Expression): arg_types = { "this": True, - "kind": True, + "kind": False, "constraints": False, "exists": False, } @@ -879,7 +901,15 @@ class Identifier(Expression): class Index(Expression): - arg_types = {"this": False, "table": False, "where": False, "columns": False} + arg_types = { + "this": False, + "table": False, + "where": False, + "columns": False, + "unique": False, + "primary": False, + "amp": False, # teradata + } class Insert(Expression): @@ -1361,6 +1391,7 @@ class Table(Expression): "laterals": False, "joins": False, "pivots": False, + "hints": False, } @@ -1818,7 +1849,12 @@ class Select(Subqueryable): join.this.replace(join.this.subquery()) if join_type: + natural: t.Optional[Token] + side: t.Optional[Token] + kind: t.Optional[Token] + natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore + if natural: join.set("natural", True) if side: @@ -2111,6 +2147,7 @@ class DataType(Expression): JSON = auto() JSONB = auto() INTERVAL = auto() + TIME = auto() TIMESTAMP = auto() TIMESTAMPTZ = auto() TIMESTAMPLTZ = auto() @@ -2171,11 +2208,24 @@ class DataType(Expression): } @classmethod - def build(cls, dtype, **kwargs) -> DataType: - return DataType( - this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()], - **kwargs, - ) + def build( + cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs + ) -> DataType: + from sqlglot import parse_one + + if isinstance(dtype, str): + data_type_exp: t.Optional[Expression] + if dtype.upper() in cls.Type.__members__: + data_type_exp = DataType(this=DataType.Type[dtype.upper()]) + else: + data_type_exp = parse_one(dtype, read=dialect, into=DataType) + if data_type_exp is None: + raise ValueError(f"Unparsable data type value: {dtype}") + elif isinstance(dtype, DataType.Type): + data_type_exp = DataType(this=dtype) + else: + raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type") + return DataType(**{**data_type_exp.args, **kwargs}) # https://www.postgresql.org/docs/15/datatype-pseudo.html @@ -2429,6 +2479,7 @@ class In(Predicate): "query": False, "unnest": False, "field": False, + "is_global": False, } @@ -2678,6 +2729,10 @@ class DatetimeTrunc(Func, TimeUnit): arg_types = {"this": True, "unit": True, "zone": False} +class LastDateOfMonth(Func): + pass + + class Extract(Func): arg_types = {"this": True, "expression": True} @@ -2815,7 +2870,13 @@ class Length(Func): class Levenshtein(Func): - arg_types = {"this": True, "expression": False} + arg_types = { + "this": True, + "expression": False, + "ins_cost": False, + "del_cost": False, + "sub_cost": False, + } class Ln(Func): @@ -2890,6 +2951,16 @@ class Quantile(AggFunc): arg_types = {"this": True, "quantile": True} +# Clickhouse-specific: +# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/quantiles/#quantiles +class Quantiles(AggFunc): + arg_types = {"parameters": True, "expressions": True} + + +class QuantileIf(AggFunc): + arg_types = {"parameters": True, "expressions": True} + + class ApproxQuantile(Quantile): arg_types = {"this": True, "quantile": True, "accuracy": False} @@ -2962,8 +3033,10 @@ class StrToTime(Func): arg_types = {"this": True, "format": True} +# Spark allows unix_timestamp() +# https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.unix_timestamp.html class StrToUnix(Func): - arg_types = {"this": True, "format": True} + arg_types = {"this": False, "format": False} class NumberToStr(Func): @@ -3131,7 +3204,7 @@ def maybe_parse( dialect=None, prefix=None, **opts, -) -> t.Optional[Expression]: +) -> Expression: """Gracefully handle a possible string or expression. Example: @@ -3627,11 +3700,11 @@ def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]: if not isinstance(sql_path, str): raise ValueError(f"Invalid type provided for a table: {type(sql_path)}") - catalog, db, table_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 3)] + catalog, db, table_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 3)) return Table(this=table_name, db=db, catalog=catalog, **kwargs) -def to_column(sql_path: str, **kwargs) -> Column: +def to_column(sql_path: str | Column, **kwargs) -> Column: """ Create a column from a `[table].[column]` sql path. Schema is optional. @@ -3646,7 +3719,7 @@ def to_column(sql_path: str, **kwargs) -> Column: return sql_path if not isinstance(sql_path, str): raise ValueError(f"Invalid type provided for column: {type(sql_path)}") - table_name, column_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 2)] + table_name, column_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 2)) return Column(this=column_name, table=table_name, **kwargs) @@ -3748,7 +3821,7 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table: def values( values: t.Iterable[t.Tuple[t.Any, ...]], alias: t.Optional[str] = None, - columns: t.Optional[t.Iterable[str]] = None, + columns: t.Optional[t.Iterable[str] | t.Dict[str, DataType]] = None, ) -> Values: """Build VALUES statement. @@ -3759,7 +3832,10 @@ def values( Args: values: values statements that will be converted to SQL alias: optional alias - columns: Optional list of ordered column names. An alias is required when providing column names. + columns: Optional list of ordered column names or ordered dictionary of column names to types. + If either are provided then an alias is also required. + If a dictionary is provided then the first column of the values will be casted to the expected type + in order to help with type inference. Returns: Values: the Values expression object @@ -3771,8 +3847,15 @@ def values( if columns else TableAlias(this=to_identifier(alias) if alias else None) ) + expressions = [convert(tup) for tup in values] + if columns and isinstance(columns, dict): + types = list(columns.values()) + expressions[0].set( + "expressions", + [Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)], + ) return Values( - expressions=[convert(tup) for tup in values], + expressions=expressions, alias=table_alias, ) diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 0c1578a..3935133 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -50,7 +50,7 @@ class Generator: The default is on the smaller end because the length only represents a segment and not the true line length. Default: 80 - comments: Whether or not to preserve comments in the ouput SQL code. + comments: Whether or not to preserve comments in the output SQL code. Default: True """ @@ -236,7 +236,10 @@ class Generator: return sql sep = "\n" if self.pretty else " " - comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments) + comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments if comment) + + if not comments: + return sql if isinstance(expression, self.WITH_SEPARATED_COMMENTS): return f"{comments}{self.sep()}{sql}" @@ -362,10 +365,10 @@ class Generator: kind = self.sql(expression, "kind") constraints = self.expressions(expression, key="constraints", sep=" ", flat=True) exists = "IF NOT EXISTS " if expression.args.get("exists") else "" + kind = f" {kind}" if kind else "" + constraints = f" {constraints}" if constraints else "" - if not constraints: - return f"{exists}{column} {kind}" - return f"{exists}{column} {kind} {constraints}" + return f"{exists}{column}{kind}{constraints}" def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str: this = self.sql(expression, "this") @@ -416,7 +419,7 @@ class Generator: this = self.sql(expression, "this") kind = self.sql(expression, "kind").upper() expression_sql = self.sql(expression, "expression") - expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else "" + expression_sql = f" AS{self.sep()}{expression_sql}" if expression_sql else "" temporary = " TEMPORARY" if expression.args.get("temporary") else "" transient = ( " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else "" @@ -427,6 +430,40 @@ class Generator: unique = " UNIQUE" if expression.args.get("unique") else "" materialized = " MATERIALIZED" if expression.args.get("materialized") else "" properties = self.sql(expression, "properties") + data = expression.args.get("data") + if data is None: + data = "" + elif data: + data = " WITH DATA" + else: + data = " WITH NO DATA" + statistics = expression.args.get("statistics") + if statistics is None: + statistics = "" + elif statistics: + statistics = " AND STATISTICS" + else: + statistics = " AND NO STATISTICS" + no_primary_index = " NO PRIMARY INDEX" if expression.args.get("no_primary_index") else "" + + indexes = expression.args.get("indexes") + index_sql = "" + if indexes is not None: + indexes_sql = [] + for index in indexes: + ind_unique = " UNIQUE" if index.args.get("unique") else "" + ind_primary = " PRIMARY" if index.args.get("primary") else "" + ind_amp = " AMP" if index.args.get("amp") else "" + ind_name = f" {index.name}" if index.name else "" + ind_columns = ( + f' ({self.expressions(index, key="columns", flat=True)})' + if index.args.get("columns") + else "" + ) + indexes_sql.append( + f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}" + ) + index_sql = "".join(indexes_sql) modifiers = "".join( ( @@ -438,7 +475,10 @@ class Generator: materialized, ) ) - expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties} {expression_sql}" + + post_expression_modifiers = "".join((data, statistics, no_primary_index)) + + expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}" return self.prepend_ctes(expression, expression_sql) def describe_sql(self, expression: exp.Describe) -> str: @@ -668,6 +708,8 @@ class Generator: alias = self.sql(expression, "alias") alias = f"{sep}{alias}" if alias else "" + hints = self.expressions(expression, key="hints", sep=", ", flat=True) + hints = f" WITH ({hints})" if hints else "" laterals = self.expressions(expression, key="laterals", sep="") joins = self.expressions(expression, key="joins", sep="") pivots = self.expressions(expression, key="pivots", sep="") @@ -676,7 +718,7 @@ class Generator: pivots = f"{pivots}{alias}" alias = "" - return f"{table}{alias}{laterals}{joins}{pivots}" + return f"{table}{alias}{hints}{laterals}{joins}{pivots}" def tablesample_sql(self, expression: exp.TableSample) -> str: if self.alias_post_tablesample and expression.this.alias: @@ -1020,7 +1062,9 @@ class Generator: if not partition and not order and not spec and alias: return f"{this} {alias}" - return f"{this} ({alias}{partition_sql}{order_sql}{spec_sql})" + window_args = alias + partition_sql + order_sql + spec_sql + + return f"{this} ({window_args.strip()})" def window_spec_sql(self, expression: exp.WindowSpec) -> str: kind = self.sql(expression, "kind") @@ -1130,6 +1174,8 @@ class Generator: query = expression.args.get("query") unnest = expression.args.get("unnest") field = expression.args.get("field") + is_global = " GLOBAL" if expression.args.get("is_global") else "" + if query: in_sql = self.wrap(query) elif unnest: @@ -1138,7 +1184,8 @@ class Generator: in_sql = self.sql(field) else: in_sql = f"({self.expressions(expression, flat=True)})" - return f"{self.sql(expression, 'this')} IN {in_sql}" + + return f"{self.sql(expression, 'this')}{is_global} IN {in_sql}" def in_unnest_op(self, unnest: exp.Unnest) -> str: return f"(SELECT {self.sql(unnest)})" @@ -1433,7 +1480,7 @@ class Generator: result_sqls = [] for i, e in enumerate(expressions): sql = self.sql(e, comment=False) - comments = self.maybe_comment("", e) + comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else "" if self.pretty: if self._leading_comma: diff --git a/sqlglot/helper.py b/sqlglot/helper.py index ed37e6c..5a0f2ac 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -131,7 +131,7 @@ def subclasses( ] -def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]: +def apply_index_offset(expressions: t.List[t.Optional[E]], offset: int) -> t.List[t.Optional[E]]: """ Applies an offset to a given integer literal expression. @@ -148,10 +148,10 @@ def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]: expression = expressions[0] - if expression.is_int: + if expression and expression.is_int: expression = expression.copy() logger.warning("Applying array index offset (%s)", offset) - expression.args["this"] = str(int(expression.this) + offset) + expression.args["this"] = str(int(expression.this) + offset) # type: ignore return [expression] return expressions @@ -225,7 +225,7 @@ def open_file(file_name: str) -> t.TextIO: return gzip.open(file_name, "rt", newline="") - return open(file_name, "rt", encoding="utf-8", newline="") + return open(file_name, encoding="utf-8", newline="") @contextmanager @@ -256,7 +256,7 @@ def csv_reader(read_csv: exp.ReadCSV) -> t.Any: file.close() -def find_new_name(taken: t.Sequence[str], base: str) -> str: +def find_new_name(taken: t.Collection[str], base: str) -> str: """ Searches for a new name. @@ -356,6 +356,15 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any, yield value +def count_params(function: t.Callable) -> int: + """ + Returns the number of formal parameters expected by a function, without counting "self" + and "cls", in case of instance and class methods, respectively. + """ + count = function.__code__.co_argcount + return count - 1 if inspect.ismethod(function) else count + + def dict_depth(d: t.Dict) -> int: """ Get the nesting depth of a dictionary. @@ -374,6 +383,7 @@ def dict_depth(d: t.Dict) -> int: Args: d (dict): dictionary + Returns: int: depth """ diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index be17f15..bfb2bb8 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -43,7 +43,7 @@ class TypeAnnotator: }, exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), - exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr), + exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()), exp.Alias: lambda self, expr: self._annotate_unary(expr), exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index 3b40710..8e6a520 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -57,7 +57,7 @@ def _join_is_used(scope, join, alias): # But columns in the ON clause shouldn't count. on = join.args.get("on") if on: - on_clause_columns = set(id(column) for column in on.find_all(exp.Column)) + on_clause_columns = {id(column) for column in on.find_all(exp.Column)} else: on_clause_columns = set() return any( @@ -71,7 +71,7 @@ def _is_joined_on_all_unique_outputs(scope, join): return False _, join_keys, _ = join_condition(join) - remaining_unique_outputs = unique_outputs - set(c.name for c in join_keys) + remaining_unique_outputs = unique_outputs - {c.name for c in join_keys} return not remaining_unique_outputs diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 9ae4966..16aaf17 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -67,11 +67,9 @@ def merge_ctes(expression, leave_tables_isolated=False): singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1] for outer_scope, inner_scope, table in singular_cte_selections: - inner_select = inner_scope.expression.unnest() from_or_join = table.find_ancestor(exp.From, exp.Join) - if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): + if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): alias = table.alias_or_name - _rename_inner_sources(outer_scope, inner_scope, alias) _merge_from(outer_scope, inner_scope, table, alias) _merge_expressions(outer_scope, inner_scope, alias) @@ -80,18 +78,17 @@ def merge_ctes(expression, leave_tables_isolated=False): _merge_order(outer_scope, inner_scope) _merge_hints(outer_scope, inner_scope) _pop_cte(inner_scope) + outer_scope.clear_cache() return expression def merge_derived_tables(expression, leave_tables_isolated=False): for outer_scope in traverse_scope(expression): for subquery in outer_scope.derived_tables: - inner_select = subquery.unnest() from_or_join = subquery.find_ancestor(exp.From, exp.Join) - if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): - alias = subquery.alias_or_name - inner_scope = outer_scope.sources[alias] - + alias = subquery.alias_or_name + inner_scope = outer_scope.sources[alias] + if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): _rename_inner_sources(outer_scope, inner_scope, alias) _merge_from(outer_scope, inner_scope, subquery, alias) _merge_expressions(outer_scope, inner_scope, alias) @@ -99,21 +96,23 @@ def merge_derived_tables(expression, leave_tables_isolated=False): _merge_where(outer_scope, inner_scope, from_or_join) _merge_order(outer_scope, inner_scope) _merge_hints(outer_scope, inner_scope) + outer_scope.clear_cache() return expression -def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): +def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): """ Return True if `inner_select` can be merged into outer query. Args: outer_scope (Scope) - inner_select (exp.Select) + inner_scope (Scope) leave_tables_isolated (bool) from_or_join (exp.From|exp.Join) Returns: bool: True if can be merged """ + inner_select = inner_scope.expression.unnest() def _is_a_window_expression_in_unmergable_operation(): window_expressions = inner_select.find_all(exp.Window) @@ -133,10 +132,40 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): ] return any(window_expressions_in_unmergable) + def _outer_select_joins_on_inner_select_join(): + """ + All columns from the inner select in the ON clause must be from the first FROM table. + + That is, this can be merged: + SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a + ^^^ ^ + But this can't: + SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a + ^^^ ^ + """ + if not isinstance(from_or_join, exp.Join): + return False + + alias = from_or_join.this.alias_or_name + + on = from_or_join.args.get("on") + if not on: + return False + selections = [c.name for c in on.find_all(exp.Column) if c.table == alias] + inner_from = inner_scope.expression.args.get("from") + if not inner_from: + return False + inner_from_table = inner_from.expressions[0].alias_or_name + inner_projections = {s.alias_or_name: s for s in inner_scope.selects} + return any( + col.table != inner_from_table + for selection in selections + for col in inner_projections[selection].find_all(exp.Column) + ) + return ( isinstance(outer_scope.expression, exp.Select) and isinstance(inner_select, exp.Select) - and isinstance(inner_select, exp.Select) and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) and inner_select.args.get("from") and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) @@ -153,6 +182,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []) ) ) + and not _outer_select_joins_on_inner_select_join() and not _is_a_window_expression_in_unmergable_operation() ) @@ -168,7 +198,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias): """ taken = set(outer_scope.selected_sources) conflicts = taken.intersection(set(inner_scope.selected_sources)) - conflicts = conflicts - {alias} + conflicts -= {alias} for conflict in conflicts: new_name = find_new_name(taken, conflict) diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 72e67d4..46b6b30 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -15,6 +15,7 @@ from sqlglot.optimizer.pushdown_projections import pushdown_projections from sqlglot.optimizer.qualify_columns import qualify_columns from sqlglot.optimizer.qualify_tables import qualify_tables from sqlglot.optimizer.unnest_subqueries import unnest_subqueries +from sqlglot.schema import ensure_schema RULES = ( lower_identities, @@ -51,12 +52,13 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar If no schema is provided then the default schema defined at `sqlgot.schema` will be used db (str): specify the default database, as might be set by a `USE DATABASE db` statement catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement - rules (list): sequence of optimizer rules to use + rules (sequence): sequence of optimizer rules to use **kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in. Returns: sqlglot.Expression: optimized expression """ - possible_kwargs = {"db": db, "catalog": catalog, "schema": schema or sqlglot.schema, **kwargs} + schema = ensure_schema(schema or sqlglot.schema) + possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs} expression = expression.copy() for rule in rules: diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 49789ac..a73647c 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -79,6 +79,7 @@ def _remove_unused_selections(scope, parent_selections): order_refs = set() new_selections = [] + removed = False for i, selection in enumerate(scope.selects): if ( SELECT_ALL in parent_selections @@ -88,12 +89,15 @@ def _remove_unused_selections(scope, parent_selections): new_selections.append(selection) else: removed_indexes.append(i) + removed = True # If there are no remaining selections, just select a single constant if not new_selections: new_selections.append(DEFAULT_SELECTION.copy()) scope.expression.set("expressions", new_selections) + if removed: + scope.clear_cache() return removed_indexes diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index e16a635..f4568c2 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -365,9 +365,9 @@ class _Resolver: def all_columns(self): """All available columns of all sources in this scope""" if self._all_columns is None: - self._all_columns = set( + self._all_columns = { column for columns in self._get_all_source_columns().values() for column in columns - ) + } return self._all_columns def get_source_columns(self, name, only_visible=False): diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index c0719f2..f560760 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -361,7 +361,7 @@ def _simplify_binary(expression, a, b): return boolean elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval): a, b = extract_date(a), extract_interval(b) - if b: + if a and b: if isinstance(expression, exp.Add): return date_literal(a + b) if isinstance(expression, exp.Sub): @@ -369,7 +369,7 @@ def _simplify_binary(expression, a, b): elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast): a, b = extract_interval(a), extract_date(b) # you cannot subtract a date from an interval - if a and isinstance(expression, exp.Add): + if a and b and isinstance(expression, exp.Add): return date_literal(a + b) return None @@ -424,9 +424,15 @@ def eval_boolean(expression, a, b): def extract_date(cast): - if cast.args["to"].this == exp.DataType.Type.DATE: - return datetime.date.fromisoformat(cast.name) - return None + # The "fromisoformat" conversion could fail if the cast is used on an identifier, + # so in that case we can't extract the date. + try: + if cast.args["to"].this == exp.DataType.Type.DATE: + return datetime.date.fromisoformat(cast.name) + if cast.args["to"].this == exp.DataType.Type.DATETIME: + return datetime.datetime.fromisoformat(cast.name) + except ValueError: + return None def extract_interval(interval): @@ -450,7 +456,8 @@ def extract_interval(interval): def date_literal(date): - return exp.Cast(this=exp.Literal.string(date), to=exp.DataType.build("DATE")) + expr_type = exp.DataType.build("DATETIME" if isinstance(date, datetime.datetime) else "DATE") + return exp.Cast(this=exp.Literal.string(date), to=expr_type) def boolean_literal(condition): diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 8d78294..a515489 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -15,8 +15,7 @@ def unnest_subqueries(expression): >>> import sqlglot >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ") >>> unnest_subqueries(expression).sql() - 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\ - AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)' + 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1' Args: expression (sqlglot.Expression): expression to unnest @@ -173,10 +172,8 @@ def decorrelate(select, parent_select, external_columns, sequence): other = _other_operand(parent_predicate) if isinstance(parent_predicate, exp.Exists): - if value.this in group_by: - parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL") - else: - parent_predicate = _replace(parent_predicate, "TRUE") + alias = exp.column(list(key_aliases.values())[0], table_alias) + parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL") elif isinstance(parent_predicate, exp.All): parent_predicate = _replace( parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})" @@ -197,6 +194,23 @@ def decorrelate(select, parent_select, external_columns, sequence): else: if is_subquery_projection: alias = exp.alias_(alias, select.parent.alias) + + # COUNT always returns 0 on empty datasets, so we need take that into consideration here + # by transforming all counts into 0 and using that as the coalesced value + if value.find(exp.Count): + + def remove_aggs(node): + if isinstance(node, exp.Count): + return exp.Literal.number(0) + elif isinstance(node, exp.AggFunc): + return exp.null() + return node + + alias = exp.Coalesce( + this=alias, + expressions=[value.this.transform(remove_aggs)], + ) + select.parent.replace(alias) for key, column, predicate in keys: @@ -209,9 +223,6 @@ def decorrelate(select, parent_select, external_columns, sequence): if key in group_by: key.replace(nested) - parent_predicate = _replace( - parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)" - ) elif isinstance(predicate, exp.EQ): parent_predicate = _replace( parent_predicate, @@ -245,7 +256,14 @@ def _other_operand(expression): if isinstance(expression, exp.In): return expression.this + if isinstance(expression, (exp.Any, exp.All)): + return _other_operand(expression.parent) + if isinstance(expression, exp.Binary): - return expression.right if expression.arg_key == "this" else expression.left + return ( + expression.right + if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All)) + else expression.left + ) return None diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 308f363..bd95db8 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -5,7 +5,13 @@ import typing as t from sqlglot import exp from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors -from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get +from sqlglot.helper import ( + apply_index_offset, + count_params, + ensure_collection, + ensure_list, + seq_get, +) from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import in_trie, new_trie @@ -54,7 +60,7 @@ class Parser(metaclass=_Parser): Default: "nulls_are_small" """ - FUNCTIONS = { + FUNCTIONS: t.Dict[str, t.Callable] = { **{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()}, "DATE_TO_DATE_STR": lambda args: exp.Cast( this=seq_get(args, 0), @@ -106,6 +112,7 @@ class Parser(metaclass=_Parser): TokenType.JSON, TokenType.JSONB, TokenType.INTERVAL, + TokenType.TIME, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, TokenType.TIMESTAMPLTZ, @@ -164,6 +171,7 @@ class Parser(metaclass=_Parser): TokenType.DELETE, TokenType.DESCRIBE, TokenType.DETERMINISTIC, + TokenType.DIV, TokenType.DISTKEY, TokenType.DISTSTYLE, TokenType.EXECUTE, @@ -252,6 +260,7 @@ class Parser(metaclass=_Parser): TokenType.FIRST, TokenType.FORMAT, TokenType.IDENTIFIER, + TokenType.INDEX, TokenType.ISNULL, TokenType.MERGE, TokenType.OFFSET, @@ -312,6 +321,7 @@ class Parser(metaclass=_Parser): } TIMESTAMPS = { + TokenType.TIME, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, TokenType.TIMESTAMPLTZ, @@ -387,6 +397,7 @@ class Parser(metaclass=_Parser): } EXPRESSION_PARSERS = { + exp.Column: lambda self: self._parse_column(), exp.DataType: lambda self: self._parse_types(), exp.From: lambda self: self._parse_from(), exp.Group: lambda self: self._parse_group(), @@ -419,6 +430,7 @@ class Parser(metaclass=_Parser): TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), TokenType.CREATE: lambda self: self._parse_create(), TokenType.DELETE: lambda self: self._parse_delete(), + TokenType.DESC: lambda self: self._parse_describe(), TokenType.DESCRIBE: lambda self: self._parse_describe(), TokenType.DROP: lambda self: self._parse_drop(), TokenType.END: lambda self: self._parse_commit_or_rollback(), @@ -583,6 +595,11 @@ class Parser(metaclass=_Parser): TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} + WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS} + + # allows tables to have special tokens as prefixes + TABLE_PREFIX_TOKENS: t.Set[TokenType] = set() + STRICT_CAST = True __slots__ = ( @@ -608,13 +625,13 @@ class Parser(metaclass=_Parser): def __init__( self, - error_level=None, - error_message_context=100, - index_offset=0, - unnest_column_only=False, - alias_post_tablesample=False, - max_errors=3, - null_ordering=None, + error_level: t.Optional[ErrorLevel] = None, + error_message_context: int = 100, + index_offset: int = 0, + unnest_column_only: bool = False, + alias_post_tablesample: bool = False, + max_errors: int = 3, + null_ordering: t.Optional[str] = None, ): self.error_level = error_level or ErrorLevel.IMMEDIATE self.error_message_context = error_message_context @@ -636,23 +653,43 @@ class Parser(metaclass=_Parser): self._prev = None self._prev_comments = None - def parse(self, raw_tokens, sql=None): + def parse( + self, raw_tokens: t.List[Token], sql: t.Optional[str] = None + ) -> t.List[t.Optional[exp.Expression]]: """ - Parses the given list of tokens and returns a list of syntax trees, one tree + Parses a list of tokens and returns a list of syntax trees, one tree per parsed SQL statement. - Args - raw_tokens (list): the list of tokens (:class:`~sqlglot.tokens.Token`). - sql (str): the original SQL string. Used to produce helpful debug messages. + Args: + raw_tokens: the list of tokens. + sql: the original SQL string, used to produce helpful debug messages. - Returns - the list of syntax trees (:class:`~sqlglot.expressions.Expression`). + Returns: + The list of syntax trees. """ return self._parse( parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql ) - def parse_into(self, expression_types, raw_tokens, sql=None): + def parse_into( + self, + expression_types: str | exp.Expression | t.Collection[exp.Expression | str], + raw_tokens: t.List[Token], + sql: t.Optional[str] = None, + ) -> t.List[t.Optional[exp.Expression]]: + """ + Parses a list of tokens into a given Expression type. If a collection of Expression + types is given instead, this method will try to parse the token list into each one + of them, stopping at the first for which the parsing succeeds. + + Args: + expression_types: the expression type(s) to try and parse the token list into. + raw_tokens: the list of tokens. + sql: the original SQL string, used to produce helpful debug messages. + + Returns: + The target Expression. + """ errors = [] for expression_type in ensure_collection(expression_types): parser = self.EXPRESSION_PARSERS.get(expression_type) @@ -668,7 +705,12 @@ class Parser(metaclass=_Parser): errors=merge_errors(errors), ) from errors[-1] - def _parse(self, parse_method, raw_tokens, sql=None): + def _parse( + self, + parse_method: t.Callable[[Parser], t.Optional[exp.Expression]], + raw_tokens: t.List[Token], + sql: t.Optional[str] = None, + ) -> t.List[t.Optional[exp.Expression]]: self.reset() self.sql = sql or "" total = len(raw_tokens) @@ -686,6 +728,7 @@ class Parser(metaclass=_Parser): self._index = -1 self._tokens = tokens self._advance() + expressions.append(parse_method(self)) if self._index < len(self._tokens): @@ -695,7 +738,10 @@ class Parser(metaclass=_Parser): return expressions - def check_errors(self): + def check_errors(self) -> None: + """ + Logs or raises any found errors, depending on the chosen error level setting. + """ if self.error_level == ErrorLevel.WARN: for error in self.errors: logger.error(str(error)) @@ -705,13 +751,18 @@ class Parser(metaclass=_Parser): errors=merge_errors(self.errors), ) - def raise_error(self, message, token=None): + def raise_error(self, message: str, token: t.Optional[Token] = None) -> None: + """ + Appends an error in the list of recorded errors or raises it, depending on the chosen + error level setting. + """ token = token or self._curr or self._prev or Token.string("") start = self._find_token(token, self.sql) end = start + len(token.text) start_context = self.sql[max(start - self.error_message_context, 0) : start] highlight = self.sql[start:end] end_context = self.sql[end : end + self.error_message_context] + error = ParseError.new( f"{message}. Line {token.line}, Col: {token.col}.\n" f" {start_context}\033[4m{highlight}\033[0m{end_context}", @@ -722,11 +773,26 @@ class Parser(metaclass=_Parser): highlight=highlight, end_context=end_context, ) + if self.error_level == ErrorLevel.IMMEDIATE: raise error + self.errors.append(error) - def expression(self, exp_class, comments=None, **kwargs): + def expression( + self, exp_class: t.Type[exp.Expression], comments: t.Optional[t.List[str]] = None, **kwargs + ) -> exp.Expression: + """ + Creates a new, validated Expression. + + Args: + exp_class: the expression class to instantiate. + comments: an optional list of comments to attach to the expression. + kwargs: the arguments to set for the expression along with their respective values. + + Returns: + The target expression. + """ instance = exp_class(**kwargs) if self._prev_comments: instance.comments = self._prev_comments @@ -736,7 +802,17 @@ class Parser(metaclass=_Parser): self.validate_expression(instance) return instance - def validate_expression(self, expression, args=None): + def validate_expression( + self, expression: exp.Expression, args: t.Optional[t.List] = None + ) -> None: + """ + Validates an already instantiated expression, making sure that all its mandatory arguments + are set. + + Args: + expression: the expression to validate. + args: an optional list of items that was used to instantiate the expression, if it's a Func. + """ if self.error_level == ErrorLevel.IGNORE: return @@ -748,13 +824,18 @@ class Parser(metaclass=_Parser): if mandatory and (v is None or (isinstance(v, list) and not v)): self.raise_error(f"Required keyword: '{k}' missing for {expression.__class__}") - if args and len(args) > len(expression.arg_types) and not expression.is_var_len_args: + if ( + args + and isinstance(expression, exp.Func) + and len(args) > len(expression.arg_types) + and not expression.is_var_len_args + ): self.raise_error( f"The number of provided arguments ({len(args)}) is greater than " f"the maximum number of supported arguments ({len(expression.arg_types)})" ) - def _find_token(self, token, sql): + def _find_token(self, token: Token, sql: str) -> int: line = 1 col = 1 index = 0 @@ -769,7 +850,7 @@ class Parser(metaclass=_Parser): return index - def _advance(self, times=1): + def _advance(self, times: int = 1) -> None: self._index += times self._curr = seq_get(self._tokens, self._index) self._next = seq_get(self._tokens, self._index + 1) @@ -780,10 +861,10 @@ class Parser(metaclass=_Parser): self._prev = None self._prev_comments = None - def _retreat(self, index): + def _retreat(self, index: int) -> None: self._advance(index - self._index) - def _parse_statement(self): + def _parse_statement(self) -> t.Optional[exp.Expression]: if self._curr is None: return None @@ -803,7 +884,7 @@ class Parser(metaclass=_Parser): self._parse_query_modifiers(expression) return expression - def _parse_drop(self, default_kind=None): + def _parse_drop(self, default_kind: t.Optional[str] = None) -> t.Optional[exp.Expression]: temporary = self._match(TokenType.TEMPORARY) materialized = self._match(TokenType.MATERIALIZED) kind = self._match_set(self.CREATABLES) and self._prev.text @@ -812,7 +893,7 @@ class Parser(metaclass=_Parser): kind = default_kind else: self.raise_error(f"Expected {self.CREATABLES}") - return + return None return self.expression( exp.Drop, @@ -824,14 +905,14 @@ class Parser(metaclass=_Parser): cascade=self._match(TokenType.CASCADE), ) - def _parse_exists(self, not_=False): + def _parse_exists(self, not_: bool = False) -> t.Optional[bool]: return ( self._match(TokenType.IF) and (not not_ or self._match(TokenType.NOT)) and self._match(TokenType.EXISTS) ) - def _parse_create(self): + def _parse_create(self) -> t.Optional[exp.Expression]: replace = self._match_pair(TokenType.OR, TokenType.REPLACE) temporary = self._match(TokenType.TEMPORARY) transient = self._match_text_seq("TRANSIENT") @@ -846,12 +927,16 @@ class Parser(metaclass=_Parser): if not create_token: self.raise_error(f"Expected {self.CREATABLES}") - return + return None exists = self._parse_exists(not_=True) this = None expression = None properties = None + data = None + statistics = None + no_primary_index = None + indexes = None if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): this = self._parse_user_defined_function() @@ -868,7 +953,28 @@ class Parser(metaclass=_Parser): this = self._parse_table(schema=True) properties = self._parse_properties() if self._match(TokenType.ALIAS): - expression = self._parse_select(nested=True) + expression = self._parse_ddl_select() + + if create_token.token_type == TokenType.TABLE: + if self._match_text_seq("WITH", "DATA"): + data = True + elif self._match_text_seq("WITH", "NO", "DATA"): + data = False + + if self._match_text_seq("AND", "STATISTICS"): + statistics = True + elif self._match_text_seq("AND", "NO", "STATISTICS"): + statistics = False + + no_primary_index = self._match_text_seq("NO", "PRIMARY", "INDEX") + + indexes = [] + while True: + index = self._parse_create_table_index() + if not index: + break + else: + indexes.append(index) return self.expression( exp.Create, @@ -883,9 +989,13 @@ class Parser(metaclass=_Parser): replace=replace, unique=unique, materialized=materialized, + data=data, + statistics=statistics, + no_primary_index=no_primary_index, + indexes=indexes, ) - def _parse_property(self): + def _parse_property(self) -> t.Optional[exp.Expression]: if self._match_set(self.PROPERTY_PARSERS): return self.PROPERTY_PARSERS[self._prev.token_type](self) @@ -906,7 +1016,7 @@ class Parser(metaclass=_Parser): return None - def _parse_property_assignment(self, exp_class): + def _parse_property_assignment(self, exp_class: t.Type[exp.Expression]) -> exp.Expression: self._match(TokenType.EQ) self._match(TokenType.ALIAS) return self.expression( @@ -914,42 +1024,50 @@ class Parser(metaclass=_Parser): this=self._parse_var_or_string() or self._parse_number() or self._parse_id_var(), ) - def _parse_partitioned_by(self): + def _parse_partitioned_by(self) -> exp.Expression: self._match(TokenType.EQ) return self.expression( exp.PartitionedByProperty, this=self._parse_schema() or self._parse_bracket(self._parse_field()), ) - def _parse_distkey(self): + def _parse_distkey(self) -> exp.Expression: return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var)) - def _parse_create_like(self): + def _parse_create_like(self) -> t.Optional[exp.Expression]: table = self._parse_table(schema=True) options = [] while self._match_texts(("INCLUDING", "EXCLUDING")): + this = self._prev.text.upper() + id_var = self._parse_id_var() + + if not id_var: + return None + options.append( self.expression( exp.Property, - this=self._prev.text.upper(), - value=exp.Var(this=self._parse_id_var().this.upper()), + this=this, + value=exp.Var(this=id_var.this.upper()), ) ) return self.expression(exp.LikeProperty, this=table, expressions=options) - def _parse_sortkey(self, compound=False): + def _parse_sortkey(self, compound: bool = False) -> exp.Expression: return self.expression( exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_id_var), compound=compound ) - def _parse_character_set(self, default=False): + def _parse_character_set(self, default: bool = False) -> exp.Expression: self._match(TokenType.EQ) return self.expression( exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default ) - def _parse_returns(self): + def _parse_returns(self) -> exp.Expression: + value: t.Optional[exp.Expression] is_table = self._match(TokenType.TABLE) + if is_table: if self._match(TokenType.LT): value = self.expression( @@ -960,13 +1078,13 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.GT): self.raise_error("Expecting >") else: - value = self._parse_schema("TABLE") + value = self._parse_schema(exp.Literal.string("TABLE")) else: value = self._parse_types() return self.expression(exp.ReturnsProperty, this=value, is_table=is_table) - def _parse_properties(self): + def _parse_properties(self) -> t.Optional[exp.Expression]: properties = [] while True: @@ -978,15 +1096,21 @@ class Parser(metaclass=_Parser): if properties: return self.expression(exp.Properties, expressions=properties) + return None - def _parse_describe(self): - self._match(TokenType.TABLE) - return self.expression(exp.Describe, this=self._parse_id_var()) + def _parse_describe(self) -> exp.Expression: + kind = self._match_set(self.CREATABLES) and self._prev.text + this = self._parse_table() - def _parse_insert(self): + return self.expression(exp.Describe, this=this, kind=kind) + + def _parse_insert(self) -> exp.Expression: overwrite = self._match(TokenType.OVERWRITE) local = self._match(TokenType.LOCAL) + + this: t.Optional[exp.Expression] + if self._match_text_seq("DIRECTORY"): this = self.expression( exp.Directory, @@ -998,21 +1122,22 @@ class Parser(metaclass=_Parser): self._match(TokenType.INTO) self._match(TokenType.TABLE) this = self._parse_table(schema=True) + return self.expression( exp.Insert, this=this, exists=self._parse_exists(), partition=self._parse_partition(), - expression=self._parse_select(nested=True), + expression=self._parse_ddl_select(), overwrite=overwrite, ) - def _parse_row(self): + def _parse_row(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.FORMAT): return None return self._parse_row_format() - def _parse_row_format(self, match_row=False): + def _parse_row_format(self, match_row: bool = False) -> t.Optional[exp.Expression]: if match_row and not self._match_pair(TokenType.ROW, TokenType.FORMAT): return None @@ -1035,9 +1160,10 @@ class Parser(metaclass=_Parser): kwargs["lines"] = self._parse_string() if self._match_text_seq("NULL", "DEFINED", "AS"): kwargs["null"] = self._parse_string() - return self.expression(exp.RowFormatDelimitedProperty, **kwargs) - def _parse_load_data(self): + return self.expression(exp.RowFormatDelimitedProperty, **kwargs) # type: ignore + + def _parse_load_data(self) -> exp.Expression: local = self._match(TokenType.LOCAL) self._match_text_seq("INPATH") inpath = self._parse_string() @@ -1055,7 +1181,7 @@ class Parser(metaclass=_Parser): serde=self._match_text_seq("SERDE") and self._parse_string(), ) - def _parse_delete(self): + def _parse_delete(self) -> exp.Expression: self._match(TokenType.FROM) return self.expression( @@ -1065,10 +1191,10 @@ class Parser(metaclass=_Parser): where=self._parse_where(), ) - def _parse_update(self): + def _parse_update(self) -> exp.Expression: return self.expression( exp.Update, - **{ + **{ # type: ignore "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS), "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality), "from": self._parse_from(), @@ -1076,16 +1202,17 @@ class Parser(metaclass=_Parser): }, ) - def _parse_uncache(self): + def _parse_uncache(self) -> exp.Expression: if not self._match(TokenType.TABLE): self.raise_error("Expecting TABLE after UNCACHE") + return self.expression( exp.Uncache, exists=self._parse_exists(), this=self._parse_table(schema=True), ) - def _parse_cache(self): + def _parse_cache(self) -> exp.Expression: lazy = self._match(TokenType.LAZY) self._match(TokenType.TABLE) table = self._parse_table(schema=True) @@ -1108,21 +1235,23 @@ class Parser(metaclass=_Parser): expression=self._parse_select(nested=True), ) - def _parse_partition(self): + def _parse_partition(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.PARTITION): return None - def parse_values(): + def parse_values() -> exp.Property: props = self._parse_csv(self._parse_var_or_string, sep=TokenType.EQ) return exp.Property(this=seq_get(props, 0), value=seq_get(props, 1)) return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values)) - def _parse_value(self): + def _parse_value(self) -> exp.Expression: expressions = self._parse_wrapped_csv(self._parse_conjunction) return self.expression(exp.Tuple, expressions=expressions) - def _parse_select(self, nested=False, table=False): + def _parse_select( + self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True + ) -> t.Optional[exp.Expression]: cte = self._parse_with() if cte: this = self._parse_statement() @@ -1178,10 +1307,11 @@ class Parser(metaclass=_Parser): self._parse_query_modifiers(this) this = self._parse_set_operations(this) self._match_r_paren() + # early return so that subquery unions aren't parsed again # SELECT * FROM (SELECT 1) UNION ALL SELECT 1 # Union ALL should be a property of the top select node, not the subquery - return self._parse_subquery(this) + return self._parse_subquery(this, parse_alias=parse_subquery_alias) elif self._match(TokenType.VALUES): if self._curr.token_type == TokenType.L_PAREN: # We don't consume the left paren because it's consumed in _parse_value @@ -1203,7 +1333,7 @@ class Parser(metaclass=_Parser): return self._parse_set_operations(this) - def _parse_with(self, skip_with_token=False): + def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.Expression]: if not skip_with_token and not self._match(TokenType.WITH): return None @@ -1220,7 +1350,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.With, expressions=expressions, recursive=recursive) - def _parse_cte(self): + def _parse_cte(self) -> exp.Expression: alias = self._parse_table_alias() if not alias or not alias.this: self.raise_error("Expected CTE to have alias") @@ -1234,7 +1364,9 @@ class Parser(metaclass=_Parser): alias=alias, ) - def _parse_table_alias(self, alias_tokens=None): + def _parse_table_alias( + self, alias_tokens: t.Optional[t.Collection[TokenType]] = None + ) -> t.Optional[exp.Expression]: any_token = self._match(TokenType.ALIAS) alias = self._parse_id_var( any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS @@ -1251,15 +1383,17 @@ class Parser(metaclass=_Parser): return self.expression(exp.TableAlias, this=alias, columns=columns) - def _parse_subquery(self, this): + def _parse_subquery( + self, this: t.Optional[exp.Expression], parse_alias: bool = True + ) -> exp.Expression: return self.expression( exp.Subquery, this=this, pivots=self._parse_pivots(), - alias=self._parse_table_alias(), + alias=self._parse_table_alias() if parse_alias else None, ) - def _parse_query_modifiers(self, this): + def _parse_query_modifiers(self, this: t.Optional[exp.Expression]) -> None: if not isinstance(this, self.MODIFIABLES): return @@ -1284,15 +1418,16 @@ class Parser(metaclass=_Parser): if expression: this.set(key, expression) - def _parse_hint(self): + def _parse_hint(self) -> t.Optional[exp.Expression]: if self._match(TokenType.HINT): hints = self._parse_csv(self._parse_function) if not self._match_pair(TokenType.STAR, TokenType.SLASH): self.raise_error("Expected */ after HINT") return self.expression(exp.Hint, expressions=hints) + return None - def _parse_into(self): + def _parse_into(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.INTO): return None @@ -1304,14 +1439,15 @@ class Parser(metaclass=_Parser): exp.Into, this=self._parse_table(schema=True), temporary=temp, unlogged=unlogged ) - def _parse_from(self): + def _parse_from(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.FROM): return None + return self.expression( exp.From, comments=self._prev_comments, expressions=self._parse_csv(self._parse_table) ) - def _parse_lateral(self): + def _parse_lateral(self) -> t.Optional[exp.Expression]: outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY) cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY) @@ -1334,6 +1470,8 @@ class Parser(metaclass=_Parser): expression=self._parse_function() or self._parse_id_var(any_token=False), ) + table_alias: t.Optional[exp.Expression] + if view: table = self._parse_id_var(any_token=False) columns = self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else [] @@ -1354,20 +1492,24 @@ class Parser(metaclass=_Parser): return expression - def _parse_join_side_and_kind(self): + def _parse_join_side_and_kind( + self, + ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]: return ( self._match(TokenType.NATURAL) and self._prev, self._match_set(self.JOIN_SIDES) and self._prev, self._match_set(self.JOIN_KINDS) and self._prev, ) - def _parse_join(self, skip_join_token=False): + def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]: natural, side, kind = self._parse_join_side_and_kind() if not skip_join_token and not self._match(TokenType.JOIN): return None - kwargs = {"this": self._parse_table()} + kwargs: t.Dict[ + str, t.Optional[exp.Expression] | bool | str | t.List[t.Optional[exp.Expression]] + ] = {"this": self._parse_table()} if natural: kwargs["natural"] = True @@ -1381,12 +1523,13 @@ class Parser(metaclass=_Parser): elif self._match(TokenType.USING): kwargs["using"] = self._parse_wrapped_id_vars() - return self.expression(exp.Join, **kwargs) + return self.expression(exp.Join, **kwargs) # type: ignore - def _parse_index(self): + def _parse_index(self) -> exp.Expression: index = self._parse_id_var() self._match(TokenType.ON) self._match(TokenType.TABLE) # hive + return self.expression( exp.Index, this=index, @@ -1394,7 +1537,28 @@ class Parser(metaclass=_Parser): columns=self._parse_expression(), ) - def _parse_table(self, schema=False, alias_tokens=None): + def _parse_create_table_index(self) -> t.Optional[exp.Expression]: + unique = self._match(TokenType.UNIQUE) + primary = self._match_text_seq("PRIMARY") + amp = self._match_text_seq("AMP") + if not self._match(TokenType.INDEX): + return None + index = self._parse_id_var() + columns = None + if self._curr and self._curr.token_type == TokenType.L_PAREN: + columns = self._parse_wrapped_csv(self._parse_column) + return self.expression( + exp.Index, + this=index, + columns=columns, + unique=unique, + primary=primary, + amp=amp, + ) + + def _parse_table( + self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None + ) -> t.Optional[exp.Expression]: lateral = self._parse_lateral() if lateral: @@ -1417,7 +1581,9 @@ class Parser(metaclass=_Parser): catalog = None db = None - table = (not schema and self._parse_function()) or self._parse_id_var(False) + table = (not schema and self._parse_function()) or self._parse_id_var( + any_token=False, prefix_tokens=self.TABLE_PREFIX_TOKENS + ) while self._match(TokenType.DOT): if catalog: @@ -1446,6 +1612,14 @@ class Parser(metaclass=_Parser): if alias: this.set("alias", alias) + if self._match(TokenType.WITH): + this.set( + "hints", + self._parse_wrapped_csv( + lambda: self._parse_function() or self._parse_var(any_token=True) + ), + ) + if not self.alias_post_tablesample: table_sample = self._parse_table_sample() @@ -1455,7 +1629,7 @@ class Parser(metaclass=_Parser): return this - def _parse_unnest(self): + def _parse_unnest(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.UNNEST): return None @@ -1473,7 +1647,7 @@ class Parser(metaclass=_Parser): exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias ) - def _parse_derived_table_values(self): + def _parse_derived_table_values(self) -> t.Optional[exp.Expression]: is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES) if not is_derived and not self._match(TokenType.VALUES): return None @@ -1485,7 +1659,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Values, expressions=expressions, alias=self._parse_table_alias()) - def _parse_table_sample(self): + def _parse_table_sample(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.TABLE_SAMPLE): return None @@ -1533,10 +1707,10 @@ class Parser(metaclass=_Parser): seed=seed, ) - def _parse_pivots(self): + def _parse_pivots(self) -> t.List[t.Optional[exp.Expression]]: return list(iter(self._parse_pivot, None)) - def _parse_pivot(self): + def _parse_pivot(self) -> t.Optional[exp.Expression]: index = self._index if self._match(TokenType.PIVOT): @@ -1572,16 +1746,18 @@ class Parser(metaclass=_Parser): return self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot) - def _parse_where(self, skip_where_token=False): + def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]: if not skip_where_token and not self._match(TokenType.WHERE): return None + return self.expression( exp.Where, comments=self._prev_comments, this=self._parse_conjunction() ) - def _parse_group(self, skip_group_by_token=False): + def _parse_group(self, skip_group_by_token: bool = False) -> t.Optional[exp.Expression]: if not skip_group_by_token and not self._match(TokenType.GROUP_BY): return None + return self.expression( exp.Group, expressions=self._parse_csv(self._parse_conjunction), @@ -1590,29 +1766,33 @@ class Parser(metaclass=_Parser): rollup=self._match(TokenType.ROLLUP) and self._parse_wrapped_id_vars(), ) - def _parse_grouping_sets(self): + def _parse_grouping_sets(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: if not self._match(TokenType.GROUPING_SETS): return None + return self._parse_wrapped_csv(self._parse_grouping_set) - def _parse_grouping_set(self): + def _parse_grouping_set(self) -> t.Optional[exp.Expression]: if self._match(TokenType.L_PAREN): grouping_set = self._parse_csv(self._parse_id_var) self._match_r_paren() return self.expression(exp.Tuple, expressions=grouping_set) + return self._parse_id_var() - def _parse_having(self, skip_having_token=False): + def _parse_having(self, skip_having_token: bool = False) -> t.Optional[exp.Expression]: if not skip_having_token and not self._match(TokenType.HAVING): return None return self.expression(exp.Having, this=self._parse_conjunction()) - def _parse_qualify(self): + def _parse_qualify(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.QUALIFY): return None return self.expression(exp.Qualify, this=self._parse_conjunction()) - def _parse_order(self, this=None, skip_order_token=False): + def _parse_order( + self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False + ) -> t.Optional[exp.Expression]: if not skip_order_token and not self._match(TokenType.ORDER_BY): return this @@ -1620,12 +1800,14 @@ class Parser(metaclass=_Parser): exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered) ) - def _parse_sort(self, token_type, exp_class): + def _parse_sort( + self, token_type: TokenType, exp_class: t.Type[exp.Expression] + ) -> t.Optional[exp.Expression]: if not self._match(token_type): return None return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered)) - def _parse_ordered(self): + def _parse_ordered(self) -> exp.Expression: this = self._parse_conjunction() self._match(TokenType.ASC) is_desc = self._match(TokenType.DESC) @@ -1647,7 +1829,9 @@ class Parser(metaclass=_Parser): return self.expression(exp.Ordered, this=this, desc=desc, nulls_first=nulls_first) - def _parse_limit(self, this=None, top=False): + def _parse_limit( + self, this: t.Optional[exp.Expression] = None, top: bool = False + ) -> t.Optional[exp.Expression]: if self._match(TokenType.TOP if top else TokenType.LIMIT): limit_paren = self._match(TokenType.L_PAREN) limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number()) @@ -1667,7 +1851,7 @@ class Parser(metaclass=_Parser): return this - def _parse_offset(self, this=None): + def _parse_offset(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: if not self._match_set((TokenType.OFFSET, TokenType.COMMA)): return this @@ -1675,7 +1859,7 @@ class Parser(metaclass=_Parser): self._match_set((TokenType.ROW, TokenType.ROWS)) return self.expression(exp.Offset, this=this, expression=count) - def _parse_set_operations(self, this): + def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match_set(self.SET_OPERATIONS): return this @@ -1695,19 +1879,19 @@ class Parser(metaclass=_Parser): expression=self._parse_select(nested=True), ) - def _parse_expression(self): + def _parse_expression(self) -> t.Optional[exp.Expression]: return self._parse_alias(self._parse_conjunction()) - def _parse_conjunction(self): + def _parse_conjunction(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_equality, self.CONJUNCTION) - def _parse_equality(self): + def _parse_equality(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_comparison, self.EQUALITY) - def _parse_comparison(self): + def _parse_comparison(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_range, self.COMPARISON) - def _parse_range(self): + def _parse_range(self) -> t.Optional[exp.Expression]: this = self._parse_bitwise() negate = self._match(TokenType.NOT) @@ -1730,7 +1914,7 @@ class Parser(metaclass=_Parser): return this - def _parse_is(self, this): + def _parse_is(self, this: t.Optional[exp.Expression]) -> exp.Expression: negate = self._match(TokenType.NOT) if self._match(TokenType.DISTINCT_FROM): klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ @@ -1743,7 +1927,7 @@ class Parser(metaclass=_Parser): ) return self.expression(exp.Not, this=this) if negate else this - def _parse_in(self, this): + def _parse_in(self, this: t.Optional[exp.Expression]) -> exp.Expression: unnest = self._parse_unnest() if unnest: this = self.expression(exp.In, this=this, unnest=unnest) @@ -1761,18 +1945,18 @@ class Parser(metaclass=_Parser): return this - def _parse_between(self, this): + def _parse_between(self, this: exp.Expression) -> exp.Expression: low = self._parse_bitwise() self._match(TokenType.AND) high = self._parse_bitwise() return self.expression(exp.Between, this=this, low=low, high=high) - def _parse_escape(self, this): + def _parse_escape(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match(TokenType.ESCAPE): return this return self.expression(exp.Escape, this=this, expression=self._parse_string()) - def _parse_bitwise(self): + def _parse_bitwise(self) -> t.Optional[exp.Expression]: this = self._parse_term() while True: @@ -1795,18 +1979,18 @@ class Parser(metaclass=_Parser): return this - def _parse_term(self): + def _parse_term(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_factor, self.TERM) - def _parse_factor(self): + def _parse_factor(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_unary, self.FACTOR) - def _parse_unary(self): + def _parse_unary(self) -> t.Optional[exp.Expression]: if self._match_set(self.UNARY_PARSERS): return self.UNARY_PARSERS[self._prev.token_type](self) return self._parse_at_time_zone(self._parse_type()) - def _parse_type(self): + def _parse_type(self) -> t.Optional[exp.Expression]: if self._match(TokenType.INTERVAL): return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_var()) @@ -1824,7 +2008,7 @@ class Parser(metaclass=_Parser): return this - def _parse_types(self, check_func=False): + def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]: index = self._index if not self._match_set(self.TYPE_TOKENS): @@ -1875,7 +2059,7 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.GT): self.raise_error("Expecting >") - value = None + value: t.Optional[exp.Expression] = None if type_token in self.TIMESTAMPS: if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ: value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions) @@ -1884,7 +2068,10 @@ class Parser(metaclass=_Parser): ): value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions) elif self._match(TokenType.WITHOUT_TIME_ZONE): - value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions) + if type_token == TokenType.TIME: + value = exp.DataType(this=exp.DataType.Type.TIME, expressions=expressions) + else: + value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions) maybe_func = maybe_func and value is None @@ -1912,7 +2099,7 @@ class Parser(metaclass=_Parser): nested=nested, ) - def _parse_struct_kwargs(self): + def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]: this = self._parse_id_var() self._match(TokenType.COLON) data_type = self._parse_types() @@ -1921,12 +2108,12 @@ class Parser(metaclass=_Parser): return None return self.expression(exp.StructKwarg, this=this, expression=data_type) - def _parse_at_time_zone(self, this): + def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match(TokenType.AT_TIME_ZONE): return this return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary()) - def _parse_column(self): + def _parse_column(self) -> t.Optional[exp.Expression]: this = self._parse_field() if isinstance(this, exp.Identifier): this = self.expression(exp.Column, this=this) @@ -1943,7 +2130,8 @@ class Parser(metaclass=_Parser): if not field: self.raise_error("Expected type") elif op: - field = exp.Literal.string(self._advance() or self._prev.text) + self._advance() + field = exp.Literal.string(self._prev.text) else: field = self._parse_star() or self._parse_function() or self._parse_id_var() @@ -1963,7 +2151,7 @@ class Parser(metaclass=_Parser): return this - def _parse_primary(self): + def _parse_primary(self) -> t.Optional[exp.Expression]: if self._match_set(self.PRIMARY_PARSERS): token_type = self._prev.token_type primary = self.PRIMARY_PARSERS[token_type](self, self._prev) @@ -1995,21 +2183,27 @@ class Parser(metaclass=_Parser): self._match_r_paren() if isinstance(this, exp.Subqueryable): - this = self._parse_set_operations(self._parse_subquery(this)) + this = self._parse_set_operations( + self._parse_subquery(this=this, parse_alias=False) + ) elif len(expressions) > 1: this = self.expression(exp.Tuple, expressions=expressions) else: this = self.expression(exp.Paren, this=this) - if comments: + + if this and comments: this.comments = comments + return this return None - def _parse_field(self, any_token=False): + def _parse_field(self, any_token: bool = False) -> t.Optional[exp.Expression]: return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token) - def _parse_function(self, functions=None): + def _parse_function( + self, functions: t.Optional[t.Dict[str, t.Callable]] = None + ) -> t.Optional[exp.Expression]: if not self._curr: return None @@ -2020,7 +2214,9 @@ class Parser(metaclass=_Parser): if not self._next or self._next.token_type != TokenType.L_PAREN: if token_type in self.NO_PAREN_FUNCTIONS: - return self.expression(self._advance() or self.NO_PAREN_FUNCTIONS[token_type]) + self._advance() + return self.expression(self.NO_PAREN_FUNCTIONS[token_type]) + return None if token_type not in self.FUNC_TOKENS: @@ -2049,7 +2245,18 @@ class Parser(metaclass=_Parser): args = self._parse_csv(self._parse_lambda) if function: - this = function(args) + + # Clickhouse supports function calls like foo(x, y)(z), so for these we need to also parse the + # second parameter list (i.e. "(z)") and the corresponding function will receive both arg lists. + if count_params(function) == 2: + params = None + if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN): + params = self._parse_csv(self._parse_lambda) + + this = function(args, params) + else: + this = function(args) + self.validate_expression(this, args) else: this = self.expression(exp.Anonymous, this=this, expressions=args) @@ -2057,7 +2264,7 @@ class Parser(metaclass=_Parser): self._match_r_paren(this) return self._parse_window(this) - def _parse_user_defined_function(self): + def _parse_user_defined_function(self) -> t.Optional[exp.Expression]: this = self._parse_id_var() while self._match(TokenType.DOT): @@ -2070,27 +2277,27 @@ class Parser(metaclass=_Parser): self._match_r_paren() return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) - def _parse_introducer(self, token): + def _parse_introducer(self, token: Token) -> t.Optional[exp.Expression]: literal = self._parse_primary() if literal: return self.expression(exp.Introducer, this=token.text, expression=literal) return self.expression(exp.Identifier, this=token.text) - def _parse_national(self, token): + def _parse_national(self, token: Token) -> exp.Expression: return self.expression(exp.National, this=exp.Literal.string(token.text)) - def _parse_session_parameter(self): + def _parse_session_parameter(self) -> exp.Expression: kind = None this = self._parse_id_var() or self._parse_primary() - if self._match(TokenType.DOT): + if this and self._match(TokenType.DOT): kind = this.name this = self._parse_var() or self._parse_primary() return self.expression(exp.SessionParameter, this=this, kind=kind) - def _parse_udf_kwarg(self): + def _parse_udf_kwarg(self) -> t.Optional[exp.Expression]: this = self._parse_id_var() kind = self._parse_types() @@ -2099,7 +2306,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.UserDefinedFunctionKwarg, this=this, kind=kind) - def _parse_lambda(self): + def _parse_lambda(self) -> t.Optional[exp.Expression]: index = self._index if self._match(TokenType.L_PAREN): @@ -2115,6 +2322,8 @@ class Parser(metaclass=_Parser): self._retreat(index) + this: t.Optional[exp.Expression] + if self._match(TokenType.DISTINCT): this = self.expression( exp.Distinct, expressions=self._parse_csv(self._parse_conjunction) @@ -2129,7 +2338,7 @@ class Parser(metaclass=_Parser): return self._parse_limit(self._parse_order(this)) - def _parse_schema(self, this=None): + def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: index = self._index if not self._match(TokenType.L_PAREN) or self._match(TokenType.SELECT): self._retreat(index) @@ -2140,14 +2349,15 @@ class Parser(metaclass=_Parser): or self._parse_column_def(self._parse_field(any_token=True)) ) self._match_r_paren() + + if isinstance(this, exp.Literal): + this = this.name + return self.expression(exp.Schema, this=this, expressions=args) - def _parse_column_def(self, this): + def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: kind = self._parse_types() - if not kind: - return this - constraints = [] while True: constraint = self._parse_column_constraint() @@ -2155,9 +2365,12 @@ class Parser(metaclass=_Parser): break constraints.append(constraint) + if not kind and not constraints: + return this + return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints) - def _parse_column_constraint(self): + def _parse_column_constraint(self) -> t.Optional[exp.Expression]: this = self._parse_references() if this: @@ -2166,6 +2379,8 @@ class Parser(metaclass=_Parser): if self._match(TokenType.CONSTRAINT): this = self._parse_id_var() + kind: exp.Expression + if self._match(TokenType.AUTO_INCREMENT): kind = exp.AutoIncrementColumnConstraint() elif self._match(TokenType.CHECK): @@ -2202,7 +2417,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.ColumnConstraint, this=this, kind=kind) - def _parse_constraint(self): + def _parse_constraint(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.CONSTRAINT): return self._parse_unnamed_constraint() @@ -2217,24 +2432,25 @@ class Parser(metaclass=_Parser): return self.expression(exp.Constraint, this=this, expressions=expressions) - def _parse_unnamed_constraint(self): + def _parse_unnamed_constraint(self) -> t.Optional[exp.Expression]: if not self._match_set(self.CONSTRAINT_PARSERS): return None return self.CONSTRAINT_PARSERS[self._prev.token_type](self) - def _parse_unique(self): + def _parse_unique(self) -> exp.Expression: return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars()) - def _parse_references(self): + def _parse_references(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.REFERENCES): return None + return self.expression( exp.Reference, this=self._parse_id_var(), expressions=self._parse_wrapped_id_vars(), ) - def _parse_foreign_key(self): + def _parse_foreign_key(self) -> exp.Expression: expressions = self._parse_wrapped_id_vars() reference = self._parse_references() options = {} @@ -2260,13 +2476,15 @@ class Parser(metaclass=_Parser): exp.ForeignKey, expressions=expressions, reference=reference, - **options, + **options, # type: ignore ) - def _parse_bracket(self, this): + def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match(TokenType.L_BRACKET): return this + expressions: t.List[t.Optional[exp.Expression]] + if self._match(TokenType.COLON): expressions = [self.expression(exp.Slice, expression=self._parse_conjunction())] else: @@ -2284,12 +2502,12 @@ class Parser(metaclass=_Parser): this.comments = self._prev_comments return self._parse_bracket(this) - def _parse_slice(self, this): + def _parse_slice(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if self._match(TokenType.COLON): return self.expression(exp.Slice, this=this, expression=self._parse_conjunction()) return this - def _parse_case(self): + def _parse_case(self) -> t.Optional[exp.Expression]: ifs = [] default = None @@ -2311,7 +2529,7 @@ class Parser(metaclass=_Parser): self.expression(exp.Case, this=expression, ifs=ifs, default=default) ) - def _parse_if(self): + def _parse_if(self) -> t.Optional[exp.Expression]: if self._match(TokenType.L_PAREN): args = self._parse_csv(self._parse_conjunction) this = exp.If.from_arg_list(args) @@ -2324,9 +2542,10 @@ class Parser(metaclass=_Parser): false = self._parse_conjunction() if self._match(TokenType.ELSE) else None self._match(TokenType.END) this = self.expression(exp.If, this=condition, true=true, false=false) + return self._parse_window(this) - def _parse_extract(self): + def _parse_extract(self) -> exp.Expression: this = self._parse_function() or self._parse_var() or self._parse_type() if self._match(TokenType.FROM): @@ -2337,7 +2556,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Extract, this=this, expression=self._parse_bitwise()) - def _parse_cast(self, strict): + def _parse_cast(self, strict: bool) -> exp.Expression: this = self._parse_conjunction() if not self._match(TokenType.ALIAS): @@ -2353,7 +2572,9 @@ class Parser(metaclass=_Parser): return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) - def _parse_string_agg(self): + def _parse_string_agg(self) -> exp.Expression: + expression: t.Optional[exp.Expression] + if self._match(TokenType.DISTINCT): args = self._parse_csv(self._parse_conjunction) expression = self.expression(exp.Distinct, expressions=[seq_get(args, 0)]) @@ -2380,8 +2601,10 @@ class Parser(metaclass=_Parser): order = self._parse_order(this=expression) return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1)) - def _parse_convert(self, strict): + def _parse_convert(self, strict: bool) -> exp.Expression: + to: t.Optional[exp.Expression] this = self._parse_column() + if self._match(TokenType.USING): to = self.expression(exp.CharacterSet, this=self._parse_var()) elif self._match(TokenType.COMMA): @@ -2390,7 +2613,7 @@ class Parser(metaclass=_Parser): to = None return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) - def _parse_position(self): + def _parse_position(self) -> exp.Expression: args = self._parse_csv(self._parse_bitwise) if self._match(TokenType.IN): @@ -2402,11 +2625,11 @@ class Parser(metaclass=_Parser): return this - def _parse_join_hint(self, func_name): + def _parse_join_hint(self, func_name: str) -> exp.Expression: args = self._parse_csv(self._parse_table) return exp.JoinHint(this=func_name.upper(), expressions=args) - def _parse_substring(self): + def _parse_substring(self) -> exp.Expression: # Postgres supports the form: substring(string [from int] [for int]) # https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 @@ -2422,7 +2645,7 @@ class Parser(metaclass=_Parser): return this - def _parse_trim(self): + def _parse_trim(self) -> exp.Expression: # https://www.w3resource.com/sql/character-functions/trim.php # https://docs.oracle.com/javadb/10.8.3.0/ref/rreftrimfunc.html @@ -2450,13 +2673,15 @@ class Parser(metaclass=_Parser): collation=collation, ) - def _parse_window_clause(self): + def _parse_window_clause(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: return self._match(TokenType.WINDOW) and self._parse_csv(self._parse_named_window) - def _parse_named_window(self): + def _parse_named_window(self) -> t.Optional[exp.Expression]: return self._parse_window(self._parse_id_var(), alias=True) - def _parse_window(self, this, alias=False): + def _parse_window( + self, this: t.Optional[exp.Expression], alias: bool = False + ) -> t.Optional[exp.Expression]: if self._match(TokenType.FILTER): where = self._parse_wrapped(self._parse_where) this = self.expression(exp.Filter, this=this, expression=where) @@ -2495,7 +2720,7 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.L_PAREN): return self.expression(exp.Window, this=this, alias=self._parse_id_var(False)) - alias = self._parse_id_var(False) + window_alias = self._parse_id_var(any_token=False, tokens=self.WINDOW_ALIAS_TOKENS) partition = None if self._match(TokenType.PARTITION_BY): @@ -2529,10 +2754,10 @@ class Parser(metaclass=_Parser): partition_by=partition, order=order, spec=spec, - alias=alias, + alias=window_alias, ) - def _parse_window_spec(self): + def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]: self._match(TokenType.BETWEEN) return { @@ -2543,7 +2768,9 @@ class Parser(metaclass=_Parser): "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text, } - def _parse_alias(self, this, explicit=False): + def _parse_alias( + self, this: t.Optional[exp.Expression], explicit: bool = False + ) -> t.Optional[exp.Expression]: any_token = self._match(TokenType.ALIAS) if explicit and not any_token: @@ -2565,63 +2792,74 @@ class Parser(metaclass=_Parser): return this - def _parse_id_var(self, any_token=True, tokens=None): + def _parse_id_var( + self, + any_token: bool = True, + tokens: t.Optional[t.Collection[TokenType]] = None, + prefix_tokens: t.Optional[t.Collection[TokenType]] = None, + ) -> t.Optional[exp.Expression]: identifier = self._parse_identifier() if identifier: return identifier + prefix = "" + + if prefix_tokens: + while self._match_set(prefix_tokens): + prefix += self._prev.text + if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS): - return exp.Identifier(this=self._prev.text, quoted=False) + return exp.Identifier(this=prefix + self._prev.text, quoted=False) return None - def _parse_string(self): + def _parse_string(self) -> t.Optional[exp.Expression]: if self._match(TokenType.STRING): return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev) return self._parse_placeholder() - def _parse_number(self): + def _parse_number(self) -> t.Optional[exp.Expression]: if self._match(TokenType.NUMBER): return self.PRIMARY_PARSERS[TokenType.NUMBER](self, self._prev) return self._parse_placeholder() - def _parse_identifier(self): + def _parse_identifier(self) -> t.Optional[exp.Expression]: if self._match(TokenType.IDENTIFIER): return self.expression(exp.Identifier, this=self._prev.text, quoted=True) return self._parse_placeholder() - def _parse_var(self, any_token=False): + def _parse_var(self, any_token: bool = False) -> t.Optional[exp.Expression]: if (any_token and self._advance_any()) or self._match(TokenType.VAR): return self.expression(exp.Var, this=self._prev.text) return self._parse_placeholder() - def _advance_any(self): + def _advance_any(self) -> t.Optional[Token]: if self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS: self._advance() return self._prev return None - def _parse_var_or_string(self): + def _parse_var_or_string(self) -> t.Optional[exp.Expression]: return self._parse_var() or self._parse_string() - def _parse_null(self): + def _parse_null(self) -> t.Optional[exp.Expression]: if self._match(TokenType.NULL): return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev) return None - def _parse_boolean(self): + def _parse_boolean(self) -> t.Optional[exp.Expression]: if self._match(TokenType.TRUE): return self.PRIMARY_PARSERS[TokenType.TRUE](self, self._prev) if self._match(TokenType.FALSE): return self.PRIMARY_PARSERS[TokenType.FALSE](self, self._prev) return None - def _parse_star(self): + def _parse_star(self) -> t.Optional[exp.Expression]: if self._match(TokenType.STAR): return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev) return None - def _parse_placeholder(self): + def _parse_placeholder(self) -> t.Optional[exp.Expression]: if self._match(TokenType.PLACEHOLDER): return self.expression(exp.Placeholder) elif self._match(TokenType.COLON): @@ -2630,18 +2868,20 @@ class Parser(metaclass=_Parser): self._advance(-1) return None - def _parse_except(self): + def _parse_except(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: if not self._match(TokenType.EXCEPT): return None return self._parse_wrapped_id_vars() - def _parse_replace(self): + def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: if not self._match(TokenType.REPLACE): return None return self._parse_wrapped_csv(lambda: self._parse_alias(self._parse_expression())) - def _parse_csv(self, parse_method, sep=TokenType.COMMA): + def _parse_csv( + self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA + ) -> t.List[t.Optional[exp.Expression]]: parse_result = parse_method() items = [parse_result] if parse_result is not None else [] @@ -2655,7 +2895,9 @@ class Parser(metaclass=_Parser): return items - def _parse_tokens(self, parse_method, expressions): + def _parse_tokens( + self, parse_method: t.Callable, expressions: t.Dict + ) -> t.Optional[exp.Expression]: this = parse_method() while self._match_set(expressions): @@ -2668,22 +2910,29 @@ class Parser(metaclass=_Parser): return this - def _parse_wrapped_id_vars(self): + def _parse_wrapped_id_vars(self) -> t.List[t.Optional[exp.Expression]]: return self._parse_wrapped_csv(self._parse_id_var) - def _parse_wrapped_csv(self, parse_method, sep=TokenType.COMMA): + def _parse_wrapped_csv( + self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA + ) -> t.List[t.Optional[exp.Expression]]: return self._parse_wrapped(lambda: self._parse_csv(parse_method, sep=sep)) - def _parse_wrapped(self, parse_method): + def _parse_wrapped(self, parse_method: t.Callable) -> t.Any: self._match_l_paren() parse_result = parse_method() self._match_r_paren() return parse_result - def _parse_select_or_expression(self): + def _parse_select_or_expression(self) -> t.Optional[exp.Expression]: return self._parse_select() or self._parse_expression() - def _parse_transaction(self): + def _parse_ddl_select(self) -> t.Optional[exp.Expression]: + return self._parse_set_operations( + self._parse_select(nested=True, parse_subquery_alias=False) + ) + + def _parse_transaction(self) -> exp.Expression: this = None if self._match_texts(self.TRANSACTION_KIND): this = self._prev.text @@ -2703,7 +2952,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Transaction, this=this, modes=modes) - def _parse_commit_or_rollback(self): + def _parse_commit_or_rollback(self) -> exp.Expression: chain = None savepoint = None is_rollback = self._prev.token_type == TokenType.ROLLBACK @@ -2722,27 +2971,30 @@ class Parser(metaclass=_Parser): return self.expression(exp.Rollback, savepoint=savepoint) return self.expression(exp.Commit, chain=chain) - def _parse_add_column(self): + def _parse_add_column(self) -> t.Optional[exp.Expression]: if not self._match_text_seq("ADD"): return None self._match(TokenType.COLUMN) exists_column = self._parse_exists(not_=True) expression = self._parse_column_def(self._parse_field(any_token=True)) - expression.set("exists", exists_column) + + if expression: + expression.set("exists", exists_column) + return expression - def _parse_drop_column(self): + def _parse_drop_column(self) -> t.Optional[exp.Expression]: return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN") - def _parse_alter(self): + def _parse_alter(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.TABLE): return None exists = self._parse_exists() this = self._parse_table(schema=True) - actions = None + actions: t.Optional[exp.Expression | t.List[t.Optional[exp.Expression]]] = None if self._match_text_seq("ADD", advance=False): actions = self._parse_csv(self._parse_add_column) elif self._match_text_seq("DROP", advance=False): @@ -2770,24 +3022,24 @@ class Parser(metaclass=_Parser): actions = ensure_list(actions) return self.expression(exp.AlterTable, this=this, exists=exists, actions=actions) - def _parse_show(self): - parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) + def _parse_show(self) -> t.Optional[exp.Expression]: + parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) # type: ignore if parser: return parser(self) self._advance() return self.expression(exp.Show, this=self._prev.text.upper()) - def _default_parse_set_item(self): + def _default_parse_set_item(self) -> exp.Expression: return self.expression( exp.SetItem, this=self._parse_statement(), ) - def _parse_set_item(self): - parser = self._find_parser(self.SET_PARSERS, self._set_trie) + def _parse_set_item(self) -> t.Optional[exp.Expression]: + parser = self._find_parser(self.SET_PARSERS, self._set_trie) # type: ignore return parser(self) if parser else self._default_parse_set_item() - def _parse_merge(self): + def _parse_merge(self) -> exp.Expression: self._match(TokenType.INTO) target = self._parse_table(schema=True) @@ -2835,10 +3087,12 @@ class Parser(metaclass=_Parser): expressions=whens, ) - def _parse_set(self): + def _parse_set(self) -> exp.Expression: return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item)) - def _find_parser(self, parsers, trie): + def _find_parser( + self, parsers: t.Dict[str, t.Callable], trie: t.Dict + ) -> t.Optional[t.Callable]: index = self._index this = [] while True: diff --git a/sqlglot/schema.py b/sqlglot/schema.py index d9a4004..a0d69a7 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -3,6 +3,7 @@ from __future__ import annotations import abc import typing as t +import sqlglot from sqlglot import expressions as exp from sqlglot.errors import SchemaError from sqlglot.helper import dict_depth @@ -157,10 +158,10 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): visible: t.Optional[t.Dict] = None, dialect: t.Optional[str] = None, ) -> None: - super().__init__(schema) - self.visible = visible or {} self.dialect = dialect + self.visible = visible or {} self._type_mapping_cache: t.Dict[str, exp.DataType] = {} + super().__init__(self._normalize(schema or {})) @classmethod def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: @@ -180,6 +181,33 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): } ) + def _normalize(self, schema: t.Dict) -> t.Dict: + """ + Converts all identifiers in the schema into lowercase, unless they're quoted. + + Args: + schema: the schema to normalize. + + Returns: + The normalized schema mapping. + """ + flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1) + + normalized_mapping: t.Dict = {} + for keys in flattened_schema: + columns = _nested_get(schema, *zip(keys, keys)) + assert columns is not None + + normalized_keys = [self._normalize_name(key) for key in keys] + for column_name, column_type in columns.items(): + _nested_set( + normalized_mapping, + normalized_keys + [self._normalize_name(column_name)], + column_type, + ) + + return normalized_mapping + def add_table( self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None ) -> None: @@ -204,6 +232,19 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): ) self.mapping_trie = self._build_trie(self.mapping) + def _normalize_name(self, name: str) -> str: + try: + identifier: t.Optional[exp.Expression] = sqlglot.parse_one( + name, read=self.dialect, into=exp.Identifier + ) + except: + identifier = exp.to_identifier(name) + assert isinstance(identifier, exp.Identifier) + + if identifier.quoted: + return identifier.name + return identifier.name.lower() + def _depth(self) -> int: # The columns themselves are a mapping, but we don't want to include those return super()._depth() - 1 diff --git a/sqlglot/serde.py b/sqlglot/serde.py new file mode 100644 index 0000000..a47ffdb --- /dev/null +++ b/sqlglot/serde.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import expressions as exp + +if t.TYPE_CHECKING: + JSON = t.Union[dict, list, str, float, int, bool] + Node = t.Union[t.List["Node"], exp.DataType.Type, exp.Expression, JSON] + + +def dump(node: Node) -> JSON: + """ + Recursively dump an AST into a JSON-serializable dict. + """ + if isinstance(node, list): + return [dump(i) for i in node] + if isinstance(node, exp.DataType.Type): + return { + "class": "DataType.Type", + "value": node.value, + } + if isinstance(node, exp.Expression): + klass = node.__class__.__qualname__ + if node.__class__.__module__ != exp.__name__: + klass = f"{node.__module__}.{klass}" + obj = { + "class": klass, + "args": {k: dump(v) for k, v in node.args.items() if v is not None and v != []}, + } + if node.type: + obj["type"] = node.type.sql() + if node.comments: + obj["comments"] = node.comments + return obj + return node + + +def load(obj: JSON) -> Node: + """ + Recursively load a dict (as returned by `dump`) into an AST. + """ + if isinstance(obj, list): + return [load(i) for i in obj] + if isinstance(obj, dict): + class_name = obj["class"] + + if class_name == "DataType.Type": + return exp.DataType.Type(obj["value"]) + + if "." in class_name: + module_path, class_name = class_name.rsplit(".", maxsplit=1) + module = __import__(module_path, fromlist=[class_name]) + else: + module = exp + + klass = getattr(module, class_name) + + expression = klass(**{k: load(v) for k, v in obj["args"].items()}) + type_ = obj.get("type") + if type_: + expression.type = exp.DataType.build(type_) + comments = obj.get("comments") + if comments: + expression.comments = load(comments) + return expression + return obj diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 0efa7d0..8e312a7 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -86,6 +86,7 @@ class TokenType(AutoName): VARBINARY = auto() JSON = auto() JSONB = auto() + TIME = auto() TIMESTAMP = auto() TIMESTAMPTZ = auto() TIMESTAMPLTZ = auto() @@ -181,6 +182,7 @@ class TokenType(AutoName): FUNCTION = auto() FROM = auto() GENERATED = auto() + GLOBAL = auto() GROUP_BY = auto() GROUPING_SETS = auto() HAVING = auto() @@ -656,6 +658,7 @@ class Tokenizer(metaclass=_Tokenizer): "FLOAT4": TokenType.FLOAT, "FLOAT8": TokenType.DOUBLE, "DOUBLE": TokenType.DOUBLE, + "DOUBLE PRECISION": TokenType.DOUBLE, "JSON": TokenType.JSON, "CHAR": TokenType.CHAR, "NCHAR": TokenType.NCHAR, @@ -671,6 +674,7 @@ class Tokenizer(metaclass=_Tokenizer): "BLOB": TokenType.VARBINARY, "BYTEA": TokenType.VARBINARY, "VARBINARY": TokenType.VARBINARY, + "TIME": TokenType.TIME, "TIMESTAMP": TokenType.TIMESTAMP, "TIMESTAMPTZ": TokenType.TIMESTAMPTZ, "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ, @@ -721,6 +725,8 @@ class Tokenizer(metaclass=_Tokenizer): COMMENTS = ["--", ("/*", "*/")] KEYWORD_TRIE = None # autofilled + IDENTIFIER_CAN_START_WITH_DIGIT = False + __slots__ = ( "sql", "size", @@ -938,17 +944,24 @@ class Tokenizer(metaclass=_Tokenizer): elif self._peek.upper() == "E" and not scientific: # type: ignore scientific += 1 self._advance() - elif self._peek.isalpha(): # type: ignore - self._add(TokenType.NUMBER) + elif self._peek.isidentifier(): # type: ignore + number_text = self._text literal = [] - while self._peek.isalpha(): # type: ignore + while self._peek.isidentifier(): # type: ignore literal.append(self._peek.upper()) # type: ignore self._advance() + literal = "".join(literal) # type: ignore token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) # type: ignore + if token_type: + self._add(TokenType.NUMBER, number_text) self._add(TokenType.DCOLON, "::") return self._add(token_type, literal) # type: ignore + elif self.IDENTIFIER_CAN_START_WITH_DIGIT: + return self._add(TokenType.VAR) + + self._add(TokenType.NUMBER, number_text) return self._advance(-len(literal)) else: return self._add(TokenType.NUMBER) diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 99949a1..35ff75a 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -82,6 +82,27 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: return expression +def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression: + """ + Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. + This transforms removes the precision from parameterized types in expressions. + """ + return expression.transform( + lambda node: exp.DataType( + **{ + **node.args, + "expressions": [ + node_expression + for node_expression in node.expressions + if isinstance(node_expression, exp.DataType) + ], + } + ) + if isinstance(node, exp.DataType) + else node, + ) + + def preprocess( transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], to_sql: t.Callable[[Generator, exp.Expression], str], @@ -121,3 +142,6 @@ def delegate(attr: str) -> t.Callable: UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))} ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))} +REMOVE_PRECISION_PARAMETERIZED_TYPES = { + exp.Cast: preprocess([remove_precision_parameterized_types], delegate("cast_sql")) +} diff --git a/sqlglot/trie.py b/sqlglot/trie.py index fa2aaf1..f3b1c38 100644 --- a/sqlglot/trie.py +++ b/sqlglot/trie.py @@ -52,7 +52,7 @@ def in_trie(trie: t.Dict, key: key) -> t.Tuple[int, t.Dict]: Returns: A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point where the search stops, and `value` - is either 0 (search was unsuccessfull), 1 (`value` is a prefix of a keyword in `trie`) or 2 (`key is in `trie`). + is either 0 (search was unsuccessful), 1 (`value` is a prefix of a keyword in `trie`) or 2 (`key is in `trie`). """ if not key: return (0, trie) diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index 37ea2e1..8b44b9f 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -1152,17 +1152,17 @@ class TestFunctions(unittest.TestCase): def test_regexp_extract(self): col_str = SF.regexp_extract("cola", r"(\d+)-(\d+)", 1) - self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)', 1)", col_str.sql()) + self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col_str.sql()) col = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)", 1) - self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)', 1)", col.sql()) + self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col.sql()) col_no_idx = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)") - self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)')", col_no_idx.sql()) + self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)')", col_no_idx.sql()) def test_regexp_replace(self): col_str = SF.regexp_replace("cola", r"(\d+)", "--") - self.assertEqual("REGEXP_REPLACE(cola, '(\\\d+)', '--')", col_str.sql()) + self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col_str.sql()) col = SF.regexp_replace(SF.col("cola"), r"(\d+)", "--") - self.assertEqual("REGEXP_REPLACE(cola, '(\\\d+)', '--')", col.sql()) + self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col.sql()) def test_initcap(self): col_str = SF.initcap("cola") diff --git a/tests/dataframe/unit/test_window.py b/tests/dataframe/unit/test_window.py index 70a868a..45d736f 100644 --- a/tests/dataframe/unit/test_window.py +++ b/tests/dataframe/unit/test_window.py @@ -15,11 +15,11 @@ class TestDataframeWindow(unittest.TestCase): def test_window_spec_rows_between(self): rows_between = WindowSpec().rowsBetween(3, 5) - self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql()) + self.assertEqual("OVER (ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql()) def test_window_spec_range_between(self): range_between = WindowSpec().rangeBetween(3, 5) - self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql()) + self.assertEqual("OVER (RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql()) def test_window_partition_by(self): partition_by = Window.partitionBy(F.col("cola"), F.col("colb")) @@ -31,46 +31,46 @@ class TestDataframeWindow(unittest.TestCase): def test_window_rows_between(self): rows_between = Window.rowsBetween(3, 5) - self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql()) + self.assertEqual("OVER (ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql()) def test_window_range_between(self): range_between = Window.rangeBetween(3, 5) - self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql()) + self.assertEqual("OVER (RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql()) def test_window_rows_unbounded(self): rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2) self.assertEqual( - "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", + "OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", rows_between_unbounded_start.sql(), ) rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing) self.assertEqual( - "OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", + "OVER (ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_end.sql(), ) rows_between_unbounded_both = Window.rowsBetween( Window.unboundedPreceding, Window.unboundedFollowing ) self.assertEqual( - "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", + "OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_both.sql(), ) def test_window_range_unbounded(self): range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2) self.assertEqual( - "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", + "OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", range_between_unbounded_start.sql(), ) range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing) self.assertEqual( - "OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", + "OVER (RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_end.sql(), ) range_between_unbounded_both = Window.rangeBetween( Window.unboundedPreceding, Window.unboundedFollowing ) self.assertEqual( - "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", + "OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_both.sql(), ) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 258e47f..c61a2f3 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -125,7 +125,7 @@ class TestBigQuery(Validator): }, ) self.validate_all( - "CURRENT_DATE", + "CURRENT_TIMESTAMP()", read={ "tsql": "GETDATE()", }, @@ -299,6 +299,14 @@ class TestBigQuery(Validator): "snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", }, ) + self.validate_all( + "SELECT cola, colb, colc FROM (VALUES (1, 'test', NULL)) AS tab(cola, colb, colc)", + write={ + "spark": "SELECT cola, colb, colc FROM VALUES (1, 'test', NULL) AS tab(cola, colb, colc)", + "bigquery": "SELECT cola, colb, colc FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb, NULL AS colc)])", + "snowflake": "SELECT cola, colb, colc FROM (VALUES (1, 'test', NULL)) AS tab(cola, colb, colc)", + }, + ) self.validate_all( "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) d, COUNT(*) e FOR c IN ('x', 'y'))", write={ @@ -324,3 +332,35 @@ class TestBigQuery(Validator): "SELECT a, GROUP_CONCAT(b) FROM table GROUP BY a", write={"bigquery": "SELECT a, STRING_AGG(b) FROM table GROUP BY a"}, ) + + def test_remove_precision_parameterized_types(self): + self.validate_all( + "SELECT CAST(1 AS NUMERIC(10, 2))", + write={ + "bigquery": "SELECT CAST(1 AS NUMERIC)", + }, + ) + self.validate_all( + "CREATE TABLE test (a NUMERIC(10, 2))", + write={ + "bigquery": "CREATE TABLE test (a NUMERIC(10, 2))", + }, + ) + self.validate_all( + "SELECT CAST('1' AS STRING(10)) UNION ALL SELECT CAST('2' AS STRING(10))", + write={ + "bigquery": "SELECT CAST('1' AS STRING) UNION ALL SELECT CAST('2' AS STRING)", + }, + ) + self.validate_all( + "SELECT cola FROM (SELECT CAST('1' AS STRING(10)) AS cola UNION ALL SELECT CAST('2' AS STRING(10)) AS cola)", + write={ + "bigquery": "SELECT cola FROM (SELECT CAST('1' AS STRING) AS cola UNION ALL SELECT CAST('2' AS STRING) AS cola)", + }, + ) + self.validate_all( + "INSERT INTO test (cola, colb) VALUES (CAST(7 AS STRING(10)), CAST(14 AS STRING(10)))", + write={ + "bigquery": "INSERT INTO test (cola, colb) VALUES (CAST(7 AS STRING), CAST(14 AS STRING))", + }, + ) diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index c95c967..109e9f3 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -14,6 +14,9 @@ class TestClickhouse(Validator): self.validate_identity("SELECT * FROM foo LEFT ASOF JOIN bla") self.validate_identity("SELECT * FROM foo ASOF JOIN bla") self.validate_identity("SELECT * FROM foo ANY JOIN bla") + self.validate_identity("SELECT quantile(0.5)(a)") + self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t") + self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)") self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", @@ -38,3 +41,9 @@ class TestClickhouse(Validator): "SELECT x #! comment", write={"": "SELECT x /* comment */"}, ) + self.validate_all( + "SELECT quantileIf(0.5)(a, true)", + write={ + "clickhouse": "SELECT quantileIf(0.5)(a, TRUE)", + }, + ) diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index ced7102..284a30d 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -85,7 +85,7 @@ class TestDialect(Validator): self.validate_all( "CAST(a AS BINARY(4))", write={ - "bigquery": "CAST(a AS BINARY(4))", + "bigquery": "CAST(a AS BINARY)", "clickhouse": "CAST(a AS BINARY(4))", "drill": "CAST(a AS VARBINARY(4))", "duckdb": "CAST(a AS BINARY(4))", @@ -104,7 +104,7 @@ class TestDialect(Validator): self.validate_all( "CAST(a AS VARBINARY(4))", write={ - "bigquery": "CAST(a AS VARBINARY(4))", + "bigquery": "CAST(a AS VARBINARY)", "clickhouse": "CAST(a AS VARBINARY(4))", "duckdb": "CAST(a AS VARBINARY(4))", "mysql": "CAST(a AS VARBINARY(4))", @@ -181,7 +181,7 @@ class TestDialect(Validator): self.validate_all( "CAST(a AS VARCHAR(3))", write={ - "bigquery": "CAST(a AS STRING(3))", + "bigquery": "CAST(a AS STRING)", "drill": "CAST(a AS VARCHAR(3))", "duckdb": "CAST(a AS TEXT(3))", "mysql": "CAST(a AS VARCHAR(3))", diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index a7f3b8f..bbf00b1 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -338,6 +338,24 @@ class TestHive(Validator): ) def test_hive(self): + self.validate_all( + "SELECT A.1a AS b FROM test_a AS A", + write={ + "spark": "SELECT A.1a AS b FROM test_a AS A", + }, + ) + self.validate_all( + "SELECT 1_a AS a FROM test_table", + write={ + "spark": "SELECT 1_a AS a FROM test_table", + }, + ) + self.validate_all( + "SELECT a_b AS 1_a FROM test_table", + write={ + "spark": "SELECT a_b AS 1_a FROM test_table", + }, + ) self.validate_all( "PERCENTILE(x, 0.5)", write={ @@ -411,7 +429,7 @@ class TestHive(Validator): "INITCAP('new york')", write={ "duckdb": "INITCAP('new york')", - "presto": "REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2]))", + "presto": r"REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2]))", "hive": "INITCAP('new york')", "spark": "INITCAP('new york')", }, diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 1e048d5..583d349 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -122,6 +122,10 @@ class TestPostgres(Validator): "TO_TIMESTAMP(123::DOUBLE PRECISION)", write={"postgres": "TO_TIMESTAMP(CAST(123 AS DOUBLE PRECISION))"}, ) + self.validate_all( + "SELECT to_timestamp(123)::time without time zone", + write={"postgres": "SELECT CAST(TO_TIMESTAMP(123) AS TIME)"}, + ) self.validate_identity( "CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)" diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 70e1059..ee535e9 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -60,11 +60,11 @@ class TestPresto(Validator): self.validate_all( "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", write={ - "bigquery": "CAST(x AS TIMESTAMPTZ(9))", + "bigquery": "CAST(x AS TIMESTAMPTZ)", "duckdb": "CAST(x AS TIMESTAMPTZ(9))", "presto": "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", - "hive": "CAST(x AS TIMESTAMPTZ(9))", - "spark": "CAST(x AS TIMESTAMPTZ(9))", + "hive": "CAST(x AS TIMESTAMPTZ)", + "spark": "CAST(x AS TIMESTAMPTZ)", }, ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index df62c6c..0e9ce9b 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -523,3 +523,33 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERA "spark": "SELECT `c0`, `c1` FROM (VALUES (1, 2), (3, 4)) AS `t0`(`c0`, `c1`)", }, ) + + def test_describe_table(self): + self.validate_all( + "DESCRIBE TABLE db.table", + write={ + "snowflake": "DESCRIBE TABLE db.table", + "spark": "DESCRIBE db.table", + }, + ) + self.validate_all( + "DESCRIBE db.table", + write={ + "snowflake": "DESCRIBE TABLE db.table", + "spark": "DESCRIBE db.table", + }, + ) + self.validate_all( + "DESC TABLE db.table", + write={ + "snowflake": "DESCRIBE TABLE db.table", + "spark": "DESCRIBE db.table", + }, + ) + self.validate_all( + "DESC VIEW db.table", + write={ + "snowflake": "DESCRIBE VIEW db.table", + "spark": "DESCRIBE db.table", + }, + ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 7395e72..f287a89 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -207,6 +207,7 @@ TBLPROPERTIES ( ) def test_spark(self): + self.validate_identity("SELECT UNIX_TIMESTAMP()") self.validate_all( "ARRAY_SORT(x, (left, right) -> -1)", write={ diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index b4ac094..b74c05f 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -6,6 +6,8 @@ class TestTSQL(Validator): def test_tsql(self): self.validate_identity('SELECT "x"."y" FROM foo') + self.validate_identity("SELECT * FROM #foo") + self.validate_identity("SELECT * FROM ##foo") self.validate_identity( "SELECT DISTINCT DepartmentName, PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY BaseRate) OVER (PARTITION BY DepartmentName) AS MedianCont FROM dbo.DimEmployee" ) @@ -71,6 +73,12 @@ class TestTSQL(Validator): "tsql": "CAST(x AS DATETIME2)", }, ) + self.validate_all( + "CAST(x AS DATETIME2(6))", + write={ + "hive": "CAST(x AS TIMESTAMP)", + }, + ) def test_charindex(self): self.validate_all( @@ -300,6 +308,12 @@ class TestTSQL(Validator): "spark": "SELECT CAST(y.x AS VARCHAR(10)) AS z FROM testdb.dbo.test AS y", }, ) + self.validate_all( + "SELECT CAST((SELECT x FROM y) AS VARCHAR) AS test", + write={ + "spark": "SELECT CAST((SELECT x FROM y) AS STRING) AS test", + }, + ) def test_add_date(self): self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')") @@ -441,3 +455,13 @@ class TestTSQL(Validator): "SELECT '''test'''", write={"spark": r"SELECT '\'test\''"}, ) + + def test_eomonth(self): + self.validate_all( + "EOMONTH(GETDATE())", + write={"spark": "LAST_DAY(CURRENT_TIMESTAMP())"}, + ) + self.validate_all( + "EOMONTH(GETDATE(), -1)", + write={"spark": "LAST_DAY(ADD_MONTHS(CURRENT_TIMESTAMP(), -1))"}, + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index e6a6e6b..beb5703 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -89,6 +89,7 @@ POSEXPLODE("x") AS ("a", "b") POSEXPLODE("x") AS ("a", "b", "c") STR_POSITION(x, 'a') STR_POSITION(x, 'a', 3) +LEVENSHTEIN('gumbo', 'gambol', 2, 1, 1) SPLIT(SPLIT(referrer, 'utm_source=')[OFFSET(1)], "&")[OFFSET(0)] x[ORDINAL(1)][SAFE_OFFSET(2)] x LIKE SUBSTR('abc', 1, 1) @@ -425,6 +426,7 @@ SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 AND 3) SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND 3) SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING) +SELECT AVG(x) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM t SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x) AS y SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x DESC) SELECT SUM(x) FILTER(WHERE x > 1) @@ -450,14 +452,24 @@ SELECT student, score FROM tests CROSS JOIN UNNEST(scores) AS t(a, b) SELECT student, score FROM tests CROSS JOIN UNNEST(scores) WITH ORDINALITY AS t(a, b) SELECT student, score FROM tests CROSS JOIN UNNEST(x.scores) AS t(score) SELECT student, score FROM tests CROSS JOIN UNNEST(ARRAY(x.scores)) AS t(score) +SELECT * FROM t WITH (TABLOCK, INDEX(myindex)) +SELECT * FROM t WITH (NOWAIT) +CREATE TABLE foo AS (SELECT 1) UNION ALL (SELECT 2) CREATE TABLE foo (id INT PRIMARY KEY ASC) CREATE TABLE a.b AS SELECT 1 +CREATE TABLE a.b AS SELECT 1 WITH DATA AND STATISTICS +CREATE TABLE a.b AS SELECT 1 WITH NO DATA AND NO STATISTICS +CREATE TABLE a.b AS (SELECT 1) NO PRIMARY INDEX +CREATE TABLE a.b AS (SELECT 1) UNIQUE PRIMARY INDEX index1 (a) UNIQUE INDEX index2 (b) +CREATE TABLE a.b AS (SELECT 1) PRIMARY AMP INDEX index1 (a) UNIQUE INDEX index2 (b) CREATE TABLE a.b AS SELECT a FROM a.c CREATE TABLE IF NOT EXISTS x AS SELECT a FROM d CREATE TEMPORARY TABLE x AS SELECT a FROM d CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d CREATE VIEW x AS SELECT a FROM b CREATE VIEW IF NOT EXISTS x AS SELECT a FROM b +CREATE VIEW z (a, b COMMENT 'b', c COMMENT 'c') AS SELECT a, b, c FROM d +CREATE VIEW IF NOT EXISTS z (a, b COMMENT 'b', c COMMENT 'c') AS SELECT a, b, c FROM d CREATE OR REPLACE VIEW x AS SELECT * CREATE OR REPLACE TEMPORARY VIEW x AS SELECT * CREATE TEMPORARY VIEW x AS SELECT a FROM d @@ -490,6 +502,8 @@ CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT) CREATE TABLE z (a INT REFERENCES parent(b, c)) CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id)) CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c)) +CREATE VIEW z (a, b) +CREATE VIEW z (a, b COMMENT 'b', c COMMENT 'c') CREATE TEMPORARY FUNCTION f CREATE TEMPORARY FUNCTION f AS 'g' CREATE FUNCTION f @@ -559,6 +573,7 @@ INSERT INTO x.z IF EXISTS SELECT * FROM y INSERT INTO x VALUES (1, 'a', 2.0) INSERT INTO x VALUES (1, 'a', 2.0), (1, 'a', 3.0), (X(), y[1], z.x) INSERT INTO y (a, b, c) SELECT a, b, c FROM x +INSERT INTO y (SELECT 1) UNION (SELECT 2) INSERT OVERWRITE TABLE x IF EXISTS SELECT * FROM y INSERT OVERWRITE TABLE a.b IF EXISTS SELECT * FROM y INSERT OVERWRITE DIRECTORY 'x' SELECT 1 @@ -627,3 +642,4 @@ ALTER TABLE integers ALTER COLUMN i SET DEFAULT 10 ALTER TABLE integers ALTER COLUMN i DROP DEFAULT ALTER TABLE mydataset.mytable DROP COLUMN A, DROP COLUMN IF EXISTS B ALTER TABLE mydataset.mytable ADD COLUMN A TEXT, ADD COLUMN IF NOT EXISTS B INT +SELECT div.a FROM test_table AS div diff --git a/tests/fixtures/optimizer/merge_subqueries.sql b/tests/fixtures/optimizer/merge_subqueries.sql index 4a3ad4b..4c06e42 100644 --- a/tests/fixtures/optimizer/merge_subqueries.sql +++ b/tests/fixtures/optimizer/merge_subqueries.sql @@ -311,3 +311,42 @@ FROM ON t1.cola = t2.cola; SELECT /*+ BROADCAST(a2) */ a1.cola AS cola, a2.cola AS cola FROM VALUES (1) AS a1(cola) JOIN VALUES (1) AS a2(cola) ON a1.cola = a2.cola; + +# title: Nested subquery selects from same table as another subquery +WITH i AS ( + SELECT + x.a AS a + FROM x AS x +), j AS ( + SELECT + x.a, + x.b + FROM x AS x +), k AS ( + SELECT + j.a, + j.b + FROM j AS j +) +SELECT + i.a, + k.b +FROM i AS i +LEFT JOIN k AS k +ON i.a = k.a; +SELECT x.a AS a, x_2.b AS b FROM x AS x LEFT JOIN x AS x_2 ON x.a = x_2.a; + +# title: Outer select joins on inner select join +WITH i AS ( + SELECT + x.a AS a + FROM y AS y + JOIN x AS x + ON y.b = x.b +) +SELECT + x.a AS a +FROM x AS x +LEFT JOIN i AS i + ON x.a = i.a; +WITH i AS (SELECT x.a AS a FROM y AS y JOIN x AS x ON y.b = x.b) SELECT x.a AS a FROM x AS x LEFT JOIN i AS i ON x.a = i.a; diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index b502d81..664b3c7 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -105,7 +105,7 @@ LEFT JOIN "_u_0" AS "_u_0" JOIN "y" AS "y" ON "x"."b" = "y"."b" WHERE - "_u_0"."_col_0" >= 0 AND "x"."a" > 1 AND NOT "_u_0"."_u_1" IS NULL + "_u_0"."_col_0" >= 0 AND "x"."a" > 1 GROUP BY "x"."a"; diff --git a/tests/fixtures/optimizer/pushdown_projections.sql b/tests/fixtures/optimizer/pushdown_projections.sql index 2a21f65..b9f6c3f 100644 --- a/tests/fixtures/optimizer/pushdown_projections.sql +++ b/tests/fixtures/optimizer/pushdown_projections.sql @@ -54,3 +54,6 @@ WITH t1 AS (SELECT q.cola AS cola FROM UNNEST(ARRAY(STRUCT(1 AS cola, 'test' AS SELECT x FROM VALUES(1, 2) AS q(x, y); SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y); + +SELECT i.a FROM x AS i LEFT JOIN (SELECT a, b FROM (SELECT a, b FROM x)) AS j ON i.a = j.a; +SELECT i.a AS a FROM x AS i LEFT JOIN (SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0) AS j ON i.a = j.a; diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index cf4195d..4e9e70c 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -375,6 +375,18 @@ CAST('1998-12-01' AS DATE) - INTERVAL '90' foo; date '1998-12-01' + interval '90' foo; CAST('1998-12-01' AS DATE) + INTERVAL '90' foo; +CAST(x AS DATE) + interval '1' week; +CAST(x AS DATE) + INTERVAL '1' week; + +CAST('2008-11-11' AS DATETIME) + INTERVAL '5' MONTH; +CAST('2009-04-11 00:00:00' AS DATETIME); + +datetime '1998-12-01' - interval '90' day; +CAST('1998-09-02 00:00:00' AS DATETIME); + +CAST(x AS DATETIME) + interval '1' week; +CAST(x AS DATETIME) + INTERVAL '1' week; + -------------------------------------- -- Comparisons -------------------------------------- diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql index 9c1f138..272fb26 100644 --- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql +++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql @@ -150,7 +150,6 @@ WHERE "part"."p_size" = 15 AND "part"."p_type" LIKE '%BRASS' AND "partsupp"."ps_supplycost" = "_u_0"."_col_0" - AND NOT "_u_0"."_u_1" IS NULL ORDER BY "s_acctbal" DESC, "n_name", @@ -1008,7 +1007,7 @@ JOIN "part" AS "part" LEFT JOIN "_u_0" AS "_u_0" ON "_u_0"."_u_1" = "part"."p_partkey" WHERE - "lineitem"."l_quantity" < "_u_0"."_col_0" AND NOT "_u_0"."_u_1" IS NULL; + "lineitem"."l_quantity" < "_u_0"."_col_0"; -------------------------------------- -- TPC-H 18 @@ -1253,10 +1252,7 @@ WITH "_u_0" AS ( LEFT JOIN "_u_3" AS "_u_3" ON "partsupp"."ps_partkey" = "_u_3"."p_partkey" WHERE - "partsupp"."ps_availqty" > "_u_0"."_col_0" - AND NOT "_u_0"."_u_1" IS NULL - AND NOT "_u_0"."_u_2" IS NULL - AND NOT "_u_3"."p_partkey" IS NULL + "partsupp"."ps_availqty" > "_u_0"."_col_0" AND NOT "_u_3"."p_partkey" IS NULL GROUP BY "partsupp"."ps_suppkey" ) diff --git a/tests/fixtures/optimizer/unnest_subqueries.sql b/tests/fixtures/optimizer/unnest_subqueries.sql index a444945..9d760e0 100644 --- a/tests/fixtures/optimizer/unnest_subqueries.sql +++ b/tests/fixtures/optimizer/unnest_subqueries.sql @@ -22,6 +22,8 @@ WHERE AND x.a > ANY (SELECT y.a FROM y) AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a LIMIT 10) AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a OFFSET 10) + AND x.a > ALL (SELECT y.c FROM y WHERE y.a = x.a) + AND x.a > (SELECT COUNT(*) as d FROM y WHERE y.a = x.a) ; SELECT * @@ -130,37 +132,42 @@ LEFT JOIN ( y.a ) AS _u_15 ON x.a = _u_15.a +LEFT JOIN ( + SELECT + ARRAY_AGG(c), + y.a AS _u_20 + FROM y + WHERE + TRUE + GROUP BY + y.a +) AS _u_19 + ON _u_19._u_20 = x.a +LEFT JOIN ( + SELECT + COUNT(*) AS d, + y.a AS _u_22 + FROM y + WHERE + TRUE + GROUP BY + y.a +) AS _u_21 + ON _u_21._u_22 = x.a WHERE x.a = _u_0.a AND NOT "_u_1"."a" IS NULL AND NOT "_u_2"."b" IS NULL AND NOT "_u_3"."a" IS NULL + AND x.a = _u_4.b + AND x.a > _u_6.b + AND x.a = _u_8.a + AND NOT x.a = _u_9.a + AND ARRAY_ANY(_u_10.a, _x -> _x = x.a) AND ( - x.a = _u_4.b AND NOT _u_4._u_5 IS NULL - ) - AND ( - x.a > _u_6.b AND NOT _u_6._u_7 IS NULL - ) - AND ( - None = _u_8.a AND NOT _u_8.a IS NULL - ) - AND NOT ( - x.a = _u_9.a AND NOT _u_9.a IS NULL - ) - AND ( - ARRAY_ANY(_u_10.a, _x -> _x = x.a) AND NOT _u_10._u_11 IS NULL - ) - AND ( - ( - ( - x.a < _u_12.a AND NOT _u_12._u_13 IS NULL - ) AND NOT _u_12._u_13 IS NULL - ) - AND ARRAY_ANY(_u_12._u_14, "_x" -> _x <> x.d) - ) - AND ( - NOT _u_15.a IS NULL AND NOT _u_15.a IS NULL + x.a < _u_12.a AND ARRAY_ANY(_u_12._u_14, "_x" -> _x <> x.d) ) + AND NOT _u_15.a IS NULL AND x.a IN ( SELECT y.a AS a @@ -199,4 +206,6 @@ WHERE WHERE y.a = x.a OFFSET 10 - ); + ) + AND ARRAY_ALL(_u_19."", _x -> _x = x.a) + AND x.a > COALESCE(_u_21.d, 0); diff --git a/tests/helpers.py b/tests/helpers.py index 9abdaae..bab4da0 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -27,8 +27,7 @@ def assert_logger_contains(message, logger, level="error"): def load_sql_fixtures(filename): with open(os.path.join(FIXTURES_DIR, filename), encoding="utf-8") as f: - for sql in _filter_comments(f.read()).splitlines(): - yield sql + yield from _filter_comments(f.read()).splitlines() def load_sql_fixture_pairs(filename): diff --git a/tests/test_executor.py b/tests/test_executor.py index b705551..f45a5d4 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -401,6 +401,36 @@ class TestExecutor(unittest.TestCase): ], ) + def test_correlated_count(self): + tables = { + "parts": [{"pnum": 0, "qoh": 1}], + "supplies": [], + } + + schema = { + "parts": {"pnum": "int", "qoh": "int"}, + "supplies": {"pnum": "int", "shipdate": "int"}, + } + + self.assertEqual( + execute( + """ + select * + from parts + where parts.qoh >= ( + select count(supplies.shipdate) + 1 + from supplies + where supplies.pnum = parts.pnum and supplies.shipdate < 10 + ) + """, + tables=tables, + schema=schema, + ).rows, + [ + (0, 1), + ], + ) + def test_table_depth_mismatch(self): tables = {"table": []} schema = {"db": {"table": {"col": "VARCHAR"}}} diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 1e23983..906e08c 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -646,3 +646,72 @@ FROM foo""", exp.Column(this=exp.to_identifier("colb")), ], ) + + def test_values(self): + self.assertEqual( + exp.values([(1, 2), (3, 4)], "t", ["a", "b"]).sql(), + "(VALUES (1, 2), (3, 4)) AS t(a, b)", + ) + self.assertEqual( + exp.values( + [(1, 2), (3, 4)], + "t", + {"a": exp.DataType.build("TEXT"), "b": exp.DataType.build("TEXT")}, + ).sql(), + "(VALUES (CAST(1 AS TEXT), CAST(2 AS TEXT)), (3, 4)) AS t(a, b)", + ) + with self.assertRaises(ValueError): + exp.values([(1, 2), (3, 4)], columns=["a"]) + + def test_data_type_builder(self): + self.assertEqual(exp.DataType.build("TEXT").sql(), "TEXT") + self.assertEqual(exp.DataType.build("DECIMAL(10, 2)").sql(), "DECIMAL(10, 2)") + self.assertEqual(exp.DataType.build("VARCHAR(255)").sql(), "VARCHAR(255)") + self.assertEqual(exp.DataType.build("ARRAY").sql(), "ARRAY") + self.assertEqual(exp.DataType.build("CHAR").sql(), "CHAR") + self.assertEqual(exp.DataType.build("NCHAR").sql(), "CHAR") + self.assertEqual(exp.DataType.build("VARCHAR").sql(), "VARCHAR") + self.assertEqual(exp.DataType.build("NVARCHAR").sql(), "VARCHAR") + self.assertEqual(exp.DataType.build("TEXT").sql(), "TEXT") + self.assertEqual(exp.DataType.build("BINARY").sql(), "BINARY") + self.assertEqual(exp.DataType.build("VARBINARY").sql(), "VARBINARY") + self.assertEqual(exp.DataType.build("INT").sql(), "INT") + self.assertEqual(exp.DataType.build("TINYINT").sql(), "TINYINT") + self.assertEqual(exp.DataType.build("SMALLINT").sql(), "SMALLINT") + self.assertEqual(exp.DataType.build("BIGINT").sql(), "BIGINT") + self.assertEqual(exp.DataType.build("FLOAT").sql(), "FLOAT") + self.assertEqual(exp.DataType.build("DOUBLE").sql(), "DOUBLE") + self.assertEqual(exp.DataType.build("DECIMAL").sql(), "DECIMAL") + self.assertEqual(exp.DataType.build("BOOLEAN").sql(), "BOOLEAN") + self.assertEqual(exp.DataType.build("JSON").sql(), "JSON") + self.assertEqual(exp.DataType.build("JSONB").sql(), "JSONB") + self.assertEqual(exp.DataType.build("INTERVAL").sql(), "INTERVAL") + self.assertEqual(exp.DataType.build("TIME").sql(), "TIME") + self.assertEqual(exp.DataType.build("TIMESTAMP").sql(), "TIMESTAMP") + self.assertEqual(exp.DataType.build("TIMESTAMPTZ").sql(), "TIMESTAMPTZ") + self.assertEqual(exp.DataType.build("TIMESTAMPLTZ").sql(), "TIMESTAMPLTZ") + self.assertEqual(exp.DataType.build("DATE").sql(), "DATE") + self.assertEqual(exp.DataType.build("DATETIME").sql(), "DATETIME") + self.assertEqual(exp.DataType.build("ARRAY").sql(), "ARRAY") + self.assertEqual(exp.DataType.build("MAP").sql(), "MAP") + self.assertEqual(exp.DataType.build("UUID").sql(), "UUID") + self.assertEqual(exp.DataType.build("GEOGRAPHY").sql(), "GEOGRAPHY") + self.assertEqual(exp.DataType.build("GEOMETRY").sql(), "GEOMETRY") + self.assertEqual(exp.DataType.build("STRUCT").sql(), "STRUCT") + self.assertEqual(exp.DataType.build("NULLABLE").sql(), "NULLABLE") + self.assertEqual(exp.DataType.build("HLLSKETCH").sql(), "HLLSKETCH") + self.assertEqual(exp.DataType.build("HSTORE").sql(), "HSTORE") + self.assertEqual(exp.DataType.build("SUPER").sql(), "SUPER") + self.assertEqual(exp.DataType.build("SERIAL").sql(), "SERIAL") + self.assertEqual(exp.DataType.build("SMALLSERIAL").sql(), "SMALLSERIAL") + self.assertEqual(exp.DataType.build("BIGSERIAL").sql(), "BIGSERIAL") + self.assertEqual(exp.DataType.build("XML").sql(), "XML") + self.assertEqual(exp.DataType.build("UNIQUEIDENTIFIER").sql(), "UNIQUEIDENTIFIER") + self.assertEqual(exp.DataType.build("MONEY").sql(), "MONEY") + self.assertEqual(exp.DataType.build("SMALLMONEY").sql(), "SMALLMONEY") + self.assertEqual(exp.DataType.build("ROWVERSION").sql(), "ROWVERSION") + self.assertEqual(exp.DataType.build("IMAGE").sql(), "IMAGE") + self.assertEqual(exp.DataType.build("VARIANT").sql(), "VARIANT") + self.assertEqual(exp.DataType.build("OBJECT").sql(), "OBJECT") + self.assertEqual(exp.DataType.build("NULL").sql(), "NULL") + self.assertEqual(exp.DataType.build("UNKNOWN").sql(), "UNKNOWN") diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 1c97be7..887f427 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -299,10 +299,10 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(set(scopes[6].sources), {"q", "z", "r", "s"}) self.assertEqual(len(scopes[6].columns), 6) - self.assertEqual(set(c.table for c in scopes[6].columns), {"r", "s"}) + self.assertEqual({c.table for c in scopes[6].columns}, {"r", "s"}) self.assertEqual(scopes[6].source_columns("q"), []) self.assertEqual(len(scopes[6].source_columns("r")), 2) - self.assertEqual(set(c.table for c in scopes[6].source_columns("r")), {"r"}) + self.assertEqual({c.table for c in scopes[6].source_columns("r")}, {"r"}) self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"}) self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b") @@ -578,3 +578,16 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') scope_t, scope_y = build_scope(query).cte_scopes self.assertEqual(set(scope_t.cte_sources), {"t"}) self.assertEqual(set(scope_y.cte_sources), {"t", "y"}) + + def test_schema_with_spaces(self): + schema = { + "a": { + "b c": "text", + '"d e"': "text", + } + } + + self.assertEqual( + optimizer.optimize(parse_one("SELECT * FROM a"), schema=schema), + parse_one('SELECT "a"."b c" AS "b c", "a"."d e" AS "d e" FROM "a" AS "a"'), + ) diff --git a/tests/test_parser.py b/tests/test_parser.py index ae2e4cd..03b801b 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -8,7 +8,8 @@ from tests.helpers import assert_logger_contains class TestParser(unittest.TestCase): def test_parse_empty(self): - self.assertIsNone(parse_one("")) + with self.assertRaises(ParseError) as ctx: + parse_one("") def test_parse_into(self): self.assertIsInstance(parse_one("left join foo", into=exp.Join), exp.Join) @@ -90,6 +91,9 @@ class TestParser(unittest.TestCase): parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(), """SELECT * FROM x, z CROSS JOIN y LATERAL VIEW EXPLODE(y)""", ) + self.assertIsNone( + parse_one("create table a as (select b from c) index").find(exp.TableAlias) + ) def test_command(self): expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1", read="hive") @@ -155,6 +159,11 @@ class TestParser(unittest.TestCase): assert expressions[0].args["from"].expressions[0].this.name == "a" assert expressions[1].args["from"].expressions[0].this.name == "b" + expressions = parse("SELECT 1; ; SELECT 2") + + assert len(expressions) == 3 + assert expressions[1] is None + def test_expression(self): ignore = Parser(error_level=ErrorLevel.IGNORE) self.assertIsInstance(ignore.expression(exp.Hint, expressions=[""]), exp.Hint) diff --git a/tests/test_schema.py b/tests/test_schema.py index 6c1ca9c..3dd9103 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -184,3 +184,19 @@ class TestSchema(unittest.TestCase): schema = MappingSchema({"foo": {"bar": parse_one("INT", into=exp.DataType)}}) self.assertEqual(schema.get_column_type("foo", "bar").this, exp.DataType.Type.INT) + + def test_schema_normalization(self): + schema = MappingSchema( + schema={"x": {"`y`": {"Z": {"a": "INT", "`B`": "VARCHAR"}, "w": {"C": "INT"}}}}, + dialect="spark", + ) + + table_z = exp.Table(this="z", db="y", catalog="x") + table_w = exp.Table(this="w", db="y", catalog="x") + + self.assertEqual(schema.column_names(table_z), ["a", "B"]) + self.assertEqual(schema.column_names(table_w), ["c"]) + + # Clickhouse supports both `` and "" for identifier quotes; sqlglot uses "" when generating sql + schema = MappingSchema(schema={"x": {"`y`": "INT"}}, dialect="clickhouse") + self.assertEqual(schema.column_names(exp.Table(this="x")), ["y"]) diff --git a/tests/test_serde.py b/tests/test_serde.py new file mode 100644 index 0000000..603a155 --- /dev/null +++ b/tests/test_serde.py @@ -0,0 +1,33 @@ +import json +import unittest + +from sqlglot import exp, parse_one +from sqlglot.optimizer.annotate_types import annotate_types +from tests.helpers import load_sql_fixtures + + +class CustomExpression(exp.Expression): + ... + + +class TestSerDe(unittest.TestCase): + def dump_load(self, expression): + return exp.Expression.load(json.loads(json.dumps(expression.dump()))) + + def test_serde(self): + for sql in load_sql_fixtures("identity.sql"): + with self.subTest(sql): + before = parse_one(sql) + after = self.dump_load(before) + self.assertEqual(before, after) + + def test_custom_expression(self): + before = CustomExpression() + after = self.dump_load(before) + self.assertEqual(before, after) + + def test_type_annotations(self): + before = annotate_types(parse_one("CAST('1' AS INT)")) + after = self.dump_load(before) + self.assertEqual(before.type, after.type) + self.assertEqual(before.this.type, after.this.type) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index cfb8d2b..cc9af7e 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,7 +1,11 @@ import unittest from sqlglot import parse_one -from sqlglot.transforms import eliminate_distinct_on, unalias_group +from sqlglot.transforms import ( + eliminate_distinct_on, + remove_precision_parameterized_types, + unalias_group, +) class TestTime(unittest.TestCase): @@ -62,3 +66,10 @@ class TestTime(unittest.TestCase): "SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC", 'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) WHERE "_row_number_2" = 1', ) + + def test_remove_precision_parameterized_types(self): + self.validate( + remove_precision_parameterized_types, + "SELECT CAST(1 AS DECIMAL(10, 2)), CAST('13' AS VARCHAR(10))", + "SELECT CAST(1 AS DECIMAL), CAST('13' AS VARCHAR)", + ) diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 9253ded..3a7fea4 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -117,6 +117,11 @@ class TestTranspile(unittest.TestCase): "select x from foo -- x", "SELECT x FROM foo /* x */", ) + self.validate( + """select x, -- + from foo""", + "SELECT x FROM foo", + ) self.validate( """ -- comment 1 -- cgit v1.2.3