summaryrefslogtreecommitdiffstats
path: root/sqlglot
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-01-30 17:08:37 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-01-30 17:08:37 +0000
commitbe1cb18ea28222fca384a5459a024b7e9af5cadb (patch)
tree4698c9069380a7c30ceb51129f93f6c8662315e4 /sqlglot
parentReleasing debian version 10.5.6-1. (diff)
downloadsqlglot-be1cb18ea28222fca384a5459a024b7e9af5cadb.tar.xz
sqlglot-be1cb18ea28222fca384a5459a024b7e9af5cadb.zip
Merging upstream version 10.5.10.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--sqlglot/__init__.py15
-rw-r--r--sqlglot/dataframe/README.md32
-rw-r--r--sqlglot/dataframe/sql/dataframe.py9
-rw-r--r--sqlglot/dataframe/sql/session.py4
-rw-r--r--sqlglot/dialects/__init__.py61
-rw-r--r--sqlglot/dialects/bigquery.py1
-rw-r--r--sqlglot/dialects/clickhouse.py9
-rw-r--r--sqlglot/dialects/mysql.py4
-rw-r--r--sqlglot/dialects/snowflake.py8
-rw-r--r--sqlglot/dialects/tsql.py116
-rw-r--r--sqlglot/diff.py12
-rw-r--r--sqlglot/executor/context.py2
-rw-r--r--sqlglot/expressions.py410
-rw-r--r--sqlglot/generator.py67
-rw-r--r--sqlglot/helper.py2
-rw-r--r--sqlglot/lineage.py228
-rw-r--r--sqlglot/optimizer/__init__.py1
-rw-r--r--sqlglot/optimizer/isolate_table_selects.py7
-rw-r--r--sqlglot/optimizer/qualify_columns.py54
-rw-r--r--sqlglot/optimizer/scope.py4
-rw-r--r--sqlglot/parser.py272
-rw-r--r--sqlglot/planner.py4
-rw-r--r--sqlglot/schema.py7
-rw-r--r--sqlglot/tokens.py9
24 files changed, 1063 insertions, 275 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index f2db4f1..67a4463 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -1,5 +1,6 @@
"""
.. include:: ../README.md
+----
"""
from __future__ import annotations
@@ -29,14 +30,16 @@ from sqlglot.expressions import table_ as table
from sqlglot.expressions import to_column, to_table, union
from sqlglot.generator import Generator
from sqlglot.parser import Parser
-from sqlglot.schema import MappingSchema
+from sqlglot.schema import MappingSchema, Schema
from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "10.5.6"
+__version__ = "10.5.10"
pretty = False
+"""Whether to format generated SQL by default."""
schema = MappingSchema()
+"""The default schema used by SQLGlot (e.g. in the optimizer)."""
def parse(
@@ -48,7 +51,7 @@ def parse(
Args:
sql: the SQL code string to parse.
read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
- **opts: other options.
+ **opts: other `sqlglot.parser.Parser` options.
Returns:
The resulting syntax tree collection.
@@ -60,7 +63,7 @@ def parse(
def parse_one(
sql: str,
read: t.Optional[str | Dialect] = None,
- into: t.Optional[t.Type[Expression] | str] = None,
+ into: t.Optional[exp.IntoType] = None,
**opts,
) -> Expression:
"""
@@ -70,7 +73,7 @@ def parse_one(
sql: the SQL code string to parse.
read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
into: the SQLGlot Expression to parse into.
- **opts: other options.
+ **opts: other `sqlglot.parser.Parser` options.
Returns:
The syntax tree for the first parsed statement.
@@ -110,7 +113,7 @@ def transpile(
identity: if set to `True` and if the target dialect is not specified the source dialect will be used as both:
the source and the target dialect.
error_level: the desired error level of the parser.
- **opts: other options.
+ **opts: other `sqlglot.generator.Generator` options.
Returns:
The list of transpiled SQL statements.
diff --git a/sqlglot/dataframe/README.md b/sqlglot/dataframe/README.md
index 54d3856..02179f4 100644
--- a/sqlglot/dataframe/README.md
+++ b/sqlglot/dataframe/README.md
@@ -1,29 +1,29 @@
# PySpark DataFrame SQL Generator
-This is a drop-in replacement for the PysPark DataFrame API that will generate SQL instead of executing DataFrame operations directly. This, when combined with the transpiling support in SQLGlot, allows one to write PySpark DataFrame code and execute it on other engines like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/).
+This is a drop-in replacement for the PySpark DataFrame API that will generate SQL instead of executing DataFrame operations directly. This, when combined with the transpiling support in SQLGlot, allows one to write PySpark DataFrame code and execute it on other engines like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/).
Currently many of the common operations are covered and more functionality will be added over time. Please [open an issue](https://github.com/tobymao/sqlglot/issues) or [PR](https://github.com/tobymao/sqlglot/pulls) with your feedback or contribution to help influence what should be prioritized next and make sure your use case is properly supported.
# How to use
## Instructions
-* [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install) and that is all that is required to just generate SQL. [The examples](#examples) show generating SQL and then executing that SQL on a specific engine and that will require that engine's client library
-* Find/replace all `from pyspark.sql` with `from sqlglot.dataframe`
-* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>)`
+* [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install) and that is all that is required to just generate SQL. [The examples](#examples) show generating SQL and then executing that SQL on a specific engine and that will require that engine's client library.
+* Find/replace all `from pyspark.sql` with `from sqlglot.dataframe`.
+* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>)`.
* The column structure can be defined the following ways:
- * Dictionary where the keys are column names and values are string of the Spark SQL type name
- * Ex: {'cola': 'string', 'colb': 'int'}
- * PySpark DataFrame `StructType` similar to when using `createDataFrame`
+ * Dictionary where the keys are column names and values are string of the Spark SQL type name.
+ * Ex: `{'cola': 'string', 'colb': 'int'}`
+ * PySpark DataFrame `StructType` similar to when using `createDataFrame`.
* Ex: `StructType([StructField('cola', StringType()), StructField('colb', IntegerType())])`
- * A string of names and types similar to what is supported in `createDataFrame`
+ * A string of names and types similar to what is supported in `createDataFrame`.
* Ex: `cola: STRING, colb: INT`
- * [Not Recommended] A list of string column names without type
- * Ex: ['cola', 'colb']
- * The lack of types may limit functionality in future releases
- * See [Registering Custom Schema](#registering-custom-schema-class) for information on how to skip this step if the information is stored externally
-* Add `.sql(pretty=True)` to your final DataFrame command to return a list of sql statements to run that command
+ * [Not Recommended] A list of string column names without type.
+ * Ex: `['cola', 'colb']`
+ * The lack of types may limit functionality in future releases.
+ * See [Registering Custom Schema](#registering-custom-schema-class) for information on how to skip this step if the information is stored externally.
+* Add `.sql(pretty=True)` to your final DataFrame command to return a list of sql statements to run that command.
* In most cases a single SQL statement is returned. Currently the only exception is when caching DataFrames which isn't supported in other dialects.
- * Spark is the default output dialect. See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects
+ * Spark is the default output dialect. See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects.
* Ex: `.sql(pretty=True, dialect='bigquery')`
## Examples
@@ -51,7 +51,7 @@ df = (
print(df.sql(pretty=True)) # Spark will be the dialect used by default
```
-Output:
+
```sparksql
SELECT
`employee`.`age` AS `age`,
@@ -206,7 +206,7 @@ sql_statements = (
.createDataFrame(data, schema)
.groupBy(F.col("age"))
.agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
- .sql(dialect="bigquery")
+ .sql(dialect="spark")
)
pyspark = PySparkSession.builder.master("local[*]").getOrCreate()
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index a17bb9d..65a37f5 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -111,16 +111,13 @@ class DataFrame:
return DataFrameNaFunctions(self)
def _replace_cte_names_with_hashes(self, expression: exp.Select):
- expression = expression.copy()
- ctes = expression.ctes
replacement_mapping = {}
- for cte in ctes:
+ for cte in expression.ctes:
old_name_id = cte.args["alias"].this
new_hashed_id = exp.to_identifier(
self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
)
replacement_mapping[old_name_id] = new_hashed_id
- cte.set("alias", exp.TableAlias(this=new_hashed_id))
expression = expression.transform(replace_id_value, replacement_mapping)
return expression
@@ -183,7 +180,7 @@ class DataFrame:
expression = df.expression
hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
for hint in df.pending_partition_hints:
- hint_expression.args.get("expressions").append(hint)
+ hint_expression.append("expressions", hint)
df.pending_hints.remove(hint)
join_aliases = {
@@ -209,7 +206,7 @@ class DataFrame:
sequence_id_expression.set("this", matching_cte.args["alias"].this)
df.pending_hints.remove(hint)
break
- hint_expression.args.get("expressions").append(hint)
+ hint_expression.append("expressions", hint)
if hint_expression.expressions:
expression.set("hint", hint_expression)
return df
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index c4a22c6..af589b0 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -129,7 +129,7 @@ class SparkSession:
@property
def _random_name(self) -> str:
- return f"a{str(uuid.uuid4())[:8]}"
+ return "r" + uuid.uuid4().hex
@property
def _random_branch_id(self) -> str:
@@ -145,7 +145,7 @@ class SparkSession:
@property
def _random_id(self) -> str:
- id = f"a{str(uuid.uuid4())[:8]}"
+ id = self._random_name
self.known_ids.add(id)
return id
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
index 2084681..34cf613 100644
--- a/sqlglot/dialects/__init__.py
+++ b/sqlglot/dialects/__init__.py
@@ -1,3 +1,64 @@
+"""
+## Dialects
+
+One of the core abstractions in SQLGlot is the concept of a "dialect". The `Dialect` class essentially implements a
+"SQLGlot dialect", which aims to be as generic and ANSI-compliant as possible. It relies on the base `Tokenizer`,
+`Parser` and `Generator` classes to achieve this goal, so these need to be very lenient when it comes to consuming
+SQL code.
+
+However, there are cases where the syntax of different SQL dialects varies wildly, even for common tasks. One such
+example is the date/time functions, which can be hard to deal with. For this reason, it's sometimes necessary to
+override the base dialect in order to specialize its behavior. This can be easily done in SQLGlot: supporting new
+dialects is as simple as subclassing from `Dialect` and overriding its various components (e.g. the `Parser` class),
+in order to implement the target behavior.
+
+
+### Implementing a custom Dialect
+
+Consider the following example:
+
+```python
+from sqlglot import exp
+from sqlglot.dialects.dialect import Dialect
+from sqlglot.generator import Generator
+from sqlglot.tokens import Tokenizer, TokenType
+
+
+class Custom(Dialect):
+ class Tokenizer(Tokenizer):
+ QUOTES = ["'", '"']
+ IDENTIFIERS = ["`"]
+
+ KEYWORDS = {
+ **Tokenizer.KEYWORDS,
+ "INT64": TokenType.BIGINT,
+ "FLOAT64": TokenType.DOUBLE,
+ }
+
+ class Generator(Generator):
+ TRANSFORMS = {exp.Array: lambda self, e: f"[{self.expressions(e)}]"}
+
+ TYPE_MAPPING = {
+ exp.DataType.Type.TINYINT: "INT64",
+ exp.DataType.Type.SMALLINT: "INT64",
+ exp.DataType.Type.INT: "INT64",
+ exp.DataType.Type.BIGINT: "INT64",
+ exp.DataType.Type.DECIMAL: "NUMERIC",
+ exp.DataType.Type.FLOAT: "FLOAT64",
+ exp.DataType.Type.DOUBLE: "FLOAT64",
+ exp.DataType.Type.BOOLEAN: "BOOL",
+ exp.DataType.Type.TEXT: "STRING",
+ }
+```
+
+This is a typical example of adding a new dialect implementation in SQLGlot: we specify its identifier and string
+delimiters, as well as what tokens it uses for its types and how they're associated with SQLGlot types. Since
+the `Expression` classes are common for each dialect supported in SQLGlot, we may also need to override the generation
+logic for some expressions; this is usually done by adding new entries to the `TRANSFORMS` mapping.
+
+----
+"""
+
from sqlglot.dialects.bigquery import BigQuery
from sqlglot.dialects.clickhouse import ClickHouse
from sqlglot.dialects.databricks import Databricks
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 9ddfbea..e7d30ec 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -124,7 +124,6 @@ class BigQuery(Dialect):
"FLOAT64": TokenType.DOUBLE,
"INT64": TokenType.BIGINT,
"NOT DETERMINISTIC": TokenType.VOLATILE,
- "QUALIFY": TokenType.QUALIFY,
"UNKNOWN": TokenType.NULL,
}
KEYWORDS.pop("DIV")
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 1c173a4..9e8c691 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -73,13 +73,8 @@ class ClickHouse(Dialect):
return this
- def _parse_position(self) -> exp.Expression:
- this = super()._parse_position()
- # clickhouse position args are swapped
- substr = this.this
- this.args["this"] = this.args.get("substr")
- this.args["substr"] = substr
- return this
+ def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
+ return super()._parse_position(haystack_first=True)
# https://clickhouse.com/docs/en/sql-reference/statements/select/with/
def _parse_cte(self) -> exp.Expression:
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 1bddfe1..2a0a917 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -124,6 +124,8 @@ class MySQL(Dialect):
**tokens.Tokenizer.KEYWORDS,
"MEDIUMTEXT": TokenType.MEDIUMTEXT,
"LONGTEXT": TokenType.LONGTEXT,
+ "MEDIUMBLOB": TokenType.MEDIUMBLOB,
+ "LONGBLOB": TokenType.LONGBLOB,
"START": TokenType.BEGIN,
"SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER,
@@ -459,6 +461,8 @@ class MySQL(Dialect):
TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT)
TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT)
+ TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB)
+ TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB)
WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index c44950a..6225a53 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -194,7 +194,8 @@ class Snowflake(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
- "QUALIFY": TokenType.QUALIFY,
+ "EXCLUDE": TokenType.EXCEPT,
+ "RENAME": TokenType.REPLACE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
@@ -232,6 +233,11 @@ class Snowflake(Dialect):
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}
+ STAR_MAPPING = {
+ "except": "EXCLUDE",
+ "replace": "RENAME",
+ }
+
ROOT_PROPERTIES = {
exp.PartitionedByProperty,
exp.ReturnsProperty,
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 9342e6b..9f9099e 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import re
+import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func
@@ -251,6 +252,7 @@ class TSQL(Dialect):
"NTEXT": TokenType.TEXT,
"NVARCHAR(MAX)": TokenType.TEXT,
"PRINT": TokenType.COMMAND,
+ "PROC": TokenType.PROCEDURE,
"REAL": TokenType.FLOAT,
"ROWVERSION": TokenType.ROWVERSION,
"SMALLDATETIME": TokenType.DATETIME,
@@ -263,6 +265,11 @@ class TSQL(Dialect):
"XML": TokenType.XML,
}
+ # TSQL allows @, # to appear as a variable/identifier prefix
+ SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy()
+ SINGLE_TOKENS.pop("@")
+ SINGLE_TOKENS.pop("#")
+
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
@@ -293,26 +300,82 @@ class TSQL(Dialect):
DataType.Type.NCHAR,
}
- # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table
- TABLE_PREFIX_TOKENS = {TokenType.HASH, TokenType.PARAMETER}
+ RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { # type: ignore
+ TokenType.TABLE,
+ *parser.Parser.TYPE_TOKENS, # type: ignore
+ }
+
+ STATEMENT_PARSERS = {
+ **parser.Parser.STATEMENT_PARSERS, # type: ignore
+ TokenType.END: lambda self: self._parse_command(),
+ }
+
+ def _parse_system_time(self) -> t.Optional[exp.Expression]:
+ if not self._match_text_seq("FOR", "SYSTEM_TIME"):
+ return None
- def _parse_convert(self, strict):
+ if self._match_text_seq("AS", "OF"):
+ system_time = self.expression(
+ exp.SystemTime, this=self._parse_bitwise(), kind="AS OF"
+ )
+ elif self._match_set((TokenType.FROM, TokenType.BETWEEN)):
+ kind = self._prev.text
+ this = self._parse_bitwise()
+ self._match_texts(("TO", "AND"))
+ expression = self._parse_bitwise()
+ system_time = self.expression(
+ exp.SystemTime, this=this, expression=expression, kind=kind
+ )
+ elif self._match_text_seq("CONTAINED", "IN"):
+ args = self._parse_wrapped_csv(self._parse_bitwise)
+ system_time = self.expression(
+ exp.SystemTime,
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1),
+ kind="CONTAINED IN",
+ )
+ elif self._match(TokenType.ALL):
+ system_time = self.expression(exp.SystemTime, kind="ALL")
+ else:
+ system_time = None
+ self.raise_error("Unable to parse FOR SYSTEM_TIME clause")
+
+ return system_time
+
+ def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
+ table = super()._parse_table_parts(schema=schema)
+ table.set("system_time", self._parse_system_time())
+ return table
+
+ def _parse_returns(self) -> exp.Expression:
+ table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS)
+ returns = super()._parse_returns()
+ returns.set("table", table)
+ return returns
+
+ def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
to = self._parse_types()
self._match(TokenType.COMMA)
this = self._parse_conjunction()
+ if not to or not this:
+ return None
+
# Retrieve length of datatype and override to default if not specified
if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)
# Check whether a conversion with format is applicable
if self._match(TokenType.COMMA):
- format_val = self._parse_number().name
- if format_val not in TSQL.convert_format_mapping:
+ format_val = self._parse_number()
+ format_val_name = format_val.name if format_val else ""
+
+ if format_val_name not in TSQL.convert_format_mapping:
raise ValueError(
- f"CONVERT function at T-SQL does not support format style {format_val}"
+ f"CONVERT function at T-SQL does not support format style {format_val_name}"
)
- format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val])
+
+ format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val_name])
# Check whether the convert entails a string to date format
if to.this == DataType.Type.DATE:
@@ -333,6 +396,21 @@ class TSQL(Dialect):
# Entails a simple cast without any format requirement
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+ def _parse_user_defined_function(
+ self, kind: t.Optional[TokenType] = None
+ ) -> t.Optional[exp.Expression]:
+ this = super()._parse_user_defined_function(kind=kind)
+
+ if (
+ kind == TokenType.FUNCTION
+ or isinstance(this, exp.UserDefinedFunction)
+ or self._match(TokenType.ALIAS, advance=False)
+ ):
+ return this
+
+ expressions = self._parse_csv(self._parse_udf_kwarg)
+ return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
+
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
@@ -354,3 +432,27 @@ class TSQL(Dialect):
exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql,
}
+
+ TRANSFORMS.pop(exp.ReturnsProperty)
+
+ def systemtime_sql(self, expression: exp.SystemTime) -> str:
+ kind = expression.args["kind"]
+ if kind == "ALL":
+ return "FOR SYSTEM_TIME ALL"
+
+ start = self.sql(expression, "this")
+ if kind == "AS OF":
+ return f"FOR SYSTEM_TIME AS OF {start}"
+
+ end = self.sql(expression, "expression")
+ if kind == "FROM":
+ return f"FOR SYSTEM_TIME FROM {start} TO {end}"
+ if kind == "BETWEEN":
+ return f"FOR SYSTEM_TIME BETWEEN {start} AND {end}"
+
+ return f"FOR SYSTEM_TIME CONTAINED IN ({start}, {end})"
+
+ def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str:
+ table = expression.args.get("table")
+ table = f"{table} " if table else ""
+ return f"RETURNS {table}{self.sql(expression, 'this')}"
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
index fa8bc1b..a5373b0 100644
--- a/sqlglot/diff.py
+++ b/sqlglot/diff.py
@@ -1,5 +1,6 @@
"""
.. include:: ../posts/sql_diff.md
+----
"""
from __future__ import annotations
@@ -75,12 +76,13 @@ def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
]
Args:
- source (sqlglot.Expression): the source expression.
- target (sqlglot.Expression): the target expression against which the diff should be calculated.
+ source: the source expression.
+ target: the target expression against which the diff should be calculated.
Returns:
- the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the target expression trees.
- This list represents a sequence of steps needed to transform the source expression tree into the target one.
+ the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the
+ target expression trees. This list represents a sequence of steps needed to transform the source
+ expression tree into the target one.
"""
return ChangeDistiller().diff(source.copy(), target.copy())
@@ -258,7 +260,7 @@ class ChangeDistiller:
return bigram_histo
-def _get_leaves(expression: exp.Expression) -> t.Generator[exp.Expression, None, None]:
+def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
has_child_exprs = False
for a in expression.args.values():
diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py
index 8a58287..c405c45 100644
--- a/sqlglot/executor/context.py
+++ b/sqlglot/executor/context.py
@@ -63,7 +63,7 @@ class Context:
reader = table[i]
yield reader, self
- def table_iter(self, table: str) -> t.Generator[t.Tuple[TableIter, Context], None, None]:
+ def table_iter(self, table: str) -> t.Iterator[t.Tuple[TableIter, Context]]:
self.env["scope"] = self.row_readers
for reader in self.tables[table]:
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index be99fe2..f9751ca 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -1,5 +1,12 @@
"""
-.. include:: ../pdoc/docs/expressions.md
+## Expressions
+
+Every AST node in SQLGlot is represented by a subclass of `Expression`.
+
+This module contains the implementation of all supported `Expression` types. Additionally,
+it exposes a number of helper functions, which are mainly used to programmatically build
+SQL expressions, such as `sqlglot.expressions.select`.
+----
"""
from __future__ import annotations
@@ -27,35 +34,66 @@ from sqlglot.tokens import Token
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import Dialect
+ IntoType = t.Union[
+ str,
+ t.Type[Expression],
+ t.Collection[t.Union[str, t.Type[Expression]]],
+ ]
+
class _Expression(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
+
+ # When an Expression class is created, its key is automatically set to be
+ # the lowercase version of the class' name.
klass.key = clsname.lower()
+
+ # This is so that docstrings are not inherited in pdoc
+ klass.__doc__ = klass.__doc__ or ""
+
return klass
class Expression(metaclass=_Expression):
"""
- The base class for all expressions in a syntax tree.
+ The base class for all expressions in a syntax tree. Each Expression encapsulates any necessary
+ context, such as its child expressions, their names (arg keys), and whether a given child expression
+ is optional or not.
Attributes:
- arg_types (dict): determines arguments supported by this expression.
- The key in a dictionary defines a unique key of an argument using
- which the argument's value can be retrieved. The value is a boolean
- flag which indicates whether the argument's value is required (True)
- or optional (False).
+ key: a unique key for each class in the Expression hierarchy. This is useful for hashing
+ and representing expressions as strings.
+ arg_types: determines what arguments (child nodes) are supported by an expression. It
+ maps arg keys to booleans that indicate whether the corresponding args are optional.
+
+ Example:
+ >>> class Foo(Expression):
+ ... arg_types = {"this": True, "expression": False}
+
+ The above definition informs us that Foo is an Expression that requires an argument called
+ "this" and may also optionally receive an argument called "expression".
+
+ Args:
+ args: a mapping used for retrieving the arguments of an expression, given their arg keys.
+ parent: a reference to the parent expression (or None, in case of root expressions).
+ arg_key: the arg key an expression is associated with, i.e. the name its parent expression
+ uses to refer to it.
+ comments: a list of comments that are associated with a given expression. This is used in
+ order to preserve comments when transpiling SQL code.
+ _type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
+ optimizer, in order to enable some transformations that require type information.
"""
- key = "Expression"
+ key = "expression"
arg_types = {"this": True}
__slots__ = ("args", "parent", "arg_key", "comments", "_type")
- def __init__(self, **args):
- self.args = args
- self.parent = None
- self.arg_key = None
- self.comments = None
+ def __init__(self, **args: t.Any):
+ self.args: t.Dict[str, t.Any] = args
+ self.parent: t.Optional[Expression] = None
+ self.arg_key: t.Optional[str] = None
+ self.comments: t.Optional[t.List[str]] = None
self._type: t.Optional[DataType] = None
for arg_key, value in self.args.items():
@@ -76,17 +114,30 @@ class Expression(metaclass=_Expression):
@property
def this(self):
+ """
+ Retrieves the argument with key "this".
+ """
return self.args.get("this")
@property
def expression(self):
+ """
+ Retrieves the argument with key "expression".
+ """
return self.args.get("expression")
@property
def expressions(self):
+ """
+ Retrieves the argument with key "expressions".
+ """
return self.args.get("expressions") or []
def text(self, key):
+ """
+ Returns a textual representation of the argument corresponding to "key". This can only be used
+ for args that are strings or leaf Expression instances, such as identifiers and literals.
+ """
field = self.args.get(key)
if isinstance(field, str):
return field
@@ -96,14 +147,23 @@ class Expression(metaclass=_Expression):
@property
def is_string(self):
+ """
+ Checks whether a Literal expression is a string.
+ """
return isinstance(self, Literal) and self.args["is_string"]
@property
def is_number(self):
+ """
+ Checks whether a Literal expression is a number.
+ """
return isinstance(self, Literal) and not self.args["is_string"]
@property
def is_int(self):
+ """
+ Checks whether a Literal expression is an integer.
+ """
if self.is_number:
try:
int(self.name)
@@ -114,6 +174,9 @@ class Expression(metaclass=_Expression):
@property
def alias(self):
+ """
+ Returns the alias of the expression, or an empty string if it's not aliased.
+ """
if isinstance(self.args.get("alias"), TableAlias):
return self.args["alias"].name
return self.text("alias")
@@ -129,6 +192,24 @@ class Expression(metaclass=_Expression):
return self.alias or self.name
@property
+ def output_name(self):
+ """
+ Name of the output column if this expression is a selection.
+
+ If the Expression has no output name, an empty string is returned.
+
+ Example:
+ >>> from sqlglot import parse_one
+ >>> parse_one("SELECT a").expressions[0].output_name
+ 'a'
+ >>> parse_one("SELECT b AS c").expressions[0].output_name
+ 'c'
+ >>> parse_one("SELECT 1 + 2").expressions[0].output_name
+ ''
+ """
+ return ""
+
+ @property
def type(self) -> t.Optional[DataType]:
return self._type
@@ -145,6 +226,9 @@ class Expression(metaclass=_Expression):
return copy
def copy(self):
+ """
+ Returns a deep copy of the expression.
+ """
new = deepcopy(self)
for item, parent, _ in new.bfs():
if isinstance(item, Expression) and parent:
@@ -169,7 +253,7 @@ class Expression(metaclass=_Expression):
Sets `arg_key` to `value`.
Args:
- arg_key (str): name of the expression arg
+ arg_key (str): name of the expression arg.
value: value to set the arg to.
"""
self.args[arg_key] = value
@@ -203,8 +287,7 @@ class Expression(metaclass=_Expression):
expression_types (type): the expression type(s) to match.
Returns:
- the node which matches the criteria or None if no node matching
- the criteria was found.
+ The node which matches the criteria or None if no such node was found.
"""
return next(self.find_all(*expression_types, bfs=bfs), None)
@@ -217,7 +300,7 @@ class Expression(metaclass=_Expression):
expression_types (type): the expression type(s) to match.
Returns:
- the generator object.
+ The generator object.
"""
for expression, _, _ in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
@@ -231,7 +314,7 @@ class Expression(metaclass=_Expression):
expression_types (type): the expression type(s) to match.
Returns:
- the parent node
+ The parent node.
"""
ancestor = self.parent
while ancestor and not isinstance(ancestor, expression_types):
@@ -269,7 +352,7 @@ class Expression(metaclass=_Expression):
the DFS (Depth-first) order.
Returns:
- the generator object.
+ The generator object.
"""
parent = parent or self.parent
yield self, parent, key
@@ -287,7 +370,7 @@ class Expression(metaclass=_Expression):
the BFS (Breadth-first) order.
Returns:
- the generator object.
+ The generator object.
"""
queue = deque([(self, self.parent, None)])
@@ -341,32 +424,33 @@ class Expression(metaclass=_Expression):
return self.sql()
def __repr__(self):
- return self.to_s()
+ return self._to_s()
def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
"""
Returns SQL string representation of this tree.
- Args
- dialect (str): the dialect of the output SQL string
- (eg. "spark", "hive", "presto", "mysql").
- opts (dict): other :class:`~sqlglot.generator.Generator` options.
+ Args:
+ dialect: the dialect of the output SQL string (eg. "spark", "hive", "presto", "mysql").
+ opts: other `sqlglot.generator.Generator` options.
- Returns
- the SQL string.
+ Returns:
+ The SQL string.
"""
from sqlglot.dialects import Dialect
return Dialect.get_or_raise(dialect)().generate(self, **opts)
- def to_s(self, hide_missing: bool = True, level: int = 0) -> str:
+ def _to_s(self, hide_missing: bool = True, level: int = 0) -> str:
indent = "" if not level else "\n"
indent += "".join([" "] * level)
left = f"({self.key.upper()} "
args: t.Dict[str, t.Any] = {
k: ", ".join(
- v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
+ v._to_s(hide_missing=hide_missing, level=level + 1)
+ if hasattr(v, "_to_s")
+ else str(v)
for v in ensure_collection(vs)
if v is not None
)
@@ -394,7 +478,7 @@ class Expression(metaclass=_Expression):
modified in place.
Returns:
- the transformed tree.
+ The transformed tree.
"""
node = self.copy() if copy else self
new_node = fun(node, *args, **kwargs)
@@ -423,8 +507,8 @@ class Expression(metaclass=_Expression):
Args:
expression (Expression|None): new node
- Returns :
- the new expression or expressions
+ Returns:
+ The new expression or expressions.
"""
if not self.parent:
return expression
@@ -458,6 +542,40 @@ class Expression(metaclass=_Expression):
assert isinstance(self, type_)
return self
+ def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]:
+ """
+ Checks if this expression is valid (e.g. all mandatory args are set).
+
+ Args:
+ args: a sequence of values that were used to instantiate a Func expression. This is used
+ to check that the provided arguments don't exceed the function argument limit.
+
+ Returns:
+ A list of error messages for all possible errors that were found.
+ """
+ errors: t.List[str] = []
+
+ for k in self.args:
+ if k not in self.arg_types:
+ errors.append(f"Unexpected keyword: '{k}' for {self.__class__}")
+ for k, mandatory in self.arg_types.items():
+ v = self.args.get(k)
+ if mandatory and (v is None or (isinstance(v, list) and not v)):
+ errors.append(f"Required keyword: '{k}' missing for {self.__class__}")
+
+ if (
+ args
+ and isinstance(self, Func)
+ and len(args) > len(self.arg_types)
+ and not self.is_var_len_args
+ ):
+ errors.append(
+ f"The number of provided arguments ({len(args)}) is greater than "
+ f"the maximum number of supported arguments ({len(self.arg_types)})"
+ )
+
+ return errors
+
def dump(self):
"""
Dump this Expression to a JSON-serializable dict.
@@ -552,7 +670,7 @@ class DerivedTable(Expression):
@property
def named_selects(self):
- return [select.alias_or_name for select in self.selects]
+ return [select.output_name for select in self.selects]
class Unionable(Expression):
@@ -654,6 +772,7 @@ class Create(Expression):
"no_primary_index": False,
"indexes": False,
"no_schema_binding": False,
+ "begin": False,
}
@@ -696,7 +815,7 @@ class Show(Expression):
class UserDefinedFunction(Expression):
- arg_types = {"this": True, "expressions": False}
+ arg_types = {"this": True, "expressions": False, "wrapped": False}
class UserDefinedFunctionKwarg(Expression):
@@ -750,6 +869,10 @@ class Column(Condition):
def table(self):
return self.text("table")
+ @property
+ def output_name(self):
+ return self.name
+
class ColumnDef(Expression):
arg_types = {
@@ -865,6 +988,10 @@ class ForeignKey(Expression):
}
+class PrimaryKey(Expression):
+ arg_types = {"expressions": True, "options": False}
+
+
class Unique(Expression):
arg_types = {"expressions": True}
@@ -904,6 +1031,10 @@ class Identifier(Expression):
def __hash__(self):
return hash((self.key, self.this.lower()))
+ @property
+ def output_name(self):
+ return self.name
+
class Index(Expression):
arg_types = {
@@ -996,6 +1127,10 @@ class Literal(Condition):
def string(cls, string) -> Literal:
return cls(this=str(string), is_string=True)
+ @property
+ def output_name(self):
+ return self.name
+
class Join(Expression):
arg_types = {
@@ -1186,7 +1321,7 @@ class SchemaCommentProperty(Property):
class ReturnsProperty(Property):
- arg_types = {"this": True, "is_table": False}
+ arg_types = {"this": True, "is_table": False, "table": False}
class LanguageProperty(Property):
@@ -1262,8 +1397,13 @@ class Qualify(Expression):
pass
+# https://www.ibm.com/docs/en/ias?topic=procedures-return-statement-in-sql
+class Return(Expression):
+ pass
+
+
class Reference(Expression):
- arg_types = {"this": True, "expressions": True}
+ arg_types = {"this": True, "expressions": False, "options": False}
class Tuple(Expression):
@@ -1397,6 +1537,16 @@ class Table(Expression):
"joins": False,
"pivots": False,
"hints": False,
+ "system_time": False,
+ }
+
+
+# See the TSQL "Querying data in a system-versioned temporal table" page
+class SystemTime(Expression):
+ arg_types = {
+ "this": False,
+ "expression": False,
+ "kind": True,
}
@@ -2027,7 +2177,7 @@ class Select(Subqueryable):
@property
def named_selects(self) -> t.List[str]:
- return [e.alias_or_name for e in self.expressions if e.alias_or_name]
+ return [e.output_name for e in self.expressions if e.alias_or_name]
@property
def selects(self) -> t.List[Expression]:
@@ -2051,6 +2201,10 @@ class Subquery(DerivedTable, Unionable):
expression = expression.this
return expression
+ @property
+ def output_name(self):
+ return self.alias
+
class TableSample(Expression):
arg_types = {
@@ -2066,6 +2220,16 @@ class TableSample(Expression):
}
+class Tag(Expression):
+ """Tags are used for generating arbitrary sql like SELECT <span>x</span>."""
+
+ arg_types = {
+ "this": False,
+ "prefix": False,
+ "postfix": False,
+ }
+
+
class Pivot(Expression):
arg_types = {
"this": False,
@@ -2106,6 +2270,10 @@ class Star(Expression):
def name(self):
return "*"
+ @property
+ def output_name(self):
+ return self.name
+
class Parameter(Expression):
pass
@@ -2143,6 +2311,8 @@ class DataType(Expression):
TEXT = auto()
MEDIUMTEXT = auto()
LONGTEXT = auto()
+ MEDIUMBLOB = auto()
+ LONGBLOB = auto()
BINARY = auto()
VARBINARY = auto()
INT = auto()
@@ -2282,11 +2452,11 @@ class Rollback(Expression):
class AlterTable(Expression):
- arg_types = {
- "this": True,
- "actions": True,
- "exists": False,
- }
+ arg_types = {"this": True, "actions": True, "exists": False}
+
+
+class AddConstraint(Expression):
+ arg_types = {"this": False, "expression": False, "enforced": False}
# Binary expressions like (ADD a b)
@@ -2456,6 +2626,10 @@ class Neg(Unary):
class Alias(Expression):
arg_types = {"this": True, "alias": False}
+ @property
+ def output_name(self):
+ return self.alias
+
class Aliases(Expression):
arg_types = {"this": True, "expressions": True}
@@ -2523,16 +2697,13 @@ class Func(Condition):
"""
The base class for all function expressions.
- Attributes
- is_var_len_args (bool): if set to True the last argument defined in
- arg_types will be treated as a variable length argument and the
- argument's value will be stored as a list.
- _sql_names (list): determines the SQL name (1st item in the list) and
- aliases (subsequent items) for this function expression. These
- values are used to map this node to a name during parsing as well
- as to provide the function's name during SQL string generation. By
- default the SQL name is set to the expression's class name transformed
- to snake case.
+ Attributes:
+ is_var_len_args (bool): if set to True the last argument defined in arg_types will be
+ treated as a variable length argument and the argument's value will be stored as a list.
+ _sql_names (list): determines the SQL name (1st item in the list) and aliases (subsequent items)
+ for this function expression. These values are used to map this node to a name during parsing
+ as well as to provide the function's name during SQL string generation. By default the SQL
+ name is set to the expression's class name transformed to snake case.
"""
is_var_len_args = False
@@ -2558,7 +2729,7 @@ class Func(Condition):
raise NotImplementedError(
"SQL name is only supported by concrete function implementations"
)
- if not hasattr(cls, "_sql_names"):
+ if "_sql_names" not in cls.__dict__:
cls._sql_names = [camel_to_snake_case(cls.__name__)]
return cls._sql_names
@@ -2658,6 +2829,10 @@ class Cast(Func):
def to(self):
return self.args["to"]
+ @property
+ def output_name(self):
+ return self.name
+
class Collate(Binary):
pass
@@ -2956,6 +3131,14 @@ class Pow(Func):
_sql_names = ["POWER", "POW"]
+class PercentileCont(AggFunc):
+ pass
+
+
+class PercentileDisc(AggFunc):
+ pass
+
+
class Quantile(AggFunc):
arg_types = {"this": True, "quantile": True}
@@ -3213,12 +3396,13 @@ def _norm_arg(arg):
ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
+# Helpers
def maybe_parse(
- sql_or_expression,
+ sql_or_expression: str | Expression,
*,
- into=None,
- dialect=None,
- prefix=None,
+ into: t.Optional[IntoType] = None,
+ dialect: t.Optional[str] = None,
+ prefix: t.Optional[str] = None,
**opts,
) -> Expression:
"""Gracefully handle a possible string or expression.
@@ -3230,11 +3414,11 @@ def maybe_parse(
(IDENTIFIER this: x, quoted: False)
Args:
- sql_or_expression (str | Expression): the SQL code string or an expression
- into (Expression): the SQLGlot Expression to parse into
- dialect (str): the dialect used to parse the input expressions (in the case that an
+ sql_or_expression: the SQL code string or an expression
+ into: the SQLGlot Expression to parse into
+ dialect: the dialect used to parse the input expressions (in the case that an
input expression is a SQL string).
- prefix (str): a string to prefix the sql with before it gets parsed
+ prefix: a string to prefix the sql with before it gets parsed
(automatically includes a space)
**opts: other options to use to parse the input expressions (again, in the case
that an input expression is a SQL string).
@@ -3993,7 +4177,7 @@ def table_name(table) -> str:
"""Get the full name of a table as a string.
Args:
- table (exp.Table | str): Table expression node or string.
+ table (exp.Table | str): table expression node or string.
Examples:
>>> from sqlglot import exp, parse_one
@@ -4001,7 +4185,7 @@ def table_name(table) -> str:
'a.b.c'
Returns:
- str: the table name
+ The table name.
"""
table = maybe_parse(table, into=Table)
@@ -4024,8 +4208,8 @@ def replace_tables(expression, mapping):
"""Replace all tables in expression according to the mapping.
Args:
- expression (sqlglot.Expression): Expression node to be transformed and replaced
- mapping (Dict[str, str]): Mapping of table names
+ expression (sqlglot.Expression): expression node to be transformed and replaced.
+ mapping (Dict[str, str]): mapping of table names.
Examples:
>>> from sqlglot import exp, parse_one
@@ -4033,7 +4217,7 @@ def replace_tables(expression, mapping):
'SELECT * FROM c'
Returns:
- The mapped expression
+ The mapped expression.
"""
def _replace_tables(node):
@@ -4053,9 +4237,9 @@ def replace_placeholders(expression, *args, **kwargs):
"""Replace placeholders in an expression.
Args:
- expression (sqlglot.Expression): Expression node to be transformed and replaced
- args: Positional names that will substitute unnamed placeholders in the given order
- kwargs: Keyword arguments that will substitute named placeholders
+ expression (sqlglot.Expression): expression node to be transformed and replaced.
+ args: positional names that will substitute unnamed placeholders in the given order.
+ kwargs: keyword arguments that will substitute named placeholders.
Examples:
>>> from sqlglot import exp, parse_one
@@ -4065,7 +4249,7 @@ def replace_placeholders(expression, *args, **kwargs):
'SELECT * FROM foo WHERE a = b'
Returns:
- The mapped expression
+ The mapped expression.
"""
def _replace_placeholders(node, args, **kwargs):
@@ -4084,15 +4268,101 @@ def replace_placeholders(expression, *args, **kwargs):
return expression.transform(_replace_placeholders, iter(args), **kwargs)
+def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True) -> Expression:
+ """Transforms an expression by expanding all referenced sources into subqueries.
+
+ Examples:
+ >>> from sqlglot import parse_one
+ >>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y")}).sql()
+ 'SELECT * FROM (SELECT * FROM y) AS z /* source: x */'
+
+ Args:
+ expression: The expression to expand.
+ sources: A dictionary of name to Subqueryables.
+ copy: Whether or not to copy the expression during transformation. Defaults to True.
+
+ Returns:
+ The transformed expression.
+ """
+
+ def _expand(node: Expression):
+ if isinstance(node, Table):
+ name = table_name(node)
+ source = sources.get(name)
+ if source:
+ subquery = source.subquery(node.alias or name)
+ subquery.comments = [f"source: {name}"]
+ return subquery
+ return node
+
+ return expression.transform(_expand, copy=copy)
+
+
+def func(name: str, *args, dialect: t.Optional[Dialect | str] = None, **kwargs) -> Func:
+ """
+ Returns a Func expression.
+
+ Examples:
+ >>> func("abs", 5).sql()
+ 'ABS(5)'
+
+ >>> func("cast", this=5, to=DataType.build("DOUBLE")).sql()
+ 'CAST(5 AS DOUBLE)'
+
+ Args:
+ name: the name of the function to build.
+ args: the args used to instantiate the function of interest.
+ dialect: the source dialect.
+ kwargs: the kwargs used to instantiate the function of interest.
+
+ Note:
+ The arguments `args` and `kwargs` are mutually exclusive.
+
+ Returns:
+ An instance of the function of interest, or an anonymous function, if `name` doesn't
+ correspond to an existing `sqlglot.expressions.Func` class.
+ """
+ if args and kwargs:
+ raise ValueError("Can't use both args and kwargs to instantiate a function.")
+
+ from sqlglot.dialects.dialect import Dialect
+
+ args = tuple(convert(arg) for arg in args)
+ kwargs = {key: convert(value) for key, value in kwargs.items()}
+
+ parser = Dialect.get_or_raise(dialect)().parser()
+ from_args_list = parser.FUNCTIONS.get(name.upper())
+
+ if from_args_list:
+ function = from_args_list(args) if args else from_args_list.__self__(**kwargs) # type: ignore
+ else:
+ kwargs = kwargs or {"expressions": args}
+ function = Anonymous(this=name, **kwargs)
+
+ for error_message in function.error_messages(args):
+ raise ValueError(error_message)
+
+ return function
+
+
def true():
+ """
+ Returns a true Boolean expression.
+ """
return Boolean(this=True)
def false():
+ """
+ Returns a false Boolean expression.
+ """
return Boolean(this=False)
def null():
+ """
+ Returns a Null expression.
+ """
return Null()
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 6375d92..b398d8e 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -16,7 +16,7 @@ class Generator:
"""
Generator interprets the given syntax tree and produces a SQL string as an output.
- Args
+ Args:
time_mapping (dict): the dictionary of custom time mappings in which the key
represents a python time format and the output the target time format
time_trie (trie): a trie of the time_mapping keys
@@ -84,6 +84,13 @@ class Generator:
exp.DataType.Type.NVARCHAR: "VARCHAR",
exp.DataType.Type.MEDIUMTEXT: "TEXT",
exp.DataType.Type.LONGTEXT: "TEXT",
+ exp.DataType.Type.MEDIUMBLOB: "BLOB",
+ exp.DataType.Type.LONGBLOB: "BLOB",
+ }
+
+ STAR_MAPPING = {
+ "except": "EXCEPT",
+ "replace": "REPLACE",
}
TOKEN_MAPPING: t.Dict[TokenType, str] = {}
@@ -106,6 +113,8 @@ class Generator:
exp.TableFormatProperty,
}
+ WITH_SINGLE_ALTER_TABLE_ACTION = (exp.AlterColumn, exp.RenameTable, exp.AddConstraint)
+
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
@@ -241,15 +250,17 @@ class Generator:
return sql
sep = "\n" if self.pretty else " "
- comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments if comment)
+ comments_sql = sep.join(
+ f"/*{self.pad_comment(comment)}*/" for comment in comments if comment
+ )
- if not comments:
+ if not comments_sql:
return sql
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
- return f"{comments}{self.sep()}{sql}"
+ return f"{comments_sql}{self.sep()}{sql}"
- return f"{sql} {comments}"
+ return f"{sql} {comments_sql}"
def wrap(self, expression: exp.Expression | str) -> str:
this_sql = self.indent(
@@ -433,8 +444,9 @@ class Generator:
def create_sql(self, expression: exp.Create) -> str:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind").upper()
+ begin = " BEGIN" if expression.args.get("begin") else ""
expression_sql = self.sql(expression, "expression")
- expression_sql = f" AS{self.sep()}{expression_sql}" if expression_sql else ""
+ expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else ""
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
transient = (
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
@@ -741,12 +753,14 @@ class Generator:
laterals = self.expressions(expression, key="laterals", sep="")
joins = self.expressions(expression, key="joins", sep="")
pivots = self.expressions(expression, key="pivots", sep="")
+ system_time = expression.args.get("system_time")
+ system_time = f" {self.sql(expression, 'system_time')}" if system_time else ""
if alias and pivots:
pivots = f"{pivots}{alias}"
alias = ""
- return f"{table}{alias}{hints}{laterals}{joins}{pivots}"
+ return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}"
def tablesample_sql(self, expression: exp.TableSample) -> str:
if self.alias_post_tablesample and expression.this.alias:
@@ -1009,9 +1023,9 @@ class Generator:
def star_sql(self, expression: exp.Star) -> str:
except_ = self.expressions(expression, key="except", flat=True)
- except_ = f"{self.seg('EXCEPT')} ({except_})" if except_ else ""
+ except_ = f"{self.seg(self.STAR_MAPPING['except'])} ({except_})" if except_ else ""
replace = self.expressions(expression, key="replace", flat=True)
- replace = f"{self.seg('REPLACE')} ({replace})" if replace else ""
+ replace = f"{self.seg(self.STAR_MAPPING['replace'])} ({replace})" if replace else ""
return f"*{except_}{replace}"
def structkwarg_sql(self, expression: exp.StructKwarg) -> str:
@@ -1193,6 +1207,12 @@ class Generator:
update = f" ON UPDATE {update}" if update else ""
return f"FOREIGN KEY ({expressions}){reference}{delete}{update}"
+ def primarykey_sql(self, expression: exp.ForeignKey) -> str:
+ expressions = self.expressions(expression, flat=True)
+ options = self.expressions(expression, "options", flat=True, sep=" ")
+ options = f" {options}" if options else ""
+ return f"PRIMARY KEY ({expressions}){options}"
+
def unique_sql(self, expression: exp.Unique) -> str:
columns = self.expressions(expression, key="expressions")
return f"UNIQUE ({columns})"
@@ -1229,10 +1249,16 @@ class Generator:
unit = f" {unit}" if unit else ""
return f"INTERVAL{this}{unit}"
+ def return_sql(self, expression: exp.Return) -> str:
+ return f"RETURN {self.sql(expression, 'this')}"
+
def reference_sql(self, expression: exp.Reference) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
- return f"REFERENCES {this}({expressions})"
+ expressions = f"({expressions})" if expressions else ""
+ options = self.expressions(expression, "options", flat=True, sep=" ")
+ options = f" {options}" if options else ""
+ return f"REFERENCES {this}{expressions}{options}"
def anonymous_sql(self, expression: exp.Anonymous) -> str:
args = self.format_args(*expression.expressions)
@@ -1362,7 +1388,7 @@ class Generator:
actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ")
elif isinstance(actions[0], exp.Drop):
actions = self.expressions(expression, "actions")
- elif isinstance(actions[0], (exp.AlterColumn, exp.RenameTable)):
+ elif isinstance(actions[0], self.WITH_SINGLE_ALTER_TABLE_ACTION):
actions = self.sql(actions[0])
else:
self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}")
@@ -1370,6 +1396,17 @@ class Generator:
exists = " IF EXISTS" if expression.args.get("exists") else ""
return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}"
+ def addconstraint_sql(self, expression: exp.AddConstraint) -> str:
+ this = self.sql(expression, "this")
+ expression_ = self.sql(expression, "expression")
+ add_constraint = f"ADD CONSTRAINT {this}" if this else "ADD"
+
+ enforced = expression.args.get("enforced")
+ if enforced is not None:
+ return f"{add_constraint} CHECK ({expression_}){' ENFORCED' if enforced else ''}"
+
+ return f"{add_constraint} {expression_}"
+
def distinct_sql(self, expression: exp.Distinct) -> str:
this = self.expressions(expression, flat=True)
this = f" {this}" if this else ""
@@ -1550,13 +1587,19 @@ class Generator:
expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
)
+ def tag_sql(self, expression: exp.Tag) -> str:
+ return f"{expression.args.get('prefix')}{self.sql(expression.this)}{expression.args.get('postfix')}"
+
def token_sql(self, token_type: TokenType) -> str:
return self.TOKEN_MAPPING.get(token_type, token_type.name)
def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str:
this = self.sql(expression, "this")
expressions = self.no_identify(lambda: self.expressions(expression))
- return f"{this}({expressions})"
+ expressions = (
+ self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}"
+ )
+ return f"{this}{expressions}"
def userdefinedfunctionkwarg_sql(self, expression: exp.UserDefinedFunctionKwarg) -> str:
this = self.sql(expression, "this")
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index 5a0f2ac..68e0383 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -332,7 +332,7 @@ def is_iterable(value: t.Any) -> bool:
return hasattr(value, "__iter__") and not isinstance(value, (str, bytes))
-def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any, None, None]:
+def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
"""
Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
type `str` and `bytes` are not regarded as iterables.
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py
new file mode 100644
index 0000000..4e7eab8
--- /dev/null
+++ b/sqlglot/lineage.py
@@ -0,0 +1,228 @@
+from __future__ import annotations
+
+import json
+import typing as t
+from dataclasses import dataclass, field
+
+from sqlglot import Schema, exp, maybe_parse
+from sqlglot.optimizer import Scope, build_scope, optimize
+from sqlglot.optimizer.qualify_columns import qualify_columns
+from sqlglot.optimizer.qualify_tables import qualify_tables
+
+
+@dataclass(frozen=True)
+class Node:
+ name: str
+ expression: exp.Expression
+ source: exp.Expression
+ downstream: t.List[Node] = field(default_factory=list)
+
+ def walk(self) -> t.Iterator[Node]:
+ yield self
+
+ for d in self.downstream:
+ if isinstance(d, Node):
+ yield from d.walk()
+ else:
+ yield d
+
+ def to_html(self, **opts) -> LineageHTML:
+ return LineageHTML(self, **opts)
+
+
+def lineage(
+ column: str | exp.Column,
+ sql: str | exp.Expression,
+ schema: t.Optional[t.Dict | Schema] = None,
+ sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
+ rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns),
+ dialect: t.Optional[str] = None,
+) -> Node:
+ """Build the lineage graph for a column of a SQL query.
+
+ Args:
+ column: The column to build the lineage for.
+ sql: The SQL string or expression.
+ schema: The schema of tables.
+ sources: A mapping of queries which will be used to continue building lineage.
+ rules: Optimizer rules to apply, by default only qualifying tables and columns.
+ dialect: The dialect of input SQL.
+
+ Returns:
+ A lineage node.
+ """
+
+ expression = maybe_parse(sql, dialect=dialect)
+
+ if sources:
+ expression = exp.expand(
+ expression,
+ {
+ k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
+ for k, v in sources.items()
+ },
+ )
+
+ optimized = optimize(expression, schema=schema, rules=rules)
+ scope = build_scope(optimized)
+ tables: t.Dict[str, Node] = {}
+
+ def to_node(
+ column_name: str,
+ scope: Scope,
+ scope_name: t.Optional[str] = None,
+ upstream: t.Optional[Node] = None,
+ ) -> Node:
+ if isinstance(scope.expression, exp.Union):
+ for scope in scope.union_scopes:
+ node = to_node(
+ column_name,
+ scope=scope,
+ scope_name=scope_name,
+ upstream=upstream,
+ )
+ return node
+
+ select = next(select for select in scope.selects if select.alias_or_name == column_name)
+ source = optimize(scope.expression.select(select, append=False), schema=schema, rules=rules)
+ select = source.selects[0]
+
+ node = Node(
+ name=f"{scope_name}.{column_name}" if scope_name else column_name,
+ source=source,
+ expression=select,
+ )
+
+ if upstream:
+ upstream.downstream.append(node)
+
+ for c in set(select.find_all(exp.Column)):
+ table = c.table
+ source = scope.sources[table]
+
+ if isinstance(source, Scope):
+ to_node(
+ c.name,
+ scope=source,
+ scope_name=table,
+ upstream=node,
+ )
+ else:
+ if table not in tables:
+ tables[table] = Node(name=table, source=source, expression=source)
+ node.downstream.append(tables[table])
+
+ return node
+
+ return to_node(column if isinstance(column, str) else column.name, scope)
+
+
+class LineageHTML:
+ """Node to HTML generator using vis.js.
+
+ https://visjs.github.io/vis-network/docs/network/
+ """
+
+ def __init__(
+ self,
+ node: Node,
+ dialect: t.Optional[str] = None,
+ imports: bool = True,
+ **opts: t.Any,
+ ):
+ self.node = node
+ self.imports = imports
+
+ self.options = {
+ "height": "500px",
+ "width": "100%",
+ "layout": {
+ "hierarchical": {
+ "enabled": True,
+ "nodeSpacing": 200,
+ "sortMethod": "directed",
+ },
+ },
+ "interaction": {
+ "dragNodes": False,
+ "selectable": False,
+ },
+ "physics": {
+ "enabled": False,
+ },
+ "edges": {
+ "arrows": "to",
+ },
+ "nodes": {
+ "font": "20px monaco",
+ "shape": "box",
+ "widthConstraint": {
+ "maximum": 300,
+ },
+ },
+ **opts,
+ }
+
+ self.nodes = {}
+ self.edges = []
+
+ for node in node.walk():
+ if isinstance(node.expression, exp.Table):
+ label = f"FROM {node.expression.this}"
+ title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
+ group = 1
+ else:
+ label = node.expression.sql(pretty=True, dialect=dialect)
+ source = node.source.transform(
+ lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
+ if n is node.expression
+ else n,
+ copy=False,
+ ).sql(pretty=True, dialect=dialect)
+ title = f"<pre>{source}</pre>"
+ group = 0
+
+ node_id = id(node)
+
+ self.nodes[node_id] = {
+ "id": node_id,
+ "label": label,
+ "title": title,
+ "group": group,
+ }
+
+ for d in node.downstream:
+ self.edges.append({"from": node_id, "to": id(d)})
+
+ def __str__(self):
+ nodes = json.dumps(list(self.nodes.values()))
+ edges = json.dumps(self.edges)
+ options = json.dumps(self.options)
+ imports = (
+ """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
+ <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
+ <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
+ if self.imports
+ else ""
+ )
+
+ return f"""<div>
+ <div id="sqlglot-lineage"></div>
+ {imports}
+ <script type="text/javascript">
+ var nodes = new vis.DataSet({nodes})
+ nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])
+
+ new vis.Network(
+ document.getElementById("sqlglot-lineage"),
+ {{
+ nodes: nodes,
+ edges: new vis.DataSet({edges})
+ }},
+ {options},
+ )
+ </script>
+</div>"""
+
+ def _repr_html_(self) -> str:
+ return self.__str__()
diff --git a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py
index bba0878..719a77e 100644
--- a/sqlglot/optimizer/__init__.py
+++ b/sqlglot/optimizer/__init__.py
@@ -1 +1,2 @@
from sqlglot.optimizer.optimizer import RULES, optimize
+from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope
diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py
index 652cdef..5bd7b30 100644
--- a/sqlglot/optimizer/isolate_table_selects.py
+++ b/sqlglot/optimizer/isolate_table_selects.py
@@ -1,15 +1,18 @@
from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.schema import ensure_schema
-def isolate_table_selects(expression):
+def isolate_table_selects(expression, schema=None):
+ schema = ensure_schema(schema)
+
for scope in traverse_scope(expression):
if len(scope.selected_sources) == 1:
continue
for (_, source) in scope.selected_sources.values():
- if not isinstance(source, exp.Table):
+ if not isinstance(source, exp.Table) or not schema.column_names(source):
continue
if not source.alias:
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 8da4e43..54425a8 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -1,7 +1,8 @@
import itertools
+import typing as t
from sqlglot import alias, exp
-from sqlglot.errors import OptimizeError, SchemaError
+from sqlglot.errors import OptimizeError
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
@@ -190,20 +191,15 @@ def _qualify_columns(scope, resolver):
column_table = column.table
column_name = column.name
- if (
- column_table
- and column_table in scope.sources
- and column_name not in resolver.get_source_columns(column_table)
- ):
- raise OptimizeError(f"Unknown column: {column_name}")
+ if column_table and column_table in scope.sources:
+ source_columns = resolver.get_source_columns(column_table)
+ if source_columns and column_name not in source_columns:
+ raise OptimizeError(f"Unknown column: {column_name}")
if not column_table:
column_table = resolver.get_table(column_name)
if not scope.is_subquery and not scope.is_udtf:
- if column_name not in resolver.all_columns:
- raise OptimizeError(f"Unknown column: {column_name}")
-
if column_table is None:
raise OptimizeError(f"Ambiguous column: {column_name}")
@@ -265,6 +261,10 @@ def _expand_stars(scope, resolver):
if table not in scope.sources:
raise OptimizeError(f"Unknown table: {table}")
columns = resolver.get_source_columns(table, only_visible=True)
+ if not columns:
+ raise OptimizeError(
+ f"Table has no schema/columns. Cannot expand star for table: {table}."
+ )
table_id = id(table)
for name in columns:
if name not in except_columns.get(table_id, set()):
@@ -306,16 +306,11 @@ def _qualify_outputs(scope):
for i, (selection, aliased_column) in enumerate(
itertools.zip_longest(scope.selects, scope.outer_column_list)
):
- if isinstance(selection, exp.Column):
- # convoluted setter because a simple selection.replace(alias) would require a copy
- alias_ = alias(exp.column(""), alias=selection.name)
- alias_.set("this", selection)
- selection = alias_
- elif isinstance(selection, exp.Subquery):
- if not selection.alias:
+ if isinstance(selection, exp.Subquery):
+ if not selection.output_name:
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
elif not isinstance(selection, exp.Alias):
- alias_ = alias(exp.column(""), f"_col_{i}")
+ alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
alias_.set("this", selection)
selection = alias_
@@ -346,20 +341,30 @@ class _Resolver:
self._unambiguous_columns = None
self._all_columns = None
- def get_table(self, column_name):
+ def get_table(self, column_name: str) -> t.Optional[str]:
"""
Get the table for a column name.
Args:
- column_name (str)
+ column_name: The column name to find the table for.
Returns:
- (str) table name
+ The table name if it can be found/inferred.
"""
if self._unambiguous_columns is None:
self._unambiguous_columns = self._get_unambiguous_columns(
self._get_all_source_columns()
)
- return self._unambiguous_columns.get(column_name)
+
+ table = self._unambiguous_columns.get(column_name)
+
+ if not table:
+ sources_without_schema = tuple(
+ source for source, columns in self._get_all_source_columns().items() if not columns
+ )
+ if len(sources_without_schema) == 1:
+ return sources_without_schema[0]
+
+ return table
@property
def all_columns(self):
@@ -379,10 +384,7 @@ class _Resolver:
# If referencing a table, return the columns from the schema
if isinstance(source, exp.Table):
- try:
- return self.schema.column_names(source, only_visible)
- except Exception as e:
- raise SchemaError(str(e)) from e
+ return self.schema.column_names(source, only_visible)
if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
return source.expression.alias_column_names
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 6125e4e..5a3ed5a 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -230,7 +230,7 @@ class Scope:
column for scope in self.subquery_scopes for column in scope.external_columns
]
- named_outputs = {e.alias_or_name for e in self.expression.expressions}
+ named_selects = set(self.expression.named_selects)
self._columns = []
for column in columns + external_columns:
@@ -238,7 +238,7 @@ class Scope:
if (
not ancestor
or column.table
- or (column.name not in named_outputs and not isinstance(ancestor, exp.Hint))
+ or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
):
self._columns.append(column)
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index c97b19a..42777d1 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -40,22 +40,23 @@ class _Parser(type):
class Parser(metaclass=_Parser):
"""
- Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer`
- and produces a parsed syntax tree.
-
- Args
- error_level (ErrorLevel): the desired error level. Default: ErrorLevel.RAISE.
- error_message_context (int): determines the amount of context to capture from
- a query string when displaying the error message (in number of characters).
+ Parser consumes a list of tokens produced by the `sqlglot.tokens.Tokenizer` and produces
+ a parsed syntax tree.
+
+ Args:
+ error_level: the desired error level.
+ Default: ErrorLevel.RAISE
+ error_message_context: determines the amount of context to capture from a
+ query string when displaying the error message (in number of characters).
Default: 50.
- index_offset (int): Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list
+ index_offset: Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list.
Default: 0
- alias_post_tablesample (bool): If the table alias comes after tablesample
+ alias_post_tablesample: If the table alias comes after tablesample.
Default: False
- max_errors (int): Maximum number of error messages to include in a raised ParseError.
+ max_errors: Maximum number of error messages to include in a raised ParseError.
This is only relevant if error_level is ErrorLevel.RAISE.
Default: 3
- null_ordering (str): Indicates the default null ordering method to use if not explicitly set.
+ null_ordering: Indicates the default null ordering method to use if not explicitly set.
Options are "nulls_are_small", "nulls_are_large", "nulls_are_last".
Default: "nulls_are_small"
"""
@@ -109,6 +110,8 @@ class Parser(metaclass=_Parser):
TokenType.TEXT,
TokenType.MEDIUMTEXT,
TokenType.LONGTEXT,
+ TokenType.MEDIUMBLOB,
+ TokenType.LONGBLOB,
TokenType.BINARY,
TokenType.VARBINARY,
TokenType.JSON,
@@ -176,6 +179,7 @@ class Parser(metaclass=_Parser):
TokenType.DIV,
TokenType.DISTKEY,
TokenType.DISTSTYLE,
+ TokenType.END,
TokenType.EXECUTE,
TokenType.ENGINE,
TokenType.ESCAPE,
@@ -468,9 +472,6 @@ class Parser(metaclass=_Parser):
TokenType.NULL: lambda self, _: self.expression(exp.Null),
TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),
TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False),
- TokenType.PARAMETER: lambda self, _: self.expression(
- exp.Parameter, this=self._parse_var() or self._parse_primary()
- ),
TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text),
TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
@@ -479,6 +480,16 @@ class Parser(metaclass=_Parser):
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
}
+ PLACEHOLDER_PARSERS = {
+ TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder),
+ TokenType.PARAMETER: lambda self: self.expression(
+ exp.Parameter, this=self._parse_var() or self._parse_primary()
+ ),
+ TokenType.COLON: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
+ if self._match_set((TokenType.NUMBER, TokenType.VAR))
+ else None,
+ }
+
RANGE_PARSERS = {
TokenType.BETWEEN: lambda self, this: self._parse_between(this),
TokenType.IN: lambda self, this: self._parse_in(this),
@@ -601,8 +612,7 @@ class Parser(metaclass=_Parser):
WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
- # allows tables to have special tokens as prefixes
- TABLE_PREFIX_TOKENS: t.Set[TokenType] = set()
+ ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}
STRICT_CAST = True
@@ -677,7 +687,7 @@ class Parser(metaclass=_Parser):
def parse_into(
self,
- expression_types: str | exp.Expression | t.Collection[exp.Expression | str],
+ expression_types: exp.IntoType,
raw_tokens: t.List[Token],
sql: t.Optional[str] = None,
) -> t.List[t.Optional[exp.Expression]]:
@@ -820,24 +830,8 @@ class Parser(metaclass=_Parser):
if self.error_level == ErrorLevel.IGNORE:
return
- for k in expression.args:
- if k not in expression.arg_types:
- self.raise_error(f"Unexpected keyword: '{k}' for {expression.__class__}")
- for k, mandatory in expression.arg_types.items():
- v = expression.args.get(k)
- if mandatory and (v is None or (isinstance(v, list) and not v)):
- self.raise_error(f"Required keyword: '{k}' missing for {expression.__class__}")
-
- if (
- args
- and isinstance(expression, exp.Func)
- and len(args) > len(expression.arg_types)
- and not expression.is_var_len_args
- ):
- self.raise_error(
- f"The number of provided arguments ({len(args)}) is greater than "
- f"the maximum number of supported arguments ({len(expression.arg_types)})"
- )
+ for error_message in expression.error_messages(args):
+ self.raise_error(error_message)
def _find_token(self, token: Token, sql: str) -> int:
line = 1
@@ -868,6 +862,9 @@ class Parser(metaclass=_Parser):
def _retreat(self, index: int) -> None:
self._advance(index - self._index)
+ def _parse_command(self) -> exp.Expression:
+ return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string())
+
def _parse_statement(self) -> t.Optional[exp.Expression]:
if self._curr is None:
return None
@@ -876,11 +873,7 @@ class Parser(metaclass=_Parser):
return self.STATEMENT_PARSERS[self._prev.token_type](self)
if self._match_set(Tokenizer.COMMANDS):
- return self.expression(
- exp.Command,
- this=self._prev.text,
- expression=self._parse_string(),
- )
+ return self._parse_command()
expression = self._parse_expression()
expression = self._parse_set_operations(expression) if expression else self._parse_select()
@@ -942,12 +935,18 @@ class Parser(metaclass=_Parser):
no_primary_index = None
indexes = None
no_schema_binding = None
+ begin = None
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
- this = self._parse_user_defined_function()
+ this = self._parse_user_defined_function(kind=create_token.token_type)
properties = self._parse_properties()
if self._match(TokenType.ALIAS):
- expression = self._parse_select_or_expression()
+ begin = self._match(TokenType.BEGIN)
+ return_ = self._match_text_seq("RETURN")
+ expression = self._parse_statement()
+
+ if return_:
+ expression = self.expression(exp.Return, this=expression)
elif create_token.token_type == TokenType.INDEX:
this = self._parse_index()
elif create_token.token_type in (
@@ -1002,6 +1001,7 @@ class Parser(metaclass=_Parser):
no_primary_index=no_primary_index,
indexes=indexes,
no_schema_binding=no_schema_binding,
+ begin=begin,
)
def _parse_property(self) -> t.Optional[exp.Expression]:
@@ -1087,7 +1087,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
else:
- value = self._parse_schema(exp.Literal.string("TABLE"))
+ value = self._parse_schema(exp.Var(this="TABLE"))
else:
value = self._parse_types()
@@ -1550,7 +1550,7 @@ class Parser(metaclass=_Parser):
return None
index = self._parse_id_var()
columns = None
- if self._curr and self._curr.token_type == TokenType.L_PAREN:
+ if self._match(TokenType.L_PAREN, advance=False):
columns = self._parse_wrapped_csv(self._parse_column)
return self.expression(
exp.Index,
@@ -1561,6 +1561,27 @@ class Parser(metaclass=_Parser):
amp=amp,
)
+ def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
+ catalog = None
+ db = None
+ table = (not schema and self._parse_function()) or self._parse_id_var(any_token=False)
+
+ while self._match(TokenType.DOT):
+ if catalog:
+ # This allows nesting the table in arbitrarily many dot expressions if needed
+ table = self.expression(exp.Dot, this=table, expression=self._parse_id_var())
+ else:
+ catalog = db
+ db = table
+ table = self._parse_id_var()
+
+ if not table:
+ self.raise_error(f"Expected table name but got {self._curr}")
+
+ return self.expression(
+ exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
+ )
+
def _parse_table(
self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None
) -> t.Optional[exp.Expression]:
@@ -1584,27 +1605,7 @@ class Parser(metaclass=_Parser):
if subquery:
return subquery
- catalog = None
- db = None
- table = (not schema and self._parse_function()) or self._parse_id_var(
- any_token=False, prefix_tokens=self.TABLE_PREFIX_TOKENS
- )
-
- while self._match(TokenType.DOT):
- if catalog:
- # This allows nesting the table in arbitrarily many dot expressions if needed
- table = self.expression(exp.Dot, this=table, expression=self._parse_id_var())
- else:
- catalog = db
- db = table
- table = self._parse_id_var()
-
- if not table:
- self.raise_error(f"Expected table name but got {self._curr}")
-
- this = self.expression(
- exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
- )
+ this = self._parse_table_parts(schema=schema)
if schema:
return self._parse_schema(this=this)
@@ -1889,7 +1890,7 @@ class Parser(metaclass=_Parser):
expression,
this=this,
distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL),
- expression=self._parse_select(nested=True),
+ expression=self._parse_set_operations(self._parse_select(nested=True)),
)
def _parse_expression(self) -> t.Optional[exp.Expression]:
@@ -2286,7 +2287,9 @@ class Parser(metaclass=_Parser):
self._match_r_paren(this)
return self._parse_window(this)
- def _parse_user_defined_function(self) -> t.Optional[exp.Expression]:
+ def _parse_user_defined_function(
+ self, kind: t.Optional[TokenType] = None
+ ) -> t.Optional[exp.Expression]:
this = self._parse_id_var()
while self._match(TokenType.DOT):
@@ -2297,7 +2300,9 @@ class Parser(metaclass=_Parser):
expressions = self._parse_csv(self._parse_udf_kwarg)
self._match_r_paren()
- return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
+ return self.expression(
+ exp.UserDefinedFunction, this=this, expressions=expressions, wrapped=True
+ )
def _parse_introducer(self, token: Token) -> t.Optional[exp.Expression]:
literal = self._parse_primary()
@@ -2371,10 +2376,6 @@ class Parser(metaclass=_Parser):
or self._parse_column_def(self._parse_field(any_token=True))
)
self._match_r_paren()
-
- if isinstance(this, exp.Literal):
- this = this.name
-
return self.expression(exp.Schema, this=this, expressions=args)
def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
@@ -2470,15 +2471,43 @@ class Parser(metaclass=_Parser):
def _parse_unique(self) -> exp.Expression:
return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars())
+ def _parse_key_constraint_options(self) -> t.List[str]:
+ options = []
+ while True:
+ if not self._curr:
+ break
+
+ if self._match_text_seq("NOT", "ENFORCED"):
+ options.append("NOT ENFORCED")
+ elif self._match_text_seq("DEFERRABLE"):
+ options.append("DEFERRABLE")
+ elif self._match_text_seq("INITIALLY", "DEFERRED"):
+ options.append("INITIALLY DEFERRED")
+ elif self._match_text_seq("NORELY"):
+ options.append("NORELY")
+ elif self._match_text_seq("MATCH", "FULL"):
+ options.append("MATCH FULL")
+ elif self._match_text_seq("ON", "UPDATE", "NO ACTION"):
+ options.append("ON UPDATE NO ACTION")
+ elif self._match_text_seq("ON", "DELETE", "NO ACTION"):
+ options.append("ON DELETE NO ACTION")
+ else:
+ break
+
+ return options
+
def _parse_references(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.REFERENCES):
return None
- return self.expression(
- exp.Reference,
- this=self._parse_id_var(),
- expressions=self._parse_wrapped_id_vars(),
- )
+ expressions = None
+ this = self._parse_id_var()
+
+ if self._match(TokenType.L_PAREN, advance=False):
+ expressions = self._parse_wrapped_id_vars()
+
+ options = self._parse_key_constraint_options()
+ return self.expression(exp.Reference, this=this, expressions=expressions, options=options)
def _parse_foreign_key(self) -> exp.Expression:
expressions = self._parse_wrapped_id_vars()
@@ -2503,12 +2532,14 @@ class Parser(metaclass=_Parser):
options[kind] = action
return self.expression(
- exp.ForeignKey,
- expressions=expressions,
- reference=reference,
- **options, # type: ignore
+ exp.ForeignKey, expressions=expressions, reference=reference, **options # type: ignore
)
+ def _parse_primary_key(self) -> exp.Expression:
+ expressions = self._parse_wrapped_id_vars()
+ options = self._parse_key_constraint_options()
+ return self.expression(exp.PrimaryKey, expressions=expressions, options=options)
+
def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match(TokenType.L_BRACKET):
return this
@@ -2631,7 +2662,7 @@ class Parser(metaclass=_Parser):
order = self._parse_order(this=expression)
return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
- def _parse_convert(self, strict: bool) -> exp.Expression:
+ def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
to: t.Optional[exp.Expression]
this = self._parse_column()
@@ -2641,19 +2672,25 @@ class Parser(metaclass=_Parser):
to = self._parse_types()
else:
to = None
+
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
- def _parse_position(self) -> exp.Expression:
+ def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
args = self._parse_csv(self._parse_bitwise)
if self._match(TokenType.IN):
- args.append(self._parse_bitwise())
+ return self.expression(
+ exp.StrPosition, this=self._parse_bitwise(), substr=seq_get(args, 0)
+ )
- this = exp.StrPosition(
- this=seq_get(args, 1),
- substr=seq_get(args, 0),
- position=seq_get(args, 2),
- )
+ if haystack_first:
+ haystack = seq_get(args, 0)
+ needle = seq_get(args, 1)
+ else:
+ needle = seq_get(args, 0)
+ haystack = seq_get(args, 1)
+
+ this = exp.StrPosition(this=haystack, substr=needle, position=seq_get(args, 2))
self.validate_expression(this, args)
@@ -2894,24 +2931,26 @@ class Parser(metaclass=_Parser):
return None
def _parse_placeholder(self) -> t.Optional[exp.Expression]:
- if self._match(TokenType.PLACEHOLDER):
- return self.expression(exp.Placeholder)
- elif self._match(TokenType.COLON):
- if self._match_set((TokenType.NUMBER, TokenType.VAR)):
- return self.expression(exp.Placeholder, this=self._prev.text)
+ if self._match_set(self.PLACEHOLDER_PARSERS):
+ placeholder = self.PLACEHOLDER_PARSERS[self._prev.token_type](self)
+ if placeholder:
+ return placeholder
self._advance(-1)
return None
def _parse_except(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
if not self._match(TokenType.EXCEPT):
return None
-
- return self._parse_wrapped_id_vars()
+ if self._match(TokenType.L_PAREN, advance=False):
+ return self._parse_wrapped_id_vars()
+ return self._parse_csv(self._parse_id_var)
def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
if not self._match(TokenType.REPLACE):
return None
- return self._parse_wrapped_csv(lambda: self._parse_alias(self._parse_expression()))
+ if self._match(TokenType.L_PAREN, advance=False):
+ return self._parse_wrapped_csv(self._parse_expression)
+ return self._parse_csv(self._parse_expression)
def _parse_csv(
self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
@@ -3021,6 +3060,28 @@ class Parser(metaclass=_Parser):
def _parse_drop_column(self) -> t.Optional[exp.Expression]:
return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN")
+ def _parse_add_constraint(self) -> t.Optional[exp.Expression]:
+ this = None
+ kind = self._prev.token_type
+
+ if kind == TokenType.CONSTRAINT:
+ this = self._parse_id_var()
+
+ if self._match(TokenType.CHECK):
+ expression = self._parse_wrapped(self._parse_conjunction)
+ enforced = self._match_text_seq("ENFORCED")
+
+ return self.expression(
+ exp.AddConstraint, this=this, expression=expression, enforced=enforced
+ )
+
+ if kind == TokenType.FOREIGN_KEY or self._match(TokenType.FOREIGN_KEY):
+ expression = self._parse_foreign_key()
+ elif kind == TokenType.PRIMARY_KEY or self._match(TokenType.PRIMARY_KEY):
+ expression = self._parse_primary_key()
+
+ return self.expression(exp.AddConstraint, this=this, expression=expression)
+
def _parse_alter(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.TABLE):
return None
@@ -3029,8 +3090,14 @@ class Parser(metaclass=_Parser):
this = self._parse_table(schema=True)
actions: t.Optional[exp.Expression | t.List[t.Optional[exp.Expression]]] = None
- if self._match_text_seq("ADD", advance=False):
- actions = self._parse_csv(self._parse_add_column)
+
+ index = self._index
+ if self._match_text_seq("ADD"):
+ if self._match_set(self.ADD_CONSTRAINT_TOKENS):
+ actions = self._parse_csv(self._parse_add_constraint)
+ else:
+ self._retreat(index)
+ actions = self._parse_csv(self._parse_add_column)
elif self._match_text_seq("DROP", advance=False):
actions = self._parse_csv(self._parse_drop_column)
elif self._match_text_seq("RENAME", "TO"):
@@ -3077,7 +3144,7 @@ class Parser(metaclass=_Parser):
def _parse_merge(self) -> exp.Expression:
self._match(TokenType.INTO)
- target = self._parse_table(schema=True)
+ target = self._parse_table()
self._match(TokenType.USING)
using = self._parse_table()
@@ -3146,12 +3213,13 @@ class Parser(metaclass=_Parser):
self._retreat(index)
return None
- def _match(self, token_type):
+ def _match(self, token_type, advance=True):
if not self._curr:
return None
if self._curr.token_type == token_type:
- self._advance()
+ if advance:
+ self._advance()
return True
return None
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index 4967231..40df39f 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -32,7 +32,7 @@ class Plan:
return self._dag
@property
- def leaves(self) -> t.Generator[Step, None, None]:
+ def leaves(self) -> t.Iterator[Step]:
return (node for node, deps in self.dag.items() if not deps)
def __repr__(self) -> str:
@@ -401,7 +401,7 @@ class SetOperation(Step):
op=expression.__class__,
left=left.name,
right=right.name,
- distinct=expression.args.get("distinct"),
+ distinct=bool(expression.args.get("distinct")),
)
step.add_dependency(left)
step.add_dependency(right)
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index a0d69a7..f6f3883 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -109,10 +109,7 @@ class AbstractMappingSchema(t.Generic[T]):
value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
if value == 0:
- if raise_on_missing:
- raise SchemaError(f"Cannot find mapping for {table}.")
- else:
- return None
+ return None
elif value == 1:
possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
if len(possibilities) == 1:
@@ -262,7 +259,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
schema = self.find(table_)
if schema is None:
- raise SchemaError(f"Could not find table schema {table}")
+ return []
if not only_visible or not self.visible:
return list(schema)
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index f12528f..19dd1d6 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -84,6 +84,8 @@ class TokenType(AutoName):
TEXT = auto()
MEDIUMTEXT = auto()
LONGTEXT = auto()
+ MEDIUMBLOB = auto()
+ LONGBLOB = auto()
BINARY = auto()
VARBINARY = auto()
JSON = auto()
@@ -587,6 +589,7 @@ class Tokenizer(metaclass=_Tokenizer):
"PRECEDING": TokenType.PRECEDING,
"PRIMARY KEY": TokenType.PRIMARY_KEY,
"PROCEDURE": TokenType.PROCEDURE,
+ "QUALIFY": TokenType.QUALIFY,
"RANGE": TokenType.RANGE,
"RECURSIVE": TokenType.RECURSIVE,
"REGEXP": TokenType.RLIKE,
@@ -726,6 +729,8 @@ class Tokenizer(metaclass=_Tokenizer):
TokenType.SHOW,
}
+ COMMAND_PREFIX_TOKENS = {TokenType.SEMICOLON, TokenType.BEGIN}
+
# handle numeric literals like in hive (3L = BIGINT)
NUMERIC_LITERALS: t.Dict[str, str] = {}
ENCODE: t.Optional[str] = None
@@ -842,8 +847,10 @@ class Tokenizer(metaclass=_Tokenizer):
)
self._comments = []
+ # If we have either a semicolon or a begin token before the command's token, we'll parse
+ # whatever follows the command's token as a string
if token_type in self.COMMANDS and (
- len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
+ len(self.tokens) == 1 or self.tokens[-2].token_type in self.COMMAND_PREFIX_TOKENS
):
start = self._current
tokens = len(self.tokens)