summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2022-11-19 14:50:39 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2022-11-19 14:50:39 +0000
commitf2981e8e4d28233864f1ca06ecec45ab80bf9eae (patch)
treeb70cb633916830138ce3424aa361f0bbaff02be2
parentReleasing debian version 10.0.1-1. (diff)
downloadsqlglot-f2981e8e4d28233864f1ca06ecec45ab80bf9eae.tar.xz
sqlglot-f2981e8e4d28233864f1ca06ecec45ab80bf9eae.zip
Merging upstream version 10.0.8.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
-rw-r--r--CHANGELOG.md22
-rw-r--r--CONTRIBUTING.md2
-rw-r--r--README.md14
-rw-r--r--benchmarks/bench.py12
-rw-r--r--sqlglot/__init__.py2
-rw-r--r--sqlglot/dataframe/sql/column.py5
-rw-r--r--sqlglot/dataframe/sql/dataframe.py8
-rw-r--r--sqlglot/dataframe/sql/functions.py18
-rw-r--r--sqlglot/dataframe/sql/session.py8
-rw-r--r--sqlglot/dialects/__init__.py1
-rw-r--r--sqlglot/dialects/bigquery.py11
-rw-r--r--sqlglot/dialects/dialect.py16
-rw-r--r--sqlglot/dialects/drill.py174
-rw-r--r--sqlglot/dialects/duckdb.py4
-rw-r--r--sqlglot/dialects/hive.py20
-rw-r--r--sqlglot/dialects/mysql.py76
-rw-r--r--sqlglot/dialects/oracle.py1
-rw-r--r--sqlglot/dialects/postgres.py38
-rw-r--r--sqlglot/dialects/presto.py38
-rw-r--r--sqlglot/dialects/snowflake.py2
-rw-r--r--sqlglot/dialects/sqlite.py5
-rw-r--r--sqlglot/dialects/tsql.py2
-rw-r--r--sqlglot/diff.py55
-rw-r--r--sqlglot/errors.py4
-rw-r--r--sqlglot/executor/__init__.py23
-rw-r--r--sqlglot/executor/context.py47
-rw-r--r--sqlglot/executor/env.py162
-rw-r--r--sqlglot/executor/python.py287
-rw-r--r--sqlglot/executor/table.py43
-rw-r--r--sqlglot/expressions.py128
-rw-r--r--sqlglot/generator.py42
-rw-r--r--sqlglot/helper.py45
-rw-r--r--sqlglot/optimizer/annotate_types.py26
-rw-r--r--sqlglot/optimizer/canonicalize.py48
-rw-r--r--sqlglot/optimizer/eliminate_joins.py13
-rw-r--r--sqlglot/optimizer/optimize_joins.py4
-rw-r--r--sqlglot/optimizer/optimizer.py4
-rw-r--r--sqlglot/optimizer/pushdown_predicates.py4
-rw-r--r--sqlglot/optimizer/qualify_columns.py4
-rw-r--r--sqlglot/optimizer/qualify_tables.py14
-rw-r--r--sqlglot/optimizer/simplify.py6
-rw-r--r--sqlglot/optimizer/unnest_subqueries.py2
-rw-r--r--sqlglot/parser.py403
-rw-r--r--sqlglot/planner.py227
-rw-r--r--sqlglot/schema.py215
-rw-r--r--sqlglot/tokens.py58
-rw-r--r--tests/dataframe/unit/test_dataframe.py20
-rw-r--r--tests/dataframe/unit/test_dataframe_writer.py36
-rw-r--r--tests/dataframe/unit/test_session.py17
-rw-r--r--tests/dialects/test_bigquery.py4
-rw-r--r--tests/dialects/test_dialect.py81
-rw-r--r--tests/dialects/test_drill.py53
-rw-r--r--tests/dialects/test_mysql.py10
-rw-r--r--tests/dialects/test_presto.py75
-rw-r--r--tests/dialects/test_snowflake.py11
-rw-r--r--tests/fixtures/identity.sql21
-rw-r--r--tests/fixtures/optimizer/canonicalize.sql5
-rw-r--r--tests/fixtures/optimizer/optimizer.sql4
-rw-r--r--tests/fixtures/optimizer/tpc-h/tpc-h.sql50
-rw-r--r--tests/fixtures/pretty.sql7
-rw-r--r--tests/helpers.py64
-rw-r--r--tests/test_executor.py403
-rw-r--r--tests/test_expressions.py13
-rw-r--r--tests/test_optimizer.py19
-rw-r--r--tests/test_parser.py57
-rw-r--r--tests/test_tokens.py1
-rw-r--r--tests/test_transpile.py11
67 files changed, 2463 insertions, 842 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 87dd21d..70f2b55 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,21 +7,37 @@ v10.0.0
Changes:
- Breaking: replaced SQLGlot annotations with comments. Now comments can be preserved after transpilation, and they can appear in other places besides SELECT's expressions.
+
- Breaking: renamed list_get to seq_get.
+
- Breaking: activated mypy type checking for SQLGlot.
+
- New: Azure Databricks support.
+
- New: placeholders can now be replaced in an expression.
+
- New: null safe equal operator (<=>).
+
- New: [SET statements](https://github.com/tobymao/sqlglot/pull/673) for MySQL.
+
- New: [SHOW commands](https://dev.mysql.com/doc/refman/8.0/en/show.html) for MySQL.
+
- New: [FORMAT function](https://www.w3schools.com/sql/func_sqlserver_format.asp) for TSQL.
+
- New: CROSS APPLY / OUTER APPLY [support](https://github.com/tobymao/sqlglot/pull/641) for TSQL.
-- New: added formats for TSQL's [DATENAME/DATEPART functions](https://learn.microsoft.com/en-us/sql/t-sql/functions/datename-transact-sql?view=sql-server-ver16)
+
+- New: added formats for TSQL's [DATENAME/DATEPART functions](https://learn.microsoft.com/en-us/sql/t-sql/functions/datename-transact-sql?view=sql-server-ver16).
+
- New: added styles for TSQL's [CONVERT function](https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16).
+
- Improvement: [refactored the schema](https://github.com/tobymao/sqlglot/pull/668) to be more lenient; before it needed to do an exact match of db.table, now it finds table if there are no ambiguities.
+
- Improvement: allow functions to [inherit](https://github.com/tobymao/sqlglot/pull/674) their arguments' types, so that annotating CASE, IF etc. is possible.
+
- Improvement: allow [joining with same names](https://github.com/tobymao/sqlglot/pull/660) in the python executor.
+
- Improvement: the "using" field can now be set for the [join expression builders](https://github.com/tobymao/sqlglot/pull/636).
+
- Improvement: qualify_columns [now qualifies](https://github.com/tobymao/sqlglot/pull/635) only non-alias columns in the having clause.
v9.0.0
@@ -37,6 +53,7 @@ v8.0.0
Changes:
- Breaking : New add\_table method in Schema ABC.
+
- New: SQLGlot now supports the [PySpark](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dataframe) dataframe API. This is still relatively experimental.
v7.1.0
@@ -45,8 +62,11 @@ v7.1.0
Changes:
- Improvement: Pretty generator now takes max\_text\_width which breaks segments into new lines
+
- New: exp.to\_table helper to turn table names into table expression objects
+
- New: int[] type parsers
+
- New: annotations are now generated in sql
v7.0.0
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 1d3b822..97c795d 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -21,7 +21,7 @@ Pull requests are the best way to propose changes to the codebase. We actively w
5. Issue that pull request and wait for it to be reviewed by a maintainer or contributor!
## Report bugs using Github's [issues](https://github.com/tobymao/sqlglot/issues)
-We use GitHub issues to track public bugs. Report a bug by [opening a new issue]().
+We use GitHub issues to track public bugs. Report a bug by opening a new issue.
**Great Bug Reports** tend to have:
diff --git a/README.md b/README.md
index b00b803..2ceadfb 100644
--- a/README.md
+++ b/README.md
@@ -90,7 +90,7 @@ sqlglot.transpile("SELECT STRFTIME(x, '%y-%-m-%S')", read="duckdb", write="hive"
"SELECT DATE_FORMAT(x, 'yy-M-ss')"
```
-As another example, let's suppose that we want to read in a SQL query that contains a CTE and a cast to `REAL`, and then transpile it to Spark, which uses backticks as identifiers and `FLOAT` instead of `REAL`:
+As another example, let's suppose that we want to read in a SQL query that contains a CTE and a cast to `REAL`, and then transpile it to Spark, which uses backticks for identifiers and `FLOAT` instead of `REAL`:
```python
import sqlglot
@@ -376,12 +376,12 @@ print(Dialect["custom"])
[Benchmarks](benchmarks) run on Python 3.10.5 in seconds.
-| Query | sqlglot | sqltree | sqlparse | moz_sql_parser | sqloxide |
-| --------------- | --------------- | --------------- | --------------- | --------------- | --------------- |
-| tpch | 0.01178 (1.0) | 0.01173 (0.995) | 0.04676 (3.966) | 0.06800 (5.768) | 0.00094 (0.080) |
-| short | 0.00084 (1.0) | 0.00079 (0.948) | 0.00296 (3.524) | 0.00443 (5.266) | 0.00006 (0.072) |
-| long | 0.01102 (1.0) | 0.01044 (0.947) | 0.04349 (3.945) | 0.05998 (5.440) | 0.00084 (0.077) |
-| crazy | 0.03751 (1.0) | 0.03471 (0.925) | 11.0796 (295.3) | 1.03355 (27.55) | 0.00529 (0.141) |
+| Query | sqlglot | sqlfluff | sqltree | sqlparse | moz_sql_parser | sqloxide |
+| --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- |
+| tpch | 0.01308 (1.0) | 1.60626 (122.7) | 0.01168 (0.893) | 0.04958 (3.791) | 0.08543 (6.531) | 0.00136 (0.104) |
+| short | 0.00109 (1.0) | 0.14134 (129.2) | 0.00099 (0.906) | 0.00342 (3.131) | 0.00652 (5.970) | 8.76621 (0.080) |
+| long | 0.01399 (1.0) | 2.12632 (151.9) | 0.01126 (0.805) | 0.04410 (3.151) | 0.06671 (4.767) | 0.00107 (0.076) |
+| crazy | 0.03969 (1.0) | 24.3777 (614.1) | 0.03917 (0.987) | 11.7043 (294.8) | 1.03280 (26.02) | 0.00625 (0.157) |
## Optional Dependencies
diff --git a/benchmarks/bench.py b/benchmarks/bench.py
index cef62a8..2475608 100644
--- a/benchmarks/bench.py
+++ b/benchmarks/bench.py
@@ -5,8 +5,10 @@ collections.Iterable = collections.abc.Iterable
import gc
import timeit
-import moz_sql_parser
import numpy as np
+
+import sqlfluff
+import moz_sql_parser
import sqloxide
import sqlparse
import sqltree
@@ -177,6 +179,10 @@ def sqloxide_parse(sql):
sqloxide.parse_sql(sql, dialect="ansi")
+def sqlfluff_parse(sql):
+ sqlfluff.parse(sql)
+
+
def border(columns):
columns = " | ".join(columns)
return f"| {columns} |"
@@ -193,6 +199,7 @@ def diff(row, column):
libs = [
"sqlglot",
+ "sqlfluff",
"sqltree",
"sqlparse",
"moz_sql_parser",
@@ -206,7 +213,8 @@ for name, sql in {"tpch": tpch, "short": short, "long": long, "crazy": crazy}.it
for lib in libs:
try:
row[lib] = np.mean(timeit.repeat(lambda: globals()[lib + "_parse"](sql), number=3))
- except:
+ except Exception as e:
+ print(e)
row[lib] = "error"
columns = ["Query"] + libs
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 6e67b19..50e2d9c 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -30,7 +30,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "10.0.1"
+__version__ = "10.0.8"
pretty = False
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index f9e1c5b..22075e9 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -260,7 +260,10 @@ class Column:
"""
if isinstance(dataType, DataType):
dataType = dataType.simpleString()
- new_expression = exp.Cast(this=self.column_expression, to=dataType)
+ new_expression = exp.Cast(
+ this=self.column_expression,
+ to=sqlglot.parse_one(dataType, into=exp.DataType, read="spark"), # type: ignore
+ )
return Column(new_expression)
def startswith(self, value: t.Union[str, Column]) -> Column:
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 40cd6c9..548c322 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -314,7 +314,13 @@ class DataFrame:
replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore
cache_table_name
)
- sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
+ sqlglot.schema.add_table(
+ cache_table_name,
+ {
+ expression.alias_or_name: expression.type.name
+ for expression in select_expression.expressions
+ },
+ )
cache_storage_level = select_expression.args["cache_storage_level"]
options = [
exp.Literal.string("storageLevel"),
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index dbfb06f..1ee361a 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -757,11 +757,15 @@ def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
def decode(col: ColumnOrName, charset: str) -> Column:
- return Column.invoke_anonymous_function(col, "DECODE", lit(charset))
+ return Column.invoke_expression_over_column(
+ col, glotexp.Decode, charset=glotexp.Literal.string(charset)
+ )
def encode(col: ColumnOrName, charset: str) -> Column:
- return Column.invoke_anonymous_function(col, "ENCODE", lit(charset))
+ return Column.invoke_expression_over_column(
+ col, glotexp.Encode, charset=glotexp.Literal.string(charset)
+ )
def format_number(col: ColumnOrName, d: int) -> Column:
@@ -867,11 +871,11 @@ def bin(col: ColumnOrName) -> Column:
def hex(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "HEX")
+ return Column.invoke_expression_over_column(col, glotexp.Hex)
def unhex(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "UNHEX")
+ return Column.invoke_expression_over_column(col, glotexp.Unhex)
def length(col: ColumnOrName) -> Column:
@@ -939,11 +943,7 @@ def array_join(
def concat(*cols: ColumnOrName) -> Column:
- if len(cols) == 1:
- return Column.invoke_anonymous_function(cols[0], "CONCAT")
- return Column.invoke_anonymous_function(
- cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]]
- )
+ return Column.invoke_expression_over_column(None, glotexp.Concat, expressions=cols)
def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index 8cb16ef..c4a22c6 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -88,14 +88,14 @@ class SparkSession:
"expressions": sel_columns,
"from": exp.From(
expressions=[
- exp.Subquery(
- this=exp.Values(expressions=data_expressions),
+ exp.Values(
+ expressions=data_expressions,
alias=exp.TableAlias(
this=exp.to_identifier(self._auto_incrementing_name),
columns=[exp.to_identifier(col_name) for col_name in column_mapping],
),
- )
- ]
+ ),
+ ],
),
}
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
index 0816831..2e42e7d 100644
--- a/sqlglot/dialects/__init__.py
+++ b/sqlglot/dialects/__init__.py
@@ -2,6 +2,7 @@ from sqlglot.dialects.bigquery import BigQuery
from sqlglot.dialects.clickhouse import ClickHouse
from sqlglot.dialects.databricks import Databricks
from sqlglot.dialects.dialect import Dialect, Dialects
+from sqlglot.dialects.drill import Drill
from sqlglot.dialects.duckdb import DuckDB
from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 5bbff9d..4550d65 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -119,6 +119,8 @@ class BigQuery(Dialect):
"UNKNOWN": TokenType.NULL,
"WINDOW": TokenType.WINDOW,
"NOT DETERMINISTIC": TokenType.VOLATILE,
+ "BEGIN": TokenType.COMMAND,
+ "BEGIN TRANSACTION": TokenType.BEGIN,
}
KEYWORDS.pop("DIV")
@@ -204,6 +206,15 @@ class BigQuery(Dialect):
EXPLICIT_UNION = True
+ def transaction_sql(self, *_):
+ return "BEGIN TRANSACTION"
+
+ def commit_sql(self, *_):
+ return "COMMIT TRANSACTION"
+
+ def rollback_sql(self, *_):
+ return "ROLLBACK TRANSACTION"
+
def in_unnest_op(self, unnest):
return self.sql(unnest)
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 3af08bb..8c497ab 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -32,6 +32,7 @@ class Dialects(str, Enum):
TRINO = "trino"
TSQL = "tsql"
DATABRICKS = "databricks"
+ DRILL = "drill"
class _Dialect(type):
@@ -362,3 +363,18 @@ def parse_date_delta(exp_class, unit_mapping=None):
return exp_class(this=this, expression=expression, unit=unit)
return inner_func
+
+
+def locate_to_strposition(args):
+ return exp.StrPosition(
+ this=seq_get(args, 1),
+ substr=seq_get(args, 0),
+ position=seq_get(args, 2),
+ )
+
+
+def strposition_to_local_sql(self, expression):
+ args = self.format_args(
+ expression.args.get("substr"), expression.this, expression.args.get("position")
+ )
+ return f"LOCATE({args})"
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
new file mode 100644
index 0000000..eb420aa
--- /dev/null
+++ b/sqlglot/dialects/drill.py
@@ -0,0 +1,174 @@
+from __future__ import annotations
+
+import re
+
+from sqlglot import exp, generator, parser, tokens
+from sqlglot.dialects.dialect import (
+ Dialect,
+ create_with_partitions_sql,
+ format_time_lambda,
+ no_pivot_sql,
+ no_trycast_sql,
+ rename_func,
+ str_position_sql,
+)
+from sqlglot.dialects.postgres import _lateral_sql
+
+
+def _to_timestamp(args):
+ # TO_TIMESTAMP accepts either a single double argument or (text, text)
+ if len(args) == 1 and args[0].is_number:
+ return exp.UnixToTime.from_arg_list(args)
+ return format_time_lambda(exp.StrToTime, "drill")(args)
+
+
+def _str_to_time_sql(self, expression):
+ return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
+
+
+def _ts_or_ds_to_date_sql(self, expression):
+ time_format = self.format_time(expression)
+ if time_format and time_format not in (Drill.time_format, Drill.date_format):
+ return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
+ return f"CAST({self.sql(expression, 'this')} AS DATE)"
+
+
+def _date_add_sql(kind):
+ def func(self, expression):
+ this = self.sql(expression, "this")
+ unit = expression.text("unit").upper() or "DAY"
+ expression = self.sql(expression, "expression")
+ return f"DATE_{kind}({this}, INTERVAL '{expression}' {unit})"
+
+ return func
+
+
+def if_sql(self, expression):
+ """
+ Drill requires backticks around certain SQL reserved words, IF being one of them, This function
+ adds the backticks around the keyword IF.
+ Args:
+ self: The Drill dialect
+ expression: The input IF expression
+
+ Returns: The expression with IF in backticks.
+
+ """
+ expressions = self.format_args(
+ expression.this, expression.args.get("true"), expression.args.get("false")
+ )
+ return f"`IF`({expressions})"
+
+
+def _str_to_date(self, expression):
+ this = self.sql(expression, "this")
+ time_format = self.format_time(expression)
+ if time_format == Drill.date_format:
+ return f"CAST({this} AS DATE)"
+ return f"TO_DATE({this}, {time_format})"
+
+
+class Drill(Dialect):
+ normalize_functions = None
+ null_ordering = "nulls_are_last"
+ date_format = "'yyyy-MM-dd'"
+ dateint_format = "'yyyyMMdd'"
+ time_format = "'yyyy-MM-dd HH:mm:ss'"
+
+ time_mapping = {
+ "y": "%Y",
+ "Y": "%Y",
+ "YYYY": "%Y",
+ "yyyy": "%Y",
+ "YY": "%y",
+ "yy": "%y",
+ "MMMM": "%B",
+ "MMM": "%b",
+ "MM": "%m",
+ "M": "%-m",
+ "dd": "%d",
+ "d": "%-d",
+ "HH": "%H",
+ "H": "%-H",
+ "hh": "%I",
+ "h": "%-I",
+ "mm": "%M",
+ "m": "%-M",
+ "ss": "%S",
+ "s": "%-S",
+ "SSSSSS": "%f",
+ "a": "%p",
+ "DD": "%j",
+ "D": "%-j",
+ "E": "%a",
+ "EE": "%a",
+ "EEE": "%a",
+ "EEEE": "%A",
+ "''T''": "T",
+ }
+
+ class Tokenizer(tokens.Tokenizer):
+ QUOTES = ["'"]
+ IDENTIFIERS = ["`"]
+ ESCAPES = ["\\"]
+ ENCODE = "utf-8"
+
+ class Parser(parser.Parser):
+ STRICT_CAST = False
+
+ FUNCTIONS = {
+ **parser.Parser.FUNCTIONS,
+ "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
+ "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
+ }
+
+ class Generator(generator.Generator):
+ TYPE_MAPPING = {
+ **generator.Generator.TYPE_MAPPING,
+ exp.DataType.Type.INT: "INTEGER",
+ exp.DataType.Type.SMALLINT: "INTEGER",
+ exp.DataType.Type.TINYINT: "INTEGER",
+ exp.DataType.Type.BINARY: "VARBINARY",
+ exp.DataType.Type.TEXT: "VARCHAR",
+ exp.DataType.Type.NCHAR: "VARCHAR",
+ exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP",
+ exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
+ exp.DataType.Type.DATETIME: "TIMESTAMP",
+ }
+
+ ROOT_PROPERTIES = {exp.PartitionedByProperty}
+
+ TRANSFORMS = {
+ **generator.Generator.TRANSFORMS,
+ exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
+ exp.Lateral: _lateral_sql,
+ exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
+ exp.ArraySize: rename_func("REPEATED_COUNT"),
+ exp.Create: create_with_partitions_sql,
+ exp.DateAdd: _date_add_sql("ADD"),
+ exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
+ exp.DateSub: _date_add_sql("SUB"),
+ exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
+ exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
+ exp.If: if_sql,
+ exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
+ exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
+ exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
+ exp.Pivot: no_pivot_sql,
+ exp.RegexpLike: rename_func("REGEXP_MATCHES"),
+ exp.StrPosition: str_position_sql,
+ exp.StrToDate: _str_to_date,
+ exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
+ exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
+ exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
+ exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
+ exp.TryCast: no_trycast_sql,
+ exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), INTERVAL '{self.sql(e, 'expression')}' DAY)",
+ exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
+ exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
+ }
+
+ def normalize_func(self, name):
+ return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`"
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 781edff..f1da72b 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -55,13 +55,13 @@ def _array_sort_sql(self, expression):
def _sort_array_sql(self, expression):
this = self.sql(expression, "this")
- if expression.args.get("asc") == exp.FALSE:
+ if expression.args.get("asc") == exp.false():
return f"ARRAY_REVERSE_SORT({this})"
return f"ARRAY_SORT({this})"
def _sort_array_reverse(args):
- return exp.SortArray(this=seq_get(args, 0), asc=exp.FALSE)
+ return exp.SortArray(this=seq_get(args, 0), asc=exp.false())
def _struct_pack_sql(self, expression):
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index ed7357c..cff7139 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -7,16 +7,19 @@ from sqlglot.dialects.dialect import (
create_with_partitions_sql,
format_time_lambda,
if_sql,
+ locate_to_strposition,
no_ilike_sql,
no_recursive_cte_sql,
no_safe_divide_sql,
no_trycast_sql,
rename_func,
+ strposition_to_local_sql,
struct_extract_sql,
var_map_sql,
)
from sqlglot.helper import seq_get
from sqlglot.parser import parse_var_map
+from sqlglot.tokens import TokenType
# (FuncType, Multiplier)
DATE_DELTA_INTERVAL = {
@@ -181,6 +184,15 @@ class Hive(Dialect):
"F": "FLOAT",
"BD": "DECIMAL",
}
+ KEYWORDS = {
+ **tokens.Tokenizer.KEYWORDS,
+ "ADD ARCHIVE": TokenType.COMMAND,
+ "ADD ARCHIVES": TokenType.COMMAND,
+ "ADD FILE": TokenType.COMMAND,
+ "ADD FILES": TokenType.COMMAND,
+ "ADD JAR": TokenType.COMMAND,
+ "ADD JARS": TokenType.COMMAND,
+ }
class Parser(parser.Parser):
STRICT_CAST = False
@@ -210,11 +222,7 @@ class Hive(Dialect):
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
- "LOCATE": lambda args: exp.StrPosition(
- this=seq_get(args, 1),
- substr=seq_get(args, 0),
- position=seq_get(args, 2),
- ),
+ "LOCATE": locate_to_strposition,
"LOG": (
lambda args: exp.Log.from_arg_list(args)
if len(args) > 1
@@ -272,7 +280,7 @@ class Hive(Dialect):
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.SetAgg: rename_func("COLLECT_SET"),
exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
- exp.StrPosition: lambda self, e: f"LOCATE({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})",
+ exp.StrPosition: strposition_to_local_sql,
exp.StrToDate: _str_to_date,
exp.StrToTime: _str_to_time,
exp.StrToUnix: _str_to_unix,
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index e742640..93a60f4 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -5,10 +5,12 @@ import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
+ locate_to_strposition,
no_ilike_sql,
no_paren_current_date_sql,
no_tablesample_sql,
no_trycast_sql,
+ strposition_to_local_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -120,6 +122,7 @@ class MySQL(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "START": TokenType.BEGIN,
"SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
@@ -172,13 +175,18 @@ class MySQL(Dialect):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
class Parser(parser.Parser):
- STRICT_CAST = False
+ FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA}
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE_ADD": _date_add(exp.DateAdd),
"DATE_SUB": _date_add(exp.DateSub),
"STR_TO_DATE": _str_to_date,
+ "LOCATE": locate_to_strposition,
+ "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
+ "LEFT": lambda args: exp.Substring(
+ this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1)
+ ),
}
FUNCTION_PARSERS = {
@@ -264,6 +272,7 @@ class MySQL(Dialect):
"CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"NAMES": lambda self: self._parse_set_item_names(),
+ "TRANSACTION": lambda self: self._parse_set_transaction(),
}
PROFILE_TYPES = {
@@ -278,39 +287,48 @@ class MySQL(Dialect):
"SWAPS",
}
+ TRANSACTION_CHARACTERISTICS = {
+ "ISOLATION LEVEL REPEATABLE READ",
+ "ISOLATION LEVEL READ COMMITTED",
+ "ISOLATION LEVEL READ UNCOMMITTED",
+ "ISOLATION LEVEL SERIALIZABLE",
+ "READ WRITE",
+ "READ ONLY",
+ }
+
def _parse_show_mysql(self, this, target=False, full=None, global_=None):
if target:
if isinstance(target, str):
- self._match_text(target)
+ self._match_text_seq(target)
target_id = self._parse_id_var()
else:
target_id = None
- log = self._parse_string() if self._match_text("IN") else None
+ log = self._parse_string() if self._match_text_seq("IN") else None
if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}:
- position = self._parse_number() if self._match_text("FROM") else None
+ position = self._parse_number() if self._match_text_seq("FROM") else None
db = None
else:
position = None
- db = self._parse_id_var() if self._match_text("FROM") else None
+ db = self._parse_id_var() if self._match_text_seq("FROM") else None
- channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None
+ channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None
- like = self._parse_string() if self._match_text("LIKE") else None
+ like = self._parse_string() if self._match_text_seq("LIKE") else None
where = self._parse_where()
if this == "PROFILE":
- types = self._parse_csv(self._parse_show_profile_type)
- query = self._parse_number() if self._match_text("FOR", "QUERY") else None
- offset = self._parse_number() if self._match_text("OFFSET") else None
- limit = self._parse_number() if self._match_text("LIMIT") else None
+ types = self._parse_csv(lambda: self._parse_var_from_options(self.PROFILE_TYPES))
+ query = self._parse_number() if self._match_text_seq("FOR", "QUERY") else None
+ offset = self._parse_number() if self._match_text_seq("OFFSET") else None
+ limit = self._parse_number() if self._match_text_seq("LIMIT") else None
else:
types, query = None, None
offset, limit = self._parse_oldstyle_limit()
- mutex = True if self._match_text("MUTEX") else None
- mutex = False if self._match_text("STATUS") else mutex
+ mutex = True if self._match_text_seq("MUTEX") else None
+ mutex = False if self._match_text_seq("STATUS") else mutex
return self.expression(
exp.Show,
@@ -331,16 +349,16 @@ class MySQL(Dialect):
**{"global": global_},
)
- def _parse_show_profile_type(self):
- for type_ in self.PROFILE_TYPES:
- if self._match_text(*type_.split(" ")):
- return exp.Var(this=type_)
+ def _parse_var_from_options(self, options):
+ for option in options:
+ if self._match_text_seq(*option.split(" ")):
+ return exp.Var(this=option)
return None
def _parse_oldstyle_limit(self):
limit = None
offset = None
- if self._match_text("LIMIT"):
+ if self._match_text_seq("LIMIT"):
parts = self._parse_csv(self._parse_number)
if len(parts) == 1:
limit = parts[0]
@@ -353,6 +371,9 @@ class MySQL(Dialect):
return self._parse_set_item_assignment(kind=None)
def _parse_set_item_assignment(self, kind):
+ if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"):
+ return self._parse_set_transaction(global_=kind == "GLOBAL")
+
left = self._parse_primary() or self._parse_id_var()
if not self._match(TokenType.EQ):
self.raise_error("Expected =")
@@ -381,7 +402,7 @@ class MySQL(Dialect):
def _parse_set_item_names(self):
charset = self._parse_string() or self._parse_id_var()
- if self._match_text("COLLATE"):
+ if self._match_text_seq("COLLATE"):
collate = self._parse_string() or self._parse_id_var()
else:
collate = None
@@ -392,6 +413,18 @@ class MySQL(Dialect):
kind="NAMES",
)
+ def _parse_set_transaction(self, global_=False):
+ self._match_text_seq("TRANSACTION")
+ characteristics = self._parse_csv(
+ lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS)
+ )
+ return self.expression(
+ exp.SetItem,
+ expressions=characteristics,
+ kind="TRANSACTION",
+ **{"global": global_},
+ )
+
class Generator(generator.Generator):
NULL_ORDERING_SUPPORTED = False
@@ -411,6 +444,7 @@ class MySQL(Dialect):
exp.Trim: _trim_sql,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
+ exp.StrPosition: strposition_to_local_sql,
}
ROOT_PROPERTIES = {
@@ -481,9 +515,11 @@ class MySQL(Dialect):
kind = self.sql(expression, "kind")
kind = f"{kind} " if kind else ""
this = self.sql(expression, "this")
+ expressions = self.expressions(expression)
collate = self.sql(expression, "collate")
collate = f" COLLATE {collate}" if collate else ""
- return f"{kind}{this}{collate}"
+ global_ = "GLOBAL " if expression.args.get("global") else ""
+ return f"{global_}{kind}{this}{expressions}{collate}"
def set_sql(self, expression):
return f"SET {self.expressions(expression)}"
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 3bc1109..870d2b9 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -91,6 +91,7 @@ class Oracle(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "START": TokenType.BEGIN,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
"NVARCHAR2": TokenType.NVARCHAR,
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 553a73b..4353164 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -164,11 +164,34 @@ class Postgres(Dialect):
BIT_STRINGS = [("b'", "'"), ("B'", "'")]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
+
+ CREATABLES = (
+ "AGGREGATE",
+ "CAST",
+ "CONVERSION",
+ "COLLATION",
+ "DEFAULT CONVERSION",
+ "CONSTRAINT",
+ "DOMAIN",
+ "EXTENSION",
+ "FOREIGN",
+ "FUNCTION",
+ "OPERATOR",
+ "POLICY",
+ "ROLE",
+ "RULE",
+ "SEQUENCE",
+ "TEXT",
+ "TRIGGER",
+ "TYPE",
+ "UNLOGGED",
+ "USER",
+ )
+
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ALWAYS": TokenType.ALWAYS,
"BY DEFAULT": TokenType.BY_DEFAULT,
- "COMMENT ON": TokenType.COMMENT_ON,
"IDENTITY": TokenType.IDENTITY,
"GENERATED": TokenType.GENERATED,
"DOUBLE PRECISION": TokenType.DOUBLE,
@@ -176,6 +199,19 @@ class Postgres(Dialect):
"SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
"UUID": TokenType.UUID,
+ "TEMP": TokenType.TEMPORARY,
+ "BEGIN TRANSACTION": TokenType.BEGIN,
+ "BEGIN": TokenType.COMMAND,
+ "COMMENT ON": TokenType.COMMAND,
+ "DECLARE": TokenType.COMMAND,
+ "DO": TokenType.COMMAND,
+ "REFRESH": TokenType.COMMAND,
+ "REINDEX": TokenType.COMMAND,
+ "RESET": TokenType.COMMAND,
+ "REVOKE": TokenType.COMMAND,
+ "GRANT": TokenType.COMMAND,
+ **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
+ **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
}
QUOTES = ["'", "$$"]
SINGLE_TOKENS = {
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 11ea778..9d5cc11 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
struct_extract_sql,
)
from sqlglot.dialects.mysql import MySQL
+from sqlglot.errors import UnsupportedError
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -61,8 +62,18 @@ def _initcap_sql(self, expression):
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
+def _decode_sql(self, expression):
+ _ensure_utf8(expression.args.get("charset"))
+ return f"FROM_UTF8({self.sql(expression, 'this')})"
+
+
+def _encode_sql(self, expression):
+ _ensure_utf8(expression.args.get("charset"))
+ return f"TO_UTF8({self.sql(expression, 'this')})"
+
+
def _no_sort_array(self, expression):
- if expression.args.get("asc") == exp.FALSE:
+ if expression.args.get("asc") == exp.false():
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
else:
comparator = None
@@ -72,7 +83,7 @@ def _no_sort_array(self, expression):
def _schema_sql(self, expression):
if isinstance(expression.parent, exp.Property):
- columns = ", ".join(f"'{c.text('this')}'" for c in expression.expressions)
+ columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
return f"ARRAY[{columns}]"
for schema in expression.parent.find_all(exp.Schema):
@@ -106,6 +117,11 @@ def _ts_or_ds_add_sql(self, expression):
return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))"
+def _ensure_utf8(charset):
+ if charset.name.lower() != "utf-8":
+ raise UnsupportedError(f"Unsupported charset {charset}")
+
+
class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
@@ -115,6 +131,7 @@ class Presto(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "START": TokenType.BEGIN,
"ROW": TokenType.STRUCT,
}
@@ -140,6 +157,14 @@ class Presto(Dialect):
"STRPOS": exp.StrPosition.from_arg_list,
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
+ "FROM_HEX": exp.Unhex.from_arg_list,
+ "TO_HEX": exp.Hex.from_arg_list,
+ "TO_UTF8": lambda args: exp.Encode(
+ this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
+ ),
+ "FROM_UTF8": lambda args: exp.Decode(
+ this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
+ ),
}
class Generator(generator.Generator):
@@ -187,7 +212,10 @@ class Presto(Dialect):
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
+ exp.Decode: _decode_sql,
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
+ exp.Encode: _encode_sql,
+ exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
@@ -212,7 +240,13 @@ class Presto(Dialect):
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
+ exp.Unhex: rename_func("FROM_HEX"),
exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
}
+
+ def transaction_sql(self, expression):
+ modes = expression.args.get("modes")
+ modes = f" {', '.join(modes)}" if modes else ""
+ return f"START TRANSACTION{modes}"
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index d1aaded..a96bd80 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -148,6 +148,7 @@ class Snowflake(Dialect):
**parser.Parser.FUNCTION_PARSERS,
"DATE_PART": _parse_date_part,
}
+ FUNCTION_PARSERS.pop("TRIM")
FUNC_TOKENS = {
*parser.Parser.FUNC_TOKENS,
@@ -203,6 +204,7 @@ class Snowflake(Dialect):
exp.StrPosition: rename_func("POSITION"),
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
+ exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
}
TYPE_MAPPING = {
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 8c9fb76..87b98a5 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -63,3 +63,8 @@ class SQLite(Dialect):
exp.TableSample: no_tablesample_sql,
exp.TryCast: no_trycast_sql,
}
+
+ def transaction_sql(self, expression):
+ this = expression.this
+ this = f" {this}" if this else ""
+ return f"BEGIN{this} TRANSACTION"
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index a233d4b..d3b83de 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -248,7 +248,7 @@ class TSQL(Dialect):
def _parse_convert(self, strict):
to = self._parse_types()
self._match(TokenType.COMMA)
- this = self._parse_column()
+ this = self._parse_conjunction()
# Retrieve length of datatype and override to default if not specified
if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
index 2d959ab..758ad1b 100644
--- a/sqlglot/diff.py
+++ b/sqlglot/diff.py
@@ -1,3 +1,6 @@
+from __future__ import annotations
+
+import typing as t
from collections import defaultdict
from dataclasses import dataclass
from heapq import heappop, heappush
@@ -6,6 +9,10 @@ from sqlglot import Dialect
from sqlglot import expressions as exp
from sqlglot.helper import ensure_collection
+if t.TYPE_CHECKING:
+ T = t.TypeVar("T")
+ Edit = t.Union[Insert, Remove, Move, Update, Keep]
+
@dataclass(frozen=True)
class Insert:
@@ -44,7 +51,7 @@ class Keep:
target: exp.Expression
-def diff(source, target):
+def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
"""
Returns the list of changes between the source and the target expressions.
@@ -89,25 +96,25 @@ class ChangeDistiller:
Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf.
"""
- def __init__(self, f=0.6, t=0.6):
+ def __init__(self, f: float = 0.6, t: float = 0.6) -> None:
self.f = f
self.t = t
self._sql_generator = Dialect().generator()
- def diff(self, source, target):
+ def diff(self, source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
self._source = source
self._target = target
self._source_index = {id(n[0]): n[0] for n in source.bfs()}
self._target_index = {id(n[0]): n[0] for n in target.bfs()}
self._unmatched_source_nodes = set(self._source_index)
self._unmatched_target_nodes = set(self._target_index)
- self._bigram_histo_cache = {}
+ self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}
matching_set = self._compute_matching_set()
return self._generate_edit_script(matching_set)
- def _generate_edit_script(self, matching_set):
- edit_script = []
+ def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.List[Edit]:
+ edit_script: t.List[Edit] = []
for removed_node_id in self._unmatched_source_nodes:
edit_script.append(Remove(self._source_index[removed_node_id]))
for inserted_node_id in self._unmatched_target_nodes:
@@ -125,7 +132,9 @@ class ChangeDistiller:
return edit_script
- def _generate_move_edits(self, source, target, matching_set):
+ def _generate_move_edits(
+ self, source: exp.Expression, target: exp.Expression, matching_set: t.Set[t.Tuple[int, int]]
+ ) -> t.List[Move]:
source_args = [id(e) for e in _expression_only_args(source)]
target_args = [id(e) for e in _expression_only_args(target)]
@@ -138,7 +147,7 @@ class ChangeDistiller:
return move_edits
- def _compute_matching_set(self):
+ def _compute_matching_set(self) -> t.Set[t.Tuple[int, int]]:
leaves_matching_set = self._compute_leaf_matching_set()
matching_set = leaves_matching_set.copy()
@@ -183,8 +192,8 @@ class ChangeDistiller:
return matching_set
- def _compute_leaf_matching_set(self):
- candidate_matchings = []
+ def _compute_leaf_matching_set(self) -> t.Set[t.Tuple[int, int]]:
+ candidate_matchings: t.List[t.Tuple[float, int, exp.Expression, exp.Expression]] = []
source_leaves = list(_get_leaves(self._source))
target_leaves = list(_get_leaves(self._target))
for source_leaf in source_leaves:
@@ -216,7 +225,7 @@ class ChangeDistiller:
return matching_set
- def _dice_coefficient(self, source, target):
+ def _dice_coefficient(self, source: exp.Expression, target: exp.Expression) -> float:
source_histo = self._bigram_histo(source)
target_histo = self._bigram_histo(target)
@@ -231,13 +240,13 @@ class ChangeDistiller:
return 2 * overlap_len / total_grams
- def _bigram_histo(self, expression):
+ def _bigram_histo(self, expression: exp.Expression) -> t.DefaultDict[str, int]:
if id(expression) in self._bigram_histo_cache:
return self._bigram_histo_cache[id(expression)]
expression_str = self._sql_generator.generate(expression)
count = max(0, len(expression_str) - 1)
- bigram_histo = defaultdict(int)
+ bigram_histo: t.DefaultDict[str, int] = defaultdict(int)
for i in range(count):
bigram_histo[expression_str[i : i + 2]] += 1
@@ -245,7 +254,7 @@ class ChangeDistiller:
return bigram_histo
-def _get_leaves(expression):
+def _get_leaves(expression: exp.Expression) -> t.Generator[exp.Expression, None, None]:
has_child_exprs = False
for a in expression.args.values():
@@ -258,7 +267,7 @@ def _get_leaves(expression):
yield expression
-def _is_same_type(source, target):
+def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
if type(source) is type(target):
if isinstance(source, exp.Join):
return source.args.get("side") == target.args.get("side")
@@ -271,15 +280,17 @@ def _is_same_type(source, target):
return False
-def _expression_only_args(expression):
- args = []
+def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
+ args: t.List[t.Union[exp.Expression, t.List]] = []
if expression:
for a in expression.args.values():
args.extend(ensure_collection(a))
return [a for a in args if isinstance(a, exp.Expression)]
-def _lcs(seq_a, seq_b, equal):
+def _lcs(
+ seq_a: t.Sequence[T], seq_b: t.Sequence[T], equal: t.Callable[[T, T], bool]
+) -> t.Sequence[t.Optional[T]]:
"""Calculates the longest common subsequence"""
len_a = len(seq_a)
@@ -289,14 +300,14 @@ def _lcs(seq_a, seq_b, equal):
for i in range(len_a + 1):
for j in range(len_b + 1):
if i == 0 or j == 0:
- lcs_result[i][j] = []
+ lcs_result[i][j] = [] # type: ignore
elif equal(seq_a[i - 1], seq_b[j - 1]):
- lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]]
+ lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]] # type: ignore
else:
lcs_result[i][j] = (
lcs_result[i - 1][j]
- if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1])
+ if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1]) # type: ignore
else lcs_result[i][j - 1]
)
- return lcs_result[len_a][len_b]
+ return lcs_result[len_a][len_b] # type: ignore
diff --git a/sqlglot/errors.py b/sqlglot/errors.py
index 2ef908f..23a08bd 100644
--- a/sqlglot/errors.py
+++ b/sqlglot/errors.py
@@ -37,6 +37,10 @@ class SchemaError(SqlglotError):
pass
+class ExecuteError(SqlglotError):
+ pass
+
+
def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str:
msg = [str(e) for e in errors[:maximum]]
remaining = len(errors) - maximum
diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py
index e765616..04621b5 100644
--- a/sqlglot/executor/__init__.py
+++ b/sqlglot/executor/__init__.py
@@ -1,20 +1,23 @@
import logging
import time
-from sqlglot import parse_one
+from sqlglot import maybe_parse
+from sqlglot.errors import ExecuteError
from sqlglot.executor.python import PythonExecutor
+from sqlglot.executor.table import Table, ensure_tables
from sqlglot.optimizer import optimize
from sqlglot.planner import Plan
+from sqlglot.schema import ensure_schema
logger = logging.getLogger("sqlglot")
-def execute(sql, schema, read=None):
+def execute(sql, schema=None, read=None, tables=None):
"""
Run a sql query against data.
Args:
- sql (str): a sql statement
+ sql (str|sqlglot.Expression): a sql statement
schema (dict|sqlglot.optimizer.Schema): database schema.
This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
the following forms:
@@ -23,10 +26,20 @@ def execute(sql, schema, read=None):
3. {catalog: {db: {table: {col: type}}}}
read (str): the SQL dialect to apply during parsing
(eg. "spark", "hive", "presto", "mysql").
+ tables (dict): additional tables to register.
Returns:
sqlglot.executor.Table: Simple columnar data structure.
"""
- expression = parse_one(sql, read=read)
+ tables = ensure_tables(tables)
+ if not schema:
+ schema = {
+ name: {column: type(table[0][column]).__name__ for column in table.columns}
+ for name, table in tables.mapping.items()
+ }
+ schema = ensure_schema(schema)
+ if tables.supported_table_args and tables.supported_table_args != schema.supported_table_args:
+ raise ExecuteError("Tables must support the same table args as schema")
+ expression = maybe_parse(sql, dialect=read)
now = time.time()
expression = optimize(expression, schema, leave_tables_isolated=True)
logger.debug("Optimization finished: %f", time.time() - now)
@@ -34,6 +47,6 @@ def execute(sql, schema, read=None):
plan = Plan(expression)
logger.debug("Logical Plan: %s", plan)
now = time.time()
- result = PythonExecutor().execute(plan)
+ result = PythonExecutor(tables=tables).execute(plan)
logger.debug("Query finished: %f", time.time() - now)
return result
diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py
index 393347b..e9ff75b 100644
--- a/sqlglot/executor/context.py
+++ b/sqlglot/executor/context.py
@@ -1,5 +1,12 @@
+from __future__ import annotations
+
+import typing as t
+
from sqlglot.executor.env import ENV
+if t.TYPE_CHECKING:
+ from sqlglot.executor.table import Table, TableIter
+
class Context:
"""
@@ -12,14 +19,14 @@ class Context:
evaluation of aggregation functions.
"""
- def __init__(self, tables, env=None):
+ def __init__(self, tables: t.Dict[str, Table], env: t.Optional[t.Dict] = None) -> None:
"""
Args
- tables (dict): table_name -> Table, representing the scope of the current execution context
- env (Optional[dict]): dictionary of functions within the execution context
+ tables: representing the scope of the current execution context.
+ env: dictionary of functions within the execution context.
"""
self.tables = tables
- self._table = None
+ self._table: t.Optional[Table] = None
self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
self.row_readers = {name: table.reader for name, table in tables.items()}
self.env = {**(env or {}), "scope": self.row_readers}
@@ -31,7 +38,7 @@ class Context:
return tuple(self.eval(code) for code in codes)
@property
- def table(self):
+ def table(self) -> Table:
if self._table is None:
self._table = list(self.tables.values())[0]
for other in self.tables.values():
@@ -41,8 +48,12 @@ class Context:
raise Exception(f"Rows are different.")
return self._table
+ def add_columns(self, *columns: str) -> None:
+ for table in self.tables.values():
+ table.add_columns(*columns)
+
@property
- def columns(self):
+ def columns(self) -> t.Tuple:
return self.table.columns
def __iter__(self):
@@ -52,35 +63,39 @@ class Context:
reader = table[i]
yield reader, self
- def table_iter(self, table):
+ def table_iter(self, table: str) -> t.Generator[t.Tuple[TableIter, Context], None, None]:
self.env["scope"] = self.row_readers
for reader in self.tables[table]:
yield reader, self
- def sort(self, key):
- table = self.table
+ def filter(self, condition) -> None:
+ rows = [reader.row for reader, _ in self if self.eval(condition)]
- def sort_key(row):
- table.reader.row = row
+ for table in self.tables.values():
+ table.rows = rows
+
+ def sort(self, key) -> None:
+ def sort_key(row: t.Tuple) -> t.Tuple:
+ self.set_row(row)
return self.eval_tuple(key)
- table.rows.sort(key=sort_key)
+ self.table.rows.sort(key=sort_key)
- def set_row(self, row):
+ def set_row(self, row: t.Tuple) -> None:
for table in self.tables.values():
table.reader.row = row
self.env["scope"] = self.row_readers
- def set_index(self, index):
+ def set_index(self, index: int) -> None:
for table in self.tables.values():
table[index]
self.env["scope"] = self.row_readers
- def set_range(self, start, end):
+ def set_range(self, start: int, end: int) -> None:
for name in self.tables:
self.range_readers[name].range = range(start, end)
self.env["scope"] = self.range_readers
- def __contains__(self, table):
+ def __contains__(self, table: str) -> bool:
return table in self.tables
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index bbe6c81..ed80cc9 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -1,7 +1,10 @@
import datetime
+import inspect
import re
import statistics
+from functools import wraps
+from sqlglot import exp
from sqlglot.helper import PYTHON_VERSION
@@ -16,20 +19,153 @@ class reverse_key:
return other.obj < self.obj
+def filter_nulls(func):
+ @wraps(func)
+ def _func(values):
+ return func(v for v in values if v is not None)
+
+ return _func
+
+
+def null_if_any(*required):
+ """
+ Decorator that makes a function return `None` if any of the `required` arguments are `None`.
+
+ This also supports decoration with no arguments, e.g.:
+
+ @null_if_any
+ def foo(a, b): ...
+
+ In which case all arguments are required.
+ """
+ f = None
+ if len(required) == 1 and callable(required[0]):
+ f = required[0]
+ required = ()
+
+ def decorator(func):
+ if required:
+ required_indices = [
+ i for i, param in enumerate(inspect.signature(func).parameters) if param in required
+ ]
+
+ def predicate(*args):
+ return any(args[i] is None for i in required_indices)
+
+ else:
+
+ def predicate(*args):
+ return any(a is None for a in args)
+
+ @wraps(func)
+ def _func(*args):
+ if predicate(*args):
+ return None
+ return func(*args)
+
+ return _func
+
+ if f:
+ return decorator(f)
+
+ return decorator
+
+
+@null_if_any("substr", "this")
+def str_position(substr, this, position=None):
+ position = position - 1 if position is not None else position
+ return this.find(substr, position) + 1
+
+
+@null_if_any("this")
+def substring(this, start=None, length=None):
+ if start is None:
+ return this
+ elif start == 0:
+ return ""
+ elif start < 0:
+ start = len(this) + start
+ else:
+ start -= 1
+
+ end = None if length is None else start + length
+
+ return this[start:end]
+
+
+@null_if_any
+def cast(this, to):
+ if to == exp.DataType.Type.DATE:
+ return datetime.date.fromisoformat(this)
+ if to == exp.DataType.Type.DATETIME:
+ return datetime.datetime.fromisoformat(this)
+ if to in exp.DataType.TEXT_TYPES:
+ return str(this)
+ if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}:
+ return float(this)
+ if to in exp.DataType.NUMERIC_TYPES:
+ return int(this)
+ raise NotImplementedError(f"Casting to '{to}' not implemented.")
+
+
+def ordered(this, desc, nulls_first):
+ if desc:
+ return reverse_key(this)
+ return this
+
+
+@null_if_any
+def interval(this, unit):
+ if unit == "DAY":
+ return datetime.timedelta(days=float(this))
+ raise NotImplementedError
+
+
ENV = {
"__builtins__": {},
- "datetime": datetime,
- "locals": locals,
- "re": re,
- "bool": bool,
- "float": float,
- "int": int,
- "str": str,
- "desc": reverse_key,
- "SUM": sum,
- "AVG": statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean, # type: ignore
- "COUNT": lambda acc: sum(1 for e in acc if e is not None),
- "MAX": max,
- "MIN": min,
+ "exp": exp,
+ # aggs
+ "SUM": filter_nulls(sum),
+ "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore
+ "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc)),
+ "MAX": filter_nulls(max),
+ "MIN": filter_nulls(min),
+ # scalar functions
+ "ABS": null_if_any(lambda this: abs(this)),
+ "ADD": null_if_any(lambda e, this: e + this),
+ "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
+ "BITWISEAND": null_if_any(lambda this, e: this & e),
+ "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
+ "BITWISEOR": null_if_any(lambda this, e: this | e),
+ "BITWISERIGHTSHIFT": null_if_any(lambda this, e: this >> e),
+ "BITWISEXOR": null_if_any(lambda this, e: this ^ e),
+ "CAST": cast,
+ "COALESCE": lambda *args: next((a for a in args if a is not None), None),
+ "CONCAT": null_if_any(lambda *args: "".join(args)),
+ "CONCATWS": null_if_any(lambda this, *args: this.join(args)),
+ "DIV": null_if_any(lambda e, this: e / this),
+ "EQ": null_if_any(lambda this, e: this == e),
+ "EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
+ "GT": null_if_any(lambda this, e: this > e),
+ "GTE": null_if_any(lambda this, e: this >= e),
+ "IFNULL": lambda e, alt: alt if e is None else e,
+ "IF": lambda predicate, true, false: true if predicate else false,
+ "INTDIV": null_if_any(lambda e, this: e // this),
+ "INTERVAL": interval,
+ "LIKE": null_if_any(
+ lambda this, e: bool(re.match(e.replace("_", ".").replace("%", ".*"), this))
+ ),
+ "LOWER": null_if_any(lambda arg: arg.lower()),
+ "LT": null_if_any(lambda this, e: this < e),
+ "LTE": null_if_any(lambda this, e: this <= e),
+ "MOD": null_if_any(lambda e, this: e % this),
+ "MUL": null_if_any(lambda e, this: e * this),
+ "NEQ": null_if_any(lambda this, e: this != e),
+ "ORD": null_if_any(ord),
+ "ORDERED": ordered,
"POW": pow,
+ "STRPOSITION": str_position,
+ "SUB": null_if_any(lambda e, this: e - this),
+ "SUBSTRING": substring,
+ "UPPER": null_if_any(lambda arg: arg.upper()),
}
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index 7d1db32..cb2543c 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -5,16 +5,18 @@ import math
from sqlglot import exp, generator, planner, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql
+from sqlglot.errors import ExecuteError
from sqlglot.executor.context import Context
from sqlglot.executor.env import ENV
-from sqlglot.executor.table import Table
-from sqlglot.helper import csv_reader
+from sqlglot.executor.table import RowReader, Table
+from sqlglot.helper import csv_reader, subclasses
class PythonExecutor:
- def __init__(self, env=None):
- self.generator = Python().generator(identify=True)
+ def __init__(self, env=None, tables=None):
+ self.generator = Python().generator(identify=True, comments=False)
self.env = {**ENV, **(env or {})}
+ self.tables = tables or {}
def execute(self, plan):
running = set()
@@ -24,36 +26,41 @@ class PythonExecutor:
while queue:
node = queue.pop()
- context = self.context(
- {
- name: table
- for dep in node.dependencies
- for name, table in contexts[dep].tables.items()
- }
- )
- running.add(node)
-
- if isinstance(node, planner.Scan):
- contexts[node] = self.scan(node, context)
- elif isinstance(node, planner.Aggregate):
- contexts[node] = self.aggregate(node, context)
- elif isinstance(node, planner.Join):
- contexts[node] = self.join(node, context)
- elif isinstance(node, planner.Sort):
- contexts[node] = self.sort(node, context)
- else:
- raise NotImplementedError
-
- running.remove(node)
- finished.add(node)
-
- for dep in node.dependents:
- if dep not in running and all(d in contexts for d in dep.dependencies):
- queue.add(dep)
-
- for dep in node.dependencies:
- if all(d in finished for d in dep.dependents):
- contexts.pop(dep)
+ try:
+ context = self.context(
+ {
+ name: table
+ for dep in node.dependencies
+ for name, table in contexts[dep].tables.items()
+ }
+ )
+ running.add(node)
+
+ if isinstance(node, planner.Scan):
+ contexts[node] = self.scan(node, context)
+ elif isinstance(node, planner.Aggregate):
+ contexts[node] = self.aggregate(node, context)
+ elif isinstance(node, planner.Join):
+ contexts[node] = self.join(node, context)
+ elif isinstance(node, planner.Sort):
+ contexts[node] = self.sort(node, context)
+ elif isinstance(node, planner.SetOperation):
+ contexts[node] = self.set_operation(node, context)
+ else:
+ raise NotImplementedError
+
+ running.remove(node)
+ finished.add(node)
+
+ for dep in node.dependents:
+ if dep not in running and all(d in contexts for d in dep.dependencies):
+ queue.add(dep)
+
+ for dep in node.dependencies:
+ if all(d in finished for d in dep.dependents):
+ contexts.pop(dep)
+ except Exception as e:
+ raise ExecuteError(f"Step '{node.id}' failed: {e}") from e
root = plan.root
return contexts[root].tables[root.name]
@@ -76,38 +83,43 @@ class PythonExecutor:
return Context(tables, env=self.env)
def table(self, expressions):
- return Table(expression.alias_or_name for expression in expressions)
+ return Table(
+ expression.alias_or_name if isinstance(expression, exp.Expression) else expression
+ for expression in expressions
+ )
def scan(self, step, context):
source = step.source
- if isinstance(source, exp.Expression):
+ if source and isinstance(source, exp.Expression):
source = source.name or source.alias
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
- if source in context:
+ if source is None:
+ context, table_iter = self.static()
+ elif source in context:
if not projections and not condition:
return self.context({step.name: context.tables[source]})
table_iter = context.table_iter(source)
- else:
+ elif isinstance(step.source, exp.Table) and isinstance(step.source.this, exp.ReadCSV):
table_iter = self.scan_csv(step)
+ context = next(table_iter)
+ else:
+ context, table_iter = self.scan_table(step)
if projections:
sink = self.table(step.projections)
else:
- sink = None
-
- for reader, ctx in table_iter:
- if sink is None:
- sink = Table(reader.columns)
+ sink = self.table(context.columns)
- if condition and not ctx.eval(condition):
+ for reader in table_iter:
+ if condition and not context.eval(condition):
continue
if projections:
- sink.append(ctx.eval_tuple(projections))
+ sink.append(context.eval_tuple(projections))
else:
sink.append(reader.row)
@@ -116,14 +128,23 @@ class PythonExecutor:
return self.context({step.name: sink})
+ def static(self):
+ return self.context({}), [RowReader(())]
+
+ def scan_table(self, step):
+ table = self.tables.find(step.source)
+ context = self.context({step.source.alias_or_name: table})
+ return context, iter(table)
+
def scan_csv(self, step):
- source = step.source
- alias = source.alias
+ alias = step.source.alias
+ source = step.source.this
with csv_reader(source) as reader:
columns = next(reader)
table = Table(columns)
context = self.context({alias: table})
+ yield context
types = []
for row in reader:
@@ -134,7 +155,7 @@ class PythonExecutor:
except (ValueError, SyntaxError):
types.append(str)
context.set_row(tuple(t(v) for t, v in zip(types, row)))
- yield context.table.reader, context
+ yield context.table.reader
def join(self, step, context):
source = step.name
@@ -160,16 +181,19 @@ class PythonExecutor:
for name, column_range in column_ranges.items()
}
)
+ condition = self.generate(join["condition"])
+ if condition:
+ source_context.filter(condition)
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
- if not condition or not projections:
+ if not condition and not projections:
return source_context
sink = self.table(step.projections if projections else source_context.columns)
- for reader, ctx in join_context:
+ for reader, ctx in source_context:
if condition and not ctx.eval(condition):
continue
@@ -181,7 +205,15 @@ class PythonExecutor:
if len(sink) >= step.limit:
break
- return self.context({step.name: sink})
+ if projections:
+ return self.context({step.name: sink})
+ else:
+ return self.context(
+ {
+ name: Table(table.columns, sink.rows, table.column_range)
+ for name, table in source_context.tables.items()
+ }
+ )
def nested_loop_join(self, _join, source_context, join_context):
table = Table(source_context.columns + join_context.columns)
@@ -195,6 +227,8 @@ class PythonExecutor:
def hash_join(self, join, source_context, join_context):
source_key = self.generate_tuple(join["source_key"])
join_key = self.generate_tuple(join["join_key"])
+ left = join.get("side") == "LEFT"
+ right = join.get("side") == "RIGHT"
results = collections.defaultdict(lambda: ([], []))
@@ -204,28 +238,47 @@ class PythonExecutor:
results[ctx.eval_tuple(join_key)][1].append(reader.row)
table = Table(source_context.columns + join_context.columns)
+ nulls = [(None,) * len(join_context.columns if left else source_context.columns)]
for a_group, b_group in results.values():
+ if left:
+ b_group = b_group or nulls
+ elif right:
+ a_group = a_group or nulls
+
for a_row, b_row in itertools.product(a_group, b_group):
table.append(a_row + b_row)
return table
def aggregate(self, step, context):
- source = step.source
- group_by = self.generate_tuple(step.group)
+ group_by = self.generate_tuple(step.group.values())
aggregations = self.generate_tuple(step.aggregations)
operands = self.generate_tuple(step.operands)
if operands:
- source_table = context.tables[source]
- operand_table = Table(source_table.columns + self.table(step.operands).columns)
+ operand_table = Table(self.table(step.operands).columns)
for reader, ctx in context:
- operand_table.append(reader.row + ctx.eval_tuple(operands))
+ operand_table.append(ctx.eval_tuple(operands))
+
+ for i, (a, b) in enumerate(zip(context.table.rows, operand_table.rows)):
+ context.table.rows[i] = a + b
+
+ width = len(context.columns)
+ context.add_columns(*operand_table.columns)
+
+ operand_table = Table(
+ context.columns,
+ context.table.rows,
+ range(width, width + len(operand_table.columns)),
+ )
context = self.context(
- {None: operand_table, **{table: operand_table for table in context.tables}}
+ {
+ None: operand_table,
+ **context.tables,
+ }
)
context.sort(group_by)
@@ -233,25 +286,22 @@ class PythonExecutor:
group = None
start = 0
end = 1
- length = len(context.tables[source])
- table = self.table(step.group + step.aggregations)
+ length = len(context.table)
+ table = self.table(list(step.group) + step.aggregations)
for i in range(length):
context.set_index(i)
key = context.eval_tuple(group_by)
group = key if group is None else group
end += 1
-
+ if key != group:
+ context.set_range(start, end - 2)
+ table.append(group + context.eval_tuple(aggregations))
+ group = key
+ start = end - 2
if i == length - 1:
context.set_range(start, end - 1)
- elif key != group:
- context.set_range(start, end - 2)
- else:
- continue
-
- table.append(group + context.eval_tuple(aggregations))
- group = key
- start = end - 2
+ table.append(group + context.eval_tuple(aggregations))
context = self.context({step.name: table, **{name: table for name in context.tables}})
@@ -262,60 +312,77 @@ class PythonExecutor:
def sort(self, step, context):
projections = self.generate_tuple(step.projections)
- sink = self.table(step.projections)
+ projection_columns = [p.alias_or_name for p in step.projections]
+ all_columns = list(context.columns) + projection_columns
+ sink = self.table(all_columns)
for reader, ctx in context:
- sink.append(ctx.eval_tuple(projections))
+ sink.append(reader.row + ctx.eval_tuple(projections))
- context = self.context(
+ sort_ctx = self.context(
{
None: sink,
**{table: sink for table in context.tables},
}
)
- context.sort(self.generate_tuple(step.key))
+ sort_ctx.sort(self.generate_tuple(step.key))
if not math.isinf(step.limit):
- context.table.rows = context.table.rows[0 : step.limit]
+ sort_ctx.table.rows = sort_ctx.table.rows[0 : step.limit]
- return self.context({step.name: context.table})
+ output = Table(
+ projection_columns,
+ rows=[r[len(context.columns) : len(all_columns)] for r in sort_ctx.table.rows],
+ )
+ return self.context({step.name: output})
+ def set_operation(self, step, context):
+ left = context.tables[step.left]
+ right = context.tables[step.right]
-def _cast_py(self, expression):
- to = expression.args["to"].this
- this = self.sql(expression, "this")
+ sink = self.table(left.columns)
+
+ if issubclass(step.op, exp.Intersect):
+ sink.rows = list(set(left.rows).intersection(set(right.rows)))
+ elif issubclass(step.op, exp.Except):
+ sink.rows = list(set(left.rows).difference(set(right.rows)))
+ elif issubclass(step.op, exp.Union) and step.distinct:
+ sink.rows = list(set(left.rows).union(set(right.rows)))
+ else:
+ sink.rows = left.rows + right.rows
- if to == exp.DataType.Type.DATE:
- return f"datetime.date.fromisoformat({this})"
- if to == exp.DataType.Type.TEXT:
- return f"str({this})"
- raise NotImplementedError
+ return self.context({step.name: sink})
-def _column_py(self, expression):
- table = self.sql(expression, "table") or None
+def _ordered_py(self, expression):
this = self.sql(expression, "this")
- return f"scope[{table}][{this}]"
+ desc = "True" if expression.args.get("desc") else "False"
+ nulls_first = "True" if expression.args.get("nulls_first") else "False"
+ return f"ORDERED({this}, {desc}, {nulls_first})"
-def _interval_py(self, expression):
- this = self.sql(expression, "this")
- unit = expression.text("unit").upper()
- if unit == "DAY":
- return f"datetime.timedelta(days=float({this}))"
- raise NotImplementedError
+def _rename(self, e):
+ try:
+ if "expressions" in e.args:
+ this = self.sql(e, "this")
+ this = f"{this}, " if this else ""
+ return f"{e.key.upper()}({this}{self.expressions(e)})"
+ return f"{e.key.upper()}({self.format_args(*e.args.values())})"
+ except Exception as ex:
+ raise Exception(f"Could not rename {repr(e)}") from ex
-def _like_py(self, expression):
+def _case_sql(self, expression):
this = self.sql(expression, "this")
- expression = self.sql(expression, "expression")
- return f"""bool(re.match({expression}.replace("_", ".").replace("%", ".*"), {this}))"""
+ chain = self.sql(expression, "default") or "None"
+ for e in reversed(expression.args["ifs"]):
+ true = self.sql(e, "true")
+ condition = self.sql(e, "this")
+ condition = f"{this} = ({condition})" if this else condition
+ chain = f"{true} if {condition} else ({chain})"
-def _ordered_py(self, expression):
- this = self.sql(expression, "this")
- desc = expression.args.get("desc")
- return f"desc({this})" if desc else this
+ return chain
class Python(Dialect):
@@ -324,32 +391,22 @@ class Python(Dialect):
class Generator(generator.Generator):
TRANSFORMS = {
+ **{klass: _rename for klass in subclasses(exp.__name__, exp.Binary)},
+ **{klass: _rename for klass in exp.ALL_FUNCTIONS},
+ exp.Case: _case_sql,
exp.Alias: lambda self, e: self.sql(e.this),
exp.Array: inline_array_sql,
exp.And: lambda self, e: self.binary(e, "and"),
+ exp.Between: _rename,
exp.Boolean: lambda self, e: "True" if e.this else "False",
- exp.Cast: _cast_py,
- exp.Column: _column_py,
- exp.EQ: lambda self, e: self.binary(e, "=="),
+ exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
+ exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
+ exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}",
- exp.Interval: _interval_py,
exp.Is: lambda self, e: self.binary(e, "is"),
- exp.Like: _like_py,
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",
exp.Or: lambda self, e: self.binary(e, "or"),
exp.Ordered: _ordered_py,
exp.Star: lambda *_: "1",
}
-
- def case_sql(self, expression):
- this = self.sql(expression, "this")
- chain = self.sql(expression, "default") or "None"
-
- for e in reversed(expression.args["ifs"]):
- true = self.sql(e, "true")
- condition = self.sql(e, "this")
- condition = f"{this} = ({condition})" if this else condition
- chain = f"{true} if {condition} else ({chain})"
-
- return chain
diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py
index 6796740..f1b5b54 100644
--- a/sqlglot/executor/table.py
+++ b/sqlglot/executor/table.py
@@ -1,14 +1,27 @@
+from __future__ import annotations
+
+from sqlglot.helper import dict_depth
+from sqlglot.schema import AbstractMappingSchema
+
+
class Table:
def __init__(self, columns, rows=None, column_range=None):
self.columns = tuple(columns)
self.column_range = column_range
self.reader = RowReader(self.columns, self.column_range)
-
self.rows = rows or []
if rows:
assert len(rows[0]) == len(self.columns)
self.range_reader = RangeReader(self)
+ def add_columns(self, *columns: str) -> None:
+ self.columns += columns
+ if self.column_range:
+ self.column_range = range(
+ self.column_range.start, self.column_range.stop + len(columns)
+ )
+ self.reader = RowReader(self.columns, self.column_range)
+
def append(self, row):
assert len(row) == len(self.columns)
self.rows.append(row)
@@ -87,3 +100,31 @@ class RowReader:
def __getitem__(self, column):
return self.row[self.columns[column]]
+
+
+class Tables(AbstractMappingSchema[Table]):
+ pass
+
+
+def ensure_tables(d: dict | None) -> Tables:
+ return Tables(_ensure_tables(d))
+
+
+def _ensure_tables(d: dict | None) -> dict:
+ if not d:
+ return {}
+
+ depth = dict_depth(d)
+
+ if depth > 1:
+ return {k: _ensure_tables(v) for k, v in d.items()}
+
+ result = {}
+ for name, table in d.items():
+ if isinstance(table, Table):
+ result[name] = table
+ else:
+ columns = tuple(table[0]) if table else ()
+ rows = [tuple(row[c] for c in columns) for row in table]
+ result[name] = Table(columns=columns, rows=rows)
+ return result
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 57a2c88..beafca8 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -641,9 +641,11 @@ class Set(Expression):
class SetItem(Expression):
arg_types = {
- "this": True,
+ "this": False,
+ "expressions": False,
"kind": False,
"collate": False, # MySQL SET NAMES statement
+ "global": False,
}
@@ -787,6 +789,7 @@ class Drop(Expression):
"exists": False,
"temporary": False,
"materialized": False,
+ "cascade": False,
}
@@ -1073,6 +1076,18 @@ class FileFormatProperty(Property):
pass
+class DistKeyProperty(Property):
+ pass
+
+
+class SortKeyProperty(Property):
+ pass
+
+
+class DistStyleProperty(Property):
+ pass
+
+
class LocationProperty(Property):
pass
@@ -1130,6 +1145,9 @@ class Properties(Expression):
"LOCATION": LocationProperty,
"PARTITIONED_BY": PartitionedByProperty,
"TABLE_FORMAT": TableFormatProperty,
+ "DISTKEY": DistKeyProperty,
+ "DISTSTYLE": DistStyleProperty,
+ "SORTKEY": SortKeyProperty,
}
@classmethod
@@ -1356,7 +1374,7 @@ class Var(Expression):
class Schema(Expression):
- arg_types = {"this": False, "expressions": True}
+ arg_types = {"this": False, "expressions": False}
class Select(Subqueryable):
@@ -1741,7 +1759,7 @@ class Select(Subqueryable):
)
if join_alias:
- join.set("this", alias_(join.args["this"], join_alias, table=True))
+ join.set("this", alias_(join.this, join_alias, table=True))
return _apply_list_builder(
join,
instance=self,
@@ -1884,6 +1902,7 @@ class Subquery(DerivedTable, Unionable):
arg_types = {
"this": True,
"alias": False,
+ "with": False,
**QUERY_MODIFIERS,
}
@@ -2025,6 +2044,31 @@ class DataType(Expression):
NULL = auto()
UNKNOWN = auto() # Sentinel value, useful for type annotation
+ TEXT_TYPES = {
+ Type.CHAR,
+ Type.NCHAR,
+ Type.VARCHAR,
+ Type.NVARCHAR,
+ Type.TEXT,
+ }
+
+ NUMERIC_TYPES = {
+ Type.INT,
+ Type.TINYINT,
+ Type.SMALLINT,
+ Type.BIGINT,
+ Type.FLOAT,
+ Type.DOUBLE,
+ }
+
+ TEMPORAL_TYPES = {
+ Type.TIMESTAMP,
+ Type.TIMESTAMPTZ,
+ Type.TIMESTAMPLTZ,
+ Type.DATE,
+ Type.DATETIME,
+ }
+
@classmethod
def build(cls, dtype, **kwargs) -> DataType:
return DataType(
@@ -2054,16 +2098,25 @@ class Exists(SubqueryPredicate):
pass
-# Commands to interact with the databases or engines
-# These expressions don't truly parse the expression and consume
-# whatever exists as a string until the end or a semicolon
+# Commands to interact with the databases or engines. For most of the command
+# expressions we parse whatever comes after the command's name as a string.
class Command(Expression):
arg_types = {"this": True, "expression": False}
-# Binary Expressions
-# (ADD a b)
-# (FROM table selects)
+class Transaction(Command):
+ arg_types = {"this": False, "modes": False}
+
+
+class Commit(Command):
+ arg_types = {} # type: ignore
+
+
+class Rollback(Command):
+ arg_types = {"savepoint": False}
+
+
+# Binary expressions like (ADD a b)
class Binary(Expression):
arg_types = {"this": True, "expression": True}
@@ -2215,7 +2268,7 @@ class Not(Unary, Condition):
class Paren(Unary, Condition):
- pass
+ arg_types = {"this": True, "with": False}
class Neg(Unary):
@@ -2428,6 +2481,10 @@ class Cast(Func):
return self.args["to"]
+class Collate(Binary):
+ pass
+
+
class TryCast(Cast):
pass
@@ -2442,13 +2499,17 @@ class Coalesce(Func):
is_var_len_args = True
-class ConcatWs(Func):
- arg_types = {"expressions": False}
+class Concat(Func):
+ arg_types = {"expressions": True}
is_var_len_args = True
+class ConcatWs(Concat):
+ _sql_names = ["CONCAT_WS"]
+
+
class Count(AggFunc):
- pass
+ arg_types = {"this": False}
class CurrentDate(Func):
@@ -2556,10 +2617,18 @@ class Day(Func):
pass
+class Decode(Func):
+ arg_types = {"this": True, "charset": True}
+
+
class DiToDate(Func):
pass
+class Encode(Func):
+ arg_types = {"this": True, "charset": True}
+
+
class Exp(Func):
pass
@@ -2581,6 +2650,10 @@ class GroupConcat(Func):
arg_types = {"this": True, "separator": False}
+class Hex(Func):
+ pass
+
+
class If(Func):
arg_types = {"this": True, "true": True, "false": False}
@@ -2641,7 +2714,7 @@ class Log10(Func):
class Lower(Func):
- pass
+ _sql_names = ["LOWER", "LCASE"]
class Map(Func):
@@ -2686,6 +2759,12 @@ class ApproxQuantile(Quantile):
arg_types = {"this": True, "quantile": True, "accuracy": False}
+class ReadCSV(Func):
+ _sql_names = ["READ_CSV"]
+ is_var_len_args = True
+ arg_types = {"this": True, "expressions": False}
+
+
class Reduce(Func):
arg_types = {"this": True, "initial": True, "merge": True, "finish": True}
@@ -2804,8 +2883,8 @@ class TimeStrToUnix(Func):
class Trim(Func):
arg_types = {
"this": True,
- "position": False,
"expression": False,
+ "position": False,
"collation": False,
}
@@ -2826,6 +2905,10 @@ class TsOrDiToDi(Func):
pass
+class Unhex(Func):
+ pass
+
+
class UnixToStr(Func):
arg_types = {"this": True, "format": False}
@@ -2843,7 +2926,7 @@ class UnixToTimeStr(Func):
class Upper(Func):
- pass
+ _sql_names = ["UPPER", "UCASE"]
class Variance(AggFunc):
@@ -3701,6 +3784,19 @@ def replace_placeholders(expression, *args, **kwargs):
return expression.transform(_replace_placeholders, iter(args), **kwargs)
+def true():
+ return Boolean(this=True)
+
+
+def false():
+ return Boolean(this=False)
+
+
+def null():
+ return Null()
+
+
+# TODO: deprecate this
TRUE = Boolean(this=True)
FALSE = Boolean(this=False)
NULL = Null()
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 11d9073..ffb34eb 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -67,7 +67,7 @@ class Generator:
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
- exp.VolatilityProperty: lambda self, e: self.sql(e.name),
+ exp.VolatilityProperty: lambda self, e: e.name,
}
# Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
@@ -94,6 +94,9 @@ class Generator:
ROOT_PROPERTIES = {
exp.ReturnsProperty,
exp.LanguageProperty,
+ exp.DistStyleProperty,
+ exp.DistKeyProperty,
+ exp.SortKeyProperty,
}
WITH_PROPERTIES = {
@@ -241,7 +244,7 @@ class Generator:
if not NEWLINE_RE.search(comment):
return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
- return f"/*{comment}*/\n{sql}"
+ return f"/*{comment}*/\n{sql}" if sql else f" /*{comment}*/"
def wrap(self, expression):
this_sql = self.indent(
@@ -475,7 +478,8 @@ class Generator:
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
- return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}"
+ cascade = " CASCADE" if expression.args.get("cascade") else ""
+ return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}"
def except_sql(self, expression):
return self.prepend_ctes(
@@ -915,13 +919,15 @@ class Generator:
def subquery_sql(self, expression):
alias = self.sql(expression, "alias")
- return self.query_modifiers(
+ sql = self.query_modifiers(
expression,
self.wrap(expression),
self.expressions(expression, key="pivots", sep=" "),
f" AS {alias}" if alias else "",
)
+ return self.prepend_ctes(expression, sql)
+
def qualify_sql(self, expression):
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('QUALIFY')}{self.sep()}{this}"
@@ -1111,9 +1117,12 @@ class Generator:
def paren_sql(self, expression):
if isinstance(expression.unnest(), exp.Select):
- return self.wrap(expression)
- sql = self.seg(self.indent(self.sql(expression, "this")), sep="")
- return f"({sql}{self.seg(')', sep='')}"
+ sql = self.wrap(expression)
+ else:
+ sql = self.seg(self.indent(self.sql(expression, "this")), sep="")
+ sql = f"({sql}{self.seg(')', sep='')}"
+
+ return self.prepend_ctes(expression, sql)
def neg_sql(self, expression):
return f"-{self.sql(expression, 'this')}"
@@ -1173,9 +1182,23 @@ class Generator:
zone = self.sql(expression, "this")
return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE"
+ def collate_sql(self, expression):
+ return self.binary(expression, "COLLATE")
+
def command_sql(self, expression):
return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}"
+ def transaction_sql(self, *_):
+ return "BEGIN"
+
+ def commit_sql(self, *_):
+ return "COMMIT"
+
+ def rollback_sql(self, expression):
+ savepoint = expression.args.get("savepoint")
+ savepoint = f" TO {savepoint}" if savepoint else ""
+ return f"ROLLBACK{savepoint}"
+
def distinct_sql(self, expression):
this = self.expressions(expression, flat=True)
this = f" {this}" if this else ""
@@ -1193,10 +1216,7 @@ class Generator:
def intdiv_sql(self, expression):
return self.sql(
exp.Cast(
- this=exp.Div(
- this=expression.args["this"],
- expression=expression.args["expression"],
- ),
+ this=exp.Div(this=expression.this, expression=expression.expression),
to=exp.DataType(this=exp.DataType.Type.INT),
)
)
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index 379c2e7..8c5808d 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -11,7 +11,8 @@ from copy import copy
from enum import Enum
if t.TYPE_CHECKING:
- from sqlglot.expressions import Expression, Table
+ from sqlglot import exp
+ from sqlglot.expressions import Expression
T = t.TypeVar("T")
E = t.TypeVar("E", bound=Expression)
@@ -150,7 +151,7 @@ def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
if expression.is_int:
expression = expression.copy()
logger.warning("Applying array index offset (%s)", offset)
- expression.args["this"] = str(int(expression.args["this"]) + offset)
+ expression.args["this"] = str(int(expression.this) + offset)
return [expression]
return expressions
@@ -228,19 +229,18 @@ def open_file(file_name: str) -> t.TextIO:
@contextmanager
-def csv_reader(table: Table) -> t.Any:
+def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
"""
Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
Args:
- table: a `Table` expression with an anonymous function `READ_CSV` in it.
+ read_csv: a `ReadCSV` function call
Yields:
A python csv reader.
"""
- file, *args = table.this.expressions
- file = file.name
- file = open_file(file)
+ args = read_csv.expressions
+ file = open_file(read_csv.name)
delimiter = ","
args = iter(arg.name for arg in args)
@@ -354,3 +354,34 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any,
yield from flatten(value)
else:
yield value
+
+
+def dict_depth(d: t.Dict) -> int:
+ """
+ Get the nesting depth of a dictionary.
+
+ For example:
+ >>> dict_depth(None)
+ 0
+ >>> dict_depth({})
+ 1
+ >>> dict_depth({"a": "b"})
+ 1
+ >>> dict_depth({"a": {}})
+ 2
+ >>> dict_depth({"a": {"b": {}}})
+ 3
+
+ Args:
+ d (dict): dictionary
+ Returns:
+ int: depth
+ """
+ try:
+ return 1 + dict_depth(next(iter(d.values())))
+ except AttributeError:
+ # d doesn't have attribute "values"
+ return 0
+ except StopIteration:
+ # d.values() returns an empty sequence
+ return 1
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 96331e2..191ea52 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -245,23 +245,31 @@ class TypeAnnotator:
def annotate(self, expression):
if isinstance(expression, self.TRAVERSABLES):
for scope in traverse_scope(expression):
- subscope_selects = {
- name: {select.alias_or_name: select for select in source.selects}
- for name, source in scope.sources.items()
- if isinstance(source, Scope)
- }
-
+ selects = {}
+ for name, source in scope.sources.items():
+ if not isinstance(source, Scope):
+ continue
+ if isinstance(source.expression, exp.Values):
+ selects[name] = {
+ alias: column
+ for alias, column in zip(
+ source.expression.alias_column_names,
+ source.expression.expressions[0].expressions,
+ )
+ }
+ else:
+ selects[name] = {
+ select.alias_or_name: select for select in source.expression.selects
+ }
# First annotate the current scope's column references
for col in scope.columns:
source = scope.sources[col.table]
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
else:
- col.type = subscope_selects[col.table][col.name].type
-
+ col.type = selects[col.table][col.name].type
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
-
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
def _maybe_annotate(self, expression):
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
new file mode 100644
index 0000000..9b3d98a
--- /dev/null
+++ b/sqlglot/optimizer/canonicalize.py
@@ -0,0 +1,48 @@
+import itertools
+
+from sqlglot import exp
+
+
+def canonicalize(expression: exp.Expression) -> exp.Expression:
+ """Converts a sql expression into a standard form.
+
+ This method relies on annotate_types because many of the
+ conversions rely on type inference.
+
+ Args:
+ expression: The expression to canonicalize.
+ """
+ exp.replace_children(expression, canonicalize)
+ expression = add_text_to_concat(expression)
+ expression = coerce_type(expression)
+ return expression
+
+
+def add_text_to_concat(node: exp.Expression) -> exp.Expression:
+ if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES:
+ node = exp.Concat(this=node.this, expression=node.expression)
+ return node
+
+
+def coerce_type(node: exp.Expression) -> exp.Expression:
+ if isinstance(node, exp.Binary):
+ _coerce_date(node.left, node.right)
+ elif isinstance(node, exp.Between):
+ _coerce_date(node.this, node.args["low"])
+ elif isinstance(node, exp.Extract):
+ if node.expression.type not in exp.DataType.TEMPORAL_TYPES:
+ _replace_cast(node.expression, "datetime")
+ return node
+
+
+def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
+ for a, b in itertools.permutations([a, b]):
+ if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE:
+ _replace_cast(b, "date")
+
+
+def _replace_cast(node: exp.Expression, to: str) -> None:
+ data_type = exp.DataType.build(to)
+ cast = exp.Cast(this=node.copy(), to=data_type)
+ cast.type = data_type
+ node.replace(cast)
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py
index 29621af..de4e011 100644
--- a/sqlglot/optimizer/eliminate_joins.py
+++ b/sqlglot/optimizer/eliminate_joins.py
@@ -128,8 +128,8 @@ def join_condition(join):
Tuple of (source key, join key, remaining predicate)
"""
name = join.this.alias_or_name
- on = join.args.get("on") or exp.TRUE
- on = on.copy()
+ on = (join.args.get("on") or exp.true()).copy()
+ on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
source_key = []
join_key = []
@@ -141,7 +141,7 @@ def join_condition(join):
#
# should pull y.b as the join key and x.a as the source key
if normalized(on):
- for condition in on.flatten() if isinstance(on, exp.And) else [on]:
+ for condition in on.flatten():
if isinstance(condition, exp.EQ):
left, right = condition.unnest_operands()
left_tables = exp.column_table_names(left)
@@ -150,13 +150,12 @@ def join_condition(join):
if name in left_tables and name not in right_tables:
join_key.append(left)
source_key.append(right)
- condition.replace(exp.TRUE)
+ condition.replace(exp.true())
elif name in right_tables and name not in left_tables:
join_key.append(right)
source_key.append(left)
- condition.replace(exp.TRUE)
+ condition.replace(exp.true())
on = simplify(on)
- remaining_condition = None if on == exp.TRUE else on
-
+ remaining_condition = None if on == exp.true() else on
return source_key, join_key, remaining_condition
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
index 40e4ab1..fd69832 100644
--- a/sqlglot/optimizer/optimize_joins.py
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -29,7 +29,7 @@ def optimize_joins(expression):
if isinstance(on, exp.Connector):
for predicate in on.flatten():
if name in exp.column_table_names(predicate):
- predicate.replace(exp.TRUE)
+ predicate.replace(exp.true())
join.on(predicate, copy=False)
expression = reorder_joins(expression)
@@ -70,6 +70,6 @@ def normalize(expression):
def other_table_names(join, exclude):
return [
name
- for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
+ for name in (exp.column_table_names(join.args.get("on") or exp.true()))
if name != exclude
]
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index b2ed062..d0e38cd 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -1,4 +1,6 @@
import sqlglot
+from sqlglot.optimizer.annotate_types import annotate_types
+from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
@@ -28,6 +30,8 @@ RULES = (
merge_subqueries,
eliminate_joins,
eliminate_ctes,
+ annotate_types,
+ canonicalize,
quote_identities,
)
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index 6364f65..f92e5c3 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -64,11 +64,11 @@ def pushdown_cnf(predicates, scope, scope_ref_count):
for predicate in predicates:
for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
if isinstance(node, exp.Join):
- predicate.replace(exp.TRUE)
+ predicate.replace(exp.true())
node.on(predicate, copy=False)
break
if isinstance(node, exp.Select):
- predicate.replace(exp.TRUE)
+ predicate.replace(exp.true())
node.where(replace_aliases(node, predicate), copy=False)
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 69fe2b8..e6e6dc9 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -382,9 +382,7 @@ class _Resolver:
raise OptimizeError(str(e)) from e
if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
- values_alias = source.expression.parent
- if hasattr(values_alias, "alias_column_names"):
- return values_alias.alias_column_names
+ return source.expression.alias_column_names
# Otherwise, if referencing another scope, return that scope's named selects
return source.expression.named_selects
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 0e467d3..5d8e0d9 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -1,10 +1,11 @@
import itertools
from sqlglot import alias, exp
+from sqlglot.helper import csv_reader
from sqlglot.optimizer.scope import traverse_scope
-def qualify_tables(expression, db=None, catalog=None):
+def qualify_tables(expression, db=None, catalog=None, schema=None):
"""
Rewrite sqlglot AST to have fully qualified tables.
@@ -18,6 +19,7 @@ def qualify_tables(expression, db=None, catalog=None):
expression (sqlglot.Expression): expression to qualify
db (str): Database name
catalog (str): Catalog name
+ schema: A schema to populate
Returns:
sqlglot.Expression: qualified expression
"""
@@ -41,7 +43,7 @@ def qualify_tables(expression, db=None, catalog=None):
source.set("catalog", exp.to_identifier(catalog))
if not source.alias:
- source.replace(
+ source = source.replace(
alias(
source.copy(),
source.this if identifier else f"_q_{next(sequence)}",
@@ -49,4 +51,12 @@ def qualify_tables(expression, db=None, catalog=None):
)
)
+ if schema and isinstance(source.this, exp.ReadCSV):
+ with csv_reader(source.this) as reader:
+ header = next(reader)
+ columns = next(reader)
+ schema.add_table(
+ source, {k: type(v).__name__ for k, v in zip(header, columns)}
+ )
+
return expression
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index d759e86..c432c59 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -189,11 +189,11 @@ def absorb_and_eliminate(expression):
# absorb
if is_complement(b, aa):
- aa.replace(exp.TRUE if kind == exp.And else exp.FALSE)
+ aa.replace(exp.true() if kind == exp.And else exp.false())
elif is_complement(b, ab):
- ab.replace(exp.TRUE if kind == exp.And else exp.FALSE)
+ ab.replace(exp.true() if kind == exp.And else exp.false())
elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
- a.replace(exp.FALSE if kind == exp.And else exp.TRUE)
+ a.replace(exp.false() if kind == exp.And else exp.true())
elif isinstance(b, kind):
# eliminate
rhs = b.unnest_operands()
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index f41a84e..dbd680b 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -169,7 +169,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
select.parent.replace(alias)
for key, column, predicate in keys:
- predicate.replace(exp.TRUE)
+ predicate.replace(exp.true())
nested = exp.column(key_aliases[key], table_alias)
if key in group_by:
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index bbea0e5..5b93510 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -141,26 +141,29 @@ class Parser(metaclass=_Parser):
ID_VAR_TOKENS = {
TokenType.VAR,
- TokenType.ALTER,
TokenType.ALWAYS,
TokenType.ANTI,
TokenType.APPLY,
+ TokenType.AUTO_INCREMENT,
TokenType.BEGIN,
TokenType.BOTH,
TokenType.BUCKET,
TokenType.CACHE,
- TokenType.CALL,
+ TokenType.CASCADE,
TokenType.COLLATE,
+ TokenType.COMMAND,
TokenType.COMMIT,
TokenType.CONSTRAINT,
+ TokenType.CURRENT_TIME,
TokenType.DEFAULT,
TokenType.DELETE,
TokenType.DESCRIBE,
TokenType.DETERMINISTIC,
+ TokenType.DISTKEY,
+ TokenType.DISTSTYLE,
TokenType.EXECUTE,
TokenType.ENGINE,
TokenType.ESCAPE,
- TokenType.EXPLAIN,
TokenType.FALSE,
TokenType.FIRST,
TokenType.FOLLOWING,
@@ -182,7 +185,6 @@ class Parser(metaclass=_Parser):
TokenType.NATURAL,
TokenType.NEXT,
TokenType.ONLY,
- TokenType.OPTIMIZE,
TokenType.OPTIONS,
TokenType.ORDINALITY,
TokenType.PARTITIONED_BY,
@@ -199,6 +201,7 @@ class Parser(metaclass=_Parser):
TokenType.SEMI,
TokenType.SET,
TokenType.SHOW,
+ TokenType.SORTKEY,
TokenType.STABLE,
TokenType.STORED,
TokenType.TABLE,
@@ -207,7 +210,6 @@ class Parser(metaclass=_Parser):
TokenType.TRANSIENT,
TokenType.TOP,
TokenType.TRAILING,
- TokenType.TRUNCATE,
TokenType.TRUE,
TokenType.UNBOUNDED,
TokenType.UNIQUE,
@@ -217,6 +219,7 @@ class Parser(metaclass=_Parser):
TokenType.VOLATILE,
*SUBQUERY_PREDICATES,
*TYPE_TOKENS,
+ *NO_PAREN_FUNCTIONS,
}
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY}
@@ -231,6 +234,7 @@ class Parser(metaclass=_Parser):
TokenType.FILTER,
TokenType.FIRST,
TokenType.FORMAT,
+ TokenType.IDENTIFIER,
TokenType.ISNULL,
TokenType.OFFSET,
TokenType.PRIMARY_KEY,
@@ -242,6 +246,7 @@ class Parser(metaclass=_Parser):
TokenType.RIGHT,
TokenType.DATE,
TokenType.DATETIME,
+ TokenType.TABLE,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
*TYPE_TOKENS,
@@ -277,6 +282,7 @@ class Parser(metaclass=_Parser):
TokenType.DASH: exp.Sub,
TokenType.PLUS: exp.Add,
TokenType.MOD: exp.Mod,
+ TokenType.COLLATE: exp.Collate,
}
FACTOR = {
@@ -391,7 +397,10 @@ class Parser(metaclass=_Parser):
TokenType.DELETE: lambda self: self._parse_delete(),
TokenType.CACHE: lambda self: self._parse_cache(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
- TokenType.USE: lambda self: self._parse_use(),
+ TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
+ TokenType.BEGIN: lambda self: self._parse_transaction(),
+ TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
+ TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
}
PRIMARY_PARSERS = {
@@ -402,7 +411,8 @@ class Parser(metaclass=_Parser):
exp.Literal, this=token.text, is_string=False
),
TokenType.STAR: lambda self, _: self.expression(
- exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()}
+ exp.Star,
+ **{"except": self._parse_except(), "replace": self._parse_replace()},
),
TokenType.NULL: lambda self, _: self.expression(exp.Null),
TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),
@@ -446,6 +456,9 @@ class Parser(metaclass=_Parser):
TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(),
TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(),
TokenType.STORED: lambda self: self._parse_stored(),
+ TokenType.DISTKEY: lambda self: self._parse_distkey(),
+ TokenType.DISTSTYLE: lambda self: self._parse_diststyle(),
+ TokenType.SORTKEY: lambda self: self._parse_sortkey(),
TokenType.RETURNS: lambda self: self._parse_returns(),
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
@@ -471,7 +484,9 @@ class Parser(metaclass=_Parser):
}
CONSTRAINT_PARSERS = {
- TokenType.CHECK: lambda self: self._parse_check(),
+ TokenType.CHECK: lambda self: self.expression(
+ exp.Check, this=self._parse_wrapped(self._parse_conjunction)
+ ),
TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(),
TokenType.UNIQUE: lambda self: self._parse_unique(),
}
@@ -521,6 +536,8 @@ class Parser(metaclass=_Parser):
TokenType.SCHEMA,
}
+ TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
+
STRICT_CAST = True
__slots__ = (
@@ -740,6 +757,7 @@ class Parser(metaclass=_Parser):
kind=kind,
temporary=temporary,
materialized=materialized,
+ cascade=self._match(TokenType.CASCADE),
)
def _parse_exists(self, not_=False):
@@ -777,7 +795,11 @@ class Parser(metaclass=_Parser):
expression = self._parse_select_or_expression()
elif create_token.token_type == TokenType.INDEX:
this = self._parse_index()
- elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW, TokenType.SCHEMA):
+ elif create_token.token_type in (
+ TokenType.TABLE,
+ TokenType.VIEW,
+ TokenType.SCHEMA,
+ ):
this = self._parse_table(schema=True)
properties = self._parse_properties()
if self._match(TokenType.ALIAS):
@@ -834,7 +856,38 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.FileFormatProperty,
this=exp.Literal.string("FORMAT"),
- value=exp.Literal.string(self._parse_var().name),
+ value=exp.Literal.string(self._parse_var_or_string().name),
+ )
+
+ def _parse_distkey(self):
+ self._match_l_paren()
+ this = exp.Literal.string("DISTKEY")
+ value = exp.Literal.string(self._parse_var().name)
+ self._match_r_paren()
+ return self.expression(
+ exp.DistKeyProperty,
+ this=this,
+ value=value,
+ )
+
+ def _parse_sortkey(self):
+ self._match_l_paren()
+ this = exp.Literal.string("SORTKEY")
+ value = exp.Literal.string(self._parse_var().name)
+ self._match_r_paren()
+ return self.expression(
+ exp.SortKeyProperty,
+ this=this,
+ value=value,
+ )
+
+ def _parse_diststyle(self):
+ this = exp.Literal.string("DISTSTYLE")
+ value = exp.Literal.string(self._parse_var().name)
+ return self.expression(
+ exp.DistStyleProperty,
+ this=this,
+ value=value,
)
def _parse_auto_increment(self):
@@ -842,7 +895,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.AutoIncrementProperty,
this=exp.Literal.string("AUTO_INCREMENT"),
- value=self._parse_var() or self._parse_number(),
+ value=self._parse_number(),
)
def _parse_schema_comment(self):
@@ -898,13 +951,10 @@ class Parser(metaclass=_Parser):
while True:
if self._match(TokenType.WITH):
- self._match_l_paren()
- properties.extend(self._parse_csv(lambda: self._parse_property()))
- self._match_r_paren()
+ properties.extend(self._parse_wrapped_csv(self._parse_property))
elif self._match(TokenType.PROPERTIES):
- self._match_l_paren()
properties.extend(
- self._parse_csv(
+ self._parse_wrapped_csv(
lambda: self.expression(
exp.AnonymousProperty,
this=self._parse_string(),
@@ -912,25 +962,24 @@ class Parser(metaclass=_Parser):
)
)
)
- self._match_r_paren()
else:
identified_property = self._parse_property()
if not identified_property:
break
properties.append(identified_property)
+
if properties:
return self.expression(exp.Properties, expressions=properties)
return None
def _parse_describe(self):
self._match(TokenType.TABLE)
-
return self.expression(exp.Describe, this=self._parse_id_var())
def _parse_insert(self):
overwrite = self._match(TokenType.OVERWRITE)
local = self._match(TokenType.LOCAL)
- if self._match_text("DIRECTORY"):
+ if self._match_text_seq("DIRECTORY"):
this = self.expression(
exp.Directory,
this=self._parse_var_or_string(),
@@ -954,27 +1003,27 @@ class Parser(metaclass=_Parser):
if not self._match_pair(TokenType.ROW, TokenType.FORMAT):
return None
- self._match_text("DELIMITED")
+ self._match_text_seq("DELIMITED")
kwargs = {}
- if self._match_text("FIELDS", "TERMINATED", "BY"):
+ if self._match_text_seq("FIELDS", "TERMINATED", "BY"):
kwargs["fields"] = self._parse_string()
- if self._match_text("ESCAPED", "BY"):
+ if self._match_text_seq("ESCAPED", "BY"):
kwargs["escaped"] = self._parse_string()
- if self._match_text("COLLECTION", "ITEMS", "TERMINATED", "BY"):
+ if self._match_text_seq("COLLECTION", "ITEMS", "TERMINATED", "BY"):
kwargs["collection_items"] = self._parse_string()
- if self._match_text("MAP", "KEYS", "TERMINATED", "BY"):
+ if self._match_text_seq("MAP", "KEYS", "TERMINATED", "BY"):
kwargs["map_keys"] = self._parse_string()
- if self._match_text("LINES", "TERMINATED", "BY"):
+ if self._match_text_seq("LINES", "TERMINATED", "BY"):
kwargs["lines"] = self._parse_string()
- if self._match_text("NULL", "DEFINED", "AS"):
+ if self._match_text_seq("NULL", "DEFINED", "AS"):
kwargs["null"] = self._parse_string()
return self.expression(exp.RowFormat, **kwargs)
def _parse_load_data(self):
local = self._match(TokenType.LOCAL)
- self._match_text("INPATH")
+ self._match_text_seq("INPATH")
inpath = self._parse_string()
overwrite = self._match(TokenType.OVERWRITE)
self._match_pair(TokenType.INTO, TokenType.TABLE)
@@ -986,8 +1035,8 @@ class Parser(metaclass=_Parser):
overwrite=overwrite,
inpath=inpath,
partition=self._parse_partition(),
- input_format=self._match_text("INPUTFORMAT") and self._parse_string(),
- serde=self._match_text("SERDE") and self._parse_string(),
+ input_format=self._match_text_seq("INPUTFORMAT") and self._parse_string(),
+ serde=self._match_text_seq("SERDE") and self._parse_string(),
)
def _parse_delete(self):
@@ -996,9 +1045,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.Delete,
this=self._parse_table(schema=True),
- using=self._parse_csv(
- lambda: self._match(TokenType.USING) and self._parse_table(schema=True)
- ),
+ using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()),
where=self._parse_where(),
)
@@ -1029,12 +1076,7 @@ class Parser(metaclass=_Parser):
options = []
if self._match(TokenType.OPTIONS):
- self._match_l_paren()
- k = self._parse_string()
- self._match(TokenType.EQ)
- v = self._parse_string()
- options = [k, v]
- self._match_r_paren()
+ options = self._parse_wrapped_csv(self._parse_string, sep=TokenType.EQ)
self._match(TokenType.ALIAS)
return self.expression(
@@ -1050,27 +1092,13 @@ class Parser(metaclass=_Parser):
return None
def parse_values():
- key = self._parse_var()
- value = None
-
- if self._match(TokenType.EQ):
- value = self._parse_string()
-
- return exp.Property(this=key, value=value)
-
- self._match_l_paren()
- values = self._parse_csv(parse_values)
- self._match_r_paren()
+ props = self._parse_csv(self._parse_var_or_string, sep=TokenType.EQ)
+ return exp.Property(this=seq_get(props, 0), value=seq_get(props, 1))
- return self.expression(
- exp.Partition,
- this=values,
- )
+ return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values))
def _parse_value(self):
- self._match_l_paren()
- expressions = self._parse_csv(self._parse_conjunction)
- self._match_r_paren()
+ expressions = self._parse_wrapped_csv(self._parse_conjunction)
return self.expression(exp.Tuple, expressions=expressions)
def _parse_select(self, nested=False, table=False):
@@ -1124,10 +1152,11 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
this = self._parse_subquery(this)
elif self._match(TokenType.VALUES):
- this = self.expression(exp.Values, expressions=self._parse_csv(self._parse_value))
- alias = self._parse_table_alias()
- if alias:
- this = self.expression(exp.Subquery, this=this, alias=alias)
+ this = self.expression(
+ exp.Values,
+ expressions=self._parse_csv(self._parse_value),
+ alias=self._parse_table_alias(),
+ )
else:
this = None
@@ -1140,7 +1169,6 @@ class Parser(metaclass=_Parser):
recursive = self._match(TokenType.RECURSIVE)
expressions = []
-
while True:
expressions.append(self._parse_cte())
@@ -1149,11 +1177,7 @@ class Parser(metaclass=_Parser):
else:
self._match(TokenType.WITH)
- return self.expression(
- exp.With,
- expressions=expressions,
- recursive=recursive,
- )
+ return self.expression(exp.With, expressions=expressions, recursive=recursive)
def _parse_cte(self):
alias = self._parse_table_alias()
@@ -1163,13 +1187,9 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.ALIAS):
self.raise_error("Expected AS in CTE")
- self._match_l_paren()
- expression = self._parse_statement()
- self._match_r_paren()
-
return self.expression(
exp.CTE,
- this=expression,
+ this=self._parse_wrapped(self._parse_statement),
alias=alias,
)
@@ -1223,7 +1243,7 @@ class Parser(metaclass=_Parser):
def _parse_hint(self):
if self._match(TokenType.HINT):
hints = self._parse_csv(self._parse_function)
- if not self._match(TokenType.HINT):
+ if not self._match_pair(TokenType.STAR, TokenType.SLASH):
self.raise_error("Expected */ after HINT")
return self.expression(exp.Hint, expressions=hints)
return None
@@ -1259,26 +1279,18 @@ class Parser(metaclass=_Parser):
columns = self._parse_csv(self._parse_id_var)
elif self._match(TokenType.L_PAREN):
columns = self._parse_csv(self._parse_id_var)
- self._match(TokenType.R_PAREN)
+ self._match_r_paren()
expression = self.expression(
exp.Lateral,
this=this,
view=view,
outer=outer,
- alias=self.expression(
- exp.TableAlias,
- this=table_alias,
- columns=columns,
- ),
+ alias=self.expression(exp.TableAlias, this=table_alias, columns=columns),
)
if outer_apply or cross_apply:
- return self.expression(
- exp.Join,
- this=expression,
- side=None if cross_apply else "LEFT",
- )
+ return self.expression(exp.Join, this=expression, side=None if cross_apply else "LEFT")
return expression
@@ -1387,12 +1399,8 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.UNNEST):
return None
- self._match_l_paren()
- expressions = self._parse_csv(self._parse_column)
- self._match_r_paren()
-
+ expressions = self._parse_wrapped_csv(self._parse_column)
ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY))
-
alias = self._parse_table_alias()
if alias and self.unnest_column_only:
@@ -1402,10 +1410,7 @@ class Parser(metaclass=_Parser):
alias.set("this", None)
return self.expression(
- exp.Unnest,
- expressions=expressions,
- ordinality=ordinality,
- alias=alias,
+ exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias
)
def _parse_derived_table_values(self):
@@ -1418,13 +1423,7 @@ class Parser(metaclass=_Parser):
if is_derived:
self._match_r_paren()
- alias = self._parse_table_alias()
-
- return self.expression(
- exp.Values,
- expressions=expressions,
- alias=alias,
- )
+ return self.expression(exp.Values, expressions=expressions, alias=self._parse_table_alias())
def _parse_table_sample(self):
if not self._match(TokenType.TABLE_SAMPLE):
@@ -1460,9 +1459,7 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
if self._match(TokenType.SEED):
- self._match_l_paren()
- seed = self._parse_number()
- self._match_r_paren()
+ seed = self._parse_wrapped(self._parse_number)
return self.expression(
exp.TableSample,
@@ -1513,12 +1510,7 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
- return self.expression(
- exp.Pivot,
- expressions=expressions,
- field=field,
- unpivot=unpivot,
- )
+ return self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot)
def _parse_where(self, skip_where_token=False):
if not skip_where_token and not self._match(TokenType.WHERE):
@@ -1539,11 +1531,7 @@ class Parser(metaclass=_Parser):
def _parse_grouping_sets(self):
if not self._match(TokenType.GROUPING_SETS):
return None
-
- self._match_l_paren()
- grouping_sets = self._parse_csv(self._parse_grouping_set)
- self._match_r_paren()
- return grouping_sets
+ return self._parse_wrapped_csv(self._parse_grouping_set)
def _parse_grouping_set(self):
if self._match(TokenType.L_PAREN):
@@ -1573,7 +1561,6 @@ class Parser(metaclass=_Parser):
def _parse_sort(self, token_type, exp_class):
if not self._match(token_type):
return None
-
return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))
def _parse_ordered(self):
@@ -1602,9 +1589,12 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.TOP if top else TokenType.LIMIT):
limit_paren = self._match(TokenType.L_PAREN)
limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number())
+
if limit_paren:
- self._match(TokenType.R_PAREN)
+ self._match_r_paren()
+
return limit_exp
+
if self._match(TokenType.FETCH):
direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
direction = self._prev.text if direction else "FIRST"
@@ -1612,11 +1602,13 @@ class Parser(metaclass=_Parser):
self._match_set((TokenType.ROW, TokenType.ROWS))
self._match(TokenType.ONLY)
return self.expression(exp.Fetch, direction=direction, count=count)
+
return this
def _parse_offset(self, this=None):
if not self._match_set((TokenType.OFFSET, TokenType.COMMA)):
return this
+
count = self._parse_number()
self._match_set((TokenType.ROW, TokenType.ROWS))
return self.expression(exp.Offset, this=this, expression=count)
@@ -1678,6 +1670,7 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.DISTINCT_FROM):
klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
return self.expression(klass, this=this, expression=self._parse_expression())
+
this = self.expression(
exp.Is,
this=this,
@@ -1754,11 +1747,7 @@ class Parser(metaclass=_Parser):
def _parse_type(self):
if self._match(TokenType.INTERVAL):
- return self.expression(
- exp.Interval,
- this=self._parse_term(),
- unit=self._parse_var(),
- )
+ return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_var())
index = self._index
type_token = self._parse_types(check_func=True)
@@ -1824,30 +1813,18 @@ class Parser(metaclass=_Parser):
value = None
if type_token in self.TIMESTAMPS:
if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ:
- value = exp.DataType(
- this=exp.DataType.Type.TIMESTAMPTZ,
- expressions=expressions,
- )
+ value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions)
elif (
self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
):
- value = exp.DataType(
- this=exp.DataType.Type.TIMESTAMPLTZ,
- expressions=expressions,
- )
+ value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
elif self._match(TokenType.WITHOUT_TIME_ZONE):
- value = exp.DataType(
- this=exp.DataType.Type.TIMESTAMP,
- expressions=expressions,
- )
+ value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
maybe_func = maybe_func and value is None
if value is None:
- value = exp.DataType(
- this=exp.DataType.Type.TIMESTAMP,
- expressions=expressions,
- )
+ value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
if maybe_func and check_func:
index2 = self._index
@@ -1872,6 +1849,7 @@ class Parser(metaclass=_Parser):
this = self._parse_id_var()
self._match(TokenType.COLON)
data_type = self._parse_types()
+
if not data_type:
return None
return self.expression(exp.StructKwarg, this=this, expression=data_type)
@@ -1879,7 +1857,6 @@ class Parser(metaclass=_Parser):
def _parse_at_time_zone(self, this):
if not self._match(TokenType.AT_TIME_ZONE):
return this
-
return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary())
def _parse_column(self):
@@ -1984,16 +1961,14 @@ class Parser(metaclass=_Parser):
else:
subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type)
- if subquery_predicate and self._curr.token_type in (
- TokenType.SELECT,
- TokenType.WITH,
- ):
+ if subquery_predicate and self._curr.token_type in (TokenType.SELECT, TokenType.WITH):
this = self.expression(subquery_predicate, this=self._parse_select())
self._match_r_paren()
return this
if functions is None:
functions = self.FUNCTIONS
+
function = functions.get(upper)
args = self._parse_csv(self._parse_lambda)
@@ -2014,6 +1989,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.L_PAREN):
return this
+
expressions = self._parse_csv(self._parse_udf_kwarg)
self._match_r_paren()
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
@@ -2021,25 +1997,19 @@ class Parser(metaclass=_Parser):
def _parse_introducer(self, token):
literal = self._parse_primary()
if literal:
- return self.expression(
- exp.Introducer,
- this=token.text,
- expression=literal,
- )
+ return self.expression(exp.Introducer, this=token.text, expression=literal)
return self.expression(exp.Identifier, this=token.text)
def _parse_session_parameter(self):
kind = None
this = self._parse_id_var() or self._parse_primary()
+
if self._match(TokenType.DOT):
kind = this.name
this = self._parse_var() or self._parse_primary()
- return self.expression(
- exp.SessionParameter,
- this=this,
- kind=kind,
- )
+
+ return self.expression(exp.SessionParameter, this=this, kind=kind)
def _parse_udf_kwarg(self):
this = self._parse_id_var()
@@ -2106,7 +2076,10 @@ class Parser(metaclass=_Parser):
return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints)
def _parse_column_constraint(self):
- this = None
+ this = self._parse_references()
+
+ if this:
+ return this
if self._match(TokenType.CONSTRAINT):
this = self._parse_id_var()
@@ -2114,13 +2087,12 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.AUTO_INCREMENT):
kind = exp.AutoIncrementColumnConstraint()
elif self._match(TokenType.CHECK):
- self._match_l_paren()
- kind = self.expression(exp.CheckColumnConstraint, this=self._parse_conjunction())
- self._match_r_paren()
+ constraint = self._parse_wrapped(self._parse_conjunction)
+ kind = self.expression(exp.CheckColumnConstraint, this=constraint)
elif self._match(TokenType.COLLATE):
kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var())
elif self._match(TokenType.DEFAULT):
- kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_field())
+ kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction())
elif self._match_pair(TokenType.NOT, TokenType.NULL):
kind = exp.NotNullColumnConstraint()
elif self._match(TokenType.SCHEMA_COMMENT):
@@ -2137,7 +2109,7 @@ class Parser(metaclass=_Parser):
kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
self._match_pair(TokenType.ALIAS, TokenType.IDENTITY)
else:
- return None
+ return this
return self.expression(exp.ColumnConstraint, this=this, kind=kind)
@@ -2159,37 +2131,29 @@ class Parser(metaclass=_Parser):
def _parse_unnamed_constraint(self):
if not self._match_set(self.CONSTRAINT_PARSERS):
return None
-
return self.CONSTRAINT_PARSERS[self._prev.token_type](self)
- def _parse_check(self):
- self._match(TokenType.CHECK)
- self._match_l_paren()
- expression = self._parse_conjunction()
- self._match_r_paren()
-
- return self.expression(exp.Check, this=expression)
-
def _parse_unique(self):
- self._match(TokenType.UNIQUE)
- columns = self._parse_wrapped_id_vars()
-
- return self.expression(exp.Unique, expressions=columns)
+ return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars())
- def _parse_foreign_key(self):
- self._match(TokenType.FOREIGN_KEY)
-
- expressions = self._parse_wrapped_id_vars()
- reference = self._match(TokenType.REFERENCES) and self.expression(
+ def _parse_references(self):
+ if not self._match(TokenType.REFERENCES):
+ return None
+ return self.expression(
exp.Reference,
this=self._parse_id_var(),
expressions=self._parse_wrapped_id_vars(),
)
+
+ def _parse_foreign_key(self):
+ expressions = self._parse_wrapped_id_vars()
+ reference = self._parse_references()
options = {}
while self._match(TokenType.ON):
if not self._match_set((TokenType.DELETE, TokenType.UPDATE)):
self.raise_error("Expected DELETE or UPDATE")
+
kind = self._prev.text.lower()
if self._match(TokenType.NO_ACTION):
@@ -2200,6 +2164,7 @@ class Parser(metaclass=_Parser):
else:
self._advance()
action = self._prev.text.upper()
+
options[kind] = action
return self.expression(
@@ -2363,20 +2328,14 @@ class Parser(metaclass=_Parser):
def _parse_window(self, this, alias=False):
if self._match(TokenType.FILTER):
- self._match_l_paren()
- this = self.expression(exp.Filter, this=this, expression=self._parse_where())
- self._match_r_paren()
+ where = self._parse_wrapped(self._parse_where)
+ this = self.expression(exp.Filter, this=this, expression=where)
# T-SQL allows the OVER (...) syntax after WITHIN GROUP.
# https://learn.microsoft.com/en-us/sql/t-sql/functions/percentile-disc-transact-sql?view=sql-server-ver16
if self._match(TokenType.WITHIN_GROUP):
- self._match_l_paren()
- this = self.expression(
- exp.WithinGroup,
- this=this,
- expression=self._parse_order(),
- )
- self._match_r_paren()
+ order = self._parse_wrapped(self._parse_order)
+ this = self.expression(exp.WithinGroup, this=this, expression=order)
# SQL spec defines an optional [ { IGNORE | RESPECT } NULLS ] OVER
# Some dialects choose to implement and some do not.
@@ -2404,18 +2363,11 @@ class Parser(metaclass=_Parser):
return this
if not self._match(TokenType.L_PAREN):
- alias = self._parse_id_var(False)
-
- return self.expression(
- exp.Window,
- this=this,
- alias=alias,
- )
-
- partition = None
+ return self.expression(exp.Window, this=this, alias=self._parse_id_var(False))
alias = self._parse_id_var(False)
+ partition = None
if self._match(TokenType.PARTITION_BY):
partition = self._parse_csv(self._parse_conjunction)
@@ -2552,17 +2504,13 @@ class Parser(metaclass=_Parser):
def _parse_replace(self):
if not self._match(TokenType.REPLACE):
return None
+ return self._parse_wrapped_csv(lambda: self._parse_alias(self._parse_expression()))
- self._match_l_paren()
- columns = self._parse_csv(lambda: self._parse_alias(self._parse_expression()))
- self._match_r_paren()
- return columns
-
- def _parse_csv(self, parse_method):
+ def _parse_csv(self, parse_method, sep=TokenType.COMMA):
parse_result = parse_method()
items = [parse_result] if parse_result is not None else []
- while self._match(TokenType.COMMA):
+ while self._match(sep):
if parse_result and self._prev_comment is not None:
parse_result.comment = self._prev_comment
@@ -2583,16 +2531,53 @@ class Parser(metaclass=_Parser):
return this
def _parse_wrapped_id_vars(self):
+ return self._parse_wrapped_csv(self._parse_id_var)
+
+ def _parse_wrapped_csv(self, parse_method, sep=TokenType.COMMA):
+ return self._parse_wrapped(lambda: self._parse_csv(parse_method, sep=sep))
+
+ def _parse_wrapped(self, parse_method):
self._match_l_paren()
- expressions = self._parse_csv(self._parse_id_var)
+ parse_result = parse_method()
self._match_r_paren()
- return expressions
+ return parse_result
def _parse_select_or_expression(self):
return self._parse_select() or self._parse_expression()
- def _parse_use(self):
- return self.expression(exp.Use, this=self._parse_id_var())
+ def _parse_transaction(self):
+ this = None
+ if self._match_texts(self.TRANSACTION_KIND):
+ this = self._prev.text
+
+ self._match_texts({"TRANSACTION", "WORK"})
+
+ modes = []
+ while True:
+ mode = []
+ while self._match(TokenType.VAR):
+ mode.append(self._prev.text)
+
+ if mode:
+ modes.append(" ".join(mode))
+ if not self._match(TokenType.COMMA):
+ break
+
+ return self.expression(exp.Transaction, this=this, modes=modes)
+
+ def _parse_commit_or_rollback(self):
+ savepoint = None
+ is_rollback = self._prev.token_type == TokenType.ROLLBACK
+
+ self._match_texts({"TRANSACTION", "WORK"})
+
+ if self._match_text_seq("TO"):
+ self._match_text_seq("SAVEPOINT")
+ savepoint = self._parse_id_var()
+
+ if is_rollback:
+ return self.expression(exp.Rollback, savepoint=savepoint)
+ return self.expression(exp.Commit)
def _parse_show(self):
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
@@ -2675,7 +2660,13 @@ class Parser(metaclass=_Parser):
if expression and self._prev_comment:
expression.comment = self._prev_comment
- def _match_text(self, *texts):
+ def _match_texts(self, texts):
+ if self._curr and self._curr.text.upper() in texts:
+ self._advance()
+ return True
+ return False
+
+ def _match_text_seq(self, *texts):
index = self._index
for text in texts:
if self._curr and self._curr.text.upper() == text:
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index cd1de5e..51db2d4 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
import itertools
import math
+import typing as t
from sqlglot import alias, exp
from sqlglot.errors import UnsupportedError
@@ -7,15 +10,15 @@ from sqlglot.optimizer.eliminate_joins import join_condition
class Plan:
- def __init__(self, expression):
- self.expression = expression
+ def __init__(self, expression: exp.Expression) -> None:
+ self.expression = expression.copy()
self.root = Step.from_expression(self.expression)
- self._dag = {}
+ self._dag: t.Dict[Step, t.Set[Step]] = {}
@property
- def dag(self):
+ def dag(self) -> t.Dict[Step, t.Set[Step]]:
if not self._dag:
- dag = {}
+ dag: t.Dict[Step, t.Set[Step]] = {}
nodes = {self.root}
while nodes:
@@ -29,32 +32,64 @@ class Plan:
return self._dag
@property
- def leaves(self):
+ def leaves(self) -> t.Generator[Step, None, None]:
return (node for node, deps in self.dag.items() if not deps)
+ def __repr__(self) -> str:
+ return f"Plan\n----\n{repr(self.root)}"
+
class Step:
@classmethod
- def from_expression(cls, expression, ctes=None):
+ def from_expression(
+ cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
+ ) -> Step:
"""
- Build a DAG of Steps from a SQL expression.
-
- Giving an expression like:
-
- SELECT x.a, SUM(x.b)
- FROM x
- JOIN y
- ON x.a = y.a
+ Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine.
+ Note: the expression's tables and subqueries must be aliased for this method to work. For
+ example, given the following expression:
+
+ SELECT
+ x.a,
+ SUM(x.b)
+ FROM x AS x
+ JOIN y AS y
+ ON x.a = y.a
GROUP BY x.a
- Transform it into a DAG of the form:
-
- Aggregate(x.a, SUM(x.b))
- Join(y)
- Scan(x)
- Scan(y)
-
- This can then more easily be executed on by an engine.
+ the following DAG is produced (the expression IDs might differ per execution):
+
+ - Aggregate: x (4347984624)
+ Context:
+ Aggregations:
+ - SUM(x.b)
+ Group:
+ - x.a
+ Projections:
+ - x.a
+ - "x".""
+ Dependencies:
+ - Join: x (4347985296)
+ Context:
+ y:
+ On: x.a = y.a
+ Projections:
+ Dependencies:
+ - Scan: x (4347983136)
+ Context:
+ Source: x AS x
+ Projections:
+ - Scan: y (4343416624)
+ Context:
+ Source: y AS y
+ Projections:
+
+ Args:
+ expression: the expression to build the DAG from.
+ ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
+
+ Returns:
+ A Step DAG corresponding to `expression`.
"""
ctes = ctes or {}
with_ = expression.args.get("with")
@@ -65,11 +100,11 @@ class Step:
for cte in with_.expressions:
step = Step.from_expression(cte.this, ctes)
step.name = cte.alias
- ctes[step.name] = step
+ ctes[step.name] = step # type: ignore
from_ = expression.args.get("from")
- if from_:
+ if isinstance(expression, exp.Select) and from_:
from_ = from_.expressions
if len(from_) > 1:
raise UnsupportedError(
@@ -77,8 +112,10 @@ class Step:
)
step = Scan.from_expression(from_[0], ctes)
+ elif isinstance(expression, exp.Union):
+ step = SetOperation.from_expression(expression, ctes)
else:
- raise UnsupportedError("Static selects are unsupported.")
+ step = Scan()
joins = expression.args.get("joins")
@@ -115,7 +152,7 @@ class Step:
group = expression.args.get("group")
- if group:
+ if group or aggregations:
aggregate = Aggregate()
aggregate.source = step.name
aggregate.name = step.name
@@ -123,7 +160,15 @@ class Step:
alias(operand, alias_) for operand, alias_ in operands.items()
)
aggregate.aggregations = aggregations
- aggregate.group = group.expressions
+ # give aggregates names and replace projections with references to them
+ aggregate.group = {
+ f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
+ }
+ for projection in projections:
+ for i, e in aggregate.group.items():
+ for child, _, _ in projection.walk():
+ if child == e:
+ child.replace(exp.column(i, step.name))
aggregate.add_dependency(step)
step = aggregate
@@ -150,22 +195,22 @@ class Step:
return step
- def __init__(self):
- self.name = None
- self.dependencies = set()
- self.dependents = set()
- self.projections = []
- self.limit = math.inf
- self.condition = None
+ def __init__(self) -> None:
+ self.name: t.Optional[str] = None
+ self.dependencies: t.Set[Step] = set()
+ self.dependents: t.Set[Step] = set()
+ self.projections: t.Sequence[exp.Expression] = []
+ self.limit: float = math.inf
+ self.condition: t.Optional[exp.Expression] = None
- def add_dependency(self, dependency):
+ def add_dependency(self, dependency: Step) -> None:
self.dependencies.add(dependency)
dependency.dependents.add(self)
- def __repr__(self):
+ def __repr__(self) -> str:
return self.to_s()
- def to_s(self, level=0):
+ def to_s(self, level: int = 0) -> str:
indent = " " * level
nested = f"{indent} "
@@ -175,7 +220,7 @@ class Step:
context = [f"{nested}Context:"] + context
lines = [
- f"{indent}- {self.__class__.__name__}: {self.name}",
+ f"{indent}- {self.id}",
*context,
f"{nested}Projections:",
]
@@ -193,13 +238,25 @@ class Step:
return "\n".join(lines)
- def _to_s(self, _indent):
+ @property
+ def type_name(self) -> str:
+ return self.__class__.__name__
+
+ @property
+ def id(self) -> str:
+ name = self.name
+ name = f" {name}" if name else ""
+ return f"{self.type_name}:{name} ({id(self)})"
+
+ def _to_s(self, _indent: str) -> t.List[str]:
return []
class Scan(Step):
@classmethod
- def from_expression(cls, expression, ctes=None):
+ def from_expression(
+ cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
+ ) -> Step:
table = expression
alias_ = expression.alias
@@ -217,26 +274,24 @@ class Scan(Step):
step = Scan()
step.name = alias_
step.source = expression
- if table.name in ctes:
+ if ctes and table.name in ctes:
step.add_dependency(ctes[table.name])
return step
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
- self.source = None
-
- def _to_s(self, indent):
- return [f"{indent}Source: {self.source.sql()}"]
+ self.source: t.Optional[exp.Expression] = None
-
-class Write(Step):
- pass
+ def _to_s(self, indent: str) -> t.List[str]:
+ return [f"{indent}Source: {self.source.sql() if self.source else '-static-'}"] # type: ignore
class Join(Step):
@classmethod
- def from_joins(cls, joins, ctes=None):
+ def from_joins(
+ cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None
+ ) -> Step:
step = Join()
for join in joins:
@@ -252,28 +307,28 @@ class Join(Step):
return step
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
- self.joins = {}
+ self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {}
- def _to_s(self, indent):
+ def _to_s(self, indent: str) -> t.List[str]:
lines = []
for name, join in self.joins.items():
lines.append(f"{indent}{name}: {join['side']}")
if join.get("condition"):
- lines.append(f"{indent}On: {join['condition'].sql()}")
+ lines.append(f"{indent}On: {join['condition'].sql()}") # type: ignore
return lines
class Aggregate(Step):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
- self.aggregations = []
- self.operands = []
- self.group = []
- self.source = None
+ self.aggregations: t.List[exp.Expression] = []
+ self.operands: t.Tuple[exp.Expression, ...] = ()
+ self.group: t.Dict[str, exp.Expression] = {}
+ self.source: t.Optional[str] = None
- def _to_s(self, indent):
+ def _to_s(self, indent: str) -> t.List[str]:
lines = [f"{indent}Aggregations:"]
for expression in self.aggregations:
@@ -281,7 +336,7 @@ class Aggregate(Step):
if self.group:
lines.append(f"{indent}Group:")
- for expression in self.group:
+ for expression in self.group.values():
lines.append(f"{indent} - {expression.sql()}")
if self.operands:
lines.append(f"{indent}Operands:")
@@ -292,14 +347,56 @@ class Aggregate(Step):
class Sort(Step):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.key = None
- def _to_s(self, indent):
+ def _to_s(self, indent: str) -> t.List[str]:
lines = [f"{indent}Key:"]
- for expression in self.key:
+ for expression in self.key: # type: ignore
lines.append(f"{indent} - {expression.sql()}")
return lines
+
+
+class SetOperation(Step):
+ def __init__(
+ self,
+ op: t.Type[exp.Expression],
+ left: str | None,
+ right: str | None,
+ distinct: bool = False,
+ ) -> None:
+ super().__init__()
+ self.op = op
+ self.left = left
+ self.right = right
+ self.distinct = distinct
+
+ @classmethod
+ def from_expression(
+ cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
+ ) -> Step:
+ assert isinstance(expression, exp.Union)
+ left = Step.from_expression(expression.left, ctes)
+ right = Step.from_expression(expression.right, ctes)
+ step = cls(
+ op=expression.__class__,
+ left=left.name,
+ right=right.name,
+ distinct=expression.args.get("distinct"),
+ )
+ step.add_dependency(left)
+ step.add_dependency(right)
+ return step
+
+ def _to_s(self, indent: str) -> t.List[str]:
+ lines = []
+ if self.distinct:
+ lines.append(f"{indent}Distinct: {self.distinct}")
+ return lines
+
+ @property
+ def type_name(self) -> str:
+ return self.op.__name__
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index fcf7291..f6f303b 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -5,7 +5,7 @@ import typing as t
from sqlglot import expressions as exp
from sqlglot.errors import SchemaError
-from sqlglot.helper import csv_reader
+from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie
if t.TYPE_CHECKING:
@@ -15,6 +15,8 @@ if t.TYPE_CHECKING:
TABLE_ARGS = ("this", "db", "catalog")
+T = t.TypeVar("T")
+
class Schema(abc.ABC):
"""Abstract base class for database schemas"""
@@ -57,8 +59,81 @@ class Schema(abc.ABC):
The resulting column type.
"""
+ @property
+ def supported_table_args(self) -> t.Tuple[str, ...]:
+ """
+ Table arguments this schema support, e.g. `("this", "db", "catalog")`
+ """
+ raise NotImplementedError
+
+
+class AbstractMappingSchema(t.Generic[T]):
+ def __init__(
+ self,
+ mapping: dict | None = None,
+ ) -> None:
+ self.mapping = mapping or {}
+ self.mapping_trie = self._build_trie(self.mapping)
+ self._supported_table_args: t.Tuple[str, ...] = tuple()
+
+ def _build_trie(self, schema: t.Dict) -> t.Dict:
+ return new_trie(tuple(reversed(t)) for t in flatten_schema(schema, depth=self._depth()))
+
+ def _depth(self) -> int:
+ return dict_depth(self.mapping)
+
+ @property
+ def supported_table_args(self) -> t.Tuple[str, ...]:
+ if not self._supported_table_args and self.mapping:
+ depth = self._depth()
+
+ if not depth: # None
+ self._supported_table_args = tuple()
+ elif 1 <= depth <= 3:
+ self._supported_table_args = TABLE_ARGS[:depth]
+ else:
+ raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
+
+ return self._supported_table_args
+
+ def table_parts(self, table: exp.Table) -> t.List[str]:
+ if isinstance(table.this, exp.ReadCSV):
+ return [table.this.name]
+ return [table.text(part) for part in TABLE_ARGS if table.text(part)]
+
+ def find(
+ self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
+ ) -> t.Optional[T]:
+ parts = self.table_parts(table)[0 : len(self.supported_table_args)]
+ value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
+
+ if value == 0:
+ if raise_on_missing:
+ raise SchemaError(f"Cannot find mapping for {table}.")
+ else:
+ return None
+ elif value == 1:
+ possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
+ if len(possibilities) == 1:
+ parts.extend(possibilities[0])
+ else:
+ message = ", ".join(".".join(parts) for parts in possibilities)
+ if raise_on_missing:
+ raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
+ return None
+ return self._nested_get(parts, raise_on_missing=raise_on_missing)
-class MappingSchema(Schema):
+ def _nested_get(
+ self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
+ ) -> t.Optional[t.Any]:
+ return _nested_get(
+ d or self.mapping,
+ *zip(self.supported_table_args, reversed(parts)),
+ raise_on_missing=raise_on_missing,
+ )
+
+
+class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
"""
Schema based on a nested mapping.
@@ -82,17 +157,17 @@ class MappingSchema(Schema):
visible: t.Optional[t.Dict] = None,
dialect: t.Optional[str] = None,
) -> None:
- self.schema = schema or {}
+ super().__init__(schema)
self.visible = visible or {}
- self.schema_trie = self._build_trie(self.schema)
self.dialect = dialect
- self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {}
- self._supported_table_args: t.Tuple[str, ...] = tuple()
+ self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {
+ "STR": exp.DataType.Type.TEXT,
+ }
@classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
return MappingSchema(
- schema=mapping_schema.schema,
+ schema=mapping_schema.mapping,
visible=mapping_schema.visible,
dialect=mapping_schema.dialect,
)
@@ -100,27 +175,13 @@ class MappingSchema(Schema):
def copy(self, **kwargs) -> MappingSchema:
return MappingSchema(
**{ # type: ignore
- "schema": self.schema.copy(),
+ "schema": self.mapping.copy(),
"visible": self.visible.copy(),
"dialect": self.dialect,
**kwargs,
}
)
- @property
- def supported_table_args(self):
- if not self._supported_table_args and self.schema:
- depth = _dict_depth(self.schema)
-
- if not depth or depth == 1: # {}
- self._supported_table_args = tuple()
- elif 2 <= depth <= 4:
- self._supported_table_args = TABLE_ARGS[: depth - 1]
- else:
- raise SchemaError(f"Invalid schema shape. Depth: {depth}")
-
- return self._supported_table_args
-
def add_table(
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
) -> None:
@@ -133,17 +194,21 @@ class MappingSchema(Schema):
"""
table_ = self._ensure_table(table)
column_mapping = ensure_column_mapping(column_mapping)
- schema = self.find_schema(table_, raise_on_missing=False)
+ schema = self.find(table_, raise_on_missing=False)
if schema and not column_mapping:
return
_nested_set(
- self.schema,
+ self.mapping,
list(reversed(self.table_parts(table_))),
column_mapping,
)
- self.schema_trie = self._build_trie(self.schema)
+ self.mapping_trie = self._build_trie(self.mapping)
+
+ def _depth(self) -> int:
+ # The columns themselves are a mapping, but we don't want to include those
+ return super()._depth() - 1
def _ensure_table(self, table: exp.Table | str) -> exp.Table:
table_ = exp.to_table(table)
@@ -153,16 +218,9 @@ class MappingSchema(Schema):
return table_
- def table_parts(self, table: exp.Table) -> t.List[str]:
- return [table.text(part) for part in TABLE_ARGS if table.text(part)]
-
def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
table_ = self._ensure_table(table)
-
- if not isinstance(table_.this, exp.Identifier):
- return fs_get(table) # type: ignore
-
- schema = self.find_schema(table_)
+ schema = self.find(table_)
if schema is None:
raise SchemaError(f"Could not find table schema {table}")
@@ -173,36 +231,13 @@ class MappingSchema(Schema):
visible = self._nested_get(self.table_parts(table_), self.visible)
return [col for col in schema if col in visible] # type: ignore
- def find_schema(
- self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
- ) -> t.Optional[t.Dict[str, str]]:
- parts = self.table_parts(table)[0 : len(self.supported_table_args)]
- value, trie = in_trie(self.schema_trie if trie is None else trie, parts)
-
- if value == 0:
- if raise_on_missing:
- raise SchemaError(f"Cannot find schema for {table}.")
- else:
- return None
- elif value == 1:
- possibilities = flatten_schema(trie)
- if len(possibilities) == 1:
- parts.extend(possibilities[0])
- else:
- message = ", ".join(".".join(parts) for parts in possibilities)
- if raise_on_missing:
- raise SchemaError(f"Ambiguous schema for {table}: {message}.")
- return None
-
- return self._nested_get(parts, raise_on_missing=raise_on_missing)
-
def get_column_type(
self, table: exp.Table | str, column: exp.Column | str
) -> exp.DataType.Type:
column_name = column if isinstance(column, str) else column.name
table_ = exp.to_table(table)
if table_:
- table_schema = self.find_schema(table_)
+ table_schema = self.find(table_)
schema_type = table_schema.get(column_name).upper() # type: ignore
return self._convert_type(schema_type)
raise SchemaError(f"Could not convert table '{table}'")
@@ -228,18 +263,6 @@ class MappingSchema(Schema):
return self._type_mapping_cache[schema_type]
- def _build_trie(self, schema: t.Dict):
- return new_trie(tuple(reversed(t)) for t in flatten_schema(schema))
-
- def _nested_get(
- self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
- ) -> t.Optional[t.Any]:
- return _nested_get(
- d or self.schema,
- *zip(self.supported_table_args, reversed(parts)),
- raise_on_missing=raise_on_missing,
- )
-
def ensure_schema(schema: t.Any) -> Schema:
if isinstance(schema, Schema):
@@ -267,29 +290,20 @@ def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
raise ValueError(f"Invalid mapping provided: {type(mapping)}")
-def flatten_schema(schema: t.Dict, keys: t.Optional[t.List[str]] = None) -> t.List[t.List[str]]:
+def flatten_schema(
+ schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
+) -> t.List[t.List[str]]:
tables = []
keys = keys or []
- depth = _dict_depth(schema)
for k, v in schema.items():
- if depth >= 3:
- tables.extend(flatten_schema(v, keys + [k]))
- elif depth == 2:
+ if depth >= 2:
+ tables.extend(flatten_schema(v, depth - 1, keys + [k]))
+ elif depth == 1:
tables.append(keys + [k])
return tables
-def fs_get(table: exp.Table) -> t.List[str]:
- name = table.this.name
-
- if name.upper() == "READ_CSV":
- with csv_reader(table) as reader:
- return next(reader)
-
- raise ValueError(f"Cannot read schema for {table}")
-
-
def _nested_get(
d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
@@ -310,7 +324,7 @@ def _nested_get(
if d is None:
if raise_on_missing:
name = "table" if name == "this" else name
- raise ValueError(f"Unknown {name}")
+ raise ValueError(f"Unknown {name}: {key}")
return None
return d
@@ -350,34 +364,3 @@ def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
subd[keys[-1]] = value
return d
-
-
-def _dict_depth(d: t.Dict) -> int:
- """
- Get the nesting depth of a dictionary.
-
- For example:
- >>> _dict_depth(None)
- 0
- >>> _dict_depth({})
- 1
- >>> _dict_depth({"a": "b"})
- 1
- >>> _dict_depth({"a": {}})
- 2
- >>> _dict_depth({"a": {"b": {}}})
- 3
-
- Args:
- d (dict): dictionary
- Returns:
- int: depth
- """
- try:
- return 1 + _dict_depth(next(iter(d.values())))
- except AttributeError:
- # d doesn't have attribute "values"
- return 0
- except StopIteration:
- # d.values() returns an empty sequence
- return 1
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 95d84d6..ec8cd91 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -105,12 +105,9 @@ class TokenType(AutoName):
OBJECT = auto()
# keywords
- ADD_FILE = auto()
ALIAS = auto()
ALWAYS = auto()
ALL = auto()
- ALTER = auto()
- ANALYZE = auto()
ANTI = auto()
ANY = auto()
APPLY = auto()
@@ -124,14 +121,14 @@ class TokenType(AutoName):
BUCKET = auto()
BY_DEFAULT = auto()
CACHE = auto()
- CALL = auto()
+ CASCADE = auto()
CASE = auto()
CHARACTER_SET = auto()
CHECK = auto()
CLUSTER_BY = auto()
COLLATE = auto()
+ COMMAND = auto()
COMMENT = auto()
- COMMENT_ON = auto()
COMMIT = auto()
CONSTRAINT = auto()
CREATE = auto()
@@ -149,7 +146,9 @@ class TokenType(AutoName):
DETERMINISTIC = auto()
DISTINCT = auto()
DISTINCT_FROM = auto()
+ DISTKEY = auto()
DISTRIBUTE_BY = auto()
+ DISTSTYLE = auto()
DIV = auto()
DROP = auto()
ELSE = auto()
@@ -159,7 +158,6 @@ class TokenType(AutoName):
EXCEPT = auto()
EXECUTE = auto()
EXISTS = auto()
- EXPLAIN = auto()
FALSE = auto()
FETCH = auto()
FILTER = auto()
@@ -216,7 +214,6 @@ class TokenType(AutoName):
OFFSET = auto()
ON = auto()
ONLY = auto()
- OPTIMIZE = auto()
OPTIONS = auto()
ORDER_BY = auto()
ORDERED = auto()
@@ -258,6 +255,7 @@ class TokenType(AutoName):
SHOW = auto()
SIMILAR_TO = auto()
SOME = auto()
+ SORTKEY = auto()
SORT_BY = auto()
STABLE = auto()
STORED = auto()
@@ -268,9 +266,8 @@ class TokenType(AutoName):
TRANSIENT = auto()
TOP = auto()
THEN = auto()
- TRUE = auto()
TRAILING = auto()
- TRUNCATE = auto()
+ TRUE = auto()
UNBOUNDED = auto()
UNCACHE = auto()
UNION = auto()
@@ -280,7 +277,6 @@ class TokenType(AutoName):
USE = auto()
USING = auto()
VALUES = auto()
- VACUUM = auto()
VIEW = auto()
VOLATILE = auto()
WHEN = auto()
@@ -420,7 +416,6 @@ class Tokenizer(metaclass=_Tokenizer):
KEYWORDS = {
"/*+": TokenType.HINT,
- "*/": TokenType.HINT,
"==": TokenType.EQ,
"::": TokenType.DCOLON,
"||": TokenType.DPIPE,
@@ -435,15 +430,7 @@ class Tokenizer(metaclass=_Tokenizer):
"#>": TokenType.HASH_ARROW,
"#>>": TokenType.DHASH_ARROW,
"<->": TokenType.LR_ARROW,
- "ADD ARCHIVE": TokenType.ADD_FILE,
- "ADD ARCHIVES": TokenType.ADD_FILE,
- "ADD FILE": TokenType.ADD_FILE,
- "ADD FILES": TokenType.ADD_FILE,
- "ADD JAR": TokenType.ADD_FILE,
- "ADD JARS": TokenType.ADD_FILE,
"ALL": TokenType.ALL,
- "ALTER": TokenType.ALTER,
- "ANALYZE": TokenType.ANALYZE,
"AND": TokenType.AND,
"ANTI": TokenType.ANTI,
"ANY": TokenType.ANY,
@@ -455,10 +442,10 @@ class Tokenizer(metaclass=_Tokenizer):
"BETWEEN": TokenType.BETWEEN,
"BOTH": TokenType.BOTH,
"BUCKET": TokenType.BUCKET,
- "CALL": TokenType.CALL,
"CACHE": TokenType.CACHE,
"UNCACHE": TokenType.UNCACHE,
"CASE": TokenType.CASE,
+ "CASCADE": TokenType.CASCADE,
"CHARACTER SET": TokenType.CHARACTER_SET,
"CHECK": TokenType.CHECK,
"CLUSTER BY": TokenType.CLUSTER_BY,
@@ -479,7 +466,9 @@ class Tokenizer(metaclass=_Tokenizer):
"DETERMINISTIC": TokenType.DETERMINISTIC,
"DISTINCT": TokenType.DISTINCT,
"DISTINCT FROM": TokenType.DISTINCT_FROM,
+ "DISTKEY": TokenType.DISTKEY,
"DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
+ "DISTSTYLE": TokenType.DISTSTYLE,
"DIV": TokenType.DIV,
"DROP": TokenType.DROP,
"ELSE": TokenType.ELSE,
@@ -489,7 +478,6 @@ class Tokenizer(metaclass=_Tokenizer):
"EXCEPT": TokenType.EXCEPT,
"EXECUTE": TokenType.EXECUTE,
"EXISTS": TokenType.EXISTS,
- "EXPLAIN": TokenType.EXPLAIN,
"FALSE": TokenType.FALSE,
"FETCH": TokenType.FETCH,
"FILTER": TokenType.FILTER,
@@ -541,7 +529,6 @@ class Tokenizer(metaclass=_Tokenizer):
"OFFSET": TokenType.OFFSET,
"ON": TokenType.ON,
"ONLY": TokenType.ONLY,
- "OPTIMIZE": TokenType.OPTIMIZE,
"OPTIONS": TokenType.OPTIONS,
"OR": TokenType.OR,
"ORDER BY": TokenType.ORDER_BY,
@@ -579,6 +566,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SET": TokenType.SET,
"SHOW": TokenType.SHOW,
"SOME": TokenType.SOME,
+ "SORTKEY": TokenType.SORTKEY,
"SORT BY": TokenType.SORT_BY,
"STABLE": TokenType.STABLE,
"STORED": TokenType.STORED,
@@ -592,7 +580,6 @@ class Tokenizer(metaclass=_Tokenizer):
"THEN": TokenType.THEN,
"TRUE": TokenType.TRUE,
"TRAILING": TokenType.TRAILING,
- "TRUNCATE": TokenType.TRUNCATE,
"UNBOUNDED": TokenType.UNBOUNDED,
"UNION": TokenType.UNION,
"UNPIVOT": TokenType.UNPIVOT,
@@ -600,7 +587,6 @@ class Tokenizer(metaclass=_Tokenizer):
"UPDATE": TokenType.UPDATE,
"USE": TokenType.USE,
"USING": TokenType.USING,
- "VACUUM": TokenType.VACUUM,
"VALUES": TokenType.VALUES,
"VIEW": TokenType.VIEW,
"VOLATILE": TokenType.VOLATILE,
@@ -659,6 +645,14 @@ class Tokenizer(metaclass=_Tokenizer):
"UNIQUE": TokenType.UNIQUE,
"STRUCT": TokenType.STRUCT,
"VARIANT": TokenType.VARIANT,
+ "ALTER": TokenType.COMMAND,
+ "ANALYZE": TokenType.COMMAND,
+ "CALL": TokenType.COMMAND,
+ "EXPLAIN": TokenType.COMMAND,
+ "OPTIMIZE": TokenType.COMMAND,
+ "PREPARE": TokenType.COMMAND,
+ "TRUNCATE": TokenType.COMMAND,
+ "VACUUM": TokenType.COMMAND,
}
WHITE_SPACE = {
@@ -670,20 +664,11 @@ class Tokenizer(metaclass=_Tokenizer):
}
COMMANDS = {
- TokenType.ALTER,
- TokenType.ADD_FILE,
- TokenType.ANALYZE,
- TokenType.BEGIN,
- TokenType.CALL,
- TokenType.COMMENT_ON,
- TokenType.COMMIT,
- TokenType.EXPLAIN,
- TokenType.OPTIMIZE,
+ TokenType.COMMAND,
+ TokenType.EXECUTE,
+ TokenType.FETCH,
TokenType.SET,
TokenType.SHOW,
- TokenType.TRUNCATE,
- TokenType.VACUUM,
- TokenType.ROLLBACK,
}
# handle numeric literals like in hive (3L = BIGINT)
@@ -885,6 +870,7 @@ class Tokenizer(metaclass=_Tokenizer):
if comment_start_line == self._prev_token_line:
if self._prev_token_comment is None:
self.tokens[-1].comment = self._comment
+ self._prev_token_comment = self._comment
self._comment = None
diff --git a/tests/dataframe/unit/test_dataframe.py b/tests/dataframe/unit/test_dataframe.py
index e36667b..24850bc 100644
--- a/tests/dataframe/unit/test_dataframe.py
+++ b/tests/dataframe/unit/test_dataframe.py
@@ -4,6 +4,8 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataframe(DataFrameSQLValidator):
+ maxDiff = None
+
def test_hash_select_expression(self):
expression = exp.select("cola").from_("table")
self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))
@@ -16,26 +18,26 @@ class TestDataframe(DataFrameSQLValidator):
def test_cache(self):
df = self.df_employee.select("fname").cache()
expected_statements = [
- "DROP VIEW IF EXISTS t11623",
- "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
- "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
+ "DROP VIEW IF EXISTS t31563",
+ "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+ "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
]
self.compare_sql(df, expected_statements)
def test_persist_default(self):
df = self.df_employee.select("fname").persist()
expected_statements = [
- "DROP VIEW IF EXISTS t11623",
- "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
- "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
+ "DROP VIEW IF EXISTS t31563",
+ "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+ "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
]
self.compare_sql(df, expected_statements)
def test_persist_storagelevel(self):
df = self.df_employee.select("fname").persist("DISK_ONLY_2")
expected_statements = [
- "DROP VIEW IF EXISTS t11623",
- "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
- "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
+ "DROP VIEW IF EXISTS t31563",
+ "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+ "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
]
self.compare_sql(df, expected_statements)
diff --git a/tests/dataframe/unit/test_dataframe_writer.py b/tests/dataframe/unit/test_dataframe_writer.py
index 14b4a0a..7c646f5 100644
--- a/tests/dataframe/unit/test_dataframe_writer.py
+++ b/tests/dataframe/unit/test_dataframe_writer.py
@@ -6,39 +6,41 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataFrameWriter(DataFrameSQLValidator):
+ maxDiff = None
+
def test_insertInto_full_path(self):
df = self.df_employee.write.insertInto("catalog.db.table_name")
- expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_db_table(self):
df = self.df_employee.write.insertInto("db.table_name")
- expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_table(self):
df = self.df_employee.write.insertInto("table_name")
- expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_overwrite(self):
df = self.df_employee.write.insertInto("table_name", overwrite=True)
- expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema())
def test_insertInto_byName(self):
sqlglot.schema.add_table("table_name", {"employee_id": "INT"})
df = self.df_employee.write.byName.insertInto("table_name")
- expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_cache(self):
df = self.df_employee.cache().write.insertInto("table_name")
expected_statements = [
- "DROP VIEW IF EXISTS t35612",
- "CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
- "INSERT INTO table_name SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`",
+ "DROP VIEW IF EXISTS t37164",
+ "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+ "INSERT INTO table_name SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
]
self.compare_sql(df, expected_statements)
@@ -48,39 +50,39 @@ class TestDataFrameWriter(DataFrameSQLValidator):
def test_saveAsTable_append(self):
df = self.df_employee.write.saveAsTable("table_name", mode="append")
- expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_overwrite(self):
df = self.df_employee.write.saveAsTable("table_name", mode="overwrite")
- expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_error(self):
df = self.df_employee.write.saveAsTable("table_name", mode="error")
- expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_ignore(self):
df = self.df_employee.write.saveAsTable("table_name", mode="ignore")
- expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_mode_standalone(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name")
- expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_mode_override(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite")
- expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_cache(self):
df = self.df_employee.cache().write.saveAsTable("table_name")
expected_statements = [
- "DROP VIEW IF EXISTS t35612",
- "CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
- "CREATE TABLE table_name AS SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`",
+ "DROP VIEW IF EXISTS t37164",
+ "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+ "CREATE TABLE table_name AS SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
]
self.compare_sql(df, expected_statements)
diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py
index 7e8bfad..55aa547 100644
--- a/tests/dataframe/unit/test_session.py
+++ b/tests/dataframe/unit/test_session.py
@@ -11,32 +11,32 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataframeSession(DataFrameSQLValidator):
def test_cdf_one_row(self):
df = self.spark.createDataFrame([[1, 2]], ["cola", "colb"])
- expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2)) AS `a2`(`cola`, `colb`)"
+ expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_multiple_rows(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]], ["cola", "colb"])
- expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`cola`, `colb`)"
+ expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_no_schema(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
- expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
+ expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`_1`, `_2`)"
self.compare_sql(df, expected)
def test_cdf_row_mixed_primitives(self):
df = self.spark.createDataFrame([[1, 10.1, "test", False, None]])
- expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM (VALUES (1, 10.1, 'test', FALSE, NULL)) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)"
+ expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM VALUES (1, 10.1, 'test', FALSE, NULL) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)"
self.compare_sql(df, expected)
def test_cdf_dict_rows(self):
df = self.spark.createDataFrame([{"cola": 1, "colb": "test"}, {"cola": 2, "colb": "test2"}])
- expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 'test'), (2, 'test2')) AS `a2`(`cola`, `colb`)"
+ expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 'test'), (2, 'test2') AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_str_schema(self):
df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING")
- expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)"
+ expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_typed_schema_basic(self):
@@ -47,7 +47,7 @@ class TestDataframeSession(DataFrameSQLValidator):
]
)
df = self.spark.createDataFrame([[1, "test"]], schema)
- expected = "SELECT CAST(`a2`.`cola` AS int) AS `cola`, CAST(`a2`.`colb` AS string) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)"
+ expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_typed_schema_nested(self):
@@ -65,7 +65,8 @@ class TestDataframeSession(DataFrameSQLValidator):
]
)
df = self.spark.createDataFrame([[{"sub_cola": 1, "sub_colb": "test"}]], schema)
- expected = "SELECT CAST(`a2`.`cola` AS struct<sub_cola:int, sub_colb:string>) AS `cola` FROM (VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`))) AS `a2`(`cola`)"
+ expected = "SELECT CAST(`a2`.`cola` AS STRUCT<`sub_cola`: INT, `sub_colb`: STRING>) AS `cola` FROM VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`)) AS `a2`(`cola`)"
+
self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema())
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index a0ebc45..cc44311 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -286,6 +286,10 @@ class TestBigQuery(Validator):
"bigquery": "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))",
},
)
+ self.validate_identity("BEGIN A B C D E F")
+ self.validate_identity("BEGIN TRANSACTION")
+ self.validate_identity("COMMIT TRANSACTION")
+ self.validate_identity("ROLLBACK TRANSACTION")
def test_user_defined_functions(self):
self.validate_identity(
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 1913f53..1b2f9c1 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -69,6 +69,7 @@ class TestDialect(Validator):
write={
"bigquery": "CAST(a AS STRING)",
"clickhouse": "CAST(a AS TEXT)",
+ "drill": "CAST(a AS VARCHAR)",
"duckdb": "CAST(a AS TEXT)",
"mysql": "CAST(a AS TEXT)",
"hive": "CAST(a AS STRING)",
@@ -86,6 +87,7 @@ class TestDialect(Validator):
write={
"bigquery": "CAST(a AS BINARY(4))",
"clickhouse": "CAST(a AS BINARY(4))",
+ "drill": "CAST(a AS VARBINARY(4))",
"duckdb": "CAST(a AS BINARY(4))",
"mysql": "CAST(a AS BINARY(4))",
"hive": "CAST(a AS BINARY(4))",
@@ -146,6 +148,7 @@ class TestDialect(Validator):
"CAST(a AS STRING)",
write={
"bigquery": "CAST(a AS STRING)",
+ "drill": "CAST(a AS VARCHAR)",
"duckdb": "CAST(a AS TEXT)",
"mysql": "CAST(a AS TEXT)",
"hive": "CAST(a AS STRING)",
@@ -162,6 +165,7 @@ class TestDialect(Validator):
"CAST(a AS VARCHAR)",
write={
"bigquery": "CAST(a AS STRING)",
+ "drill": "CAST(a AS VARCHAR)",
"duckdb": "CAST(a AS TEXT)",
"mysql": "CAST(a AS VARCHAR)",
"hive": "CAST(a AS STRING)",
@@ -178,6 +182,7 @@ class TestDialect(Validator):
"CAST(a AS VARCHAR(3))",
write={
"bigquery": "CAST(a AS STRING(3))",
+ "drill": "CAST(a AS VARCHAR(3))",
"duckdb": "CAST(a AS TEXT(3))",
"mysql": "CAST(a AS VARCHAR(3))",
"hive": "CAST(a AS VARCHAR(3))",
@@ -194,6 +199,7 @@ class TestDialect(Validator):
"CAST(a AS SMALLINT)",
write={
"bigquery": "CAST(a AS INT64)",
+ "drill": "CAST(a AS INTEGER)",
"duckdb": "CAST(a AS SMALLINT)",
"mysql": "CAST(a AS SMALLINT)",
"hive": "CAST(a AS SMALLINT)",
@@ -215,6 +221,7 @@ class TestDialect(Validator):
},
write={
"duckdb": "TRY_CAST(a AS DOUBLE)",
+ "drill": "CAST(a AS DOUBLE)",
"postgres": "CAST(a AS DOUBLE PRECISION)",
"redshift": "CAST(a AS DOUBLE PRECISION)",
},
@@ -225,6 +232,7 @@ class TestDialect(Validator):
write={
"bigquery": "CAST(a AS FLOAT64)",
"clickhouse": "CAST(a AS Float64)",
+ "drill": "CAST(a AS DOUBLE)",
"duckdb": "CAST(a AS DOUBLE)",
"mysql": "CAST(a AS DOUBLE)",
"hive": "CAST(a AS DOUBLE)",
@@ -279,6 +287,7 @@ class TestDialect(Validator):
"duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)",
"presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')",
+ "drill": "TO_TIMESTAMP(x, 'yyyy-MM-dd''T''HH:mm:ss')",
"redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH:MI:SS')",
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')",
},
@@ -286,6 +295,7 @@ class TestDialect(Validator):
self.validate_all(
"STR_TO_TIME('2020-01-01', '%Y-%m-%d')",
write={
+ "drill": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')",
"duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')",
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
"oracle": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
@@ -298,6 +308,7 @@ class TestDialect(Validator):
self.validate_all(
"STR_TO_TIME(x, '%y')",
write={
+ "drill": "TO_TIMESTAMP(x, 'yy')",
"duckdb": "STRPTIME(x, '%y')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)",
"presto": "DATE_PARSE(x, '%y')",
@@ -319,6 +330,7 @@ class TestDialect(Validator):
self.validate_all(
"TIME_STR_TO_DATE('2020-01-01')",
write={
+ "drill": "CAST('2020-01-01' AS DATE)",
"duckdb": "CAST('2020-01-01' AS DATE)",
"hive": "TO_DATE('2020-01-01')",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')",
@@ -328,6 +340,7 @@ class TestDialect(Validator):
self.validate_all(
"TIME_STR_TO_TIME('2020-01-01')",
write={
+ "drill": "CAST('2020-01-01' AS TIMESTAMP)",
"duckdb": "CAST('2020-01-01' AS TIMESTAMP)",
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')",
@@ -344,6 +357,7 @@ class TestDialect(Validator):
self.validate_all(
"TIME_TO_STR(x, '%Y-%m-%d')",
write={
+ "drill": "TO_CHAR(x, 'yyyy-MM-dd')",
"duckdb": "STRFTIME(x, '%Y-%m-%d')",
"hive": "DATE_FORMAT(x, 'yyyy-MM-dd')",
"oracle": "TO_CHAR(x, 'YYYY-MM-DD')",
@@ -355,6 +369,7 @@ class TestDialect(Validator):
self.validate_all(
"TIME_TO_TIME_STR(x)",
write={
+ "drill": "CAST(x AS VARCHAR)",
"duckdb": "CAST(x AS TEXT)",
"hive": "CAST(x AS STRING)",
"presto": "CAST(x AS VARCHAR)",
@@ -364,6 +379,7 @@ class TestDialect(Validator):
self.validate_all(
"TIME_TO_UNIX(x)",
write={
+ "drill": "UNIX_TIMESTAMP(x)",
"duckdb": "EPOCH(x)",
"hive": "UNIX_TIMESTAMP(x)",
"presto": "TO_UNIXTIME(x)",
@@ -425,6 +441,7 @@ class TestDialect(Validator):
self.validate_all(
"DATE_TO_DATE_STR(x)",
write={
+ "drill": "CAST(x AS VARCHAR)",
"duckdb": "CAST(x AS TEXT)",
"hive": "CAST(x AS STRING)",
"presto": "CAST(x AS VARCHAR)",
@@ -433,6 +450,7 @@ class TestDialect(Validator):
self.validate_all(
"DATE_TO_DI(x)",
write={
+ "drill": "CAST(TO_DATE(x, 'yyyyMMdd') AS INT)",
"duckdb": "CAST(STRFTIME(x, '%Y%m%d') AS INT)",
"hive": "CAST(DATE_FORMAT(x, 'yyyyMMdd') AS INT)",
"presto": "CAST(DATE_FORMAT(x, '%Y%m%d') AS INT)",
@@ -441,6 +459,7 @@ class TestDialect(Validator):
self.validate_all(
"DI_TO_DATE(x)",
write={
+ "drill": "TO_DATE(CAST(x AS VARCHAR), 'yyyyMMdd')",
"duckdb": "CAST(STRPTIME(CAST(x AS TEXT), '%Y%m%d') AS DATE)",
"hive": "TO_DATE(CAST(x AS STRING), 'yyyyMMdd')",
"presto": "CAST(DATE_PARSE(CAST(x AS VARCHAR), '%Y%m%d') AS DATE)",
@@ -463,6 +482,7 @@ class TestDialect(Validator):
},
write={
"bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
+ "drill": "DATE_ADD(x, INTERVAL '1' DAY)",
"duckdb": "x + INTERVAL 1 day",
"hive": "DATE_ADD(x, 1)",
"mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
@@ -477,6 +497,7 @@ class TestDialect(Validator):
"DATE_ADD(x, 1)",
write={
"bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
+ "drill": "DATE_ADD(x, INTERVAL '1' DAY)",
"duckdb": "x + INTERVAL 1 DAY",
"hive": "DATE_ADD(x, 1)",
"mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
@@ -546,6 +567,7 @@ class TestDialect(Validator):
"starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
},
write={
+ "drill": "TO_DATE(x, 'yyyy-MM-dd''T''HH:mm:ss')",
"mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
"starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)",
@@ -556,6 +578,7 @@ class TestDialect(Validator):
self.validate_all(
"STR_TO_DATE(x, '%Y-%m-%d')",
write={
+ "drill": "CAST(x AS DATE)",
"mysql": "STR_TO_DATE(x, '%Y-%m-%d')",
"starrocks": "STR_TO_DATE(x, '%Y-%m-%d')",
"hive": "CAST(x AS DATE)",
@@ -566,6 +589,7 @@ class TestDialect(Validator):
self.validate_all(
"DATE_STR_TO_DATE(x)",
write={
+ "drill": "CAST(x AS DATE)",
"duckdb": "CAST(x AS DATE)",
"hive": "TO_DATE(x)",
"presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)",
@@ -575,6 +599,7 @@ class TestDialect(Validator):
self.validate_all(
"TS_OR_DS_ADD('2021-02-01', 1, 'DAY')",
write={
+ "drill": "DATE_ADD(CAST('2021-02-01' AS DATE), INTERVAL '1' DAY)",
"duckdb": "CAST('2021-02-01' AS DATE) + INTERVAL 1 DAY",
"hive": "DATE_ADD('2021-02-01', 1)",
"presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR('2021-02-01', 1, 10), '%Y-%m-%d'))",
@@ -584,6 +609,7 @@ class TestDialect(Validator):
self.validate_all(
"DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
write={
+ "drill": "DATE_ADD(CAST('2020-01-01' AS DATE), INTERVAL '1' DAY)",
"duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY",
"hive": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
"presto": "DATE_ADD('day', 1, CAST('2020-01-01' AS DATE))",
@@ -593,6 +619,7 @@ class TestDialect(Validator):
self.validate_all(
"TIMESTAMP '2022-01-01'",
write={
+ "drill": "CAST('2022-01-01' AS TIMESTAMP)",
"mysql": "CAST('2022-01-01' AS TIMESTAMP)",
"starrocks": "CAST('2022-01-01' AS DATETIME)",
"hive": "CAST('2022-01-01' AS TIMESTAMP)",
@@ -614,6 +641,7 @@ class TestDialect(Validator):
dialect: f"{unit}(x)"
for dialect in (
"bigquery",
+ "drill",
"duckdb",
"mysql",
"presto",
@@ -624,6 +652,7 @@ class TestDialect(Validator):
dialect: f"{unit}(x)"
for dialect in (
"bigquery",
+ "drill",
"duckdb",
"mysql",
"presto",
@@ -649,6 +678,7 @@ class TestDialect(Validator):
write={
"bigquery": "ARRAY_LENGTH(x)",
"duckdb": "ARRAY_LENGTH(x)",
+ "drill": "REPEATED_COUNT(x)",
"presto": "CARDINALITY(x)",
"spark": "SIZE(x)",
},
@@ -736,6 +766,7 @@ class TestDialect(Validator):
self.validate_all(
"SELECT a FROM x CROSS JOIN UNNEST(y) AS t (a)",
write={
+ "drill": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)",
"presto": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)",
"spark": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a",
},
@@ -743,6 +774,7 @@ class TestDialect(Validator):
self.validate_all(
"SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t (a, b)",
write={
+ "drill": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
"presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
"spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b",
},
@@ -775,6 +807,7 @@ class TestDialect(Validator):
},
write={
"bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b",
+ "drill": "SELECT * FROM a UNION SELECT * FROM b",
"duckdb": "SELECT * FROM a UNION SELECT * FROM b",
"presto": "SELECT * FROM a UNION SELECT * FROM b",
"spark": "SELECT * FROM a UNION SELECT * FROM b",
@@ -887,6 +920,7 @@ class TestDialect(Validator):
write={
"bigquery": "LOWER(x) LIKE '%y'",
"clickhouse": "x ILIKE '%y'",
+ "drill": "x `ILIKE` '%y'",
"duckdb": "x ILIKE '%y'",
"hive": "LOWER(x) LIKE '%y'",
"mysql": "LOWER(x) LIKE '%y'",
@@ -910,32 +944,38 @@ class TestDialect(Validator):
self.validate_all(
"POSITION(' ' in x)",
write={
+ "drill": "STRPOS(x, ' ')",
"duckdb": "STRPOS(x, ' ')",
"postgres": "STRPOS(x, ' ')",
"presto": "STRPOS(x, ' ')",
"spark": "LOCATE(' ', x)",
"clickhouse": "position(x, ' ')",
"snowflake": "POSITION(' ', x)",
+ "mysql": "LOCATE(' ', x)",
},
)
self.validate_all(
"STR_POSITION('a', x)",
write={
+ "drill": "STRPOS(x, 'a')",
"duckdb": "STRPOS(x, 'a')",
"postgres": "STRPOS(x, 'a')",
"presto": "STRPOS(x, 'a')",
"spark": "LOCATE('a', x)",
"clickhouse": "position(x, 'a')",
"snowflake": "POSITION('a', x)",
+ "mysql": "LOCATE('a', x)",
},
)
self.validate_all(
"POSITION('a', x, 3)",
write={
+ "drill": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
"presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
"spark": "LOCATE('a', x, 3)",
"clickhouse": "position(x, 'a', 3)",
"snowflake": "POSITION('a', x, 3)",
+ "mysql": "LOCATE('a', x, 3)",
},
)
self.validate_all(
@@ -960,6 +1000,7 @@ class TestDialect(Validator):
self.validate_all(
"IF(x > 1, 1, 0)",
write={
+ "drill": "`IF`(x > 1, 1, 0)",
"duckdb": "CASE WHEN x > 1 THEN 1 ELSE 0 END",
"presto": "IF(x > 1, 1, 0)",
"hive": "IF(x > 1, 1, 0)",
@@ -970,6 +1011,7 @@ class TestDialect(Validator):
self.validate_all(
"CASE WHEN 1 THEN x ELSE 0 END",
write={
+ "drill": "CASE WHEN 1 THEN x ELSE 0 END",
"duckdb": "CASE WHEN 1 THEN x ELSE 0 END",
"presto": "CASE WHEN 1 THEN x ELSE 0 END",
"hive": "CASE WHEN 1 THEN x ELSE 0 END",
@@ -980,6 +1022,7 @@ class TestDialect(Validator):
self.validate_all(
"x[y]",
write={
+ "drill": "x[y]",
"duckdb": "x[y]",
"presto": "x[y]",
"hive": "x[y]",
@@ -1000,6 +1043,7 @@ class TestDialect(Validator):
'true or null as "foo"',
write={
"bigquery": "TRUE OR NULL AS `foo`",
+ "drill": "TRUE OR NULL AS `foo`",
"duckdb": 'TRUE OR NULL AS "foo"',
"presto": 'TRUE OR NULL AS "foo"',
"hive": "TRUE OR NULL AS `foo`",
@@ -1020,6 +1064,7 @@ class TestDialect(Validator):
"LEVENSHTEIN(col1, col2)",
write={
"duckdb": "LEVENSHTEIN(col1, col2)",
+ "drill": "LEVENSHTEIN_DISTANCE(col1, col2)",
"presto": "LEVENSHTEIN_DISTANCE(col1, col2)",
"hive": "LEVENSHTEIN(col1, col2)",
"spark": "LEVENSHTEIN(col1, col2)",
@@ -1029,6 +1074,7 @@ class TestDialect(Validator):
"LEVENSHTEIN(coalesce(col1, col2), coalesce(col2, col1))",
write={
"duckdb": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
+ "drill": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
"presto": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
"hive": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
"spark": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
@@ -1152,6 +1198,7 @@ class TestDialect(Validator):
self.validate_all(
"SELECT a AS b FROM x GROUP BY b",
write={
+ "drill": "SELECT a AS b FROM x GROUP BY b",
"duckdb": "SELECT a AS b FROM x GROUP BY b",
"presto": "SELECT a AS b FROM x GROUP BY 1",
"hive": "SELECT a AS b FROM x GROUP BY 1",
@@ -1162,6 +1209,7 @@ class TestDialect(Validator):
self.validate_all(
"SELECT y x FROM my_table t",
write={
+ "drill": "SELECT y AS x FROM my_table AS t",
"hive": "SELECT y AS x FROM my_table AS t",
"oracle": "SELECT y AS x FROM my_table t",
"postgres": "SELECT y AS x FROM my_table AS t",
@@ -1230,3 +1278,36 @@ SELECT
},
pretty=True,
)
+
+ def test_transactions(self):
+ self.validate_all(
+ "BEGIN TRANSACTION",
+ write={
+ "bigquery": "BEGIN TRANSACTION",
+ "mysql": "BEGIN",
+ "postgres": "BEGIN",
+ "presto": "START TRANSACTION",
+ "trino": "START TRANSACTION",
+ "redshift": "BEGIN",
+ "snowflake": "BEGIN",
+ "sqlite": "BEGIN TRANSACTION",
+ },
+ )
+ self.validate_all(
+ "BEGIN",
+ read={
+ "presto": "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE",
+ "trino": "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE",
+ },
+ )
+ self.validate_all(
+ "BEGIN",
+ read={
+ "presto": "START TRANSACTION ISOLATION LEVEL REPEATABLE READ",
+ "trino": "START TRANSACTION ISOLATION LEVEL REPEATABLE READ",
+ },
+ )
+ self.validate_all(
+ "BEGIN IMMEDIATE TRANSACTION",
+ write={"sqlite": "BEGIN IMMEDIATE TRANSACTION"},
+ )
diff --git a/tests/dialects/test_drill.py b/tests/dialects/test_drill.py
new file mode 100644
index 0000000..9819daa
--- /dev/null
+++ b/tests/dialects/test_drill.py
@@ -0,0 +1,53 @@
+from tests.dialects.test_dialect import Validator
+
+
+class TestDrill(Validator):
+ dialect = "drill"
+
+ def test_string_literals(self):
+ self.validate_all(
+ "SELECT '2021-01-01' + INTERVAL 1 MONTH",
+ write={
+ "mysql": "SELECT '2021-01-01' + INTERVAL 1 MONTH",
+ },
+ )
+
+ def test_quotes(self):
+ self.validate_all(
+ "'\\''",
+ write={
+ "duckdb": "''''",
+ "presto": "''''",
+ "hive": "'\\''",
+ "spark": "'\\''",
+ },
+ )
+ self.validate_all(
+ "'\"x\"'",
+ write={
+ "duckdb": "'\"x\"'",
+ "presto": "'\"x\"'",
+ "hive": "'\"x\"'",
+ "spark": "'\"x\"'",
+ },
+ )
+ self.validate_all(
+ "'\\\\a'",
+ read={
+ "presto": "'\\a'",
+ },
+ write={
+ "duckdb": "'\\a'",
+ "presto": "'\\a'",
+ "hive": "'\\\\a'",
+ "spark": "'\\\\a'",
+ },
+ )
+
+ def test_table_function(self):
+ self.validate_all(
+ "SELECT * FROM table( dfs.`test_data.xlsx` (type => 'excel', sheetName => 'secondSheet'))",
+ write={
+ "drill": "SELECT * FROM table(dfs.`test_data.xlsx`(type => 'excel', sheetName => 'secondSheet'))",
+ },
+ )
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index 1ba118b..af98249 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -58,6 +58,16 @@ class TestMySQL(Validator):
self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'")
self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci")
self.validate_identity("SET autocommit = ON")
+ self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL SERIALIZABLE")
+ self.validate_identity("SET TRANSACTION READ ONLY")
+ self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE")
+ self.validate_identity("SELECT SCHEMA()")
+
+ def test_canonical_functions(self):
+ self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)")
+ self.validate_identity("SELECT INSTR('str', 'substr')", "SELECT LOCATE('substr', 'str')")
+ self.validate_identity("SELECT UCASE('foo')", "SELECT UPPER('foo')")
+ self.validate_identity("SELECT LCASE('foo')", "SELECT LOWER('foo')")
def test_escape(self):
self.validate_all(
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index 098ad2b..8179cf7 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -178,6 +178,15 @@ class TestPresto(Validator):
},
)
self.validate_all(
+ "CREATE TABLE test STORED = 'PARQUET' AS SELECT 1",
+ write={
+ "duckdb": "CREATE TABLE test AS SELECT 1",
+ "presto": "CREATE TABLE test WITH (FORMAT='PARQUET') AS SELECT 1",
+ "hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
+ "spark": "CREATE TABLE test USING PARQUET AS SELECT 1",
+ },
+ )
+ self.validate_all(
"CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1",
write={
"duckdb": "CREATE TABLE test AS SELECT 1",
@@ -427,3 +436,69 @@ class TestPresto(Validator):
"spark": UnsupportedError,
},
)
+ self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE")
+ self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ")
+
+ def test_encode_decode(self):
+ self.validate_all(
+ "TO_UTF8(x)",
+ write={
+ "spark": "ENCODE(x, 'utf-8')",
+ },
+ )
+ self.validate_all(
+ "FROM_UTF8(x)",
+ write={
+ "spark": "DECODE(x, 'utf-8')",
+ },
+ )
+ self.validate_all(
+ "ENCODE(x, 'utf-8')",
+ write={
+ "presto": "TO_UTF8(x)",
+ },
+ )
+ self.validate_all(
+ "DECODE(x, 'utf-8')",
+ write={
+ "presto": "FROM_UTF8(x)",
+ },
+ )
+ self.validate_all(
+ "ENCODE(x, 'invalid')",
+ write={
+ "presto": UnsupportedError,
+ },
+ )
+ self.validate_all(
+ "DECODE(x, 'invalid')",
+ write={
+ "presto": UnsupportedError,
+ },
+ )
+
+ def test_hex_unhex(self):
+ self.validate_all(
+ "TO_HEX(x)",
+ write={
+ "spark": "HEX(x)",
+ },
+ )
+ self.validate_all(
+ "FROM_HEX(x)",
+ write={
+ "spark": "UNHEX(x)",
+ },
+ )
+ self.validate_all(
+ "HEX(x)",
+ write={
+ "presto": "TO_HEX(x)",
+ },
+ )
+ self.validate_all(
+ "UNHEX(x)",
+ write={
+ "presto": "FROM_HEX(x)",
+ },
+ )
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index 1846b17..0e69f4e 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -169,6 +169,17 @@ class TestSnowflake(Validator):
"snowflake": "SELECT a FROM test AS unpivot",
},
)
+ self.validate_all(
+ "trim(date_column, 'UTC')",
+ write={
+ "snowflake": "TRIM(date_column, 'UTC')",
+ "postgres": "TRIM('UTC' FROM date_column)",
+ },
+ )
+ self.validate_all(
+ "trim(date_column)",
+ write={"snowflake": "TRIM(date_column)"},
+ )
def test_null_treatment(self):
self.validate_all(
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index 836ab28..75bd25d 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -122,13 +122,6 @@ x AT TIME ZONE 'UTC'
CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo'
SET x = 1
SET -v
-ADD JAR s3://bucket
-ADD JARS s3://bucket, c
-ADD FILE s3://file
-ADD FILES s3://file, s3://a
-ADD ARCHIVE s3://file
-ADD ARCHIVES s3://file, s3://a
-BEGIN IMMEDIATE TRANSACTION
COMMIT
USE db
NOT 1
@@ -278,6 +271,7 @@ SELECT CEIL(a, b) FROM test
SELECT COUNT(a) FROM test
SELECT COUNT(1) FROM test
SELECT COUNT(*) FROM test
+SELECT COUNT() FROM test
SELECT COUNT(DISTINCT a) FROM test
SELECT EXP(a) FROM test
SELECT FLOOR(a) FROM test
@@ -372,6 +366,8 @@ WITH a AS (SELECT 1) SELECT 1 UNION SELECT 2
WITH a AS (SELECT 1) SELECT 1 INTERSECT SELECT 2
WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2
WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2
+WITH sub_query AS (SELECT a FROM table) (SELECT a FROM sub_query)
+WITH sub_query AS (SELECT a FROM table) ((((SELECT a FROM sub_query))))
(SELECT 1) UNION (SELECT 2)
(SELECT 1) UNION SELECT 2
SELECT 1 UNION (SELECT 2)
@@ -463,6 +459,7 @@ CREATE TABLE z (a INT, b VARCHAR COMMENT 'z', c VARCHAR(100) COMMENT 'z', d DECI
CREATE TABLE z (a INT(11) DEFAULT UUID())
CREATE TABLE z (a INT(11) DEFAULT NULL COMMENT '客户id')
CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1)
+CREATE TABLE z (a INT(11) NOT NULL DEFAULT -1)
CREATE TABLE z (a INT(11) NOT NULL COLLATE utf8_bin AUTO_INCREMENT)
CREATE TABLE z (a INT, PRIMARY KEY(a))
CREATE TABLE z WITH (FORMAT='parquet') AS SELECT 1
@@ -476,6 +473,9 @@ CREATE TABLE z AS ((WITH cte AS (SELECT 1) SELECT * FROM cte))
CREATE TABLE z (a INT UNIQUE)
CREATE TABLE z (a INT AUTO_INCREMENT)
CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT)
+CREATE TABLE z (a INT REFERENCES parent(b, c))
+CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id))
+CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c))
CREATE TEMPORARY FUNCTION f
CREATE TEMPORARY FUNCTION f AS 'g'
CREATE FUNCTION f
@@ -514,17 +514,23 @@ DELETE FROM x WHERE y > 1
DELETE FROM y
DELETE FROM event USING sales WHERE event.eventid = sales.eventid
DELETE FROM event USING sales, USING bla WHERE event.eventid = sales.eventid
+DELETE FROM event USING sales AS s WHERE event.eventid = s.eventid
+PREPARE statement
+EXECUTE statement
DROP TABLE a
DROP TABLE a.b
DROP TABLE IF EXISTS a
DROP TABLE IF EXISTS a.b
+DROP TABLE a CASCADE
DROP VIEW a
DROP VIEW a.b
DROP VIEW IF EXISTS a
DROP VIEW IF EXISTS a.b
SHOW TABLES
USE db
+BEGIN
ROLLBACK
+ROLLBACK TO b
EXPLAIN SELECT * FROM x
INSERT INTO x SELECT * FROM y
INSERT INTO x (SELECT * FROM y)
@@ -581,3 +587,4 @@ SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */
SELECT x FROM a.b.c /* x */, e.f.g /* x */
SELECT FOO(x /* c */) /* FOO */, b /* b */
SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */
+SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b'
diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql
new file mode 100644
index 0000000..7fcdbb8
--- /dev/null
+++ b/tests/fixtures/optimizer/canonicalize.sql
@@ -0,0 +1,5 @@
+SELECT w.d + w.e AS c FROM w AS w;
+SELECT CONCAT(w.d, w.e) AS c FROM w AS w;
+
+SELECT CAST(w.d AS DATE) > w.e AS a FROM w AS w;
+SELECT CAST(w.d AS DATE) > CAST(w.e AS DATE) AS a FROM w AS w;
diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql
index eb7e9cb..a1e531b 100644
--- a/tests/fixtures/optimizer/optimizer.sql
+++ b/tests/fixtures/optimizer/optimizer.sql
@@ -119,7 +119,7 @@ GROUP BY
LIMIT 1;
# title: Root subquery is union
-(SELECT b FROM x UNION SELECT b FROM y) LIMIT 1;
+(SELECT b FROM x UNION SELECT b FROM y ORDER BY b) LIMIT 1;
(
SELECT
"x"."b" AS "b"
@@ -128,6 +128,8 @@ LIMIT 1;
SELECT
"y"."b" AS "b"
FROM "y" AS "y"
+ ORDER BY
+ "b"
)
LIMIT 1;
diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql
index b91205c..8138b11 100644
--- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql
+++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql
@@ -15,7 +15,7 @@ select
from
lineitem
where
- CAST(l_shipdate AS DATE) <= date '1998-12-01' - interval '90' day
+ l_shipdate <= date '1998-12-01' - interval '90' day
group by
l_returnflag,
l_linestatus
@@ -250,8 +250,8 @@ FROM "orders" AS "orders"
LEFT JOIN "_u_0" AS "_u_0"
ON "_u_0"."l_orderkey" = "orders"."o_orderkey"
WHERE
- "orders"."o_orderdate" < CAST('1993-10-01' AS DATE)
- AND "orders"."o_orderdate" >= CAST('1993-07-01' AS DATE)
+ CAST("orders"."o_orderdate" AS DATE) < CAST('1993-10-01' AS DATE)
+ AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1993-07-01' AS DATE)
AND NOT "_u_0"."l_orderkey" IS NULL
GROUP BY
"orders"."o_orderpriority"
@@ -293,8 +293,8 @@ SELECT
FROM "customer" AS "customer"
JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
- AND "orders"."o_orderdate" < CAST('1995-01-01' AS DATE)
- AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE)
+ AND CAST("orders"."o_orderdate" AS DATE) < CAST('1995-01-01' AS DATE)
+ AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1994-01-01' AS DATE)
JOIN "region" AS "region"
ON "region"."r_name" = 'ASIA'
JOIN "nation" AS "nation"
@@ -328,8 +328,8 @@ FROM "lineitem" AS "lineitem"
WHERE
"lineitem"."l_discount" BETWEEN 0.05 AND 0.07
AND "lineitem"."l_quantity" < 24
- AND "lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE)
- AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE);
+ AND CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE)
+ AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE);
--------------------------------------
-- TPC-H 7
@@ -384,13 +384,13 @@ WITH "n1" AS (
SELECT
"n1"."n_name" AS "supp_nation",
"n2"."n_name" AS "cust_nation",
- EXTRACT(year FROM "lineitem"."l_shipdate") AS "l_year",
+ EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME)) AS "l_year",
SUM("lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
)) AS "revenue"
FROM "supplier" AS "supplier"
JOIN "lineitem" AS "lineitem"
- ON "lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
+ ON CAST("lineitem"."l_shipdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
AND "supplier"."s_suppkey" = "lineitem"."l_suppkey"
JOIN "orders" AS "orders"
ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
@@ -409,7 +409,7 @@ JOIN "n1" AS "n2"
GROUP BY
"n1"."n_name",
"n2"."n_name",
- EXTRACT(year FROM "lineitem"."l_shipdate")
+ EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME))
ORDER BY
"supp_nation",
"cust_nation",
@@ -456,7 +456,7 @@ group by
order by
o_year;
SELECT
- EXTRACT(year FROM "orders"."o_orderdate") AS "o_year",
+ EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
SUM(
CASE
WHEN "nation_2"."n_name" = 'BRAZIL'
@@ -477,7 +477,7 @@ JOIN "customer" AS "customer"
ON "customer"."c_nationkey" = "nation"."n_nationkey"
JOIN "orders" AS "orders"
ON "orders"."o_custkey" = "customer"."c_custkey"
- AND "orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
+ AND CAST("orders"."o_orderdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
AND "part"."p_partkey" = "lineitem"."l_partkey"
@@ -488,7 +488,7 @@ JOIN "nation" AS "nation_2"
WHERE
"part"."p_type" = 'ECONOMY ANODIZED STEEL'
GROUP BY
- EXTRACT(year FROM "orders"."o_orderdate")
+ EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
ORDER BY
"o_year";
@@ -529,7 +529,7 @@ order by
o_year desc;
SELECT
"nation"."n_name" AS "nation",
- EXTRACT(year FROM "orders"."o_orderdate") AS "o_year",
+ EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
SUM(
"lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
@@ -551,7 +551,7 @@ WHERE
"part"."p_name" LIKE '%green%'
GROUP BY
"nation"."n_name",
- EXTRACT(year FROM "orders"."o_orderdate")
+ EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
ORDER BY
"nation",
"o_year" DESC;
@@ -606,8 +606,8 @@ SELECT
FROM "customer" AS "customer"
JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
- AND "orders"."o_orderdate" < CAST('1994-01-01' AS DATE)
- AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE)
+ AND CAST("orders"."o_orderdate" AS DATE) < CAST('1994-01-01' AS DATE)
+ AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1993-10-01' AS DATE)
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey" AND "lineitem"."l_returnflag" = 'R'
JOIN "nation" AS "nation"
@@ -740,8 +740,8 @@ SELECT
FROM "orders" AS "orders"
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_commitdate" < "lineitem"."l_receiptdate"
- AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE)
- AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE)
+ AND CAST("lineitem"."l_receiptdate" AS DATE) < CAST('1995-01-01' AS DATE)
+ AND CAST("lineitem"."l_receiptdate" AS DATE) >= CAST('1994-01-01' AS DATE)
AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate"
AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP')
AND "orders"."o_orderkey" = "lineitem"."l_orderkey"
@@ -832,8 +832,8 @@ FROM "lineitem" AS "lineitem"
JOIN "part" AS "part"
ON "lineitem"."l_partkey" = "part"."p_partkey"
WHERE
- "lineitem"."l_shipdate" < CAST('1995-10-01' AS DATE)
- AND "lineitem"."l_shipdate" >= CAST('1995-09-01' AS DATE);
+ CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-10-01' AS DATE)
+ AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1995-09-01' AS DATE);
--------------------------------------
-- TPC-H 15
@@ -876,8 +876,8 @@ WITH "revenue" AS (
)) AS "total_revenue"
FROM "lineitem" AS "lineitem"
WHERE
- "lineitem"."l_shipdate" < CAST('1996-04-01' AS DATE)
- AND "lineitem"."l_shipdate" >= CAST('1996-01-01' AS DATE)
+ CAST("lineitem"."l_shipdate" AS DATE) < CAST('1996-04-01' AS DATE)
+ AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1996-01-01' AS DATE)
GROUP BY
"lineitem"."l_suppkey"
)
@@ -1220,8 +1220,8 @@ WITH "_u_0" AS (
"lineitem"."l_suppkey" AS "_u_2"
FROM "lineitem" AS "lineitem"
WHERE
- "lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE)
- AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE)
+ CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE)
+ AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE)
GROUP BY
"lineitem"."l_partkey",
"lineitem"."l_suppkey"
diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql
index 5e27b5e..067fe77 100644
--- a/tests/fixtures/pretty.sql
+++ b/tests/fixtures/pretty.sql
@@ -315,3 +315,10 @@ FROM (
WHERE
id = 1
) /* x */;
+SELECT * /* multi
+ line
+ comment */;
+SELECT
+ * /* multi
+ line
+ comment */;
diff --git a/tests/helpers.py b/tests/helpers.py
index dabaf1c..9abdaae 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -57,79 +57,79 @@ SKIP_INTEGRATION = string_to_bool(os.environ.get("SKIP_INTEGRATION", "0").lower(
TPCH_SCHEMA = {
"lineitem": {
- "l_orderkey": "uint64",
- "l_partkey": "uint64",
- "l_suppkey": "uint64",
- "l_linenumber": "uint64",
- "l_quantity": "float64",
- "l_extendedprice": "float64",
- "l_discount": "float64",
- "l_tax": "float64",
+ "l_orderkey": "bigint",
+ "l_partkey": "bigint",
+ "l_suppkey": "bigint",
+ "l_linenumber": "bigint",
+ "l_quantity": "double",
+ "l_extendedprice": "double",
+ "l_discount": "double",
+ "l_tax": "double",
"l_returnflag": "string",
"l_linestatus": "string",
- "l_shipdate": "date32",
- "l_commitdate": "date32",
- "l_receiptdate": "date32",
+ "l_shipdate": "string",
+ "l_commitdate": "string",
+ "l_receiptdate": "string",
"l_shipinstruct": "string",
"l_shipmode": "string",
"l_comment": "string",
},
"orders": {
- "o_orderkey": "uint64",
- "o_custkey": "uint64",
+ "o_orderkey": "bigint",
+ "o_custkey": "bigint",
"o_orderstatus": "string",
- "o_totalprice": "float64",
- "o_orderdate": "date32",
+ "o_totalprice": "double",
+ "o_orderdate": "string",
"o_orderpriority": "string",
"o_clerk": "string",
- "o_shippriority": "int32",
+ "o_shippriority": "int",
"o_comment": "string",
},
"customer": {
- "c_custkey": "uint64",
+ "c_custkey": "bigint",
"c_name": "string",
"c_address": "string",
- "c_nationkey": "uint64",
+ "c_nationkey": "bigint",
"c_phone": "string",
- "c_acctbal": "float64",
+ "c_acctbal": "double",
"c_mktsegment": "string",
"c_comment": "string",
},
"part": {
- "p_partkey": "uint64",
+ "p_partkey": "bigint",
"p_name": "string",
"p_mfgr": "string",
"p_brand": "string",
"p_type": "string",
- "p_size": "int32",
+ "p_size": "int",
"p_container": "string",
- "p_retailprice": "float64",
+ "p_retailprice": "double",
"p_comment": "string",
},
"supplier": {
- "s_suppkey": "uint64",
+ "s_suppkey": "bigint",
"s_name": "string",
"s_address": "string",
- "s_nationkey": "uint64",
+ "s_nationkey": "bigint",
"s_phone": "string",
- "s_acctbal": "float64",
+ "s_acctbal": "double",
"s_comment": "string",
},
"partsupp": {
- "ps_partkey": "uint64",
- "ps_suppkey": "uint64",
- "ps_availqty": "int32",
- "ps_supplycost": "float64",
+ "ps_partkey": "bigint",
+ "ps_suppkey": "bigint",
+ "ps_availqty": "int",
+ "ps_supplycost": "double",
"ps_comment": "string",
},
"nation": {
- "n_nationkey": "uint64",
+ "n_nationkey": "bigint",
"n_name": "string",
- "n_regionkey": "uint64",
+ "n_regionkey": "bigint",
"n_comment": "string",
},
"region": {
- "r_regionkey": "uint64",
+ "r_regionkey": "bigint",
"r_name": "string",
"r_comment": "string",
},
diff --git a/tests/test_executor.py b/tests/test_executor.py
index 49805b9..2c4d7cd 100644
--- a/tests/test_executor.py
+++ b/tests/test_executor.py
@@ -1,12 +1,15 @@
import unittest
+from datetime import date
import duckdb
import pandas as pd
from pandas.testing import assert_frame_equal
from sqlglot import exp, parse_one
+from sqlglot.errors import ExecuteError
from sqlglot.executor import execute
from sqlglot.executor.python import Python
+from sqlglot.executor.table import Table, ensure_tables
from tests.helpers import (
FIXTURES_DIR,
SKIP_INTEGRATION,
@@ -67,13 +70,399 @@ class TestExecutor(unittest.TestCase):
def to_csv(expression):
if isinstance(expression, exp.Table):
return parse_one(
- f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}"
+ f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
)
return expression
- for sql, _ in self.sqls[0:3]:
- a = self.cached_execute(sql)
- sql = parse_one(sql).transform(to_csv).sql(pretty=True)
- table = execute(sql, TPCH_SCHEMA)
- b = pd.DataFrame(table.rows, columns=table.columns)
- assert_frame_equal(a, b, check_dtype=False)
+ for i, (sql, _) in enumerate(self.sqls[0:7]):
+ with self.subTest(f"tpch-h {i + 1}"):
+ a = self.cached_execute(sql)
+ sql = parse_one(sql).transform(to_csv).sql(pretty=True)
+ table = execute(sql, TPCH_SCHEMA)
+ b = pd.DataFrame(table.rows, columns=table.columns)
+ assert_frame_equal(a, b, check_dtype=False)
+
+ def test_execute_callable(self):
+ tables = {
+ "x": [
+ {"a": "a", "b": "d"},
+ {"a": "b", "b": "e"},
+ {"a": "c", "b": "f"},
+ ],
+ "y": [
+ {"b": "d", "c": "g"},
+ {"b": "e", "c": "h"},
+ {"b": "f", "c": "i"},
+ ],
+ "z": [],
+ }
+ schema = {
+ "x": {
+ "a": "VARCHAR",
+ "b": "VARCHAR",
+ },
+ "y": {
+ "b": "VARCHAR",
+ "c": "VARCHAR",
+ },
+ "z": {"d": "VARCHAR"},
+ }
+
+ for sql, cols, rows in [
+ ("SELECT * FROM x", ["a", "b"], [("a", "d"), ("b", "e"), ("c", "f")]),
+ (
+ "SELECT * FROM x JOIN y ON x.b = y.b",
+ ["a", "b", "b", "c"],
+ [("a", "d", "d", "g"), ("b", "e", "e", "h"), ("c", "f", "f", "i")],
+ ),
+ (
+ "SELECT j.c AS d FROM x AS i JOIN y AS j ON i.b = j.b",
+ ["d"],
+ [("g",), ("h",), ("i",)],
+ ),
+ (
+ "SELECT CONCAT(x.a, y.c) FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'",
+ ["_col_0"],
+ [("bh",)],
+ ),
+ (
+ "SELECT * FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'",
+ ["a", "b", "b", "c"],
+ [("b", "e", "e", "h")],
+ ),
+ (
+ "SELECT * FROM z",
+ ["d"],
+ [],
+ ),
+ (
+ "SELECT d FROM z ORDER BY d",
+ ["d"],
+ [],
+ ),
+ (
+ "SELECT a FROM x WHERE x.a <> 'b'",
+ ["a"],
+ [("a",), ("c",)],
+ ),
+ (
+ "SELECT a AS i FROM x ORDER BY a",
+ ["i"],
+ [("a",), ("b",), ("c",)],
+ ),
+ (
+ "SELECT a AS i FROM x ORDER BY i",
+ ["i"],
+ [("a",), ("b",), ("c",)],
+ ),
+ (
+ "SELECT 100 - ORD(a) AS a, a AS i FROM x ORDER BY a",
+ ["a", "i"],
+ [(1, "c"), (2, "b"), (3, "a")],
+ ),
+ (
+ "SELECT a /* test */ FROM x LIMIT 1",
+ ["a"],
+ [("a",)],
+ ),
+ ]:
+ with self.subTest(sql):
+ result = execute(sql, schema=schema, tables=tables)
+ self.assertEqual(result.columns, tuple(cols))
+ self.assertEqual(result.rows, rows)
+
+ def test_set_operations(self):
+ tables = {
+ "x": [
+ {"a": "a"},
+ {"a": "b"},
+ {"a": "c"},
+ ],
+ "y": [
+ {"a": "b"},
+ {"a": "c"},
+ {"a": "d"},
+ ],
+ }
+ schema = {
+ "x": {
+ "a": "VARCHAR",
+ },
+ "y": {
+ "a": "VARCHAR",
+ },
+ }
+
+ for sql, cols, rows in [
+ (
+ "SELECT a FROM x UNION ALL SELECT a FROM y",
+ ["a"],
+ [("a",), ("b",), ("c",), ("b",), ("c",), ("d",)],
+ ),
+ (
+ "SELECT a FROM x UNION SELECT a FROM y",
+ ["a"],
+ [("a",), ("b",), ("c",), ("d",)],
+ ),
+ (
+ "SELECT a FROM x EXCEPT SELECT a FROM y",
+ ["a"],
+ [("a",)],
+ ),
+ (
+ "SELECT a FROM x INTERSECT SELECT a FROM y",
+ ["a"],
+ [("b",), ("c",)],
+ ),
+ (
+ """SELECT i.a
+ FROM (
+ SELECT a FROM x UNION SELECT a FROM y
+ ) AS i
+ JOIN (
+ SELECT a FROM x UNION SELECT a FROM y
+ ) AS j
+ ON i.a = j.a""",
+ ["a"],
+ [("a",), ("b",), ("c",), ("d",)],
+ ),
+ (
+ "SELECT 1 AS a UNION SELECT 2 AS a UNION SELECT 3 AS a",
+ ["a"],
+ [(1,), (2,), (3,)],
+ ),
+ ]:
+ with self.subTest(sql):
+ result = execute(sql, schema=schema, tables=tables)
+ self.assertEqual(result.columns, tuple(cols))
+ self.assertEqual(set(result.rows), set(rows))
+
+ def test_execute_catalog_db_table(self):
+ tables = {
+ "catalog": {
+ "db": {
+ "x": [
+ {"a": "a"},
+ {"a": "b"},
+ {"a": "c"},
+ ],
+ }
+ }
+ }
+ schema = {
+ "catalog": {
+ "db": {
+ "x": {
+ "a": "VARCHAR",
+ }
+ }
+ }
+ }
+ result1 = execute("SELECT * FROM x", schema=schema, tables=tables)
+ result2 = execute("SELECT * FROM catalog.db.x", schema=schema, tables=tables)
+ assert result1.columns == result2.columns
+ assert result1.rows == result2.rows
+
+ def test_execute_tables(self):
+ tables = {
+ "sushi": [
+ {"id": 1, "price": 1.0},
+ {"id": 2, "price": 2.0},
+ {"id": 3, "price": 3.0},
+ ],
+ "order_items": [
+ {"sushi_id": 1, "order_id": 1},
+ {"sushi_id": 1, "order_id": 1},
+ {"sushi_id": 2, "order_id": 1},
+ {"sushi_id": 3, "order_id": 2},
+ ],
+ "orders": [
+ {"id": 1, "user_id": 1},
+ {"id": 2, "user_id": 2},
+ ],
+ }
+
+ self.assertEqual(
+ execute(
+ """
+ SELECT
+ o.user_id,
+ SUM(s.price) AS price
+ FROM orders o
+ JOIN order_items i
+ ON o.id = i.order_id
+ JOIN sushi s
+ ON i.sushi_id = s.id
+ GROUP BY o.user_id
+ """,
+ tables=tables,
+ ).rows,
+ [
+ (1, 4.0),
+ (2, 3.0),
+ ],
+ )
+
+ self.assertEqual(
+ execute(
+ """
+ SELECT
+ o.id, x.*
+ FROM orders o
+ LEFT JOIN (
+ SELECT
+ 1 AS id, 'b' AS x
+ UNION ALL
+ SELECT
+ 3 AS id, 'c' AS x
+ ) x
+ ON o.id = x.id
+ """,
+ tables=tables,
+ ).rows,
+ [(1, 1, "b"), (2, None, None)],
+ )
+ self.assertEqual(
+ execute(
+ """
+ SELECT
+ o.id, x.*
+ FROM orders o
+ RIGHT JOIN (
+ SELECT
+ 1 AS id,
+ 'b' AS x
+ UNION ALL
+ SELECT
+ 3 AS id, 'c' AS x
+ ) x
+ ON o.id = x.id
+ """,
+ tables=tables,
+ ).rows,
+ [
+ (1, 1, "b"),
+ (None, 3, "c"),
+ ],
+ )
+
+ def test_table_depth_mismatch(self):
+ tables = {"table": []}
+ schema = {"db": {"table": {"col": "VARCHAR"}}}
+ with self.assertRaises(ExecuteError):
+ execute("SELECT * FROM table", schema=schema, tables=tables)
+
+ def test_tables(self):
+ tables = ensure_tables(
+ {
+ "catalog1": {
+ "db1": {
+ "t1": [
+ {"a": 1},
+ ],
+ "t2": [
+ {"a": 1},
+ ],
+ },
+ "db2": {
+ "t3": [
+ {"a": 1},
+ ],
+ "t4": [
+ {"a": 1},
+ ],
+ },
+ },
+ "catalog2": {
+ "db3": {
+ "t5": Table(columns=("a",), rows=[(1,)]),
+ "t6": Table(columns=("a",), rows=[(1,)]),
+ },
+ "db4": {
+ "t7": Table(columns=("a",), rows=[(1,)]),
+ "t8": Table(columns=("a",), rows=[(1,)]),
+ },
+ },
+ }
+ )
+
+ t1 = tables.find(exp.table_(table="t1", db="db1", catalog="catalog1"))
+ self.assertEqual(t1.columns, ("a",))
+ self.assertEqual(t1.rows, [(1,)])
+
+ t8 = tables.find(exp.table_(table="t8"))
+ self.assertEqual(t1.columns, t8.columns)
+ self.assertEqual(t1.rows, t8.rows)
+
+ def test_static_queries(self):
+ for sql, cols, rows in [
+ ("SELECT 1", ["_col_0"], [(1,)]),
+ ("SELECT 1 + 2 AS x", ["x"], [(3,)]),
+ ("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]),
+ ("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]),
+ ("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]),
+ ]:
+ result = execute(sql)
+ self.assertEqual(result.columns, tuple(cols))
+ self.assertEqual(result.rows, rows)
+
+ def test_aggregate_without_group_by(self):
+ result = execute("SELECT SUM(x) FROM t", tables={"t": [{"x": 1}, {"x": 2}]})
+ self.assertEqual(result.columns, ("_col_0",))
+ self.assertEqual(result.rows, [(3,)])
+
+ def test_scalar_functions(self):
+ for sql, expected in [
+ ("CONCAT('a', 'b')", "ab"),
+ ("CONCAT('a', NULL)", None),
+ ("CONCAT_WS('_', 'a', 'b')", "a_b"),
+ ("STR_POSITION('bar', 'foobarbar')", 4),
+ ("STR_POSITION('bar', 'foobarbar', 5)", 7),
+ ("STR_POSITION(NULL, 'foobarbar')", None),
+ ("STR_POSITION('bar', NULL)", None),
+ ("UPPER('foo')", "FOO"),
+ ("UPPER(NULL)", None),
+ ("LOWER('FOO')", "foo"),
+ ("LOWER(NULL)", None),
+ ("IFNULL('a', 'b')", "a"),
+ ("IFNULL(NULL, 'b')", "b"),
+ ("IFNULL(NULL, NULL)", None),
+ ("SUBSTRING('12345')", "12345"),
+ ("SUBSTRING('12345', 3)", "345"),
+ ("SUBSTRING('12345', 3, 0)", ""),
+ ("SUBSTRING('12345', 3, 1)", "3"),
+ ("SUBSTRING('12345', 3, 2)", "34"),
+ ("SUBSTRING('12345', 3, 3)", "345"),
+ ("SUBSTRING('12345', 3, 4)", "345"),
+ ("SUBSTRING('12345', -3)", "345"),
+ ("SUBSTRING('12345', -3, 0)", ""),
+ ("SUBSTRING('12345', -3, 1)", "3"),
+ ("SUBSTRING('12345', -3, 2)", "34"),
+ ("SUBSTRING('12345', 0)", ""),
+ ("SUBSTRING('12345', 0, 1)", ""),
+ ("SUBSTRING(NULL)", None),
+ ("SUBSTRING(NULL, 1)", None),
+ ("CAST(1 AS TEXT)", "1"),
+ ("CAST('1' AS LONG)", 1),
+ ("CAST('1.1' AS FLOAT)", 1.1),
+ ("COALESCE(NULL)", None),
+ ("COALESCE(NULL, NULL)", None),
+ ("COALESCE(NULL, 'b')", "b"),
+ ("COALESCE('a', 'b')", "a"),
+ ("1 << 1", 2),
+ ("1 >> 1", 0),
+ ("1 & 1", 1),
+ ("1 | 1", 1),
+ ("1 < 1", False),
+ ("1 <= 1", True),
+ ("1 > 1", False),
+ ("1 >= 1", True),
+ ("1 + NULL", None),
+ ("IF(true, 1, 0)", 1),
+ ("IF(false, 1, 0)", 0),
+ ("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"),
+ ("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)),
+ ]:
+ with self.subTest(sql):
+ result = execute(f"SELECT {sql}")
+ self.assertEqual(result.rows, [(expected,)])
diff --git a/tests/test_expressions.py b/tests/test_expressions.py
index 63371d8..c0927ad 100644
--- a/tests/test_expressions.py
+++ b/tests/test_expressions.py
@@ -441,6 +441,9 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(parse_one("VARIANCE(a)"), exp.Variance)
self.assertIsInstance(parse_one("VARIANCE_POP(a)"), exp.VariancePop)
self.assertIsInstance(parse_one("YEAR(a)"), exp.Year)
+ self.assertIsInstance(parse_one("BEGIN DEFERRED TRANSACTION"), exp.Transaction)
+ self.assertIsInstance(parse_one("COMMIT"), exp.Commit)
+ self.assertIsInstance(parse_one("ROLLBACK"), exp.Rollback)
def test_column(self):
dot = parse_one("a.b.c")
@@ -479,9 +482,9 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(column.text("expression"), "c")
self.assertEqual(column.text("y"), "")
self.assertEqual(parse_one("select * from x.y").find(exp.Table).text("db"), "x")
- self.assertEqual(parse_one("select *").text("this"), "")
- self.assertEqual(parse_one("1 + 1").text("this"), "1")
- self.assertEqual(parse_one("'a'").text("this"), "a")
+ self.assertEqual(parse_one("select *").name, "")
+ self.assertEqual(parse_one("1 + 1").name, "1")
+ self.assertEqual(parse_one("'a'").name, "a")
def test_alias(self):
self.assertEqual(alias("foo", "bar").sql(), "foo AS bar")
@@ -538,8 +541,8 @@ class TestExpressions(unittest.TestCase):
this=exp.Literal.string("TABLE_FORMAT"),
value=exp.to_identifier("test_format"),
),
- exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.NULL),
- exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.TRUE),
+ exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.null()),
+ exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.true()),
]
),
)
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index a1b7e70..6637a1d 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -29,6 +29,7 @@ class TestOptimizer(unittest.TestCase):
CREATE TABLE x (a INT, b INT);
CREATE TABLE y (b INT, c INT);
CREATE TABLE z (b INT, c INT);
+ CREATE TABLE w (d TEXT, e TEXT);
INSERT INTO x VALUES (1, 1);
INSERT INTO x VALUES (2, 2);
@@ -47,6 +48,8 @@ class TestOptimizer(unittest.TestCase):
INSERT INTO y VALUES (4, 4);
INSERT INTO y VALUES (5, 5);
INSERT INTO y VALUES (null, null);
+
+ INSERT INTO w VALUES ('a', 'b');
"""
)
@@ -64,6 +67,10 @@ class TestOptimizer(unittest.TestCase):
"b": "INT",
"c": "INT",
},
+ "w": {
+ "d": "TEXT",
+ "e": "TEXT",
+ },
}
def check_file(self, file, func, pretty=False, execute=False, **kwargs):
@@ -224,6 +231,18 @@ class TestOptimizer(unittest.TestCase):
def test_eliminate_subqueries(self):
self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries)
+ def test_canonicalize(self):
+ optimize = partial(
+ optimizer.optimize,
+ rules=[
+ optimizer.qualify_tables.qualify_tables,
+ optimizer.qualify_columns.qualify_columns,
+ annotate_types,
+ optimizer.canonicalize.canonicalize,
+ ],
+ )
+ self.check_file("canonicalize", optimize, schema=self.schema)
+
def test_tpch(self):
self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True)
diff --git a/tests/test_parser.py b/tests/test_parser.py
index 04c20b1..c747ea3 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -41,12 +41,41 @@ class TestParser(unittest.TestCase):
)
def test_command(self):
- expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1")
+ expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1", read="hive")
self.assertEqual(len(expressions), 3)
self.assertEqual(expressions[0].sql(), "SET x = 1")
self.assertEqual(expressions[1].sql(), "ADD JAR s3://a")
self.assertEqual(expressions[2].sql(), "SELECT 1")
+ def test_transactions(self):
+ expression = parse_one("BEGIN TRANSACTION")
+ self.assertIsNone(expression.this)
+ self.assertEqual(expression.args["modes"], [])
+ self.assertEqual(expression.sql(), "BEGIN")
+
+ expression = parse_one("START TRANSACTION", read="mysql")
+ self.assertIsNone(expression.this)
+ self.assertEqual(expression.args["modes"], [])
+ self.assertEqual(expression.sql(), "BEGIN")
+
+ expression = parse_one("BEGIN DEFERRED TRANSACTION")
+ self.assertEqual(expression.this, "DEFERRED")
+ self.assertEqual(expression.args["modes"], [])
+ self.assertEqual(expression.sql(), "BEGIN")
+
+ expression = parse_one(
+ "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE", read="presto"
+ )
+ self.assertIsNone(expression.this)
+ self.assertEqual(expression.args["modes"][0], "READ WRITE")
+ self.assertEqual(expression.args["modes"][1], "ISOLATION LEVEL SERIALIZABLE")
+ self.assertEqual(expression.sql(), "BEGIN")
+
+ expression = parse_one("BEGIN", read="bigquery")
+ self.assertNotIsInstance(expression, exp.Transaction)
+ self.assertIsNone(expression.expression)
+ self.assertEqual(expression.sql(), "BEGIN")
+
def test_identify(self):
expression = parse_one(
"""
@@ -55,14 +84,14 @@ class TestParser(unittest.TestCase):
"""
)
- assert expression.expressions[0].text("this") == "a"
- assert expression.expressions[1].text("this") == "b"
- assert expression.expressions[2].text("alias") == "c"
- assert expression.expressions[3].text("alias") == "D"
- assert expression.expressions[4].text("alias") == "y|z'"
+ assert expression.expressions[0].name == "a"
+ assert expression.expressions[1].name == "b"
+ assert expression.expressions[2].alias == "c"
+ assert expression.expressions[3].alias == "D"
+ assert expression.expressions[4].alias == "y|z'"
table = expression.args["from"].expressions[0]
- assert table.args["this"].args["this"] == "z"
- assert table.args["db"].args["this"] == "y"
+ assert table.this.name == "z"
+ assert table.args["db"].name == "y"
def test_multi(self):
expressions = parse(
@@ -72,8 +101,8 @@ class TestParser(unittest.TestCase):
)
assert len(expressions) == 2
- assert expressions[0].args["from"].expressions[0].args["this"].args["this"] == "a"
- assert expressions[1].args["from"].expressions[0].args["this"].args["this"] == "b"
+ assert expressions[0].args["from"].expressions[0].this.name == "a"
+ assert expressions[1].args["from"].expressions[0].this.name == "b"
def test_expression(self):
ignore = Parser(error_level=ErrorLevel.IGNORE)
@@ -200,7 +229,7 @@ class TestParser(unittest.TestCase):
@patch("sqlglot.parser.logger")
def test_comment_error_n(self, logger):
parse_one(
- """CREATE TABLE x
+ """SUM
(
-- test
)""",
@@ -208,19 +237,19 @@ class TestParser(unittest.TestCase):
)
assert_logger_contains(
- "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Schema'>. Line 4, Col: 1.",
+ "Required keyword: 'this' missing for <class 'sqlglot.expressions.Sum'>. Line 4, Col: 1.",
logger,
)
@patch("sqlglot.parser.logger")
def test_comment_error_r(self, logger):
parse_one(
- """CREATE TABLE x (-- test\r)""",
+ """SUM(-- test\r)""",
error_level=ErrorLevel.WARN,
)
assert_logger_contains(
- "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Schema'>. Line 2, Col: 1.",
+ "Required keyword: 'this' missing for <class 'sqlglot.expressions.Sum'>. Line 2, Col: 1.",
logger,
)
diff --git a/tests/test_tokens.py b/tests/test_tokens.py
index 943c2b0..d4772ba 100644
--- a/tests/test_tokens.py
+++ b/tests/test_tokens.py
@@ -12,6 +12,7 @@ class TestTokens(unittest.TestCase):
("--comment\nfoo --test", "comment"),
("foo --comment", "comment"),
("foo", None),
+ ("foo /*comment 1*/ /*comment 2*/", "comment 1"),
]
for sql, comment in sql_comment:
diff --git a/tests/test_transpile.py b/tests/test_transpile.py
index 942053e..1bd2527 100644
--- a/tests/test_transpile.py
+++ b/tests/test_transpile.py
@@ -20,6 +20,13 @@ class TestTranspile(unittest.TestCase):
self.assertEqual(transpile(sql, **kwargs)[0], target)
def test_alias(self):
+ self.assertEqual(transpile("SELECT 1 current_time")[0], "SELECT 1 AS current_time")
+ self.assertEqual(
+ transpile("SELECT 1 current_timestamp")[0], "SELECT 1 AS current_timestamp"
+ )
+ self.assertEqual(transpile("SELECT 1 current_date")[0], "SELECT 1 AS current_date")
+ self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime")
+
for key in ("union", "filter", "over", "from", "join"):
with self.subTest(f"alias {key}"):
self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}")
@@ -69,6 +76,10 @@ class TestTranspile(unittest.TestCase):
self.validate("SELECT 3>=3", "SELECT 3 >= 3")
def test_comments(self):
+ self.validate("SELECT */*comment*/", "SELECT * /* comment */")
+ self.validate(
+ "SELECT * FROM table /*comment 1*/ /*comment 2*/", "SELECT * FROM table /* comment 1 */"
+ )
self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */")
self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo")
self.validate("SELECT --!5\nx FROM foo", "/* !5 */ SELECT x FROM foo")