From 75d158890b303b701c51f12b34c422fb823ba9aa Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 30 Jan 2023 18:08:33 +0100 Subject: Adding upstream version 10.5.10. Signed-off-by: Daniel Baumann --- CONTRIBUTING.md | 1 + Makefile | 4 +- README.md | 3 +- pdoc/docs/expressions.md | 41 --- sqlglot/__init__.py | 15 +- sqlglot/dataframe/README.md | 32 +- sqlglot/dataframe/sql/dataframe.py | 9 +- sqlglot/dataframe/sql/session.py | 4 +- sqlglot/dialects/__init__.py | 61 +++ sqlglot/dialects/bigquery.py | 1 - sqlglot/dialects/clickhouse.py | 9 +- sqlglot/dialects/mysql.py | 4 + sqlglot/dialects/snowflake.py | 8 +- sqlglot/dialects/tsql.py | 116 +++++- sqlglot/diff.py | 12 +- sqlglot/executor/context.py | 2 +- sqlglot/expressions.py | 410 +++++++++++++++++---- sqlglot/generator.py | 67 +++- sqlglot/helper.py | 2 +- sqlglot/lineage.py | 228 ++++++++++++ sqlglot/optimizer/__init__.py | 1 + sqlglot/optimizer/isolate_table_selects.py | 7 +- sqlglot/optimizer/qualify_columns.py | 54 +-- sqlglot/optimizer/scope.py | 4 +- sqlglot/parser.py | 272 +++++++++----- sqlglot/planner.py | 4 +- sqlglot/schema.py | 7 +- sqlglot/tokens.py | 9 +- tests/dialects/test_clickhouse.py | 7 +- tests/dialects/test_dialect.py | 66 ++-- tests/dialects/test_mysql.py | 9 + tests/dialects/test_snowflake.py | 18 + tests/dialects/test_tsql.py | 142 +++++++ tests/fixtures/identity.sql | 25 +- tests/fixtures/optimizer/isolate_table_selects.sql | 3 + tests/fixtures/optimizer/pushdown_projections.sql | 2 +- tests/fixtures/optimizer/qualify_columns.sql | 16 +- .../optimizer/qualify_columns__invalid.sql | 2 +- tests/test_executor.py | 4 +- tests/test_expressions.py | 21 ++ tests/test_lineage.py | 20 + tests/test_optimizer.py | 1 + tests/test_schema.py | 19 +- 43 files changed, 1385 insertions(+), 357 deletions(-) delete mode 100644 pdoc/docs/expressions.md create mode 100644 sqlglot/lineage.py create mode 100644 tests/test_lineage.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 97c795d..4dd7cf0 
100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -32,6 +32,7 @@ We use GitHub issues to track public bugs. Report a bug by opening a new issue. - What you expected would happen - What actually happens - Notes (possibly including why you think this might be happening, or stuff you tried that didn't work) +- References (e.g. documentation pages related to the issue) ## Start a discussion using Github's [discussions](https://github.com/tobymao/sqlglot/discussions) [We use GitHub discussions](https://github.com/tobymao/sqlglot/discussions/190) to discuss about the current state diff --git a/Makefile b/Makefile index 2da2493..8f27ecf 100644 --- a/Makefile +++ b/Makefile @@ -18,7 +18,7 @@ style: check: style test docs: - pdoc/cli.py -o pdoc/docs + python pdoc/cli.py -o pdoc/docs docs-serve: - pdoc/cli.py + python pdoc/cli.py diff --git a/README.md b/README.md index 0416521..a2e2836 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # SQLGlot -SQLGlot is a no dependency Python SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between [18 different dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects. +SQLGlot is a no dependency Python SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between [19 different dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). 
It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects. It is a very comprehensive generic SQL parser with a robust [test suite](https://github.com/tobymao/sqlglot/blob/main/tests/). It is also quite [performant](#benchmarks) while being written purely in Python. @@ -189,7 +189,6 @@ except sqlglot.errors.ParseError as e: print(e.errors) ``` -Output: ```python [{ 'description': 'Expecting )', diff --git a/pdoc/docs/expressions.md b/pdoc/docs/expressions.md deleted file mode 100644 index c82674b..0000000 --- a/pdoc/docs/expressions.md +++ /dev/null @@ -1,41 +0,0 @@ -# Expressions - -Every AST node in SQLGlot is represented by a subclass of `Expression`. Each such expression encapsulates any necessary context, such as its child expressions, their names, or arg keys, and whether each child expression is optional or not. - -Furthermore, the following attributes are common across all expressions: - -#### key - -A unique key for each class in the `Expression` hierarchy. This is useful for hashing and representing expressions as strings. - -#### args - -A dictionary used for mapping child arg keys, to the corresponding expressions. A value in this mapping is usually either a single or a list of `Expression` instances, but SQLGlot doesn't impose any constraints on the actual type of the value. - -#### arg_types - -A dictionary used for mapping arg keys to booleans that determine whether the corresponding expressions are optional or not. Consider the following example: - -```python -class Limit(Expression): - arg_types = {"this": False, "expression": True} - -``` - -Here, `Limit` declares that it expects to have one optional and one required child expression, which can be referenced through `this` and `expression`, respectively. The arg keys are generally arbitrary, but there are helper methods for keys like `this`, `expression` and `expressions` that abstract away dictionary lookups and related checks. 
For this reason, these keys are common throughout SQLGlot's codebase. - -#### parent - -A reference to the parent expression (may be `None`). - -#### arg_key - -The arg key an expression is associated with, i.e. the name its parent expression uses to refer to it. - -#### comments - -A list of comments that are associated with a given expression. This is used in order to preserve comments when transpiling SQL code. - -#### type - -The data type of an expression, as inferred by SQLGlot's optimizer. diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index f2db4f1..67a4463 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -1,5 +1,6 @@ """ .. include:: ../README.md +---- """ from __future__ import annotations @@ -29,14 +30,16 @@ from sqlglot.expressions import table_ as table from sqlglot.expressions import to_column, to_table, union from sqlglot.generator import Generator from sqlglot.parser import Parser -from sqlglot.schema import MappingSchema +from sqlglot.schema import MappingSchema, Schema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "10.5.6" +__version__ = "10.5.10" pretty = False +"""Whether to format generated SQL by default.""" schema = MappingSchema() +"""The default schema used by SQLGlot (e.g. in the optimizer).""" def parse( @@ -48,7 +51,7 @@ def parse( Args: sql: the SQL code string to parse. read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql"). - **opts: other options. + **opts: other `sqlglot.parser.Parser` options. Returns: The resulting syntax tree collection. @@ -60,7 +63,7 @@ def parse( def parse_one( sql: str, read: t.Optional[str | Dialect] = None, - into: t.Optional[t.Type[Expression] | str] = None, + into: t.Optional[exp.IntoType] = None, **opts, ) -> Expression: """ @@ -70,7 +73,7 @@ def parse_one( sql: the SQL code string to parse. read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql"). into: the SQLGlot Expression to parse into. 
- **opts: other options. + **opts: other `sqlglot.parser.Parser` options. Returns: The syntax tree for the first parsed statement. @@ -110,7 +113,7 @@ def transpile( identity: if set to `True` and if the target dialect is not specified the source dialect will be used as both: the source and the target dialect. error_level: the desired error level of the parser. - **opts: other options. + **opts: other `sqlglot.generator.Generator` options. Returns: The list of transpiled SQL statements. diff --git a/sqlglot/dataframe/README.md b/sqlglot/dataframe/README.md index 54d3856..02179f4 100644 --- a/sqlglot/dataframe/README.md +++ b/sqlglot/dataframe/README.md @@ -1,29 +1,29 @@ # PySpark DataFrame SQL Generator -This is a drop-in replacement for the PysPark DataFrame API that will generate SQL instead of executing DataFrame operations directly. This, when combined with the transpiling support in SQLGlot, allows one to write PySpark DataFrame code and execute it on other engines like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). +This is a drop-in replacement for the PySpark DataFrame API that will generate SQL instead of executing DataFrame operations directly. This, when combined with the transpiling support in SQLGlot, allows one to write PySpark DataFrame code and execute it on other engines like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). Currently many of the common operations are covered and more functionality will be added over time. 
Please [open an issue](https://github.com/tobymao/sqlglot/issues) or [PR](https://github.com/tobymao/sqlglot/pulls) with your feedback or contribution to help influence what should be prioritized next and make sure your use case is properly supported. # How to use ## Instructions -* [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install) and that is all that is required to just generate SQL. [The examples](#examples) show generating SQL and then executing that SQL on a specific engine and that will require that engine's client library -* Find/replace all `from pyspark.sql` with `from sqlglot.dataframe` -* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('', )` +* [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install) and that is all that is required to just generate SQL. [The examples](#examples) show generating SQL and then executing that SQL on a specific engine and that will require that engine's client library. +* Find/replace all `from pyspark.sql` with `from sqlglot.dataframe`. +* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('', )`. * The column structure can be defined the following ways: - * Dictionary where the keys are column names and values are string of the Spark SQL type name - * Ex: {'cola': 'string', 'colb': 'int'} - * PySpark DataFrame `StructType` similar to when using `createDataFrame` + * Dictionary where the keys are column names and values are string of the Spark SQL type name. + * Ex: `{'cola': 'string', 'colb': 'int'}` + * PySpark DataFrame `StructType` similar to when using `createDataFrame`. * Ex: `StructType([StructField('cola', StringType()), StructField('colb', IntegerType())])` - * A string of names and types similar to what is supported in `createDataFrame` + * A string of names and types similar to what is supported in `createDataFrame`. 
* Ex: `cola: STRING, colb: INT` - * [Not Recommended] A list of string column names without type - * Ex: ['cola', 'colb'] - * The lack of types may limit functionality in future releases - * See [Registering Custom Schema](#registering-custom-schema-class) for information on how to skip this step if the information is stored externally -* Add `.sql(pretty=True)` to your final DataFrame command to return a list of sql statements to run that command + * [Not Recommended] A list of string column names without type. + * Ex: `['cola', 'colb']` + * The lack of types may limit functionality in future releases. + * See [Registering Custom Schema](#registering-custom-schema-class) for information on how to skip this step if the information is stored externally. +* Add `.sql(pretty=True)` to your final DataFrame command to return a list of sql statements to run that command. * In most cases a single SQL statement is returned. Currently the only exception is when caching DataFrames which isn't supported in other dialects. - * Spark is the default output dialect. See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects + * Spark is the default output dialect. See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects. 
* Ex: `.sql(pretty=True, dialect='bigquery')` ## Examples @@ -51,7 +51,7 @@ df = ( print(df.sql(pretty=True)) # Spark will be the dialect used by default ``` -Output: + ```sparksql SELECT `employee`.`age` AS `age`, @@ -206,7 +206,7 @@ sql_statements = ( .createDataFrame(data, schema) .groupBy(F.col("age")) .agg(F.countDistinct(F.col("employee_id")).alias("num_employees")) - .sql(dialect="bigquery") + .sql(dialect="spark") ) pyspark = PySparkSession.builder.master("local[*]").getOrCreate() diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index a17bb9d..65a37f5 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -111,16 +111,13 @@ class DataFrame: return DataFrameNaFunctions(self) def _replace_cte_names_with_hashes(self, expression: exp.Select): - expression = expression.copy() - ctes = expression.ctes replacement_mapping = {} - for cte in ctes: + for cte in expression.ctes: old_name_id = cte.args["alias"].this new_hashed_id = exp.to_identifier( self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"] ) replacement_mapping[old_name_id] = new_hashed_id - cte.set("alias", exp.TableAlias(this=new_hashed_id)) expression = expression.transform(replace_id_value, replacement_mapping) return expression @@ -183,7 +180,7 @@ class DataFrame: expression = df.expression hint_expression = expression.args.get("hint") or exp.Hint(expressions=[]) for hint in df.pending_partition_hints: - hint_expression.args.get("expressions").append(hint) + hint_expression.append("expressions", hint) df.pending_hints.remove(hint) join_aliases = { @@ -209,7 +206,7 @@ class DataFrame: sequence_id_expression.set("this", matching_cte.args["alias"].this) df.pending_hints.remove(hint) break - hint_expression.args.get("expressions").append(hint) + hint_expression.append("expressions", hint) if hint_expression.expressions: expression.set("hint", hint_expression) return df diff --git 
a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py index c4a22c6..af589b0 100644 --- a/sqlglot/dataframe/sql/session.py +++ b/sqlglot/dataframe/sql/session.py @@ -129,7 +129,7 @@ class SparkSession: @property def _random_name(self) -> str: - return f"a{str(uuid.uuid4())[:8]}" + return "r" + uuid.uuid4().hex @property def _random_branch_id(self) -> str: @@ -145,7 +145,7 @@ class SparkSession: @property def _random_id(self) -> str: - id = f"a{str(uuid.uuid4())[:8]}" + id = self._random_name self.known_ids.add(id) return id diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py index 2084681..34cf613 100644 --- a/sqlglot/dialects/__init__.py +++ b/sqlglot/dialects/__init__.py @@ -1,3 +1,64 @@ +""" +## Dialects + +One of the core abstractions in SQLGlot is the concept of a "dialect". The `Dialect` class essentially implements a +"SQLGlot dialect", which aims to be as generic and ANSI-compliant as possible. It relies on the base `Tokenizer`, +`Parser` and `Generator` classes to achieve this goal, so these need to be very lenient when it comes to consuming +SQL code. + +However, there are cases where the syntax of different SQL dialects varies wildly, even for common tasks. One such +example is the date/time functions, which can be hard to deal with. For this reason, it's sometimes necessary to +override the base dialect in order to specialize its behavior. This can be easily done in SQLGlot: supporting new +dialects is as simple as subclassing from `Dialect` and overriding its various components (e.g. the `Parser` class), +in order to implement the target behavior. 
+ + +### Implementing a custom Dialect + +Consider the following example: + +```python +from sqlglot import exp +from sqlglot.dialects.dialect import Dialect +from sqlglot.generator import Generator +from sqlglot.tokens import Tokenizer, TokenType + + +class Custom(Dialect): + class Tokenizer(Tokenizer): + QUOTES = ["'", '"'] + IDENTIFIERS = ["`"] + + KEYWORDS = { + **Tokenizer.KEYWORDS, + "INT64": TokenType.BIGINT, + "FLOAT64": TokenType.DOUBLE, + } + + class Generator(Generator): + TRANSFORMS = {exp.Array: lambda self, e: f"[{self.expressions(e)}]"} + + TYPE_MAPPING = { + exp.DataType.Type.TINYINT: "INT64", + exp.DataType.Type.SMALLINT: "INT64", + exp.DataType.Type.INT: "INT64", + exp.DataType.Type.BIGINT: "INT64", + exp.DataType.Type.DECIMAL: "NUMERIC", + exp.DataType.Type.FLOAT: "FLOAT64", + exp.DataType.Type.DOUBLE: "FLOAT64", + exp.DataType.Type.BOOLEAN: "BOOL", + exp.DataType.Type.TEXT: "STRING", + } +``` + +This is a typical example of adding a new dialect implementation in SQLGlot: we specify its identifier and string +delimiters, as well as what tokens it uses for its types and how they're associated with SQLGlot types. Since +the `Expression` classes are common for each dialect supported in SQLGlot, we may also need to override the generation +logic for some expressions; this is usually done by adding new entries to the `TRANSFORMS` mapping. 
+ +---- +""" + from sqlglot.dialects.bigquery import BigQuery from sqlglot.dialects.clickhouse import ClickHouse from sqlglot.dialects.databricks import Databricks diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 9ddfbea..e7d30ec 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -124,7 +124,6 @@ class BigQuery(Dialect): "FLOAT64": TokenType.DOUBLE, "INT64": TokenType.BIGINT, "NOT DETERMINISTIC": TokenType.VOLATILE, - "QUALIFY": TokenType.QUALIFY, "UNKNOWN": TokenType.NULL, } KEYWORDS.pop("DIV") diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 1c173a4..9e8c691 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -73,13 +73,8 @@ class ClickHouse(Dialect): return this - def _parse_position(self) -> exp.Expression: - this = super()._parse_position() - # clickhouse position args are swapped - substr = this.this - this.args["this"] = this.args.get("substr") - this.args["substr"] = substr - return this + def _parse_position(self, haystack_first: bool = False) -> exp.Expression: + return super()._parse_position(haystack_first=True) # https://clickhouse.com/docs/en/sql-reference/statements/select/with/ def _parse_cte(self) -> exp.Expression: diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 1bddfe1..2a0a917 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -124,6 +124,8 @@ class MySQL(Dialect): **tokens.Tokenizer.KEYWORDS, "MEDIUMTEXT": TokenType.MEDIUMTEXT, "LONGTEXT": TokenType.LONGTEXT, + "MEDIUMBLOB": TokenType.MEDIUMBLOB, + "LONGBLOB": TokenType.LONGBLOB, "START": TokenType.BEGIN, "SEPARATOR": TokenType.SEPARATOR, "_ARMSCII8": TokenType.INTRODUCER, @@ -459,6 +461,8 @@ class MySQL(Dialect): TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy() TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT) TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT) + TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB) + 
TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB) WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set() diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index c44950a..6225a53 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -194,7 +194,8 @@ class Snowflake(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, - "QUALIFY": TokenType.QUALIFY, + "EXCLUDE": TokenType.EXCEPT, + "RENAME": TokenType.REPLACE, "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, "TIMESTAMP_NTZ": TokenType.TIMESTAMP, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, @@ -232,6 +233,11 @@ class Snowflake(Dialect): exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", } + STAR_MAPPING = { + "except": "EXCLUDE", + "replace": "RENAME", + } + ROOT_PROPERTIES = { exp.PartitionedByProperty, exp.ReturnsProperty, diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 9342e6b..9f9099e 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -1,6 +1,7 @@ from __future__ import annotations import re +import typing as t from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func @@ -251,6 +252,7 @@ class TSQL(Dialect): "NTEXT": TokenType.TEXT, "NVARCHAR(MAX)": TokenType.TEXT, "PRINT": TokenType.COMMAND, + "PROC": TokenType.PROCEDURE, "REAL": TokenType.FLOAT, "ROWVERSION": TokenType.ROWVERSION, "SMALLDATETIME": TokenType.DATETIME, @@ -263,6 +265,11 @@ class TSQL(Dialect): "XML": TokenType.XML, } + # TSQL allows @, # to appear as a variable/identifier prefix + SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy() + SINGLE_TOKENS.pop("@") + SINGLE_TOKENS.pop("#") + class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore @@ -293,26 +300,82 @@ class TSQL(Dialect): DataType.Type.NCHAR, } - # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table - TABLE_PREFIX_TOKENS = {TokenType.HASH, 
TokenType.PARAMETER} + RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { # type: ignore + TokenType.TABLE, + *parser.Parser.TYPE_TOKENS, # type: ignore + } + + STATEMENT_PARSERS = { + **parser.Parser.STATEMENT_PARSERS, # type: ignore + TokenType.END: lambda self: self._parse_command(), + } + + def _parse_system_time(self) -> t.Optional[exp.Expression]: + if not self._match_text_seq("FOR", "SYSTEM_TIME"): + return None - def _parse_convert(self, strict): + if self._match_text_seq("AS", "OF"): + system_time = self.expression( + exp.SystemTime, this=self._parse_bitwise(), kind="AS OF" + ) + elif self._match_set((TokenType.FROM, TokenType.BETWEEN)): + kind = self._prev.text + this = self._parse_bitwise() + self._match_texts(("TO", "AND")) + expression = self._parse_bitwise() + system_time = self.expression( + exp.SystemTime, this=this, expression=expression, kind=kind + ) + elif self._match_text_seq("CONTAINED", "IN"): + args = self._parse_wrapped_csv(self._parse_bitwise) + system_time = self.expression( + exp.SystemTime, + this=seq_get(args, 0), + expression=seq_get(args, 1), + kind="CONTAINED IN", + ) + elif self._match(TokenType.ALL): + system_time = self.expression(exp.SystemTime, kind="ALL") + else: + system_time = None + self.raise_error("Unable to parse FOR SYSTEM_TIME clause") + + return system_time + + def _parse_table_parts(self, schema: bool = False) -> exp.Expression: + table = super()._parse_table_parts(schema=schema) + table.set("system_time", self._parse_system_time()) + return table + + def _parse_returns(self) -> exp.Expression: + table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS) + returns = super()._parse_returns() + returns.set("table", table) + return returns + + def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]: to = self._parse_types() self._match(TokenType.COMMA) this = self._parse_conjunction() + if not to or not this: + return None + # Retrieve length of datatype and override to default if not 
specified if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES: to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False) # Check whether a conversion with format is applicable if self._match(TokenType.COMMA): - format_val = self._parse_number().name - if format_val not in TSQL.convert_format_mapping: + format_val = self._parse_number() + format_val_name = format_val.name if format_val else "" + + if format_val_name not in TSQL.convert_format_mapping: raise ValueError( - f"CONVERT function at T-SQL does not support format style {format_val}" + f"CONVERT function at T-SQL does not support format style {format_val_name}" ) - format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val]) + + format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val_name]) # Check whether the convert entails a string to date format if to.this == DataType.Type.DATE: @@ -333,6 +396,21 @@ class TSQL(Dialect): # Entails a simple cast without any format requirement return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) + def _parse_user_defined_function( + self, kind: t.Optional[TokenType] = None + ) -> t.Optional[exp.Expression]: + this = super()._parse_user_defined_function(kind=kind) + + if ( + kind == TokenType.FUNCTION + or isinstance(this, exp.UserDefinedFunction) + or self._match(TokenType.ALIAS, advance=False) + ): + return this + + expressions = self._parse_csv(self._parse_udf_kwarg) + return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) + class Generator(generator.Generator): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore @@ -354,3 +432,27 @@ class TSQL(Dialect): exp.TimeToStr: _format_sql, exp.GroupConcat: _string_agg_sql, } + + TRANSFORMS.pop(exp.ReturnsProperty) + + def systemtime_sql(self, expression: exp.SystemTime) -> str: + kind = expression.args["kind"] + if kind == "ALL": + return "FOR SYSTEM_TIME ALL" + + start = 
self.sql(expression, "this") + if kind == "AS OF": + return f"FOR SYSTEM_TIME AS OF {start}" + + end = self.sql(expression, "expression") + if kind == "FROM": + return f"FOR SYSTEM_TIME FROM {start} TO {end}" + if kind == "BETWEEN": + return f"FOR SYSTEM_TIME BETWEEN {start} AND {end}" + + return f"FOR SYSTEM_TIME CONTAINED IN ({start}, {end})" + + def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str: + table = expression.args.get("table") + table = f"{table} " if table else "" + return f"RETURNS {table}{self.sql(expression, 'this')}" diff --git a/sqlglot/diff.py b/sqlglot/diff.py index fa8bc1b..a5373b0 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -1,5 +1,6 @@ """ .. include:: ../posts/sql_diff.md +---- """ from __future__ import annotations @@ -75,12 +76,13 @@ def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]: ] Args: - source (sqlglot.Expression): the source expression. - target (sqlglot.Expression): the target expression against which the diff should be calculated. + source: the source expression. + target: the target expression against which the diff should be calculated. Returns: - the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the target expression trees. - This list represents a sequence of steps needed to transform the source expression tree into the target one. + the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the + target expression trees. This list represents a sequence of steps needed to transform the source + expression tree into the target one. 
""" return ChangeDistiller().diff(source.copy(), target.copy()) @@ -258,7 +260,7 @@ class ChangeDistiller: return bigram_histo -def _get_leaves(expression: exp.Expression) -> t.Generator[exp.Expression, None, None]: +def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]: has_child_exprs = False for a in expression.args.values(): diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py index 8a58287..c405c45 100644 --- a/sqlglot/executor/context.py +++ b/sqlglot/executor/context.py @@ -63,7 +63,7 @@ class Context: reader = table[i] yield reader, self - def table_iter(self, table: str) -> t.Generator[t.Tuple[TableIter, Context], None, None]: + def table_iter(self, table: str) -> t.Iterator[t.Tuple[TableIter, Context]]: self.env["scope"] = self.row_readers for reader in self.tables[table]: diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index be99fe2..f9751ca 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1,5 +1,12 @@ """ -.. include:: ../pdoc/docs/expressions.md +## Expressions + +Every AST node in SQLGlot is represented by a subclass of `Expression`. + +This module contains the implementation of all supported `Expression` types. Additionally, +it exposes a number of helper functions, which are mainly used to programmatically build +SQL expressions, such as `sqlglot.expressions.select`. +---- """ from __future__ import annotations @@ -27,35 +34,66 @@ from sqlglot.tokens import Token if t.TYPE_CHECKING: from sqlglot.dialects.dialect import Dialect + IntoType = t.Union[ + str, + t.Type[Expression], + t.Collection[t.Union[str, t.Type[Expression]]], + ] + class _Expression(type): def __new__(cls, clsname, bases, attrs): klass = super().__new__(cls, clsname, bases, attrs) + + # When an Expression class is created, its key is automatically set to be + # the lowercase version of the class' name. 
klass.key = clsname.lower() + + # This is so that docstrings are not inherited in pdoc + klass.__doc__ = klass.__doc__ or "" + return klass class Expression(metaclass=_Expression): """ - The base class for all expressions in a syntax tree. + The base class for all expressions in a syntax tree. Each Expression encapsulates any necessary + context, such as its child expressions, their names (arg keys), and whether a given child expression + is optional or not. Attributes: - arg_types (dict): determines arguments supported by this expression. - The key in a dictionary defines a unique key of an argument using - which the argument's value can be retrieved. The value is a boolean - flag which indicates whether the argument's value is required (True) - or optional (False). + key: a unique key for each class in the Expression hierarchy. This is useful for hashing + and representing expressions as strings. + arg_types: determines what arguments (child nodes) are supported by an expression. It + maps arg keys to booleans that indicate whether the corresponding args are required (True) or optional (False). + + Example: + >>> class Foo(Expression): + ... arg_types = {"this": True, "expression": False} + + The above definition informs us that Foo is an Expression that requires an argument called + "this" and may also optionally receive an argument called "expression". + + Args: + args: a mapping used for retrieving the arguments of an expression, given their arg keys. + parent: a reference to the parent expression (or None, in case of root expressions). + arg_key: the arg key an expression is associated with, i.e. the name its parent expression + uses to refer to it. + comments: a list of comments that are associated with a given expression. This is used in + order to preserve comments when transpiling SQL code. + _type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the + optimizer, in order to enable some transformations that require type information.
""" - key = "Expression" + key = "expression" arg_types = {"this": True} __slots__ = ("args", "parent", "arg_key", "comments", "_type") - def __init__(self, **args): - self.args = args - self.parent = None - self.arg_key = None - self.comments = None + def __init__(self, **args: t.Any): + self.args: t.Dict[str, t.Any] = args + self.parent: t.Optional[Expression] = None + self.arg_key: t.Optional[str] = None + self.comments: t.Optional[t.List[str]] = None self._type: t.Optional[DataType] = None for arg_key, value in self.args.items(): @@ -76,17 +114,30 @@ class Expression(metaclass=_Expression): @property def this(self): + """ + Retrieves the argument with key "this". + """ return self.args.get("this") @property def expression(self): + """ + Retrieves the argument with key "expression". + """ return self.args.get("expression") @property def expressions(self): + """ + Retrieves the argument with key "expressions". + """ return self.args.get("expressions") or [] def text(self, key): + """ + Returns a textual representation of the argument corresponding to "key". This can only be used + for args that are strings or leaf Expression instances, such as identifiers and literals. + """ field = self.args.get(key) if isinstance(field, str): return field @@ -96,14 +147,23 @@ class Expression(metaclass=_Expression): @property def is_string(self): + """ + Checks whether a Literal expression is a string. + """ return isinstance(self, Literal) and self.args["is_string"] @property def is_number(self): + """ + Checks whether a Literal expression is a number. + """ return isinstance(self, Literal) and not self.args["is_string"] @property def is_int(self): + """ + Checks whether a Literal expression is an integer. + """ if self.is_number: try: int(self.name) @@ -114,6 +174,9 @@ class Expression(metaclass=_Expression): @property def alias(self): + """ + Returns the alias of the expression, or an empty string if it's not aliased. 
+ """ if isinstance(self.args.get("alias"), TableAlias): return self.args["alias"].name return self.text("alias") @@ -128,6 +191,24 @@ class Expression(metaclass=_Expression): return "NULL" return self.alias or self.name + @property + def output_name(self): + """ + Name of the output column if this expression is a selection. + + If the Expression has no output name, an empty string is returned. + + Example: + >>> from sqlglot import parse_one + >>> parse_one("SELECT a").expressions[0].output_name + 'a' + >>> parse_one("SELECT b AS c").expressions[0].output_name + 'c' + >>> parse_one("SELECT 1 + 2").expressions[0].output_name + '' + """ + return "" + @property def type(self) -> t.Optional[DataType]: return self._type @@ -145,6 +226,9 @@ class Expression(metaclass=_Expression): return copy def copy(self): + """ + Returns a deep copy of the expression. + """ new = deepcopy(self) for item, parent, _ in new.bfs(): if isinstance(item, Expression) and parent: @@ -169,7 +253,7 @@ class Expression(metaclass=_Expression): Sets `arg_key` to `value`. Args: - arg_key (str): name of the expression arg + arg_key (str): name of the expression arg. value: value to set the arg to. """ self.args[arg_key] = value @@ -203,8 +287,7 @@ class Expression(metaclass=_Expression): expression_types (type): the expression type(s) to match. Returns: - the node which matches the criteria or None if no node matching - the criteria was found. + The node which matches the criteria or None if no such node was found. """ return next(self.find_all(*expression_types, bfs=bfs), None) @@ -217,7 +300,7 @@ class Expression(metaclass=_Expression): expression_types (type): the expression type(s) to match. Returns: - the generator object. + The generator object. """ for expression, _, _ in self.walk(bfs=bfs): if isinstance(expression, expression_types): @@ -231,7 +314,7 @@ class Expression(metaclass=_Expression): expression_types (type): the expression type(s) to match. 
Returns: - the parent node + The parent node. """ ancestor = self.parent while ancestor and not isinstance(ancestor, expression_types): @@ -269,7 +352,7 @@ class Expression(metaclass=_Expression): the DFS (Depth-first) order. Returns: - the generator object. + The generator object. """ parent = parent or self.parent yield self, parent, key @@ -287,7 +370,7 @@ class Expression(metaclass=_Expression): the BFS (Breadth-first) order. Returns: - the generator object. + The generator object. """ queue = deque([(self, self.parent, None)]) @@ -341,32 +424,33 @@ class Expression(metaclass=_Expression): return self.sql() def __repr__(self): - return self.to_s() + return self._to_s() def sql(self, dialect: Dialect | str | None = None, **opts) -> str: """ Returns SQL string representation of this tree. - Args - dialect (str): the dialect of the output SQL string - (eg. "spark", "hive", "presto", "mysql"). - opts (dict): other :class:`~sqlglot.generator.Generator` options. + Args: + dialect: the dialect of the output SQL string (eg. "spark", "hive", "presto", "mysql"). + opts: other `sqlglot.generator.Generator` options. - Returns - the SQL string. + Returns: + The SQL string. """ from sqlglot.dialects import Dialect return Dialect.get_or_raise(dialect)().generate(self, **opts) - def to_s(self, hide_missing: bool = True, level: int = 0) -> str: + def _to_s(self, hide_missing: bool = True, level: int = 0) -> str: indent = "" if not level else "\n" indent += "".join([" "] * level) left = f"({self.key.upper()} " args: t.Dict[str, t.Any] = { k: ", ".join( - v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v) + v._to_s(hide_missing=hide_missing, level=level + 1) + if hasattr(v, "_to_s") + else str(v) for v in ensure_collection(vs) if v is not None ) @@ -394,7 +478,7 @@ class Expression(metaclass=_Expression): modified in place. Returns: - the transformed tree. + The transformed tree. 
""" node = self.copy() if copy else self new_node = fun(node, *args, **kwargs) @@ -423,8 +507,8 @@ class Expression(metaclass=_Expression): Args: expression (Expression|None): new node - Returns : - the new expression or expressions + Returns: + The new expression or expressions. """ if not self.parent: return expression @@ -458,6 +542,40 @@ class Expression(metaclass=_Expression): assert isinstance(self, type_) return self + def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]: + """ + Checks if this expression is valid (e.g. all mandatory args are set). + + Args: + args: a sequence of values that were used to instantiate a Func expression. This is used + to check that the provided arguments don't exceed the function argument limit. + + Returns: + A list of error messages for all possible errors that were found. + """ + errors: t.List[str] = [] + + for k in self.args: + if k not in self.arg_types: + errors.append(f"Unexpected keyword: '{k}' for {self.__class__}") + for k, mandatory in self.arg_types.items(): + v = self.args.get(k) + if mandatory and (v is None or (isinstance(v, list) and not v)): + errors.append(f"Required keyword: '{k}' missing for {self.__class__}") + + if ( + args + and isinstance(self, Func) + and len(args) > len(self.arg_types) + and not self.is_var_len_args + ): + errors.append( + f"The number of provided arguments ({len(args)}) is greater than " + f"the maximum number of supported arguments ({len(self.arg_types)})" + ) + + return errors + def dump(self): """ Dump this Expression to a JSON-serializable dict. 
@@ -552,7 +670,7 @@ class DerivedTable(Expression): @property def named_selects(self): - return [select.alias_or_name for select in self.selects] + return [select.output_name for select in self.selects] class Unionable(Expression): @@ -654,6 +772,7 @@ class Create(Expression): "no_primary_index": False, "indexes": False, "no_schema_binding": False, + "begin": False, } @@ -696,7 +815,7 @@ class Show(Expression): class UserDefinedFunction(Expression): - arg_types = {"this": True, "expressions": False} + arg_types = {"this": True, "expressions": False, "wrapped": False} class UserDefinedFunctionKwarg(Expression): @@ -750,6 +869,10 @@ class Column(Condition): def table(self): return self.text("table") + @property + def output_name(self): + return self.name + class ColumnDef(Expression): arg_types = { @@ -865,6 +988,10 @@ class ForeignKey(Expression): } +class PrimaryKey(Expression): + arg_types = {"expressions": True, "options": False} + + class Unique(Expression): arg_types = {"expressions": True} @@ -904,6 +1031,10 @@ class Identifier(Expression): def __hash__(self): return hash((self.key, self.this.lower())) + @property + def output_name(self): + return self.name + class Index(Expression): arg_types = { @@ -996,6 +1127,10 @@ class Literal(Condition): def string(cls, string) -> Literal: return cls(this=str(string), is_string=True) + @property + def output_name(self): + return self.name + class Join(Expression): arg_types = { @@ -1186,7 +1321,7 @@ class SchemaCommentProperty(Property): class ReturnsProperty(Property): - arg_types = {"this": True, "is_table": False} + arg_types = {"this": True, "is_table": False, "table": False} class LanguageProperty(Property): @@ -1262,8 +1397,13 @@ class Qualify(Expression): pass +# https://www.ibm.com/docs/en/ias?topic=procedures-return-statement-in-sql +class Return(Expression): + pass + + class Reference(Expression): - arg_types = {"this": True, "expressions": True} + arg_types = {"this": True, "expressions": False, "options": 
False} class Tuple(Expression): @@ -1397,6 +1537,16 @@ class Table(Expression): "joins": False, "pivots": False, "hints": False, + "system_time": False, + } + + +# See the TSQL "Querying data in a system-versioned temporal table" page +class SystemTime(Expression): + arg_types = { + "this": False, + "expression": False, + "kind": True, } @@ -2027,7 +2177,7 @@ class Select(Subqueryable): @property def named_selects(self) -> t.List[str]: - return [e.alias_or_name for e in self.expressions if e.alias_or_name] + return [e.output_name for e in self.expressions if e.alias_or_name] @property def selects(self) -> t.List[Expression]: @@ -2051,6 +2201,10 @@ class Subquery(DerivedTable, Unionable): expression = expression.this return expression + @property + def output_name(self): + return self.alias + class TableSample(Expression): arg_types = { @@ -2066,6 +2220,16 @@ class TableSample(Expression): } +class Tag(Expression): + """Tags are used for generating arbitrary sql like SELECT x.""" + + arg_types = { + "this": False, + "prefix": False, + "postfix": False, + } + + class Pivot(Expression): arg_types = { "this": False, @@ -2106,6 +2270,10 @@ class Star(Expression): def name(self): return "*" + @property + def output_name(self): + return self.name + class Parameter(Expression): pass @@ -2143,6 +2311,8 @@ class DataType(Expression): TEXT = auto() MEDIUMTEXT = auto() LONGTEXT = auto() + MEDIUMBLOB = auto() + LONGBLOB = auto() BINARY = auto() VARBINARY = auto() INT = auto() @@ -2282,11 +2452,11 @@ class Rollback(Expression): class AlterTable(Expression): - arg_types = { - "this": True, - "actions": True, - "exists": False, - } + arg_types = {"this": True, "actions": True, "exists": False} + + +class AddConstraint(Expression): + arg_types = {"this": False, "expression": False, "enforced": False} # Binary expressions like (ADD a b) @@ -2456,6 +2626,10 @@ class Neg(Unary): class Alias(Expression): arg_types = {"this": True, "alias": False} + @property + def output_name(self): + 
return self.alias + class Aliases(Expression): arg_types = {"this": True, "expressions": True} @@ -2523,16 +2697,13 @@ class Func(Condition): """ The base class for all function expressions. - Attributes - is_var_len_args (bool): if set to True the last argument defined in - arg_types will be treated as a variable length argument and the - argument's value will be stored as a list. - _sql_names (list): determines the SQL name (1st item in the list) and - aliases (subsequent items) for this function expression. These - values are used to map this node to a name during parsing as well - as to provide the function's name during SQL string generation. By - default the SQL name is set to the expression's class name transformed - to snake case. + Attributes: + is_var_len_args (bool): if set to True the last argument defined in arg_types will be + treated as a variable length argument and the argument's value will be stored as a list. + _sql_names (list): determines the SQL name (1st item in the list) and aliases (subsequent items) + for this function expression. These values are used to map this node to a name during parsing + as well as to provide the function's name during SQL string generation. By default the SQL + name is set to the expression's class name transformed to snake case. 
""" is_var_len_args = False @@ -2558,7 +2729,7 @@ class Func(Condition): raise NotImplementedError( "SQL name is only supported by concrete function implementations" ) - if not hasattr(cls, "_sql_names"): + if "_sql_names" not in cls.__dict__: cls._sql_names = [camel_to_snake_case(cls.__name__)] return cls._sql_names @@ -2658,6 +2829,10 @@ class Cast(Func): def to(self): return self.args["to"] + @property + def output_name(self): + return self.name + class Collate(Binary): pass @@ -2956,6 +3131,14 @@ class Pow(Func): _sql_names = ["POWER", "POW"] +class PercentileCont(AggFunc): + pass + + +class PercentileDisc(AggFunc): + pass + + class Quantile(AggFunc): arg_types = {"this": True, "quantile": True} @@ -3213,12 +3396,13 @@ def _norm_arg(arg): ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func)) +# Helpers def maybe_parse( - sql_or_expression, + sql_or_expression: str | Expression, *, - into=None, - dialect=None, - prefix=None, + into: t.Optional[IntoType] = None, + dialect: t.Optional[str] = None, + prefix: t.Optional[str] = None, **opts, ) -> Expression: """Gracefully handle a possible string or expression. @@ -3230,11 +3414,11 @@ def maybe_parse( (IDENTIFIER this: x, quoted: False) Args: - sql_or_expression (str | Expression): the SQL code string or an expression - into (Expression): the SQLGlot Expression to parse into - dialect (str): the dialect used to parse the input expressions (in the case that an + sql_or_expression: the SQL code string or an expression + into: the SQLGlot Expression to parse into + dialect: the dialect used to parse the input expressions (in the case that an input expression is a SQL string). - prefix (str): a string to prefix the sql with before it gets parsed + prefix: a string to prefix the sql with before it gets parsed (automatically includes a space) **opts: other options to use to parse the input expressions (again, in the case that an input expression is a SQL string). 
@@ -3993,7 +4177,7 @@ def table_name(table) -> str: """Get the full name of a table as a string. Args: - table (exp.Table | str): Table expression node or string. + table (exp.Table | str): table expression node or string. Examples: >>> from sqlglot import exp, parse_one @@ -4001,7 +4185,7 @@ def table_name(table) -> str: 'a.b.c' Returns: - str: the table name + The table name. """ table = maybe_parse(table, into=Table) @@ -4024,8 +4208,8 @@ def replace_tables(expression, mapping): """Replace all tables in expression according to the mapping. Args: - expression (sqlglot.Expression): Expression node to be transformed and replaced - mapping (Dict[str, str]): Mapping of table names + expression (sqlglot.Expression): expression node to be transformed and replaced. + mapping (Dict[str, str]): mapping of table names. Examples: >>> from sqlglot import exp, parse_one @@ -4033,7 +4217,7 @@ def replace_tables(expression, mapping): 'SELECT * FROM c' Returns: - The mapped expression + The mapped expression. """ def _replace_tables(node): @@ -4053,9 +4237,9 @@ def replace_placeholders(expression, *args, **kwargs): """Replace placeholders in an expression. Args: - expression (sqlglot.Expression): Expression node to be transformed and replaced - args: Positional names that will substitute unnamed placeholders in the given order - kwargs: Keyword arguments that will substitute named placeholders + expression (sqlglot.Expression): expression node to be transformed and replaced. + args: positional names that will substitute unnamed placeholders in the given order. + kwargs: keyword arguments that will substitute named placeholders. Examples: >>> from sqlglot import exp, parse_one @@ -4065,7 +4249,7 @@ def replace_placeholders(expression, *args, **kwargs): 'SELECT * FROM foo WHERE a = b' Returns: - The mapped expression + The mapped expression. 
""" def _replace_placeholders(node, args, **kwargs): @@ -4084,15 +4268,101 @@ def replace_placeholders(expression, *args, **kwargs): return expression.transform(_replace_placeholders, iter(args), **kwargs) +def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True) -> Expression: + """Transforms an expression by expanding all referenced sources into subqueries. + + Examples: + >>> from sqlglot import parse_one + >>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y")}).sql() + 'SELECT * FROM (SELECT * FROM y) AS z /* source: x */' + + Args: + expression: The expression to expand. + sources: A dictionary of name to Subqueryables. + copy: Whether or not to copy the expression during transformation. Defaults to True. + + Returns: + The transformed expression. + """ + + def _expand(node: Expression): + if isinstance(node, Table): + name = table_name(node) + source = sources.get(name) + if source: + subquery = source.subquery(node.alias or name) + subquery.comments = [f"source: {name}"] + return subquery + return node + + return expression.transform(_expand, copy=copy) + + +def func(name: str, *args, dialect: t.Optional[Dialect | str] = None, **kwargs) -> Func: + """ + Returns a Func expression. + + Examples: + >>> func("abs", 5).sql() + 'ABS(5)' + + >>> func("cast", this=5, to=DataType.build("DOUBLE")).sql() + 'CAST(5 AS DOUBLE)' + + Args: + name: the name of the function to build. + args: the args used to instantiate the function of interest. + dialect: the source dialect. + kwargs: the kwargs used to instantiate the function of interest. + + Note: + The arguments `args` and `kwargs` are mutually exclusive. + + Returns: + An instance of the function of interest, or an anonymous function, if `name` doesn't + correspond to an existing `sqlglot.expressions.Func` class. 
+ """ + if args and kwargs: + raise ValueError("Can't use both args and kwargs to instantiate a function.") + + from sqlglot.dialects.dialect import Dialect + + args = tuple(convert(arg) for arg in args) + kwargs = {key: convert(value) for key, value in kwargs.items()} + + parser = Dialect.get_or_raise(dialect)().parser() + from_args_list = parser.FUNCTIONS.get(name.upper()) + + if from_args_list: + function = from_args_list(args) if args else from_args_list.__self__(**kwargs) # type: ignore + else: + kwargs = kwargs or {"expressions": args} + function = Anonymous(this=name, **kwargs) + + for error_message in function.error_messages(args): + raise ValueError(error_message) + + return function + + def true(): + """ + Returns a true Boolean expression. + """ return Boolean(this=True) def false(): + """ + Returns a false Boolean expression. + """ return Boolean(this=False) def null(): + """ + Returns a Null expression. + """ return Null() diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 6375d92..b398d8e 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -16,7 +16,7 @@ class Generator: """ Generator interprets the given syntax tree and produces a SQL string as an output. 
- Args + Args: time_mapping (dict): the dictionary of custom time mappings in which the key represents a python time format and the output the target time format time_trie (trie): a trie of the time_mapping keys @@ -84,6 +84,13 @@ class Generator: exp.DataType.Type.NVARCHAR: "VARCHAR", exp.DataType.Type.MEDIUMTEXT: "TEXT", exp.DataType.Type.LONGTEXT: "TEXT", + exp.DataType.Type.MEDIUMBLOB: "BLOB", + exp.DataType.Type.LONGBLOB: "BLOB", + } + + STAR_MAPPING = { + "except": "EXCEPT", + "replace": "REPLACE", } TOKEN_MAPPING: t.Dict[TokenType, str] = {} @@ -106,6 +113,8 @@ class Generator: exp.TableFormatProperty, } + WITH_SINGLE_ALTER_TABLE_ACTION = (exp.AlterColumn, exp.RenameTable, exp.AddConstraint) + WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary) SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" @@ -241,15 +250,17 @@ class Generator: return sql sep = "\n" if self.pretty else " " - comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments if comment) + comments_sql = sep.join( + f"/*{self.pad_comment(comment)}*/" for comment in comments if comment + ) - if not comments: + if not comments_sql: return sql if isinstance(expression, self.WITH_SEPARATED_COMMENTS): - return f"{comments}{self.sep()}{sql}" + return f"{comments_sql}{self.sep()}{sql}" - return f"{sql} {comments}" + return f"{sql} {comments_sql}" def wrap(self, expression: exp.Expression | str) -> str: this_sql = self.indent( @@ -433,8 +444,9 @@ class Generator: def create_sql(self, expression: exp.Create) -> str: this = self.sql(expression, "this") kind = self.sql(expression, "kind").upper() + begin = " BEGIN" if expression.args.get("begin") else "" expression_sql = self.sql(expression, "expression") - expression_sql = f" AS{self.sep()}{expression_sql}" if expression_sql else "" + expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else "" temporary = " TEMPORARY" if expression.args.get("temporary") else "" transient = ( " TRANSIENT" if 
self.CREATE_TRANSIENT and expression.args.get("transient") else "" @@ -741,12 +753,14 @@ class Generator: laterals = self.expressions(expression, key="laterals", sep="") joins = self.expressions(expression, key="joins", sep="") pivots = self.expressions(expression, key="pivots", sep="") + system_time = expression.args.get("system_time") + system_time = f" {self.sql(expression, 'system_time')}" if system_time else "" if alias and pivots: pivots = f"{pivots}{alias}" alias = "" - return f"{table}{alias}{hints}{laterals}{joins}{pivots}" + return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}" def tablesample_sql(self, expression: exp.TableSample) -> str: if self.alias_post_tablesample and expression.this.alias: @@ -1009,9 +1023,9 @@ class Generator: def star_sql(self, expression: exp.Star) -> str: except_ = self.expressions(expression, key="except", flat=True) - except_ = f"{self.seg('EXCEPT')} ({except_})" if except_ else "" + except_ = f"{self.seg(self.STAR_MAPPING['except'])} ({except_})" if except_ else "" replace = self.expressions(expression, key="replace", flat=True) - replace = f"{self.seg('REPLACE')} ({replace})" if replace else "" + replace = f"{self.seg(self.STAR_MAPPING['replace'])} ({replace})" if replace else "" return f"*{except_}{replace}" def structkwarg_sql(self, expression: exp.StructKwarg) -> str: @@ -1193,6 +1207,12 @@ class Generator: update = f" ON UPDATE {update}" if update else "" return f"FOREIGN KEY ({expressions}){reference}{delete}{update}" + def primarykey_sql(self, expression: exp.ForeignKey) -> str: + expressions = self.expressions(expression, flat=True) + options = self.expressions(expression, "options", flat=True, sep=" ") + options = f" {options}" if options else "" + return f"PRIMARY KEY ({expressions}){options}" + def unique_sql(self, expression: exp.Unique) -> str: columns = self.expressions(expression, key="expressions") return f"UNIQUE ({columns})" @@ -1229,10 +1249,16 @@ class Generator: unit = f" {unit}" if unit 
else "" return f"INTERVAL{this}{unit}" + def return_sql(self, expression: exp.Return) -> str: + return f"RETURN {self.sql(expression, 'this')}" + def reference_sql(self, expression: exp.Reference) -> str: this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True) - return f"REFERENCES {this}({expressions})" + expressions = f"({expressions})" if expressions else "" + options = self.expressions(expression, "options", flat=True, sep=" ") + options = f" {options}" if options else "" + return f"REFERENCES {this}{expressions}{options}" def anonymous_sql(self, expression: exp.Anonymous) -> str: args = self.format_args(*expression.expressions) @@ -1362,7 +1388,7 @@ class Generator: actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ") elif isinstance(actions[0], exp.Drop): actions = self.expressions(expression, "actions") - elif isinstance(actions[0], (exp.AlterColumn, exp.RenameTable)): + elif isinstance(actions[0], self.WITH_SINGLE_ALTER_TABLE_ACTION): actions = self.sql(actions[0]) else: self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}") @@ -1370,6 +1396,17 @@ class Generator: exists = " IF EXISTS" if expression.args.get("exists") else "" return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}" + def addconstraint_sql(self, expression: exp.AddConstraint) -> str: + this = self.sql(expression, "this") + expression_ = self.sql(expression, "expression") + add_constraint = f"ADD CONSTRAINT {this}" if this else "ADD" + + enforced = expression.args.get("enforced") + if enforced is not None: + return f"{add_constraint} CHECK ({expression_}){' ENFORCED' if enforced else ''}" + + return f"{add_constraint} {expression_}" + def distinct_sql(self, expression: exp.Distinct) -> str: this = self.expressions(expression, flat=True) this = f" {this}" if this else "" @@ -1550,13 +1587,19 @@ class Generator: expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}" ) + def 
tag_sql(self, expression: exp.Tag) -> str: + return f"{expression.args.get('prefix')}{self.sql(expression.this)}{expression.args.get('postfix')}" + def token_sql(self, token_type: TokenType) -> str: return self.TOKEN_MAPPING.get(token_type, token_type.name) def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str: this = self.sql(expression, "this") expressions = self.no_identify(lambda: self.expressions(expression)) - return f"{this}({expressions})" + expressions = ( + self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}" + ) + return f"{this}{expressions}" def userdefinedfunctionkwarg_sql(self, expression: exp.UserDefinedFunctionKwarg) -> str: this = self.sql(expression, "this") diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 5a0f2ac..68e0383 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -332,7 +332,7 @@ def is_iterable(value: t.Any) -> bool: return hasattr(value, "__iter__") and not isinstance(value, (str, bytes)) -def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any, None, None]: +def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]: """ Flattens an iterable that can contain both iterable and non-iterable elements. Objects of type `str` and `bytes` are not regarded as iterables. 
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py new file mode 100644 index 0000000..4e7eab8 --- /dev/null +++ b/sqlglot/lineage.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +import json +import typing as t +from dataclasses import dataclass, field + +from sqlglot import Schema, exp, maybe_parse +from sqlglot.optimizer import Scope, build_scope, optimize +from sqlglot.optimizer.qualify_columns import qualify_columns +from sqlglot.optimizer.qualify_tables import qualify_tables + + +@dataclass(frozen=True) +class Node: + name: str + expression: exp.Expression + source: exp.Expression + downstream: t.List[Node] = field(default_factory=list) + + def walk(self) -> t.Iterator[Node]: + yield self + + for d in self.downstream: + if isinstance(d, Node): + yield from d.walk() + else: + yield d + + def to_html(self, **opts) -> LineageHTML: + return LineageHTML(self, **opts) + + +def lineage( + column: str | exp.Column, + sql: str | exp.Expression, + schema: t.Optional[t.Dict | Schema] = None, + sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None, + rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns), + dialect: t.Optional[str] = None, +) -> Node: + """Build the lineage graph for a column of a SQL query. + + Args: + column: The column to build the lineage for. + sql: The SQL string or expression. + schema: The schema of tables. + sources: A mapping of queries which will be used to continue building lineage. + rules: Optimizer rules to apply, by default only qualifying tables and columns. + dialect: The dialect of input SQL. + + Returns: + A lineage node. 
+ """ + + expression = maybe_parse(sql, dialect=dialect) + + if sources: + expression = exp.expand( + expression, + { + k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect)) + for k, v in sources.items() + }, + ) + + optimized = optimize(expression, schema=schema, rules=rules) + scope = build_scope(optimized) + tables: t.Dict[str, Node] = {} + + def to_node( + column_name: str, + scope: Scope, + scope_name: t.Optional[str] = None, + upstream: t.Optional[Node] = None, + ) -> Node: + if isinstance(scope.expression, exp.Union): + for scope in scope.union_scopes: + node = to_node( + column_name, + scope=scope, + scope_name=scope_name, + upstream=upstream, + ) + return node + + select = next(select for select in scope.selects if select.alias_or_name == column_name) + source = optimize(scope.expression.select(select, append=False), schema=schema, rules=rules) + select = source.selects[0] + + node = Node( + name=f"{scope_name}.{column_name}" if scope_name else column_name, + source=source, + expression=select, + ) + + if upstream: + upstream.downstream.append(node) + + for c in set(select.find_all(exp.Column)): + table = c.table + source = scope.sources[table] + + if isinstance(source, Scope): + to_node( + c.name, + scope=source, + scope_name=table, + upstream=node, + ) + else: + if table not in tables: + tables[table] = Node(name=table, source=source, expression=source) + node.downstream.append(tables[table]) + + return node + + return to_node(column if isinstance(column, str) else column.name, scope) + + +class LineageHTML: + """Node to HTML generator using vis.js. 
+ + https://visjs.github.io/vis-network/docs/network/ + """ + + def __init__( + self, + node: Node, + dialect: t.Optional[str] = None, + imports: bool = True, + **opts: t.Any, + ): + self.node = node + self.imports = imports + + self.options = { + "height": "500px", + "width": "100%", + "layout": { + "hierarchical": { + "enabled": True, + "nodeSpacing": 200, + "sortMethod": "directed", + }, + }, + "interaction": { + "dragNodes": False, + "selectable": False, + }, + "physics": { + "enabled": False, + }, + "edges": { + "arrows": "to", + }, + "nodes": { + "font": "20px monaco", + "shape": "box", + "widthConstraint": { + "maximum": 300, + }, + }, + **opts, + } + + self.nodes = {} + self.edges = [] + + for node in node.walk(): + if isinstance(node.expression, exp.Table): + label = f"FROM {node.expression.this}" + title = f"
SELECT {node.name} FROM {node.expression.this}
" + group = 1 + else: + label = node.expression.sql(pretty=True, dialect=dialect) + source = node.source.transform( + lambda n: exp.Tag(this=n, prefix="", postfix="") + if n is node.expression + else n, + copy=False, + ).sql(pretty=True, dialect=dialect) + title = f"
{source}
" + group = 0 + + node_id = id(node) + + self.nodes[node_id] = { + "id": node_id, + "label": label, + "title": title, + "group": group, + } + + for d in node.downstream: + self.edges.append({"from": node_id, "to": id(d)}) + + def __str__(self): + nodes = json.dumps(list(self.nodes.values())) + edges = json.dumps(self.edges) + options = json.dumps(self.options) + imports = ( + """ + + """ + if self.imports + else "" + ) + + return f"""
+
+ {imports} + +
""" + + def _repr_html_(self) -> str: + return self.__str__() diff --git a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py index bba0878..719a77e 100644 --- a/sqlglot/optimizer/__init__.py +++ b/sqlglot/optimizer/__init__.py @@ -1 +1,2 @@ from sqlglot.optimizer.optimizer import RULES, optimize +from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py index 652cdef..5bd7b30 100644 --- a/sqlglot/optimizer/isolate_table_selects.py +++ b/sqlglot/optimizer/isolate_table_selects.py @@ -1,15 +1,18 @@ from sqlglot import alias, exp from sqlglot.errors import OptimizeError from sqlglot.optimizer.scope import traverse_scope +from sqlglot.schema import ensure_schema -def isolate_table_selects(expression): +def isolate_table_selects(expression, schema=None): + schema = ensure_schema(schema) + for scope in traverse_scope(expression): if len(scope.selected_sources) == 1: continue for (_, source) in scope.selected_sources.values(): - if not isinstance(source, exp.Table): + if not isinstance(source, exp.Table) or not schema.column_names(source): continue if not source.alias: diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 8da4e43..54425a8 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -1,7 +1,8 @@ import itertools +import typing as t from sqlglot import alias, exp -from sqlglot.errors import OptimizeError, SchemaError +from sqlglot.errors import OptimizeError from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import ensure_schema @@ -190,20 +191,15 @@ def _qualify_columns(scope, resolver): column_table = column.table column_name = column.name - if ( - column_table - and column_table in scope.sources - and column_name not in resolver.get_source_columns(column_table) - ): - raise OptimizeError(f"Unknown column: {column_name}") + if 
column_table and column_table in scope.sources: + source_columns = resolver.get_source_columns(column_table) + if source_columns and column_name not in source_columns: + raise OptimizeError(f"Unknown column: {column_name}") if not column_table: column_table = resolver.get_table(column_name) if not scope.is_subquery and not scope.is_udtf: - if column_name not in resolver.all_columns: - raise OptimizeError(f"Unknown column: {column_name}") - if column_table is None: raise OptimizeError(f"Ambiguous column: {column_name}") @@ -265,6 +261,10 @@ def _expand_stars(scope, resolver): if table not in scope.sources: raise OptimizeError(f"Unknown table: {table}") columns = resolver.get_source_columns(table, only_visible=True) + if not columns: + raise OptimizeError( + f"Table has no schema/columns. Cannot expand star for table: {table}." + ) table_id = id(table) for name in columns: if name not in except_columns.get(table_id, set()): @@ -306,16 +306,11 @@ def _qualify_outputs(scope): for i, (selection, aliased_column) in enumerate( itertools.zip_longest(scope.selects, scope.outer_column_list) ): - if isinstance(selection, exp.Column): - # convoluted setter because a simple selection.replace(alias) would require a copy - alias_ = alias(exp.column(""), alias=selection.name) - alias_.set("this", selection) - selection = alias_ - elif isinstance(selection, exp.Subquery): - if not selection.alias: + if isinstance(selection, exp.Subquery): + if not selection.output_name: selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) elif not isinstance(selection, exp.Alias): - alias_ = alias(exp.column(""), f"_col_{i}") + alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}") alias_.set("this", selection) selection = alias_ @@ -346,20 +341,30 @@ class _Resolver: self._unambiguous_columns = None self._all_columns = None - def get_table(self, column_name): + def get_table(self, column_name: str) -> t.Optional[str]: """ Get the table for a column name. 
Args: - column_name (str) + column_name: The column name to find the table for. Returns: - (str) table name + The table name if it can be found/inferred. """ if self._unambiguous_columns is None: self._unambiguous_columns = self._get_unambiguous_columns( self._get_all_source_columns() ) - return self._unambiguous_columns.get(column_name) + + table = self._unambiguous_columns.get(column_name) + + if not table: + sources_without_schema = tuple( + source for source, columns in self._get_all_source_columns().items() if not columns + ) + if len(sources_without_schema) == 1: + return sources_without_schema[0] + + return table @property def all_columns(self): @@ -379,10 +384,7 @@ class _Resolver: # If referencing a table, return the columns from the schema if isinstance(source, exp.Table): - try: - return self.schema.column_names(source, only_visible) - except Exception as e: - raise SchemaError(str(e)) from e + return self.schema.column_names(source, only_visible) if isinstance(source, Scope) and isinstance(source.expression, exp.Values): return source.expression.alias_column_names diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 6125e4e..5a3ed5a 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -230,7 +230,7 @@ class Scope: column for scope in self.subquery_scopes for column in scope.external_columns ] - named_outputs = {e.alias_or_name for e in self.expression.expressions} + named_selects = set(self.expression.named_selects) self._columns = [] for column in columns + external_columns: @@ -238,7 +238,7 @@ class Scope: if ( not ancestor or column.table - or (column.name not in named_outputs and not isinstance(ancestor, exp.Hint)) + or (column.name not in named_selects and not isinstance(ancestor, exp.Hint)) ): self._columns.append(column) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index c97b19a..42777d1 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -40,22 +40,23 @@ class _Parser(type): class 
Parser(metaclass=_Parser): """ - Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer` - and produces a parsed syntax tree. - - Args - error_level (ErrorLevel): the desired error level. Default: ErrorLevel.RAISE. - error_message_context (int): determines the amount of context to capture from - a query string when displaying the error message (in number of characters). + Parser consumes a list of tokens produced by the `sqlglot.tokens.Tokenizer` and produces + a parsed syntax tree. + + Args: + error_level: the desired error level. + Default: ErrorLevel.RAISE + error_message_context: determines the amount of context to capture from a + query string when displaying the error message (in number of characters). Default: 50. - index_offset (int): Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list + index_offset: Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list. Default: 0 - alias_post_tablesample (bool): If the table alias comes after tablesample + alias_post_tablesample: If the table alias comes after tablesample. Default: False - max_errors (int): Maximum number of error messages to include in a raised ParseError. + max_errors: Maximum number of error messages to include in a raised ParseError. This is only relevant if error_level is ErrorLevel.RAISE. Default: 3 - null_ordering (str): Indicates the default null ordering method to use if not explicitly set. + null_ordering: Indicates the default null ordering method to use if not explicitly set. Options are "nulls_are_small", "nulls_are_large", "nulls_are_last". 
Default: "nulls_are_small" """ @@ -109,6 +110,8 @@ class Parser(metaclass=_Parser): TokenType.TEXT, TokenType.MEDIUMTEXT, TokenType.LONGTEXT, + TokenType.MEDIUMBLOB, + TokenType.LONGBLOB, TokenType.BINARY, TokenType.VARBINARY, TokenType.JSON, @@ -176,6 +179,7 @@ class Parser(metaclass=_Parser): TokenType.DIV, TokenType.DISTKEY, TokenType.DISTSTYLE, + TokenType.END, TokenType.EXECUTE, TokenType.ENGINE, TokenType.ESCAPE, @@ -468,9 +472,6 @@ class Parser(metaclass=_Parser): TokenType.NULL: lambda self, _: self.expression(exp.Null), TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True), TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False), - TokenType.PARAMETER: lambda self, _: self.expression( - exp.Parameter, this=self._parse_var() or self._parse_primary() - ), TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text), TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text), TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text), @@ -479,6 +480,16 @@ class Parser(metaclass=_Parser): TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(), } + PLACEHOLDER_PARSERS = { + TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder), + TokenType.PARAMETER: lambda self: self.expression( + exp.Parameter, this=self._parse_var() or self._parse_primary() + ), + TokenType.COLON: lambda self: self.expression(exp.Placeholder, this=self._prev.text) + if self._match_set((TokenType.NUMBER, TokenType.VAR)) + else None, + } + RANGE_PARSERS = { TokenType.BETWEEN: lambda self, this: self._parse_between(this), TokenType.IN: lambda self, this: self._parse_in(this), @@ -601,8 +612,7 @@ class Parser(metaclass=_Parser): WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS} - # allows tables to have special tokens as prefixes - TABLE_PREFIX_TOKENS: t.Set[TokenType] = set() + ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, 
TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY} STRICT_CAST = True @@ -677,7 +687,7 @@ class Parser(metaclass=_Parser): def parse_into( self, - expression_types: str | exp.Expression | t.Collection[exp.Expression | str], + expression_types: exp.IntoType, raw_tokens: t.List[Token], sql: t.Optional[str] = None, ) -> t.List[t.Optional[exp.Expression]]: @@ -820,24 +830,8 @@ class Parser(metaclass=_Parser): if self.error_level == ErrorLevel.IGNORE: return - for k in expression.args: - if k not in expression.arg_types: - self.raise_error(f"Unexpected keyword: '{k}' for {expression.__class__}") - for k, mandatory in expression.arg_types.items(): - v = expression.args.get(k) - if mandatory and (v is None or (isinstance(v, list) and not v)): - self.raise_error(f"Required keyword: '{k}' missing for {expression.__class__}") - - if ( - args - and isinstance(expression, exp.Func) - and len(args) > len(expression.arg_types) - and not expression.is_var_len_args - ): - self.raise_error( - f"The number of provided arguments ({len(args)}) is greater than " - f"the maximum number of supported arguments ({len(expression.arg_types)})" - ) + for error_message in expression.error_messages(args): + self.raise_error(error_message) def _find_token(self, token: Token, sql: str) -> int: line = 1 @@ -868,6 +862,9 @@ class Parser(metaclass=_Parser): def _retreat(self, index: int) -> None: self._advance(index - self._index) + def _parse_command(self) -> exp.Expression: + return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string()) + def _parse_statement(self) -> t.Optional[exp.Expression]: if self._curr is None: return None @@ -876,11 +873,7 @@ class Parser(metaclass=_Parser): return self.STATEMENT_PARSERS[self._prev.token_type](self) if self._match_set(Tokenizer.COMMANDS): - return self.expression( - exp.Command, - this=self._prev.text, - expression=self._parse_string(), - ) + return self._parse_command() expression = self._parse_expression() expression = 
self._parse_set_operations(expression) if expression else self._parse_select() @@ -942,12 +935,18 @@ class Parser(metaclass=_Parser): no_primary_index = None indexes = None no_schema_binding = None + begin = None if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): - this = self._parse_user_defined_function() + this = self._parse_user_defined_function(kind=create_token.token_type) properties = self._parse_properties() if self._match(TokenType.ALIAS): - expression = self._parse_select_or_expression() + begin = self._match(TokenType.BEGIN) + return_ = self._match_text_seq("RETURN") + expression = self._parse_statement() + + if return_: + expression = self.expression(exp.Return, this=expression) elif create_token.token_type == TokenType.INDEX: this = self._parse_index() elif create_token.token_type in ( @@ -1002,6 +1001,7 @@ class Parser(metaclass=_Parser): no_primary_index=no_primary_index, indexes=indexes, no_schema_binding=no_schema_binding, + begin=begin, ) def _parse_property(self) -> t.Optional[exp.Expression]: @@ -1087,7 +1087,7 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.GT): self.raise_error("Expecting >") else: - value = self._parse_schema(exp.Literal.string("TABLE")) + value = self._parse_schema(exp.Var(this="TABLE")) else: value = self._parse_types() @@ -1550,7 +1550,7 @@ class Parser(metaclass=_Parser): return None index = self._parse_id_var() columns = None - if self._curr and self._curr.token_type == TokenType.L_PAREN: + if self._match(TokenType.L_PAREN, advance=False): columns = self._parse_wrapped_csv(self._parse_column) return self.expression( exp.Index, @@ -1561,6 +1561,27 @@ class Parser(metaclass=_Parser): amp=amp, ) + def _parse_table_parts(self, schema: bool = False) -> exp.Expression: + catalog = None + db = None + table = (not schema and self._parse_function()) or self._parse_id_var(any_token=False) + + while self._match(TokenType.DOT): + if catalog: + # This allows nesting the table in arbitrarily many 
dot expressions if needed + table = self.expression(exp.Dot, this=table, expression=self._parse_id_var()) + else: + catalog = db + db = table + table = self._parse_id_var() + + if not table: + self.raise_error(f"Expected table name but got {self._curr}") + + return self.expression( + exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots() + ) + def _parse_table( self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None ) -> t.Optional[exp.Expression]: @@ -1584,27 +1605,7 @@ class Parser(metaclass=_Parser): if subquery: return subquery - catalog = None - db = None - table = (not schema and self._parse_function()) or self._parse_id_var( - any_token=False, prefix_tokens=self.TABLE_PREFIX_TOKENS - ) - - while self._match(TokenType.DOT): - if catalog: - # This allows nesting the table in arbitrarily many dot expressions if needed - table = self.expression(exp.Dot, this=table, expression=self._parse_id_var()) - else: - catalog = db - db = table - table = self._parse_id_var() - - if not table: - self.raise_error(f"Expected table name but got {self._curr}") - - this = self.expression( - exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots() - ) + this = self._parse_table_parts(schema=schema) if schema: return self._parse_schema(this=this) @@ -1889,7 +1890,7 @@ class Parser(metaclass=_Parser): expression, this=this, distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL), - expression=self._parse_select(nested=True), + expression=self._parse_set_operations(self._parse_select(nested=True)), ) def _parse_expression(self) -> t.Optional[exp.Expression]: @@ -2286,7 +2287,9 @@ class Parser(metaclass=_Parser): self._match_r_paren(this) return self._parse_window(this) - def _parse_user_defined_function(self) -> t.Optional[exp.Expression]: + def _parse_user_defined_function( + self, kind: t.Optional[TokenType] = None + ) -> t.Optional[exp.Expression]: this = self._parse_id_var() while 
self._match(TokenType.DOT): @@ -2297,7 +2300,9 @@ class Parser(metaclass=_Parser): expressions = self._parse_csv(self._parse_udf_kwarg) self._match_r_paren() - return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) + return self.expression( + exp.UserDefinedFunction, this=this, expressions=expressions, wrapped=True + ) def _parse_introducer(self, token: Token) -> t.Optional[exp.Expression]: literal = self._parse_primary() @@ -2371,10 +2376,6 @@ class Parser(metaclass=_Parser): or self._parse_column_def(self._parse_field(any_token=True)) ) self._match_r_paren() - - if isinstance(this, exp.Literal): - this = this.name - return self.expression(exp.Schema, this=this, expressions=args) def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: @@ -2470,15 +2471,43 @@ class Parser(metaclass=_Parser): def _parse_unique(self) -> exp.Expression: return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars()) + def _parse_key_constraint_options(self) -> t.List[str]: + options = [] + while True: + if not self._curr: + break + + if self._match_text_seq("NOT", "ENFORCED"): + options.append("NOT ENFORCED") + elif self._match_text_seq("DEFERRABLE"): + options.append("DEFERRABLE") + elif self._match_text_seq("INITIALLY", "DEFERRED"): + options.append("INITIALLY DEFERRED") + elif self._match_text_seq("NORELY"): + options.append("NORELY") + elif self._match_text_seq("MATCH", "FULL"): + options.append("MATCH FULL") + elif self._match_text_seq("ON", "UPDATE", "NO ACTION"): + options.append("ON UPDATE NO ACTION") + elif self._match_text_seq("ON", "DELETE", "NO ACTION"): + options.append("ON DELETE NO ACTION") + else: + break + + return options + def _parse_references(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.REFERENCES): return None - return self.expression( - exp.Reference, - this=self._parse_id_var(), - expressions=self._parse_wrapped_id_vars(), - ) + expressions = None + this = 
self._parse_id_var() + + if self._match(TokenType.L_PAREN, advance=False): + expressions = self._parse_wrapped_id_vars() + + options = self._parse_key_constraint_options() + return self.expression(exp.Reference, this=this, expressions=expressions, options=options) def _parse_foreign_key(self) -> exp.Expression: expressions = self._parse_wrapped_id_vars() @@ -2503,12 +2532,14 @@ class Parser(metaclass=_Parser): options[kind] = action return self.expression( - exp.ForeignKey, - expressions=expressions, - reference=reference, - **options, # type: ignore + exp.ForeignKey, expressions=expressions, reference=reference, **options # type: ignore ) + def _parse_primary_key(self) -> exp.Expression: + expressions = self._parse_wrapped_id_vars() + options = self._parse_key_constraint_options() + return self.expression(exp.PrimaryKey, expressions=expressions, options=options) + def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match(TokenType.L_BRACKET): return this @@ -2631,7 +2662,7 @@ class Parser(metaclass=_Parser): order = self._parse_order(this=expression) return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1)) - def _parse_convert(self, strict: bool) -> exp.Expression: + def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]: to: t.Optional[exp.Expression] this = self._parse_column() @@ -2641,19 +2672,25 @@ class Parser(metaclass=_Parser): to = self._parse_types() else: to = None + return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) - def _parse_position(self) -> exp.Expression: + def _parse_position(self, haystack_first: bool = False) -> exp.Expression: args = self._parse_csv(self._parse_bitwise) if self._match(TokenType.IN): - args.append(self._parse_bitwise()) + return self.expression( + exp.StrPosition, this=self._parse_bitwise(), substr=seq_get(args, 0) + ) - this = exp.StrPosition( - this=seq_get(args, 1), - substr=seq_get(args, 0), - 
position=seq_get(args, 2), - ) + if haystack_first: + haystack = seq_get(args, 0) + needle = seq_get(args, 1) + else: + needle = seq_get(args, 0) + haystack = seq_get(args, 1) + + this = exp.StrPosition(this=haystack, substr=needle, position=seq_get(args, 2)) self.validate_expression(this, args) @@ -2894,24 +2931,26 @@ class Parser(metaclass=_Parser): return None def _parse_placeholder(self) -> t.Optional[exp.Expression]: - if self._match(TokenType.PLACEHOLDER): - return self.expression(exp.Placeholder) - elif self._match(TokenType.COLON): - if self._match_set((TokenType.NUMBER, TokenType.VAR)): - return self.expression(exp.Placeholder, this=self._prev.text) + if self._match_set(self.PLACEHOLDER_PARSERS): + placeholder = self.PLACEHOLDER_PARSERS[self._prev.token_type](self) + if placeholder: + return placeholder self._advance(-1) return None def _parse_except(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: if not self._match(TokenType.EXCEPT): return None - - return self._parse_wrapped_id_vars() + if self._match(TokenType.L_PAREN, advance=False): + return self._parse_wrapped_id_vars() + return self._parse_csv(self._parse_id_var) def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: if not self._match(TokenType.REPLACE): return None - return self._parse_wrapped_csv(lambda: self._parse_alias(self._parse_expression())) + if self._match(TokenType.L_PAREN, advance=False): + return self._parse_wrapped_csv(self._parse_expression) + return self._parse_csv(self._parse_expression) def _parse_csv( self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA @@ -3021,6 +3060,28 @@ class Parser(metaclass=_Parser): def _parse_drop_column(self) -> t.Optional[exp.Expression]: return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN") + def _parse_add_constraint(self) -> t.Optional[exp.Expression]: + this = None + kind = self._prev.token_type + + if kind == TokenType.CONSTRAINT: + this = self._parse_id_var() + + if 
self._match(TokenType.CHECK): + expression = self._parse_wrapped(self._parse_conjunction) + enforced = self._match_text_seq("ENFORCED") + + return self.expression( + exp.AddConstraint, this=this, expression=expression, enforced=enforced + ) + + if kind == TokenType.FOREIGN_KEY or self._match(TokenType.FOREIGN_KEY): + expression = self._parse_foreign_key() + elif kind == TokenType.PRIMARY_KEY or self._match(TokenType.PRIMARY_KEY): + expression = self._parse_primary_key() + + return self.expression(exp.AddConstraint, this=this, expression=expression) + def _parse_alter(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.TABLE): return None @@ -3029,8 +3090,14 @@ class Parser(metaclass=_Parser): this = self._parse_table(schema=True) actions: t.Optional[exp.Expression | t.List[t.Optional[exp.Expression]]] = None - if self._match_text_seq("ADD", advance=False): - actions = self._parse_csv(self._parse_add_column) + + index = self._index + if self._match_text_seq("ADD"): + if self._match_set(self.ADD_CONSTRAINT_TOKENS): + actions = self._parse_csv(self._parse_add_constraint) + else: + self._retreat(index) + actions = self._parse_csv(self._parse_add_column) elif self._match_text_seq("DROP", advance=False): actions = self._parse_csv(self._parse_drop_column) elif self._match_text_seq("RENAME", "TO"): @@ -3077,7 +3144,7 @@ class Parser(metaclass=_Parser): def _parse_merge(self) -> exp.Expression: self._match(TokenType.INTO) - target = self._parse_table(schema=True) + target = self._parse_table() self._match(TokenType.USING) using = self._parse_table() @@ -3146,12 +3213,13 @@ class Parser(metaclass=_Parser): self._retreat(index) return None - def _match(self, token_type): + def _match(self, token_type, advance=True): if not self._curr: return None if self._curr.token_type == token_type: - self._advance() + if advance: + self._advance() return True return None diff --git a/sqlglot/planner.py b/sqlglot/planner.py index 4967231..40df39f 100644 --- 
a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -32,7 +32,7 @@ class Plan: return self._dag @property - def leaves(self) -> t.Generator[Step, None, None]: + def leaves(self) -> t.Iterator[Step]: return (node for node, deps in self.dag.items() if not deps) def __repr__(self) -> str: @@ -401,7 +401,7 @@ class SetOperation(Step): op=expression.__class__, left=left.name, right=right.name, - distinct=expression.args.get("distinct"), + distinct=bool(expression.args.get("distinct")), ) step.add_dependency(left) step.add_dependency(right) diff --git a/sqlglot/schema.py b/sqlglot/schema.py index a0d69a7..f6f3883 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -109,10 +109,7 @@ class AbstractMappingSchema(t.Generic[T]): value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) if value == 0: - if raise_on_missing: - raise SchemaError(f"Cannot find mapping for {table}.") - else: - return None + return None elif value == 1: possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) if len(possibilities) == 1: @@ -262,7 +259,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): schema = self.find(table_) if schema is None: - raise SchemaError(f"Could not find table schema {table}") + return [] if not only_visible or not self.visible: return list(schema) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index f12528f..19dd1d6 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -84,6 +84,8 @@ class TokenType(AutoName): TEXT = auto() MEDIUMTEXT = auto() LONGTEXT = auto() + MEDIUMBLOB = auto() + LONGBLOB = auto() BINARY = auto() VARBINARY = auto() JSON = auto() @@ -587,6 +589,7 @@ class Tokenizer(metaclass=_Tokenizer): "PRECEDING": TokenType.PRECEDING, "PRIMARY KEY": TokenType.PRIMARY_KEY, "PROCEDURE": TokenType.PROCEDURE, + "QUALIFY": TokenType.QUALIFY, "RANGE": TokenType.RANGE, "RECURSIVE": TokenType.RECURSIVE, "REGEXP": TokenType.RLIKE, @@ -726,6 +729,8 @@ class Tokenizer(metaclass=_Tokenizer): TokenType.SHOW, } + 
COMMAND_PREFIX_TOKENS = {TokenType.SEMICOLON, TokenType.BEGIN} + # handle numeric literals like in hive (3L = BIGINT) NUMERIC_LITERALS: t.Dict[str, str] = {} ENCODE: t.Optional[str] = None @@ -842,8 +847,10 @@ class Tokenizer(metaclass=_Tokenizer): ) self._comments = [] + # If we have either a semicolon or a begin token before the command's token, we'll parse + # whatever follows the command's token as a string if token_type in self.COMMANDS and ( - len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON + len(self.tokens) == 1 or self.tokens[-2].token_type in self.COMMAND_PREFIX_TOKENS ): start = self._current tokens = len(self.tokens) diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 2827dd4..905e1f4 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -17,7 +17,8 @@ class TestClickhouse(Validator): self.validate_identity("SELECT quantile(0.5)(a)") self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t") self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)") - self.validate_identity("position(a, b)") + self.validate_identity("position(haystack, needle)") + self.validate_identity("position(haystack, needle, position)") self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", @@ -48,6 +49,10 @@ class TestClickhouse(Validator): "clickhouse": "SELECT quantileIf(0.5)(a, TRUE)", }, ) + self.validate_all( + "SELECT position(needle IN haystack)", + write={"clickhouse": "SELECT position(haystack, needle)"}, + ) def test_cte(self): self.validate_identity("WITH 'x' AS foo SELECT foo") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index b2f4676..f1144ce 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -950,40 +950,40 @@ class TestDialect(Validator): }, ) self.validate_all( - "POSITION(' ' in x)", + "POSITION(needle in 
haystack)", write={ - "drill": "STRPOS(x, ' ')", - "duckdb": "STRPOS(x, ' ')", - "postgres": "STRPOS(x, ' ')", - "presto": "STRPOS(x, ' ')", - "spark": "LOCATE(' ', x)", - "clickhouse": "position(x, ' ')", - "snowflake": "POSITION(' ', x)", - "mysql": "LOCATE(' ', x)", + "drill": "STRPOS(haystack, needle)", + "duckdb": "STRPOS(haystack, needle)", + "postgres": "STRPOS(haystack, needle)", + "presto": "STRPOS(haystack, needle)", + "spark": "LOCATE(needle, haystack)", + "clickhouse": "position(haystack, needle)", + "snowflake": "POSITION(needle, haystack)", + "mysql": "LOCATE(needle, haystack)", }, ) self.validate_all( - "STR_POSITION(x, 'a')", + "STR_POSITION(haystack, needle)", write={ - "drill": "STRPOS(x, 'a')", - "duckdb": "STRPOS(x, 'a')", - "postgres": "STRPOS(x, 'a')", - "presto": "STRPOS(x, 'a')", - "spark": "LOCATE('a', x)", - "clickhouse": "position(x, 'a')", - "snowflake": "POSITION('a', x)", - "mysql": "LOCATE('a', x)", + "drill": "STRPOS(haystack, needle)", + "duckdb": "STRPOS(haystack, needle)", + "postgres": "STRPOS(haystack, needle)", + "presto": "STRPOS(haystack, needle)", + "spark": "LOCATE(needle, haystack)", + "clickhouse": "position(haystack, needle)", + "snowflake": "POSITION(needle, haystack)", + "mysql": "LOCATE(needle, haystack)", }, ) self.validate_all( - "POSITION('a', x, 3)", + "POSITION(needle, haystack, pos)", write={ - "drill": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", - "presto": "STRPOS(x, 'a', 3)", - "spark": "LOCATE('a', x, 3)", - "clickhouse": "position(x, 'a', 3)", - "snowflake": "POSITION('a', x, 3)", - "mysql": "LOCATE('a', x, 3)", + "drill": "STRPOS(SUBSTR(haystack, pos), needle) + pos - 1", + "presto": "STRPOS(haystack, needle, pos)", + "spark": "LOCATE(needle, haystack, pos)", + "clickhouse": "position(haystack, needle, pos)", + "snowflake": "POSITION(needle, haystack, pos)", + "mysql": "LOCATE(needle, haystack, pos)", }, ) self.validate_all( @@ -1365,3 +1365,19 @@ SELECT "spark": "MERGE INTO target USING source ON target.id = 
source.id WHEN MATCHED THEN UPDATE * WHEN NOT MATCHED THEN INSERT *", }, ) + self.validate_all( + """ + MERGE a b USING c d ON b.id = d.id + WHEN MATCHED AND EXISTS ( + SELECT b.name + EXCEPT + SELECT d.name + ) + THEN UPDATE SET b.name = d.name + """, + write={ + "bigquery": "MERGE INTO a AS b USING c AS d ON b.id = d.id WHEN MATCHED AND EXISTS(SELECT b.name EXCEPT DISTINCT SELECT d.name) THEN UPDATE SET b.name = d.name", + "snowflake": "MERGE INTO a AS b USING c AS d ON b.id = d.id WHEN MATCHED AND EXISTS(SELECT b.name EXCEPT SELECT d.name) THEN UPDATE SET b.name = d.name", + "spark": "MERGE INTO a AS b USING c AS d ON b.id = d.id WHEN MATCHED AND EXISTS(SELECT b.name EXCEPT SELECT d.name) THEN UPDATE SET b.name = d.name", + }, + ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index dfd2f8e..ce865e1 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -75,6 +75,15 @@ class TestMySQL(Validator): "spark": "CAST(x AS TEXT) + CAST(y AS TEXT)", }, ) + self.validate_all( + "CAST(x AS MEDIUMBLOB) + CAST(y AS LONGBLOB)", + read={ + "mysql": "CAST(x AS MEDIUMBLOB) + CAST(y AS LONGBLOB)", + }, + write={ + "spark": "CAST(x AS BLOB) + CAST(y AS BLOB)", + }, + ) def test_canonical_functions(self): self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)") diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 0e9ce9b..7bac166 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -12,6 +12,24 @@ class TestSnowflake(Validator): "snowflake": "SELECT * FROM xxx WHERE col ILIKE '%Don\\'t%'", }, ) + self.validate_all( + "SELECT * EXCLUDE a, b FROM xxx", + write={ + "snowflake": "SELECT * EXCLUDE (a, b) FROM xxx", + }, + ) + self.validate_all( + "SELECT * RENAME a AS b, c AS d FROM xxx", + write={ + "snowflake": "SELECT * RENAME (a AS b, c AS d) FROM xxx", + }, + ) + self.validate_all( + "SELECT * EXCLUDE a, b RENAME (c AS d, E as F) FROM 
xxx", + write={ + "snowflake": "SELECT * EXCLUDE (a, b) RENAME (c AS d, E AS F) FROM xxx", + }, + ) self.validate_all( 'x:a:"b c"', write={ diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index d2972ca..4224a1e 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -1,3 +1,4 @@ +from sqlglot import exp, parse, parse_one from tests.dialects.test_dialect import Validator @@ -5,6 +6,10 @@ class TestTSQL(Validator): dialect = "tsql" def test_tsql(self): + self.validate_identity("SELECT CASE WHEN a > 1 THEN b END") + self.validate_identity("END") + self.validate_identity("@x") + self.validate_identity("#x") self.validate_identity("DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'") self.validate_identity("PRINT @TestVariable") self.validate_identity("SELECT Employee_ID, Department_ID FROM @MyTableVar") @@ -87,6 +92,95 @@ class TestTSQL(Validator): }, ) + def test_udf(self): + self.validate_identity( + "CREATE PROCEDURE foo @a INTEGER, @b INTEGER AS SELECT @a = SUM(bla) FROM baz AS bar" + ) + self.validate_identity( + "CREATE PROC foo @ID INTEGER, @AGE INTEGER AS SELECT DB_NAME(@ID) AS ThatDB" + ) + self.validate_identity("CREATE PROC foo AS SELECT BAR() AS baz") + self.validate_identity("CREATE PROCEDURE foo AS SELECT BAR() AS baz") + self.validate_identity("CREATE FUNCTION foo(@bar INTEGER) RETURNS TABLE AS RETURN SELECT 1") + self.validate_identity("CREATE FUNCTION dbo.ISOweek(@DATE DATETIME2) RETURNS INTEGER") + + # The following two cases don't necessarily correspond to valid TSQL, but they are used to verify + # that the syntax RETURNS @return_variable TABLE ... is parsed correctly. 
+ # + # See also "Transact-SQL Multi-Statement Table-Valued Function Syntax" + # https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql?view=sql-server-ver16 + self.validate_identity( + "CREATE FUNCTION foo(@bar INTEGER) RETURNS @foo TABLE (x INTEGER, y NUMERIC) AS RETURN SELECT 1" + ) + self.validate_identity( + "CREATE FUNCTION foo() RETURNS @contacts TABLE (first_name VARCHAR(50), phone VARCHAR(25)) AS SELECT @fname, @phone" + ) + + self.validate_all( + """ + CREATE FUNCTION udfProductInYear ( + @model_year INT + ) + RETURNS TABLE + AS + RETURN + SELECT + product_name, + model_year, + list_price + FROM + production.products + WHERE + model_year = @model_year + """, + write={ + "tsql": """CREATE FUNCTION udfProductInYear( + @model_year INTEGER +) +RETURNS TABLE AS +RETURN SELECT + product_name, + model_year, + list_price +FROM production.products +WHERE + model_year = @model_year""", + }, + pretty=True, + ) + + sql = """ + CREATE procedure [TRANSF].[SP_Merge_Sales_Real] + @Loadid INTEGER + ,@NumberOfRows INTEGER + AS + BEGIN + SET XACT_ABORT ON; + + DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104); + DECLARE @DWH_DateModified DATETIME = CONVERT(DATETIME, getdate(), 104); + DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (SYSTEM_USER); + DECLARE @DWH_IdUserModified INTEGER = SUSER_ID (SYSTEM_USER); + + DECLARE @SalesAmountBefore float; + SELECT @SalesAmountBefore=SUM(SalesAmount) FROM TRANSF.[Pre_Merge_Sales_Real] S; + END + """ + + expected_sqls = [ + 'CREATE PROCEDURE "TRANSF"."SP_Merge_Sales_Real" @Loadid INTEGER, @NumberOfRows INTEGER AS BEGIN SET XACT_ABORT ON', + "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)", + "DECLARE @DWH_DateModified DATETIME = CONVERT(DATETIME, getdate(), 104)", + "DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (SYSTEM_USER)", + "DECLARE @DWH_IdUserModified INTEGER = SUSER_ID (SYSTEM_USER)", + "DECLARE @SalesAmountBefore float", + 'SELECT @SalesAmountBefore = 
SUM(SalesAmount) FROM TRANSF."Pre_Merge_Sales_Real" AS S', + "END", + ] + + for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): + self.assertEqual(expr.sql(dialect="tsql"), expected_sql) + def test_charindex(self): self.validate_all( "CHARINDEX(x, y, 9)", @@ -472,3 +566,51 @@ class TestTSQL(Validator): "EOMONTH(GETDATE(), -1)", write={"spark": "LAST_DAY(ADD_MONTHS(CURRENT_TIMESTAMP(), -1))"}, ) + + def test_variables(self): + # In TSQL @, # can be used as a prefix for variables/identifiers + expr = parse_one("@x", read="tsql") + self.assertIsInstance(expr, exp.Column) + self.assertIsInstance(expr.this, exp.Identifier) + + expr = parse_one("#x", read="tsql") + self.assertIsInstance(expr, exp.Column) + self.assertIsInstance(expr.this, exp.Identifier) + + def test_system_time(self): + self.validate_all( + "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME AS OF 'foo'", + write={ + "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME AS OF 'foo'""", + }, + ) + self.validate_all( + "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME AS OF 'foo' AS alias", + write={ + "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME AS OF 'foo' AS alias""", + }, + ) + self.validate_all( + "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME FROM c TO d", + write={ + "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME FROM c TO d""", + }, + ) + self.validate_all( + "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME BETWEEN c AND d", + write={ + "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME BETWEEN c AND d""", + }, + ) + self.validate_all( + "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME CONTAINED IN (c, d)", + write={ + "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME CONTAINED IN (c, d)""", + }, + ) + self.validate_all( + "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME ALL AS alias", + write={ + "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME ALL AS alias""", + }, + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 4e21d2b..d52b417 100644 --- a/tests/fixtures/identity.sql +++ 
b/tests/fixtures/identity.sql @@ -94,8 +94,8 @@ CONCAT_WS('-', 'a', 'b') CONCAT_WS('-', 'a', 'b', 'c') POSEXPLODE("x") AS ("a", "b") POSEXPLODE("x") AS ("a", "b", "c") -STR_POSITION(x, 'a') -STR_POSITION(x, 'a', 3) +STR_POSITION(haystack, needle) +STR_POSITION(haystack, needle, pos) LEVENSHTEIN('gumbo', 'gambol', 2, 1, 1) SPLIT(SPLIT(referrer, 'utm_source=')[OFFSET(1)], "&")[OFFSET(0)] x[ORDINAL(1)][SAFE_OFFSET(2)] @@ -375,12 +375,16 @@ SELECT * FROM (SELECT 1 UNION ALL SELECT 2) AS x SELECT * FROM (SELECT 1 UNION ALL SELECT 2) SELECT * FROM ((SELECT 1) AS a UNION ALL (SELECT 2) AS b) SELECT * FROM ((SELECT 1) AS a(b)) +SELECT * FROM ((SELECT 1) UNION (SELECT 2) UNION (SELECT 3)) SELECT * FROM x AS y(a, b) SELECT * EXCEPT (a, b) +SELECT * EXCEPT (a, b) FROM y SELECT * REPLACE (a AS b, b AS C) SELECT * REPLACE (a + 1 AS b, b AS C) SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C) +SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C) FROM y SELECT a.* EXCEPT (a, b), b.* REPLACE (a AS b, b AS C) +SELECT a.* EXCEPT (a, b), b.* REPLACE (a AS b, b AS C) FROM x SELECT zoo, animals FROM (VALUES ('oakland', ARRAY('a', 'b')), ('sf', ARRAY('b', 'c'))) AS t(zoo, animals) SELECT zoo, animals FROM UNNEST(ARRAY(STRUCT('oakland' AS zoo, ARRAY('a', 'b') AS animals), STRUCT('sf' AS zoo, ARRAY('b', 'c') AS animals))) AS t(zoo, animals) WITH a AS (SELECT 1) SELECT 1 UNION ALL SELECT 2 @@ -438,6 +442,8 @@ SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLO SELECT AVG(x) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM t SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x) AS y SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x DESC) +SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) +SELECT PERCENTILE_DISC(0.5) WITHIN GROUP (ORDER BY x) SELECT SUM(x) FILTER(WHERE x > 1) SELECT SUM(x) FILTER(WHERE x > 1) OVER (ORDER BY y) SELECT COUNT(DISTINCT a) OVER (PARTITION BY c ORDER BY d ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) @@ -611,6 +617,7 @@ WITH a AS (SELECT * 
FROM b) DELETE FROM a WITH a AS (SELECT * FROM b) CACHE TABLE a SELECT ? AS ? FROM x WHERE b BETWEEN ? AND ? GROUP BY ?, 1 LIMIT ? SELECT :hello, ? FROM x LIMIT :my_limit +SELECT * FROM x FETCH NEXT @take ROWS ONLY OFFSET @skip WITH a AS ((SELECT b.foo AS foo, b.bar AS bar FROM b) UNION ALL (SELECT c.foo AS foo, c.bar AS bar FROM c)) SELECT * FROM a WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z @@ -670,3 +677,17 @@ CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY) CREATE TABLE IF NOT EXISTS customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (INCREMENT BY 1)) CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10 INCREMENT BY 1)) CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10)) +ALTER TABLE "schema"."tablename" ADD CONSTRAINT "CHK_Name" CHECK (NOT "IdDwh" IS NULL AND "IdDwh" <> (0)) +ALTER TABLE persons ADD CONSTRAINT persons_pk PRIMARY KEY (first_name, last_name) +ALTER TABLE pets ADD CONSTRAINT pets_persons_fk FOREIGN KEY (owner_first_name, owner_last_name) REFERENCES persons +ALTER TABLE pets ADD CONSTRAINT pets_name_not_cute_chk CHECK (LENGTH(name) < 20) +ALTER TABLE people10m ADD CONSTRAINT dateWithinRange CHECK (birthDate > '1900-01-01') +ALTER TABLE people10m ADD CONSTRAINT validIds CHECK (id > 1 AND id < 99999999) ENFORCED +ALTER TABLE baa ADD CONSTRAINT boo PRIMARY KEY (x, y) NOT ENFORCED DEFERRABLE INITIALLY DEFERRED NORELY +ALTER TABLE baa ADD CONSTRAINT boo PRIMARY KEY (x, y) NOT ENFORCED DEFERRABLE INITIALLY DEFERRED NORELY +ALTER TABLE baa ADD CONSTRAINT boo FOREIGN KEY (x, y) REFERENCES persons ON UPDATE NO ACTION ON DELETE NO ACTION MATCH FULL +ALTER TABLE a ADD PRIMARY KEY (x, y) NOT ENFORCED +ALTER TABLE a ADD FOREIGN KEY (x, y) REFERENCES bla +CREATE TABLE foo (baz_id INT REFERENCES baz(id) DEFERRABLE) +SELECT end FROM a +SELECT id FROM b.a AS a QUALIFY ROW_NUMBER() OVER (PARTITION BY 
br ORDER BY sadf DESC) = 1 diff --git a/tests/fixtures/optimizer/isolate_table_selects.sql b/tests/fixtures/optimizer/isolate_table_selects.sql index 3b9a938..93c0f7c 100644 --- a/tests/fixtures/optimizer/isolate_table_selects.sql +++ b/tests/fixtures/optimizer/isolate_table_selects.sql @@ -18,3 +18,6 @@ WITH y AS (SELECT *) SELECT * FROM x AS x; WITH y AS (SELECT * FROM y AS y2 JOIN x AS z2) SELECT * FROM x AS x JOIN y as y; WITH y AS (SELECT * FROM (SELECT * FROM y AS y) AS y2 JOIN (SELECT * FROM x AS x) AS z2) SELECT * FROM (SELECT * FROM x AS x) AS x JOIN y AS y; + +SELECT * FROM x AS x JOIN xx AS y; +SELECT * FROM (SELECT * FROM x AS x) AS x JOIN xx AS y; diff --git a/tests/fixtures/optimizer/pushdown_projections.sql b/tests/fixtures/optimizer/pushdown_projections.sql index b9f6c3f..03ecf16 100644 --- a/tests/fixtures/optimizer/pushdown_projections.sql +++ b/tests/fixtures/optimizer/pushdown_projections.sql @@ -2,7 +2,7 @@ SELECT a FROM (SELECT * FROM x); SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0; SELECT 1 FROM (SELECT * FROM x) WHERE b = 2; -SELECT 1 AS _col_0 FROM (SELECT x.b AS b FROM x AS x) AS _q_0 WHERE _q_0.b = 2; +SELECT 1 AS "1" FROM (SELECT x.b AS b FROM x AS x) AS _q_0 WHERE _q_0.b = 2; SELECT (SELECT c FROM y WHERE q.b = y.b) FROM (SELECT * FROM x) AS q; SELECT (SELECT y.c AS c FROM y AS y WHERE q.b = y.b) AS _col_0 FROM (SELECT x.b AS b FROM x AS x) AS q; diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index 9c5a0be..ee041e2 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -4,6 +4,14 @@ SELECT a FROM x; SELECT x.a AS a FROM x AS x; +# execute: false +SELECT a FROM zz GROUP BY a ORDER BY a; +SELECT zz.a AS a FROM zz AS zz GROUP BY zz.a ORDER BY a; + +# execute: false +SELECT x, p FROM (SELECT x from xx) xx CROSS JOIN yy; +SELECT xx.x AS x, yy.p AS p FROM (SELECT xx.x AS x FROM xx AS xx) AS xx CROSS JOIN yy AS 
yy; + SELECT a FROM x AS z; SELECT z.a AS a FROM x AS z; @@ -20,8 +28,8 @@ SELECT a AS b FROM x; SELECT x.a AS b FROM x AS x; # execute: false -SELECT 1, 2 FROM x; -SELECT 1 AS _col_0, 2 AS _col_1 FROM x AS x; +SELECT 1, 2 + 3 FROM x; +SELECT 1 AS "1", 2 + 3 AS _col_1 FROM x AS x; # execute: false SELECT a + b FROM x; @@ -57,6 +65,10 @@ SELECT x.a AS j, x.b AS a FROM x AS x ORDER BY x.a; SELECT SUM(a) AS c, SUM(b) AS d FROM x ORDER BY 1, 2; SELECT SUM(x.a) AS c, SUM(x.b) AS d FROM x AS x ORDER BY SUM(x.a), SUM(x.b); +# execute: false +SELECT CAST(a AS INT) FROM x ORDER BY a; +SELECT CAST(x.a AS INT) AS a FROM x AS x ORDER BY a; + # execute: false SELECT SUM(a), SUM(b) AS c FROM x ORDER BY 1, 2; SELECT SUM(x.a) AS _col_0, SUM(x.b) AS c FROM x AS x ORDER BY SUM(x.a), SUM(x.b); diff --git a/tests/fixtures/optimizer/qualify_columns__invalid.sql b/tests/fixtures/optimizer/qualify_columns__invalid.sql index 1104b6e..2a3ccfb 100644 --- a/tests/fixtures/optimizer/qualify_columns__invalid.sql +++ b/tests/fixtures/optimizer/qualify_columns__invalid.sql @@ -1,4 +1,3 @@ -SELECT a FROM zz; SELECT * FROM zz; SELECT z.a FROM x; SELECT z.* FROM x; @@ -11,3 +10,4 @@ SELECT q.a FROM (SELECT x.b FROM x) AS z JOIN (SELECT a FROM z) AS q ON z.b = q. 
SELECT b FROM x AS a CROSS JOIN y AS b CROSS JOIN y AS c; SELECT x.a FROM x JOIN y USING (a); SELECT a, SUM(b) FROM x GROUP BY 3; +SELECT p FROM (SELECT x from xx) y CROSS JOIN yy CROSS JOIN zz diff --git a/tests/test_executor.py b/tests/test_executor.py index f45a5d4..013ff34 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -481,11 +481,11 @@ class TestExecutor(unittest.TestCase): def test_static_queries(self): for sql, cols, rows in [ - ("SELECT 1", ["_col_0"], [(1,)]), + ("SELECT 1", ["1"], [(1,)]), ("SELECT 1 + 2 AS x", ["x"], [(3,)]), ("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]), ("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]), - ("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]), + ("SELECT 'foo' LIMIT 1", ["foo"], [("foo",)]), ( "SELECT SUM(x), COUNT(x) FROM (SELECT 1 AS x WHERE FALSE)", ["_col_0", "_col_1"], diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 9e5f988..2d5407e 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -189,6 +189,27 @@ class TestExpressions(unittest.TestCase): "SELECT * FROM (SELECT a FROM tbl1) WHERE b > 100", ) + def test_function_building(self): + self.assertEqual(exp.func("bla", 1, "foo").sql(), "BLA(1, 'foo')") + self.assertEqual(exp.func("COUNT", exp.Star()).sql(), "COUNT(*)") + self.assertEqual(exp.func("bloo").sql(), "BLOO()") + self.assertEqual( + exp.func("locate", "x", "xo", dialect="hive").sql("hive"), "LOCATE('x', 'xo')" + ) + + self.assertIsInstance(exp.func("instr", "x", "b", dialect="mysql"), exp.StrPosition) + self.assertIsInstance(exp.func("bla", 1, "foo"), exp.Anonymous) + self.assertIsInstance( + exp.func("cast", this=exp.Literal.number(5), to=exp.DataType.build("DOUBLE")), + exp.Cast, + ) + + with self.assertRaises(ValueError): + exp.func("some_func", 1, arg2="foo") + + with self.assertRaises(ValueError): + exp.func("abs") + def test_named_selects(self): expression = parse_one( "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, 
baz" diff --git a/tests/test_lineage.py b/tests/test_lineage.py new file mode 100644 index 0000000..7a48605 --- /dev/null +++ b/tests/test_lineage.py @@ -0,0 +1,20 @@ +import unittest + +from sqlglot.lineage import lineage + + +class TestLineage(unittest.TestCase): + maxDiff = None + + def test_lineage(self) -> None: + node = lineage( + "a", + "SELECT a FROM y", + schema={"x": {"a": "int"}}, + sources={"y": "SELECT * FROM x"}, + ) + self.assertEqual( + node.source.sql(), + "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y /* source: y */", + ) + self.assertGreater(len(node.to_html()._repr_html_()), 1000) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index af21679..360dfb5 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -117,6 +117,7 @@ class TestOptimizer(unittest.TestCase): self.check_file( "isolate_table_selects", optimizer.isolate_table_selects.isolate_table_selects, + schema=self.schema, ) def test_qualify_tables(self): diff --git a/tests/test_schema.py b/tests/test_schema.py index 3dd9103..dc7e5b2 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -17,6 +17,11 @@ class TestSchema(unittest.TestCase): with self.assertRaises(SchemaError): schema.column_names(to_table(table)) + def assert_column_names_empty(self, schema, *tables): + for table in tables: + with self.subTest(table): + self.assertEqual(schema.column_names(to_table(table)), []) + def test_schema(self): schema = ensure_schema( { @@ -38,7 +43,7 @@ class TestSchema(unittest.TestCase): ("z.x.y", ["b", "c"]), ) - self.assert_column_names_raises( + self.assert_column_names_empty( schema, "z", "z.z", @@ -76,6 +81,10 @@ class TestSchema(unittest.TestCase): self.assert_column_names_raises( schema, "x", + ) + + self.assert_column_names_empty( + schema, "z.x", "z.y", ) @@ -129,12 +138,16 @@ class TestSchema(unittest.TestCase): self.assert_column_names_raises( schema, - "q", - "d2.x", "y", "z", "d1.y", "d1.z", + ) + + self.assert_column_names_empty( + 
schema, + "q", + "d2.x", "a.b.c", ) -- cgit v1.2.3