summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--CHANGELOG.md11
-rw-r--r--LICENSE2
-rwxr-xr-xrun_checks.sh2
-rw-r--r--sqlglot/__init__.py2
-rw-r--r--sqlglot/__main__.py7
-rw-r--r--sqlglot/dialects/__init__.py1
-rw-r--r--sqlglot/dialects/bigquery.py9
-rw-r--r--sqlglot/dialects/dialect.py31
-rw-r--r--sqlglot/dialects/duckdb.py5
-rw-r--r--sqlglot/dialects/hive.py15
-rw-r--r--sqlglot/dialects/mysql.py29
-rw-r--r--sqlglot/dialects/oracle.py8
-rw-r--r--sqlglot/dialects/postgres.py116
-rw-r--r--sqlglot/dialects/presto.py6
-rw-r--r--sqlglot/dialects/redshift.py34
-rw-r--r--sqlglot/dialects/snowflake.py4
-rw-r--r--sqlglot/dialects/spark.py15
-rw-r--r--sqlglot/dialects/sqlite.py1
-rw-r--r--sqlglot/dialects/trino.py3
-rw-r--r--sqlglot/diff.py35
-rw-r--r--sqlglot/executor/__init__.py10
-rw-r--r--sqlglot/executor/context.py4
-rw-r--r--sqlglot/executor/python.py14
-rw-r--r--sqlglot/executor/table.py5
-rw-r--r--sqlglot/expressions.py169
-rw-r--r--sqlglot/generator.py167
-rw-r--r--sqlglot/optimizer/__init__.py2
-rw-r--r--sqlglot/optimizer/isolate_table_selects.py4
-rw-r--r--sqlglot/optimizer/merge_derived_tables.py232
-rw-r--r--sqlglot/optimizer/normalize.py22
-rw-r--r--sqlglot/optimizer/optimize_joins.py6
-rw-r--r--sqlglot/optimizer/optimizer.py39
-rw-r--r--sqlglot/optimizer/pushdown_predicates.py20
-rw-r--r--sqlglot/optimizer/qualify_columns.py36
-rw-r--r--sqlglot/optimizer/qualify_tables.py4
-rw-r--r--sqlglot/optimizer/schema.py4
-rw-r--r--sqlglot/optimizer/scope.py58
-rw-r--r--sqlglot/optimizer/simplify.py8
-rw-r--r--sqlglot/optimizer/unnest_subqueries.py22
-rw-r--r--sqlglot/parser.py404
-rw-r--r--sqlglot/planner.py21
-rw-r--r--sqlglot/tokens.py184
-rw-r--r--sqlglot/transforms.py4
-rw-r--r--tests/dialects/test_dialect.py133
-rw-r--r--tests/dialects/test_hive.py15
-rw-r--r--tests/dialects/test_mysql.py52
-rw-r--r--tests/dialects/test_postgres.py93
-rw-r--r--tests/dialects/test_redshift.py64
-rw-r--r--tests/dialects/test_snowflake.py32
-rw-r--r--tests/dialects/test_sqlite.py18
-rw-r--r--tests/fixtures/identity.sql7
-rw-r--r--tests/fixtures/optimizer/merge_derived_tables.sql63
-rw-r--r--tests/fixtures/optimizer/optimizer.sql57
-rw-r--r--tests/fixtures/optimizer/tpc-h/tpc-h.sql761
-rw-r--r--tests/helpers.py8
-rw-r--r--tests/test_build.py127
-rw-r--r--tests/test_executor.py20
-rw-r--r--tests/test_expressions.py53
-rw-r--r--tests/test_optimizer.py33
-rw-r--r--tests/test_parser.py37
-rw-r--r--tests/test_transpile.py51
61 files changed, 1844 insertions, 1555 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..0eba6cc
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,11 @@
+Changelog
+=========
+
+v6.1.0
+------
+
+Changes:
+
+- New: mysql group\_concat separator [49a4099](https://github.com/tobymao/sqlglot/commit/49a4099adc93780eeffef8204af36559eab50a9f)
+
+- Improvement: Better nested select parsing [45603f](https://github.com/tobymao/sqlglot/commit/45603f14bf9146dc3f8b330b85a0e25b77630b9b)
diff --git a/LICENSE b/LICENSE
index 388cd5e..05dbdae 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
MIT License
-Copyright (c) 2021 Toby Mao
+Copyright (c) 2022 Toby Mao
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
diff --git a/run_checks.sh b/run_checks.sh
index a7dddf4..770f443 100755
--- a/run_checks.sh
+++ b/run_checks.sh
@@ -8,5 +8,5 @@ python -m autoflake -i -r \
--remove-unused-variables \
sqlglot/ tests/
python -m isort --profile black sqlglot/ tests/
-python -m black sqlglot/ tests/
+python -m black --line-length 120 sqlglot/ tests/
python -m unittest
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 0007e34..3fa40ce 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -20,7 +20,7 @@ from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "6.0.4"
+__version__ = "6.1.1"
pretty = False
diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py
index 25200c4..4161259 100644
--- a/sqlglot/__main__.py
+++ b/sqlglot/__main__.py
@@ -49,12 +49,7 @@ args = parser.parse_args()
error_level = sqlglot.ErrorLevel[args.error_level.upper()]
if args.parse:
- sqls = [
- repr(expression)
- for expression in sqlglot.parse(
- args.sql, read=args.read, error_level=error_level
- )
- ]
+ sqls = [repr(expression) for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)]
else:
sqls = sqlglot.transpile(
args.sql,
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
index 5aa7d77..f7d03ad 100644
--- a/sqlglot/dialects/__init__.py
+++ b/sqlglot/dialects/__init__.py
@@ -7,6 +7,7 @@ from sqlglot.dialects.mysql import MySQL
from sqlglot.dialects.oracle import Oracle
from sqlglot.dialects.postgres import Postgres
from sqlglot.dialects.presto import Presto
+from sqlglot.dialects.redshift import Redshift
from sqlglot.dialects.snowflake import Snowflake
from sqlglot.dialects.spark import Spark
from sqlglot.dialects.sqlite import SQLite
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index f4e87c3..1f1f90a 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -44,6 +44,7 @@ class BigQuery(Dialect):
]
IDENTIFIERS = ["`"]
ESCAPE = "\\"
+ HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
@@ -120,9 +121,5 @@ class BigQuery(Dialect):
def intersect_op(self, expression):
if not expression.args.get("distinct", False):
- self.unsupported(
- "INTERSECT without DISTINCT is not supported in BigQuery"
- )
- return (
- f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
- )
+ self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
+ return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 8045f7a..f338c81 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -20,6 +20,7 @@ class Dialects(str, Enum):
ORACLE = "oracle"
POSTGRES = "postgres"
PRESTO = "presto"
+ REDSHIFT = "redshift"
SNOWFLAKE = "snowflake"
SPARK = "spark"
SQLITE = "sqlite"
@@ -53,12 +54,19 @@ class _Dialect(type):
klass.generator_class = getattr(klass, "Generator", Generator)
klass.tokenizer = klass.tokenizer_class()
- klass.quote_start, klass.quote_end = list(klass.tokenizer_class.QUOTES.items())[
- 0
- ]
- klass.identifier_start, klass.identifier_end = list(
- klass.tokenizer_class.IDENTIFIERS.items()
- )[0]
+ klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
+ klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0]
+
+ if klass.tokenizer_class._BIT_STRINGS and exp.BitString not in klass.generator_class.TRANSFORMS:
+ bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
+ klass.generator_class.TRANSFORMS[
+ exp.BitString
+ ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
+ if klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS:
+ hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
+ klass.generator_class.TRANSFORMS[
+ exp.HexString
+ ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
return klass
@@ -122,9 +130,7 @@ class Dialect(metaclass=_Dialect):
return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql)
def parse_into(self, expression_type, sql, **opts):
- return self.parser(**opts).parse_into(
- expression_type, self.tokenizer.tokenize(sql), sql
- )
+ return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql)
def generate(self, expression, **opts):
return self.generator(**opts).generate(expression)
@@ -164,9 +170,7 @@ class Dialect(metaclass=_Dialect):
def rename_func(name):
- return (
- lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})"
- )
+ return lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})"
def approx_count_distinct_sql(self, expression):
@@ -260,8 +264,7 @@ def format_time_lambda(exp_class, dialect, default=None):
return exp_class(
this=list_get(args, 0),
format=Dialect[dialect].format_time(
- list_get(args, 1)
- or (Dialect[dialect].time_format if default is True else default)
+ list_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
),
)
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index d83a620..ff3a8b1 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -63,10 +63,7 @@ def _sort_array_reverse(args):
def _struct_pack_sql(self, expression):
- args = [
- self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e)
- for e in expression.expressions
- ]
+ args = [self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) for e in expression.expressions]
return f"STRUCT_PACK({', '.join(args)})"
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index e3f3f39..59aa8fa 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -109,9 +109,7 @@ def _unnest_to_explode_sql(self, expression):
alias=exp.TableAlias(this=alias.this, columns=[column]),
)
)
- for expression, column in zip(
- unnest.expressions, alias.columns if alias else []
- )
+ for expression, column in zip(unnest.expressions, alias.columns if alias else [])
)
return self.join_sql(expression)
@@ -206,14 +204,11 @@ class Hive(Dialect):
substr=list_get(args, 0),
position=list_get(args, 2),
),
- "LOG": (
- lambda args: exp.Log.from_arg_list(args)
- if len(args) > 1
- else exp.Ln.from_arg_list(args)
- ),
+ "LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)),
"MAP": _parse_map,
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
"PERCENTILE": exp.Quantile.from_arg_list,
+ "PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list,
"COLLECT_SET": exp.SetAgg.from_arg_list,
"SIZE": exp.ArraySize.from_arg_list,
"SPLIT": exp.RegexpSplit.from_arg_list,
@@ -262,6 +257,7 @@ class Hive(Dialect):
HiveMap: _map_sql,
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e.args['value'])}",
exp.Quantile: rename_func("PERCENTILE"),
+ exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
exp.RegexpSplit: rename_func("SPLIT"),
exp.SafeDivide: no_safe_divide_sql,
@@ -296,8 +292,7 @@ class Hive(Dialect):
def datatype_sql(self, expression):
if (
- expression.this
- in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR)
+ expression.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR)
and not expression.expressions
):
expression = exp.DataType.build("text")
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 93800a6..87a2c41 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -49,6 +49,21 @@ def _str_to_date_sql(self, expression):
return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
+def _trim_sql(self, expression):
+ target = self.sql(expression, "this")
+ trim_type = self.sql(expression, "position")
+ remove_chars = self.sql(expression, "expression")
+
+ # Use TRIM/LTRIM/RTRIM syntax if the expression isn't mysql-specific
+ if not remove_chars:
+ return self.trim_sql(expression)
+
+ trim_type = f"{trim_type} " if trim_type else ""
+ remove_chars = f"{remove_chars} " if remove_chars else ""
+ from_part = "FROM " if trim_type or remove_chars else ""
+ return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
+
+
def _date_add(expression_class):
def func(args):
interval = list_get(args, 1)
@@ -88,9 +103,12 @@ class MySQL(Dialect):
QUOTES = ["'", '"']
COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"]
+ BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
+ HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
+ "SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
"_BIG5": TokenType.INTRODUCER,
@@ -145,6 +163,15 @@ class MySQL(Dialect):
"STR_TO_DATE": _str_to_date,
}
+ FUNCTION_PARSERS = {
+ **Parser.FUNCTION_PARSERS,
+ "GROUP_CONCAT": lambda self: self.expression(
+ exp.GroupConcat,
+ this=self._parse_lambda(),
+ separator=self._match(TokenType.SEPARATOR) and self._parse_field(),
+ ),
+ }
+
class Generator(Generator):
NULL_ORDERING_SUPPORTED = False
@@ -158,6 +185,8 @@ class MySQL(Dialect):
exp.DateAdd: _date_add_sql("ADD"),
exp.DateSub: _date_add_sql("SUB"),
exp.DateTrunc: _date_trunc_sql,
+ exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
+ exp.Trim: _trim_sql,
}
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 9c8b6f2..91e30b2 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -51,6 +51,14 @@ class Oracle(Dialect):
sep="",
)
+ def alias_sql(self, expression):
+ if isinstance(expression.this, exp.Table):
+ to_sql = self.sql(expression, "alias")
+ # oracle does not allow "AS" between table and alias
+ to_sql = f" {to_sql}" if to_sql else ""
+ return f"{self.sql(expression, 'this')}{to_sql}"
+ return super().alias_sql(expression)
+
def offset_sql(self, expression):
return f"{super().offset_sql(expression)} ROWS"
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 61dff86..c796839 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import (
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.transforms import delegate, preprocess
def _date_add_sql(kind):
@@ -32,11 +33,96 @@ def _date_add_sql(kind):
return func
+def _lateral_sql(self, expression):
+ this = self.sql(expression, "this")
+ if isinstance(expression.this, exp.Subquery):
+ return f"LATERAL{self.sep()}{this}"
+ alias = expression.args["alias"]
+ table = alias.name
+ table = f" {table}" if table else table
+ columns = self.expressions(alias, key="columns", flat=True)
+ columns = f" AS {columns}" if columns else ""
+ return f"LATERAL{self.sep()}{this}{table}{columns}"
+
+
+def _substring_sql(self, expression):
+ this = self.sql(expression, "this")
+ start = self.sql(expression, "start")
+ length = self.sql(expression, "length")
+
+ from_part = f" FROM {start}" if start else ""
+ for_part = f" FOR {length}" if length else ""
+
+ return f"SUBSTRING({this}{from_part}{for_part})"
+
+
+def _trim_sql(self, expression):
+ target = self.sql(expression, "this")
+ trim_type = self.sql(expression, "position")
+ remove_chars = self.sql(expression, "expression")
+ collation = self.sql(expression, "collation")
+
+ # Use TRIM/LTRIM/RTRIM syntax if the expression isn't postgres-specific
+ if not remove_chars and not collation:
+ return self.trim_sql(expression)
+
+ trim_type = f"{trim_type} " if trim_type else ""
+ remove_chars = f"{remove_chars} " if remove_chars else ""
+ from_part = "FROM " if trim_type or remove_chars else ""
+ collation = f" COLLATE {collation}" if collation else ""
+ return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
+
+
+def _auto_increment_to_serial(expression):
+ auto = expression.find(exp.AutoIncrementColumnConstraint)
+
+ if auto:
+ expression = expression.copy()
+ expression.args["constraints"].remove(auto.parent)
+ kind = expression.args["kind"]
+
+ if kind.this == exp.DataType.Type.INT:
+ kind.replace(exp.DataType(this=exp.DataType.Type.SERIAL))
+ elif kind.this == exp.DataType.Type.SMALLINT:
+ kind.replace(exp.DataType(this=exp.DataType.Type.SMALLSERIAL))
+ elif kind.this == exp.DataType.Type.BIGINT:
+ kind.replace(exp.DataType(this=exp.DataType.Type.BIGSERIAL))
+
+ return expression
+
+
+def _serial_to_generated(expression):
+ kind = expression.args["kind"]
+
+ if kind.this == exp.DataType.Type.SERIAL:
+ data_type = exp.DataType(this=exp.DataType.Type.INT)
+ elif kind.this == exp.DataType.Type.SMALLSERIAL:
+ data_type = exp.DataType(this=exp.DataType.Type.SMALLINT)
+ elif kind.this == exp.DataType.Type.BIGSERIAL:
+ data_type = exp.DataType(this=exp.DataType.Type.BIGINT)
+ else:
+ data_type = None
+
+ if data_type:
+ expression = expression.copy()
+ expression.args["kind"].replace(data_type)
+ constraints = expression.args["constraints"]
+ generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False))
+ notnull = exp.ColumnConstraint(kind=exp.NotNullColumnConstraint())
+ if notnull not in constraints:
+ constraints.insert(0, notnull)
+ if generated not in constraints:
+ constraints.insert(0, generated)
+
+ return expression
+
+
class Postgres(Dialect):
null_ordering = "nulls_are_large"
time_format = "'YYYY-MM-DD HH24:MI:SS'"
time_mapping = {
- "AM": "%p", # AM or PM
+ "AM": "%p",
+ "PM": "%p",
"D": "%w", # 1-based day of week
"DD": "%d", # day of month
"DDD": "%j", # zero padded day of year
@@ -65,14 +151,25 @@ class Postgres(Dialect):
}
class Tokenizer(Tokenizer):
+ BIT_STRINGS = [("b'", "'"), ("B'", "'")]
+ HEX_STRINGS = [("x'", "'"), ("X'", "'")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
- "SERIAL": TokenType.AUTO_INCREMENT,
+ "ALWAYS": TokenType.ALWAYS,
+ "BY DEFAULT": TokenType.BY_DEFAULT,
+ "IDENTITY": TokenType.IDENTITY,
+ "FOR": TokenType.FOR,
+ "GENERATED": TokenType.GENERATED,
+ "DOUBLE PRECISION": TokenType.DOUBLE,
+ "BIGSERIAL": TokenType.BIGSERIAL,
+ "SERIAL": TokenType.SERIAL,
+ "SMALLSERIAL": TokenType.SMALLSERIAL,
"UUID": TokenType.UUID,
}
class Parser(Parser):
STRICT_CAST = False
+
FUNCTIONS = {
**Parser.FUNCTIONS,
"TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "postgres"),
@@ -86,14 +183,18 @@ class Postgres(Dialect):
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
exp.DataType.Type.BINARY: "BYTEA",
- }
-
- TOKEN_MAPPING = {
- TokenType.AUTO_INCREMENT: "SERIAL",
+ exp.DataType.Type.DATETIME: "TIMESTAMP",
}
TRANSFORMS = {
**Generator.TRANSFORMS,
+ exp.ColumnDef: preprocess(
+ [
+ _auto_increment_to_serial,
+ _serial_to_generated,
+ ],
+ delegate("columndef_sql"),
+ ),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}",
@@ -102,8 +203,11 @@ class Postgres(Dialect):
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: _date_add_sql("+"),
exp.DateSub: _date_add_sql("-"),
+ exp.Lateral: _lateral_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.Substring: _substring_sql,
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TableSample: no_tablesample_sql,
+ exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
}
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index ca913e4..7253f7e 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -96,9 +96,7 @@ def _ts_or_ds_to_date_sql(self, expression):
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.time_format, Presto.date_format):
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
- return (
- f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
- )
+ return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
def _ts_or_ds_add_sql(self, expression):
@@ -141,6 +139,7 @@ class Presto(Dialect):
"FROM_UNIXTIME": exp.UnixToTime.from_arg_list,
"STRPOS": exp.StrPosition.from_arg_list,
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
+ "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
}
class Generator(Generator):
@@ -193,6 +192,7 @@ class Presto(Dialect):
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY = {self.sql(e.args['value'])}",
exp.Quantile: _quantile_sql,
+ exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.SortArray: _no_sort_array,
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
new file mode 100644
index 0000000..e1f7b78
--- /dev/null
+++ b/sqlglot/dialects/redshift.py
@@ -0,0 +1,34 @@
+from sqlglot import exp
+from sqlglot.dialects.postgres import Postgres
+from sqlglot.tokens import TokenType
+
+
+class Redshift(Postgres):
+ time_format = "'YYYY-MM-DD HH:MI:SS'"
+ time_mapping = {
+ **Postgres.time_mapping,
+ "MON": "%b",
+ "HH": "%H",
+ }
+
+ class Tokenizer(Postgres.Tokenizer):
+ ESCAPE = "\\"
+
+ KEYWORDS = {
+ **Postgres.Tokenizer.KEYWORDS,
+ "GEOMETRY": TokenType.GEOMETRY,
+ "GEOGRAPHY": TokenType.GEOGRAPHY,
+ "HLLSKETCH": TokenType.HLLSKETCH,
+ "SUPER": TokenType.SUPER,
+ "TIME": TokenType.TIMESTAMP,
+ "TIMETZ": TokenType.TIMESTAMPTZ,
+ "VARBYTE": TokenType.BINARY,
+ "SIMILAR TO": TokenType.SIMILAR_TO,
+ }
+
+ class Generator(Postgres.Generator):
+ TYPE_MAPPING = {
+ **Postgres.Generator.TYPE_MAPPING,
+ exp.DataType.Type.BINARY: "VARBYTE",
+ exp.DataType.Type.INT: "INTEGER",
+ }
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 148dfb5..8d6ee78 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -23,9 +23,7 @@ def _snowflake_to_timestamp(args):
# case: <numeric_expr> [ , <scale> ]
if second_arg.name not in ["0", "3", "9"]:
- raise ValueError(
- f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9"
- )
+ raise ValueError(f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9")
if second_arg.name == "0":
timescale = exp.UnixToTime.SECONDS
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 89c7ed5..a331191 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -65,12 +65,11 @@ class Spark(Hive):
this=list_get(args, 0),
start=exp.Sub(
this=exp.Length(this=list_get(args, 0)),
- expression=exp.Add(
- this=list_get(args, 1), expression=exp.Literal.number(1)
- ),
+ expression=exp.Add(this=list_get(args, 1), expression=exp.Literal.number(1)),
),
length=list_get(args, 1),
),
+ "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
}
class Generator(Hive.Generator):
@@ -82,11 +81,7 @@ class Spark(Hive):
}
TRANSFORMS = {
- **{
- k: v
- for k, v in Hive.Generator.TRANSFORMS.items()
- if k not in {exp.ArraySort}
- },
+ **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort}},
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
@@ -102,5 +97,5 @@ class Spark(Hive):
HiveMap: _map_sql,
}
- def bitstring_sql(self, expression):
- return f"X'{self.sql(expression, 'this')}'"
+ class Tokenizer(Hive.Tokenizer):
+ HEX_STRINGS = [("X'", "'")]
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 6cf5022..cfdbe1b 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -16,6 +16,7 @@ from sqlglot.tokens import Tokenizer, TokenType
class SQLite(Dialect):
class Tokenizer(Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
+ HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py
index 805106c..9a6f7fe 100644
--- a/sqlglot/dialects/trino.py
+++ b/sqlglot/dialects/trino.py
@@ -8,3 +8,6 @@ class Trino(Presto):
**Presto.Generator.TRANSFORMS,
exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
}
+
+ class Tokenizer(Presto.Tokenizer):
+ HEX_STRINGS = [("X'", "'")]
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
index 8eeb4e9..0567c12 100644
--- a/sqlglot/diff.py
+++ b/sqlglot/diff.py
@@ -115,13 +115,8 @@ class ChangeDistiller:
for kept_source_node_id, kept_target_node_id in matching_set:
source_node = self._source_index[kept_source_node_id]
target_node = self._target_index[kept_target_node_id]
- if (
- not isinstance(source_node, LEAF_EXPRESSION_TYPES)
- or source_node == target_node
- ):
- edit_script.extend(
- self._generate_move_edits(source_node, target_node, matching_set)
- )
+ if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node:
+ edit_script.extend(self._generate_move_edits(source_node, target_node, matching_set))
edit_script.append(Keep(source_node, target_node))
else:
edit_script.append(Update(source_node, target_node))
@@ -132,9 +127,7 @@ class ChangeDistiller:
source_args = [id(e) for e in _expression_only_args(source)]
target_args = [id(e) for e in _expression_only_args(target)]
- args_lcs = set(
- _lcs(source_args, target_args, lambda l, r: (l, r) in matching_set)
- )
+ args_lcs = set(_lcs(source_args, target_args, lambda l, r: (l, r) in matching_set))
move_edits = []
for a in source_args:
@@ -148,14 +141,10 @@ class ChangeDistiller:
matching_set = leaves_matching_set.copy()
ordered_unmatched_source_nodes = {
- id(n[0]): None
- for n in self._source.bfs()
- if id(n[0]) in self._unmatched_source_nodes
+ id(n[0]): None for n in self._source.bfs() if id(n[0]) in self._unmatched_source_nodes
}
ordered_unmatched_target_nodes = {
- id(n[0]): None
- for n in self._target.bfs()
- if id(n[0]) in self._unmatched_target_nodes
+ id(n[0]): None for n in self._target.bfs() if id(n[0]) in self._unmatched_target_nodes
}
for source_node_id in ordered_unmatched_source_nodes:
@@ -169,18 +158,13 @@ class ChangeDistiller:
max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids))
if max_leaves_num:
common_leaves_num = sum(
- 1 if s in source_leaf_ids and t in target_leaf_ids else 0
- for s, t in leaves_matching_set
+ 1 if s in source_leaf_ids and t in target_leaf_ids else 0 for s, t in leaves_matching_set
)
leaf_similarity_score = common_leaves_num / max_leaves_num
else:
leaf_similarity_score = 0.0
- adjusted_t = (
- self.t
- if min(len(source_leaf_ids), len(target_leaf_ids)) > 4
- else 0.4
- )
+ adjusted_t = self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4
if leaf_similarity_score >= 0.8 or (
leaf_similarity_score >= adjusted_t
@@ -217,10 +201,7 @@ class ChangeDistiller:
matching_set = set()
while candidate_matchings:
_, _, source_leaf, target_leaf = heappop(candidate_matchings)
- if (
- id(source_leaf) in self._unmatched_source_nodes
- and id(target_leaf) in self._unmatched_target_nodes
- ):
+ if id(source_leaf) in self._unmatched_source_nodes and id(target_leaf) in self._unmatched_target_nodes:
matching_set.add((id(source_leaf), id(target_leaf)))
self._unmatched_source_nodes.remove(id(source_leaf))
self._unmatched_target_nodes.remove(id(target_leaf))
diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py
index a437431..bca9f3e 100644
--- a/sqlglot/executor/__init__.py
+++ b/sqlglot/executor/__init__.py
@@ -3,11 +3,17 @@ import time
from sqlglot import parse_one
from sqlglot.executor.python import PythonExecutor
-from sqlglot.optimizer import optimize
+from sqlglot.optimizer import RULES, optimize
+from sqlglot.optimizer.merge_derived_tables import merge_derived_tables
from sqlglot.planner import Plan
logger = logging.getLogger("sqlglot")
+OPTIMIZER_RULES = list(RULES)
+
+# The executor needs isolated table selects
+OPTIMIZER_RULES.remove(merge_derived_tables)
+
def execute(sql, schema, read=None):
"""
@@ -28,7 +34,7 @@ def execute(sql, schema, read=None):
"""
expression = parse_one(sql, read=read)
now = time.time()
- expression = optimize(expression, schema)
+ expression = optimize(expression, schema, rules=OPTIMIZER_RULES)
logger.debug("Optimization finished: %f", time.time() - now)
logger.debug("Optimized SQL: %s", expression.sql(pretty=True))
plan = Plan(expression)
diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py
index 457bea7..d265a2c 100644
--- a/sqlglot/executor/context.py
+++ b/sqlglot/executor/context.py
@@ -19,9 +19,7 @@ class Context:
env (Optional[dict]): dictionary of functions within the execution context
"""
self.tables = tables
- self.range_readers = {
- name: table.range_reader for name, table in self.tables.items()
- }
+ self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
self.row_readers = {name: table.reader for name, table in tables.items()}
self.env = {**(env or {}), "scope": self.row_readers}
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index 388a419..610aa4b 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -26,11 +26,7 @@ class PythonExecutor:
while queue:
node = queue.pop()
context = self.context(
- {
- name: table
- for dep in node.dependencies
- for name, table in contexts[dep].tables.items()
- }
+ {name: table for dep in node.dependencies for name, table in contexts[dep].tables.items()}
)
running.add(node)
@@ -151,9 +147,7 @@ class PythonExecutor:
return self.context({name: table for name in ctx.tables})
for name, join in step.joins.items():
- join_context = self.context(
- {**join_context.tables, name: context.tables[name]}
- )
+ join_context = self.context({**join_context.tables, name: context.tables[name]})
if join.get("source_key"):
table = self.hash_join(join, source, name, join_context)
@@ -247,9 +241,7 @@ class PythonExecutor:
if step.operands:
source_table = context.tables[source]
- operand_table = Table(
- source_table.columns + self.table(step.operands).columns
- )
+ operand_table = Table(source_table.columns + self.table(step.operands).columns)
for reader, ctx in context:
operand_table.append(reader.row + ctx.eval_tuple(operands))
diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py
index 6df49f7..80674cb 100644
--- a/sqlglot/executor/table.py
+++ b/sqlglot/executor/table.py
@@ -37,10 +37,7 @@ class Table:
break
lines.append(
- " ".join(
- str(row[column]).rjust(widths[column])[0 : widths[column]]
- for column in self.columns
- )
+ " ".join(str(row[column]).rjust(widths[column])[0 : widths[column]] for column in self.columns)
)
return "\n".join(lines)
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 7acc63d..b983bf9 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -47,10 +47,7 @@ class Expression(metaclass=_Expression):
return hash(
(
self.key,
- tuple(
- (k, tuple(v) if isinstance(v, list) else v)
- for k, v in _norm_args(self).items()
- ),
+ tuple((k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()),
)
)
@@ -116,9 +113,22 @@ class Expression(metaclass=_Expression):
item.parent = parent
return new
+ def append(self, arg_key, value):
+ """
+ Appends value to arg_key if it's a list or sets it as a new list.
+
+ Args:
+ arg_key (str): name of the list expression arg
+ value (Any): value to append to the list
+ """
+ if not isinstance(self.args.get(arg_key), list):
+ self.args[arg_key] = []
+ self.args[arg_key].append(value)
+ self._set_parent(arg_key, value)
+
def set(self, arg_key, value):
"""
- Sets `arg` to `value`.
+ Sets `arg_key` to `value`.
Args:
arg_key (str): name of the expression arg
@@ -267,6 +277,14 @@ class Expression(metaclass=_Expression):
expression = expression.this
return expression
+ def unalias(self):
+ """
+ Returns the inner expression if this is an Alias.
+ """
+ if isinstance(self, Alias):
+ return self.this
+ return self
+
def unnest_operands(self):
"""
Returns unnested operands as a tuple.
@@ -279,9 +297,7 @@ class Expression(metaclass=_Expression):
A AND B AND C -> [A, B, C]
"""
- for node, _, _ in self.dfs(
- prune=lambda n, p, *_: p and not isinstance(n, self.__class__)
- ):
+ for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not isinstance(n, self.__class__)):
if not isinstance(node, self.__class__):
yield node.unnest() if unnest else node
@@ -314,9 +330,7 @@ class Expression(metaclass=_Expression):
args = {
k: ", ".join(
- v.to_s(hide_missing=hide_missing, level=level + 1)
- if hasattr(v, "to_s")
- else str(v)
+ v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
for v in ensure_list(vs)
if v is not None
)
@@ -354,9 +368,7 @@ class Expression(metaclass=_Expression):
new_node.parent = node.parent
return new_node
- replace_children(
- new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs)
- )
+ replace_children(new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs))
return new_node
def replace(self, expression):
@@ -546,6 +558,10 @@ class BitString(Condition):
pass
+class HexString(Condition):
+ pass
+
+
class Column(Condition):
arg_types = {"this": True, "table": False}
@@ -566,35 +582,44 @@ class ColumnConstraint(Expression):
arg_types = {"this": False, "kind": True}
-class AutoIncrementColumnConstraint(Expression):
+class ColumnConstraintKind(Expression):
pass
-class CheckColumnConstraint(Expression):
+class AutoIncrementColumnConstraint(ColumnConstraintKind):
pass
-class CollateColumnConstraint(Expression):
+class CheckColumnConstraint(ColumnConstraintKind):
pass
-class CommentColumnConstraint(Expression):
+class CollateColumnConstraint(ColumnConstraintKind):
pass
-class DefaultColumnConstraint(Expression):
+class CommentColumnConstraint(ColumnConstraintKind):
pass
-class NotNullColumnConstraint(Expression):
+class DefaultColumnConstraint(ColumnConstraintKind):
pass
-class PrimaryKeyColumnConstraint(Expression):
+class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
+ # this: True -> ALWAYS, this: False -> BY DEFAULT
+ arg_types = {"this": True, "expression": False}
+
+
+class NotNullColumnConstraint(ColumnConstraintKind):
pass
-class UniqueColumnConstraint(Expression):
+class PrimaryKeyColumnConstraint(ColumnConstraintKind):
+ pass
+
+
+class UniqueColumnConstraint(ColumnConstraintKind):
pass
@@ -651,9 +676,7 @@ class Identifier(Expression):
return bool(self.args.get("quoted"))
def __eq__(self, other):
- return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(
- other.this
- )
+ return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(other.this)
def __hash__(self):
return hash((self.key, self.this.lower()))
@@ -709,9 +732,7 @@ class Literal(Condition):
def __eq__(self, other):
return (
- isinstance(other, Literal)
- and self.this == other.this
- and self.args["is_string"] == other.args["is_string"]
+ isinstance(other, Literal) and self.this == other.this and self.args["is_string"] == other.args["is_string"]
)
def __hash__(self):
@@ -733,6 +754,7 @@ class Join(Expression):
"side": False,
"kind": False,
"using": False,
+ "natural": False,
}
@property
@@ -743,6 +765,10 @@ class Join(Expression):
def side(self):
return self.text("side").upper()
+ @property
+ def alias_or_name(self):
+ return self.this.alias_or_name
+
def on(self, *expressions, append=True, dialect=None, copy=True, **opts):
"""
Append to or set the ON expressions.
@@ -873,10 +899,6 @@ class Reference(Expression):
arg_types = {"this": True, "expressions": True}
-class Table(Expression):
- arg_types = {"this": True, "db": False, "catalog": False}
-
-
class Tuple(Expression):
arg_types = {"expressions": False}
@@ -986,6 +1008,16 @@ QUERY_MODIFIERS = {
}
+class Table(Expression):
+ arg_types = {
+ "this": True,
+ "db": False,
+ "catalog": False,
+ "laterals": False,
+ "joins": False,
+ }
+
+
class Union(Subqueryable, Expression):
arg_types = {
"with": False,
@@ -1396,7 +1428,9 @@ class Select(Subqueryable, Expression):
join.this.replace(join.this.subquery())
if join_type:
- side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)
+ natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)
+ if natural:
+ join.set("natural", True)
if side:
join.set("side", side.text)
if kind:
@@ -1529,10 +1563,7 @@ class Select(Subqueryable, Expression):
properties_expression = None
if properties:
properties_str = " ".join(
- [
- f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}"
- for k, v in properties.items()
- ]
+ [f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in properties.items()]
)
properties_expression = maybe_parse(
properties_str,
@@ -1654,6 +1685,7 @@ class DataType(Expression):
DECIMAL = auto()
BOOLEAN = auto()
JSON = auto()
+ INTERVAL = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
DATE = auto()
@@ -1662,15 +1694,19 @@ class DataType(Expression):
MAP = auto()
UUID = auto()
GEOGRAPHY = auto()
+ GEOMETRY = auto()
STRUCT = auto()
NULLABLE = auto()
+ HLLSKETCH = auto()
+ SUPER = auto()
+ SERIAL = auto()
+ SMALLSERIAL = auto()
+ BIGSERIAL = auto()
@classmethod
def build(cls, dtype, **kwargs):
return DataType(
- this=dtype
- if isinstance(dtype, DataType.Type)
- else DataType.Type[dtype.upper()],
+ this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()],
**kwargs,
)
@@ -1798,6 +1834,14 @@ class Like(Binary, Predicate):
pass
+class SimilarTo(Binary, Predicate):
+ pass
+
+
+class Distance(Binary):
+ pass
+
+
class LT(Binary, Predicate):
pass
@@ -1899,6 +1943,10 @@ class IgnoreNulls(Expression):
pass
+class RespectNulls(Expression):
+ pass
+
+
# Functions
class Func(Condition):
"""
@@ -1924,9 +1972,7 @@ class Func(Condition):
all_arg_keys = list(cls.arg_types)
# If this function supports variable length argument treat the last argument as such.
- non_var_len_arg_keys = (
- all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
- )
+ non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
args_dict = {}
arg_idx = 0
@@ -1944,9 +1990,7 @@ class Func(Condition):
@classmethod
def sql_names(cls):
if cls is Func:
- raise NotImplementedError(
- "SQL name is only supported by concrete function implementations"
- )
+ raise NotImplementedError("SQL name is only supported by concrete function implementations")
if not hasattr(cls, "_sql_names"):
cls._sql_names = [camel_to_snake_case(cls.__name__)]
return cls._sql_names
@@ -2178,6 +2222,10 @@ class Greatest(Func):
is_var_len_args = True
+class GroupConcat(Func):
+ arg_types = {"this": True, "separator": False}
+
+
class If(Func):
arg_types = {"this": True, "true": True, "false": False}
@@ -2274,6 +2322,10 @@ class Quantile(AggFunc):
arg_types = {"this": True, "quantile": True}
+class ApproxQuantile(Quantile):
+ pass
+
+
class Reduce(Func):
arg_types = {"this": True, "initial": True, "merge": True, "finish": True}
@@ -2306,8 +2358,10 @@ class Split(Func):
arg_types = {"this": True, "expression": True}
+# Start may be omitted in the case of postgres
+# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
class Substring(Func):
- arg_types = {"this": True, "start": True, "length": False}
+ arg_types = {"this": True, "start": False, "length": False}
class StrPosition(Func):
@@ -2379,6 +2433,15 @@ class TimeStrToUnix(Func):
pass
+class Trim(Func):
+ arg_types = {
+ "this": True,
+ "position": False,
+ "expression": False,
+ "collation": False,
+ }
+
+
class TsOrDsAdd(Func, TimeUnit):
arg_types = {"this": True, "expression": True, "unit": False}
@@ -2455,9 +2518,7 @@ def _all_functions():
obj
for _, obj in inspect.getmembers(
sys.modules[__name__],
- lambda obj: inspect.isclass(obj)
- and issubclass(obj, Func)
- and obj not in (AggFunc, Anonymous, Func),
+ lambda obj: inspect.isclass(obj) and issubclass(obj, Func) and obj not in (AggFunc, Anonymous, Func),
)
]
@@ -2633,9 +2694,7 @@ def _apply_conjunction_builder(
def _combine(expressions, operator, dialect=None, **opts):
- expressions = [
- condition(expression, dialect=dialect, **opts) for expression in expressions
- ]
+ expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions]
this = expressions[0]
if expressions[1:]:
this = _wrap_operator(this)
@@ -2809,9 +2868,7 @@ def to_identifier(alias, quoted=None):
quoted = not re.match(SAFE_IDENTIFIER_RE, alias)
identifier = Identifier(this=alias, quoted=quoted)
else:
- raise ValueError(
- f"Alias needs to be a string or an Identifier, got: {alias.__class__}"
- )
+ raise ValueError(f"Alias needs to be a string or an Identifier, got: {alias.__class__}")
return identifier
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 793cff0..a445178 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -41,6 +41,8 @@ class Generator:
max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError.
This is only relevant if unsupported_level is ErrorLevel.RAISE.
Default: 3
+ leading_comma (bool): if the the comma is leading or trailing in select statements
+ Default: False
"""
TRANSFORMS = {
@@ -108,6 +110,7 @@ class Generator:
"_indent",
"_replace_backslash",
"_escaped_quote_end",
+ "_leading_comma",
)
def __init__(
@@ -131,6 +134,7 @@ class Generator:
unsupported_level=ErrorLevel.WARN,
null_ordering=None,
max_unsupported=3,
+ leading_comma=False,
):
import sqlglot
@@ -157,6 +161,7 @@ class Generator:
self._indent = indent
self._replace_backslash = self.escape == "\\"
self._escaped_quote_end = self.escape + self.quote_end
+ self._leading_comma = leading_comma
def generate(self, expression):
"""
@@ -178,9 +183,7 @@ class Generator:
for msg in self.unsupported_messages:
logger.warning(msg)
elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
- raise UnsupportedError(
- concat_errors(self.unsupported_messages, self.max_unsupported)
- )
+ raise UnsupportedError(concat_errors(self.unsupported_messages, self.max_unsupported))
return sql
@@ -197,9 +200,7 @@ class Generator:
def wrap(self, expression):
this_sql = self.indent(
- self.sql(expression)
- if isinstance(expression, (exp.Select, exp.Union))
- else self.sql(expression, "this"),
+ self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"),
level=1,
pad=0,
)
@@ -251,9 +252,7 @@ class Generator:
return transform
if not isinstance(expression, exp.Expression):
- raise ValueError(
- f"Expected an Expression. Received {type(expression)}: {expression}"
- )
+ raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
exp_handler_name = f"{expression.key}_sql"
if hasattr(self, exp_handler_name):
@@ -276,11 +275,7 @@ class Generator:
lazy = " LAZY" if expression.args.get("lazy") else ""
table = self.sql(expression, "this")
options = expression.args.get("options")
- options = (
- f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})"
- if options
- else ""
- )
+ options = f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})" if options else ""
sql = self.sql(expression, "expression")
sql = f" AS{self.sep()}{sql}" if sql else ""
sql = f"CACHE{lazy} TABLE {table}{options}{sql}"
@@ -306,9 +301,7 @@ class Generator:
def columndef_sql(self, expression):
column = self.sql(expression, "this")
kind = self.sql(expression, "kind")
- constraints = self.expressions(
- expression, key="constraints", sep=" ", flat=True
- )
+ constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
if not constraints:
return f"{column} {kind}"
@@ -338,6 +331,9 @@ class Generator:
default = self.sql(expression, "this")
return f"DEFAULT {default}"
+ def generatedasidentitycolumnconstraint_sql(self, expression):
+ return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY"
+
def notnullcolumnconstraint_sql(self, _):
return "NOT NULL"
@@ -384,7 +380,10 @@ class Generator:
return f"{alias}{columns}"
def bitstring_sql(self, expression):
- return f"b'{self.sql(expression, 'this')}'"
+ return self.sql(expression, "this")
+
+ def hexstring_sql(self, expression):
+ return self.sql(expression, "this")
def datatype_sql(self, expression):
type_value = expression.this
@@ -452,10 +451,7 @@ class Generator:
def partition_sql(self, expression):
keys = csv(
- *[
- f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"]
- for k, v in expression.args.get("this")
- ]
+ *[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")]
)
return f"PARTITION({keys})"
@@ -470,9 +466,9 @@ class Generator:
elif p_class in self.WITH_PROPERTIES:
with_properties.append(p)
- return self.root_properties(
- exp.Properties(expressions=root_properties)
- ) + self.with_properties(exp.Properties(expressions=with_properties))
+ return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties(
+ exp.Properties(expressions=with_properties)
+ )
def root_properties(self, properties):
if properties.expressions:
@@ -508,11 +504,7 @@ class Generator:
kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO"
this = self.sql(expression, "this")
exists = " IF EXISTS " if expression.args.get("exists") else " "
- partition_sql = (
- self.sql(expression, "partition")
- if expression.args.get("partition")
- else ""
- )
+ partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else ""
expression_sql = self.sql(expression, "expression")
sep = self.sep() if partition_sql else ""
sql = f"INSERT {kind} {this}{exists}{partition_sql}{sep}{expression_sql}"
@@ -531,7 +523,7 @@ class Generator:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
def table_sql(self, expression):
- return ".".join(
+ table = ".".join(
part
for part in [
self.sql(expression, "catalog"),
@@ -541,6 +533,10 @@ class Generator:
if part
)
+ laterals = self.expressions(expression, key="laterals", sep="")
+ joins = self.expressions(expression, key="joins", sep="")
+ return f"{table}{laterals}{joins}"
+
def tablesample_sql(self, expression):
if self.alias_post_tablesample and isinstance(expression.this, exp.Alias):
this = self.sql(expression.this, "this")
@@ -586,11 +582,7 @@ class Generator:
def group_sql(self, expression):
group_by = self.op_expressions("GROUP BY", expression)
grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
- grouping_sets = (
- f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}"
- if grouping_sets
- else ""
- )
+ grouping_sets = f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
cube = self.expressions(expression, key="cube", indent=False)
cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else ""
rollup = self.expressions(expression, key="rollup", indent=False)
@@ -603,7 +595,16 @@ class Generator:
def join_sql(self, expression):
op_sql = self.seg(
- " ".join(op for op in (expression.side, expression.kind, "JOIN") if op)
+ " ".join(
+ op
+ for op in (
+ "NATURAL" if expression.args.get("natural") else None,
+ expression.side,
+ expression.kind,
+ "JOIN",
+ )
+ if op
+ )
)
on_sql = self.sql(expression, "on")
using = expression.args.get("using")
@@ -630,9 +631,9 @@ class Generator:
def lateral_sql(self, expression):
this = self.sql(expression, "this")
- op_sql = self.seg(
- f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}"
- )
+ if isinstance(expression.this, exp.Subquery):
+ return f"LATERAL{self.sep()}{this}"
+ op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}")
alias = expression.args["alias"]
table = alias.name
table = f" {table}" if table else table
@@ -688,21 +689,13 @@ class Generator:
sort_order = " DESC" if desc else ""
nulls_sort_change = ""
- if nulls_first and (
- (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last
- ):
+ if nulls_first and ((asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last):
nulls_sort_change = " NULLS FIRST"
- elif (
- nulls_last
- and ((asc and nulls_are_small) or (desc and nulls_are_large))
- and not nulls_are_last
- ):
+ elif nulls_last and ((asc and nulls_are_small) or (desc and nulls_are_large)) and not nulls_are_last:
nulls_sort_change = " NULLS LAST"
if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
- self.unsupported(
- "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect"
- )
+ self.unsupported("Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect")
nulls_sort_change = ""
return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
@@ -798,14 +791,20 @@ class Generator:
def window_sql(self, expression):
this = self.sql(expression, "this")
+
partition = self.expressions(expression, key="partition_by", flat=True)
partition = f"PARTITION BY {partition}" if partition else ""
+
order = expression.args.get("order")
order_sql = self.order_sql(order, flat=True) if order else ""
+
partition_sql = partition + " " if partition and order else partition
+
spec = expression.args.get("spec")
spec_sql = " " + self.window_spec_sql(spec) if spec else ""
+
alias = self.sql(expression, "alias")
+
if expression.arg_key == "window":
this = this = f"{self.seg('WINDOW')} {this} AS"
else:
@@ -818,13 +817,8 @@ class Generator:
def window_spec_sql(self, expression):
kind = self.sql(expression, "kind")
- start = csv(
- self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" "
- )
- end = (
- csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ")
- or "CURRENT ROW"
- )
+ start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
+ end = csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") or "CURRENT ROW"
return f"{kind} BETWEEN {start} AND {end}"
def withingroup_sql(self, expression):
@@ -879,6 +873,17 @@ class Generator:
expression_sql = self.sql(expression, "expression")
return f"EXTRACT({this} FROM {expression_sql})"
+ def trim_sql(self, expression):
+ target = self.sql(expression, "this")
+ trim_type = self.sql(expression, "position")
+
+ if trim_type == "LEADING":
+ return f"LTRIM({target})"
+ elif trim_type == "TRAILING":
+ return f"RTRIM({target})"
+ else:
+ return f"TRIM({target})"
+
def check_sql(self, expression):
this = self.sql(expression, key="this")
return f"CHECK ({this})"
@@ -898,9 +903,7 @@ class Generator:
return f"UNIQUE ({columns})"
def if_sql(self, expression):
- return self.case_sql(
- exp.Case(ifs=[expression], default=expression.args.get("false"))
- )
+ return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
def in_sql(self, expression):
query = expression.args.get("query")
@@ -917,7 +920,9 @@ class Generator:
return f"(SELECT {self.sql(unnest)})"
def interval_sql(self, expression):
- return f"INTERVAL {self.sql(expression, 'this')} {self.sql(expression, 'unit')}"
+ unit = self.sql(expression, "unit")
+ unit = f" {unit}" if unit else ""
+ return f"INTERVAL {self.sql(expression, 'this')}{unit}"
def reference_sql(self, expression):
this = self.sql(expression, "this")
@@ -925,9 +930,7 @@ class Generator:
return f"REFERENCES {this}({expressions})"
def anonymous_sql(self, expression):
- args = self.indent(
- self.expressions(expression, flat=True), skip_first=True, skip_last=True
- )
+ args = self.indent(self.expressions(expression, flat=True), skip_first=True, skip_last=True)
return f"{self.normalize_func(self.sql(expression, 'this'))}({args})"
def paren_sql(self, expression):
@@ -1006,6 +1009,9 @@ class Generator:
def ignorenulls_sql(self, expression):
return f"{self.sql(expression, 'this')} IGNORE NULLS"
+ def respectnulls_sql(self, expression):
+ return f"{self.sql(expression, 'this')} RESPECT NULLS"
+
def intdiv_sql(self, expression):
return self.sql(
exp.Cast(
@@ -1023,6 +1029,9 @@ class Generator:
def div_sql(self, expression):
return self.binary(expression, "/")
+ def distance_sql(self, expression):
+ return self.binary(expression, "<->")
+
def dot_sql(self, expression):
return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}"
@@ -1047,6 +1056,9 @@ class Generator:
def like_sql(self, expression):
return self.binary(expression, "LIKE")
+ def similarto_sql(self, expression):
+ return self.binary(expression, "SIMILAR TO")
+
def lt_sql(self, expression):
return self.binary(expression, "<")
@@ -1069,14 +1081,10 @@ class Generator:
return self.binary(expression, "-")
def trycast_sql(self, expression):
- return (
- f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
- )
+ return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
def binary(self, expression, op):
- return (
- f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
- )
+ return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
def function_fallback_sql(self, expression):
args = []
@@ -1089,9 +1097,7 @@ class Generator:
return f"{self.normalize_func(expression.sql_name())}({args_str})"
def format_time(self, expression):
- return format_time(
- self.sql(expression, "format"), self.time_mapping, self.time_trie
- )
+ return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie)
def expressions(self, expression, key=None, flat=False, indent=True, sep=", "):
expressions = expression.args.get(key or "expressions")
@@ -1102,7 +1108,14 @@ class Generator:
if flat:
return sep.join(self.sql(e) for e in expressions)
- expressions = self.sep(sep).join(self.sql(e) for e in expressions)
+ sql = (self.sql(e) for e in expressions)
+ # the only time leading_comma changes the output is if pretty print is enabled
+ if self._leading_comma and self.pretty:
+ pad = " " * self.pad
+ expressions = "\n".join(f"{sep}{s}" if i > 0 else f"{pad}{s}" for i, s in enumerate(sql))
+ else:
+ expressions = self.sep(sep).join(sql)
+
if indent:
return self.indent(expressions, skip_first=False)
return expressions
@@ -1116,9 +1129,7 @@ class Generator:
def set_operation(self, expression, op):
this = self.sql(expression, "this")
op = self.seg(op)
- return self.query_modifiers(
- expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
- )
+ return self.query_modifiers(expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}")
def token_sql(self, token_type):
return self.TOKEN_MAPPING.get(token_type, token_type.name)
diff --git a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py
index a4c4cc2..d1146ca 100644
--- a/sqlglot/optimizer/__init__.py
+++ b/sqlglot/optimizer/__init__.py
@@ -1,2 +1,2 @@
-from sqlglot.optimizer.optimizer import optimize
+from sqlglot.optimizer.optimizer import RULES, optimize
from sqlglot.optimizer.schema import Schema
diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py
index c2e021e..e060739 100644
--- a/sqlglot/optimizer/isolate_table_selects.py
+++ b/sqlglot/optimizer/isolate_table_selects.py
@@ -13,9 +13,7 @@ def isolate_table_selects(expression):
continue
if not isinstance(source.parent, exp.Alias):
- raise OptimizeError(
- "Tables require an alias. Run qualify_tables optimization."
- )
+ raise OptimizeError("Tables require an alias. Run qualify_tables optimization.")
parent = source.parent
diff --git a/sqlglot/optimizer/merge_derived_tables.py b/sqlglot/optimizer/merge_derived_tables.py
new file mode 100644
index 0000000..8b161fb
--- /dev/null
+++ b/sqlglot/optimizer/merge_derived_tables.py
@@ -0,0 +1,232 @@
+from collections import defaultdict
+
+from sqlglot import expressions as exp
+from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.simplify import simplify
+
+
+def merge_derived_tables(expression):
+ """
+ Rewrite sqlglot AST to merge derived tables into the outer query.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)")
+ >>> merge_derived_tables(expression).sql()
+ 'SELECT x.a FROM x'
+
+ Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
+
+ Args:
+ expression (sqlglot.Expression): expression to optimize
+ Returns:
+ sqlglot.Expression: optimized expression
+ """
+ for outer_scope in traverse_scope(expression):
+ for subquery in outer_scope.derived_tables:
+ inner_select = subquery.unnest()
+ if (
+ isinstance(outer_scope.expression, exp.Select)
+ and isinstance(inner_select, exp.Select)
+ and _mergeable(inner_select)
+ ):
+ alias = subquery.alias_or_name
+ from_or_join = subquery.find_ancestor(exp.From, exp.Join)
+ inner_scope = outer_scope.sources[alias]
+
+ _rename_inner_sources(outer_scope, inner_scope, alias)
+ _merge_from(outer_scope, inner_scope, subquery)
+ _merge_joins(outer_scope, inner_scope, from_or_join)
+ _merge_expressions(outer_scope, inner_scope, alias)
+ _merge_where(outer_scope, inner_scope, from_or_join)
+ _merge_order(outer_scope, inner_scope)
+ return expression
+
+
+# If a derived table has these Select args, it can't be merged
+UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
+ "expressions",
+ "from",
+ "joins",
+ "where",
+ "order",
+}
+
+
+def _mergeable(inner_select):
+ """
+ Return True if `inner_select` can be merged into outer query.
+
+ Args:
+ inner_select (exp.Select)
+ Returns:
+ bool: True if can be merged
+ """
+ return (
+ isinstance(inner_select, exp.Select)
+ and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
+ and inner_select.args.get("from")
+ and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
+ )
+
+
+def _rename_inner_sources(outer_scope, inner_scope, alias):
+ """
+ Renames any sources in the inner query that conflict with names in the outer query.
+
+ Args:
+ outer_scope (sqlglot.optimizer.scope.Scope)
+ inner_scope (sqlglot.optimizer.scope.Scope)
+ alias (str)
+ """
+ taken = set(outer_scope.selected_sources)
+ conflicts = taken.intersection(set(inner_scope.selected_sources))
+ conflicts = conflicts - {alias}
+
+ for conflict in conflicts:
+ new_name = _find_new_name(taken, conflict)
+
+ source, _ = inner_scope.selected_sources[conflict]
+ new_alias = exp.to_identifier(new_name)
+
+ if isinstance(source, exp.Subquery):
+ source.set("alias", exp.TableAlias(this=new_alias))
+ elif isinstance(source, exp.Table) and isinstance(source.parent, exp.Alias):
+ source.parent.set("alias", new_alias)
+ elif isinstance(source, exp.Table):
+ source.replace(exp.alias_(source.copy(), new_alias))
+
+ for column in inner_scope.source_columns(conflict):
+ column.set("table", exp.to_identifier(new_name))
+
+ inner_scope.rename_source(conflict, new_name)
+
+
+def _find_new_name(taken, base):
+ """
+ Searches for a new source name.
+
+ Args:
+ taken (set[str]): set of taken names
+ base (str): base name to alter
+ """
+ i = 2
+ new = f"{base}_{i}"
+ while new in taken:
+ i += 1
+ new = f"{base}_{i}"
+ return new
+
+
+def _merge_from(outer_scope, inner_scope, subquery):
+ """
+ Merge FROM clause of inner query into outer query.
+
+ Args:
+ outer_scope (sqlglot.optimizer.scope.Scope)
+ inner_scope (sqlglot.optimizer.scope.Scope)
+ subquery (exp.Subquery)
+ """
+ new_subquery = inner_scope.expression.args.get("from").expressions[0]
+ subquery.replace(new_subquery)
+ outer_scope.remove_source(subquery.alias_or_name)
+ outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
+
+
+def _merge_joins(outer_scope, inner_scope, from_or_join):
+ """
+ Merge JOIN clauses of inner query into outer query.
+
+ Args:
+ outer_scope (sqlglot.optimizer.scope.Scope)
+ inner_scope (sqlglot.optimizer.scope.Scope)
+ from_or_join (exp.From|exp.Join)
+ """
+
+ new_joins = []
+ comma_joins = inner_scope.expression.args.get("from").expressions[1:]
+ for subquery in comma_joins:
+ new_joins.append(exp.Join(this=subquery, kind="CROSS"))
+ outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name])
+
+ joins = inner_scope.expression.args.get("joins") or []
+ for join in joins:
+ new_joins.append(join)
+ outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name])
+
+ if new_joins:
+ outer_joins = outer_scope.expression.args.get("joins", [])
+
+ # Maintain the join order
+ if isinstance(from_or_join, exp.From):
+ position = 0
+ else:
+ position = outer_joins.index(from_or_join) + 1
+ outer_joins[position:position] = new_joins
+
+ outer_scope.expression.set("joins", outer_joins)
+
+
+def _merge_expressions(outer_scope, inner_scope, alias):
+ """
+ Merge projections of inner query into outer query.
+
+ Args:
+ outer_scope (sqlglot.optimizer.scope.Scope)
+ inner_scope (sqlglot.optimizer.scope.Scope)
+ alias (str)
+ """
+ # Collect all columns that for the alias of the inner query
+ outer_columns = defaultdict(list)
+ for column in outer_scope.columns:
+ if column.table == alias:
+ outer_columns[column.name].append(column)
+
+ # Replace columns with the projection expression in the inner query
+ for expression in inner_scope.expression.expressions:
+ projection_name = expression.alias_or_name
+ if not projection_name:
+ continue
+ columns_to_replace = outer_columns.get(projection_name, [])
+ for column in columns_to_replace:
+ column.replace(expression.unalias())
+
+
+def _merge_where(outer_scope, inner_scope, from_or_join):
+ """
+ Merge WHERE clause of inner query into outer query.
+
+ Args:
+ outer_scope (sqlglot.optimizer.scope.Scope)
+ inner_scope (sqlglot.optimizer.scope.Scope)
+ from_or_join (exp.From|exp.Join)
+ """
+ where = inner_scope.expression.args.get("where")
+ if not where or not where.this:
+ return
+
+ if isinstance(from_or_join, exp.Join) and from_or_join.side:
+ # Merge predicates from an outer join to the ON clause
+ from_or_join.on(where.this, copy=False)
+ from_or_join.set("on", simplify(from_or_join.args.get("on")))
+ else:
+ outer_scope.expression.where(where.this, copy=False)
+ outer_scope.expression.set("where", simplify(outer_scope.expression.args.get("where")))
+
+
+def _merge_order(outer_scope, inner_scope):
+ """
+ Merge ORDER clause of inner query into outer query.
+
+ Args:
+ outer_scope (sqlglot.optimizer.scope.Scope)
+ inner_scope (sqlglot.optimizer.scope.Scope)
+ """
+ if (
+ any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"])
+ or len(outer_scope.selected_sources) != 1
+ or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions)
+ ):
+ return
+
+ outer_scope.expression.set("order", inner_scope.expression.args.get("order"))
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index 2c9f89c..ab30d7a 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -22,18 +22,14 @@ def normalize(expression, dnf=False, max_distance=128):
"""
expression = simplify(expression)
- expression = while_changing(
- expression, lambda e: distributive_law(e, dnf, max_distance)
- )
+ expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance))
return simplify(expression)
def normalized(expression, dnf=False):
ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
- return not any(
- connector.find_ancestor(ancestor) for connector in expression.find_all(root)
- )
+ return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root))
def normalization_distance(expression, dnf=False):
@@ -54,9 +50,7 @@ def normalization_distance(expression, dnf=False):
Returns:
int: difference
"""
- return sum(_predicate_lengths(expression, dnf)) - (
- len(list(expression.find_all(exp.Connector))) + 1
- )
+ return sum(_predicate_lengths(expression, dnf)) - (len(list(expression.find_all(exp.Connector))) + 1)
def _predicate_lengths(expression, dnf):
@@ -73,11 +67,7 @@ def _predicate_lengths(expression, dnf):
left, right = expression.args.values()
if isinstance(expression, exp.And if dnf else exp.Or):
- x = [
- a + b
- for a in _predicate_lengths(left, dnf)
- for b in _predicate_lengths(right, dnf)
- ]
+ x = [a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)]
return x
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
@@ -102,9 +92,7 @@ def distributive_law(expression, dnf, max_distance):
to_func = exp.and_ if to_exp == exp.And else exp.or_
if isinstance(a, to_exp) and isinstance(b, to_exp):
- if len(tuple(a.find_all(exp.Connector))) > len(
- tuple(b.find_all(exp.Connector))
- ):
+ if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
return _distribute(a, b, from_func, to_func)
return _distribute(b, a, from_func, to_func)
if isinstance(a, to_exp):
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
index 40e4ab1..0c74e36 100644
--- a/sqlglot/optimizer/optimize_joins.py
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -68,8 +68,4 @@ def normalize(expression):
def other_table_names(join, exclude):
- return [
- name
- for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
- if name != exclude
- ]
+ return [name for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) if name != exclude]
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index c03fe3c..c8c2403 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -1,6 +1,7 @@
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
+from sqlglot.optimizer.merge_derived_tables import merge_derived_tables
from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
@@ -10,8 +11,23 @@ from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.quote_identities import quote_identities
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
+RULES = (
+ qualify_tables,
+ isolate_table_selects,
+ qualify_columns,
+ pushdown_projections,
+ normalize,
+ unnest_subqueries,
+ expand_multi_table_selects,
+ pushdown_predicates,
+ optimize_joins,
+ eliminate_subqueries,
+ merge_derived_tables,
+ quote_identities,
+)
-def optimize(expression, schema=None, db=None, catalog=None):
+
+def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwargs):
"""
Rewrite a sqlglot AST into an optimized form.
@@ -25,19 +41,18 @@ def optimize(expression, schema=None, db=None, catalog=None):
3. {catalog: {db: {table: {col: type}}}}
db (str): specify the default database, as might be set by a `USE DATABASE db` statement
catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
+ rules (list): sequence of optimizer rules to use
+ **kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
Returns:
sqlglot.Expression: optimized expression
"""
+ possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
expression = expression.copy()
- expression = qualify_tables(expression, db=db, catalog=catalog)
- expression = isolate_table_selects(expression)
- expression = qualify_columns(expression, schema)
- expression = pushdown_projections(expression)
- expression = normalize(expression)
- expression = unnest_subqueries(expression)
- expression = expand_multi_table_selects(expression)
- expression = pushdown_predicates(expression)
- expression = optimize_joins(expression)
- expression = eliminate_subqueries(expression)
- expression = quote_identities(expression)
+ for rule in rules:
+
+ # Find any additional rule parameters, beyond `expression`
+ rule_params = rule.__code__.co_varnames
+ rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs}
+
+ expression = rule(expression, **rule_kwargs)
return expression
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index e757322..a070d70 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -42,11 +42,7 @@ def pushdown(condition, sources):
condition = condition.replace(simplify(condition))
cnf_like = normalized(condition) or not normalized(condition, dnf=True)
- predicates = list(
- condition.flatten()
- if isinstance(condition, exp.And if cnf_like else exp.Or)
- else [condition]
- )
+ predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])
if cnf_like:
pushdown_cnf(predicates, sources)
@@ -105,17 +101,11 @@ def pushdown_dnf(predicates, scope):
for column in predicate.find_all(exp.Column):
if column.table == table:
condition = column.find_ancestor(exp.Condition)
- predicate_condition = (
- exp.and_(predicate_condition, condition)
- if predicate_condition
- else condition
- )
+ predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition
if predicate_condition:
conditions[table] = (
- exp.or_(conditions[table], predicate_condition)
- if table in conditions
- else predicate_condition
+ exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition
)
for name, node in nodes.items():
@@ -133,9 +123,7 @@ def pushdown_dnf(predicates, scope):
def nodes_for_predicate(predicate, sources):
nodes = {}
tables = exp.column_table_names(predicate)
- where_condition = isinstance(
- predicate.find_ancestor(exp.Join, exp.Where), exp.Where
- )
+ where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
for table in tables:
node, source = sources.get(table) or (None, None)
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 394f49e..0bb947a 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -226,9 +226,7 @@ def _expand_stars(scope, resolver):
tables = list(scope.selected_sources)
_add_except_columns(expression, tables, except_columns)
_add_replace_columns(expression, tables, replace_columns)
- elif isinstance(expression, exp.Column) and isinstance(
- expression.this, exp.Star
- ):
+ elif isinstance(expression, exp.Column) and isinstance(expression.this, exp.Star):
tables = [expression.table]
_add_except_columns(expression.this, tables, except_columns)
_add_replace_columns(expression.this, tables, replace_columns)
@@ -245,9 +243,7 @@ def _expand_stars(scope, resolver):
if name not in except_columns.get(table_id, set()):
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table)
- new_selections.append(
- alias(column, alias_) if alias_ != name else column
- )
+ new_selections.append(alias(column, alias_) if alias_ != name else column)
scope.expression.set("expressions", new_selections)
@@ -280,9 +276,7 @@ def _qualify_outputs(scope):
"""Ensure all output columns are aliased"""
new_selections = []
- for i, (selection, aliased_column) in enumerate(
- itertools.zip_longest(scope.selects, scope.outer_column_list)
- ):
+ for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)):
if isinstance(selection, exp.Column):
# convoluted setter because a simple selection.replace(alias) would require a copy
alias_ = alias(exp.column(""), alias=selection.name)
@@ -302,11 +296,7 @@ def _qualify_outputs(scope):
def _check_unknown_tables(scope):
- if (
- scope.external_columns
- and not scope.is_unnest
- and not scope.is_correlated_subquery
- ):
+ if scope.external_columns and not scope.is_unnest and not scope.is_correlated_subquery:
raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
@@ -334,20 +324,14 @@ class _Resolver:
(str) table name
"""
if self._unambiguous_columns is None:
- self._unambiguous_columns = self._get_unambiguous_columns(
- self._get_all_source_columns()
- )
+ self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns())
return self._unambiguous_columns.get(column_name)
@property
def all_columns(self):
"""All available columns of all sources in this scope"""
if self._all_columns is None:
- self._all_columns = set(
- column
- for columns in self._get_all_source_columns().values()
- for column in columns
- )
+ self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns)
return self._all_columns
def get_source_columns(self, name):
@@ -369,9 +353,7 @@ class _Resolver:
def _get_all_source_columns(self):
if self._source_columns is None:
- self._source_columns = {
- k: self.get_source_columns(k) for k in self.scope.selected_sources
- }
+ self._source_columns = {k: self.get_source_columns(k) for k in self.scope.selected_sources}
return self._source_columns
def _get_unambiguous_columns(self, source_columns):
@@ -389,9 +371,7 @@ class _Resolver:
source_columns = list(source_columns.items())
first_table, first_columns = source_columns[0]
- unambiguous_columns = {
- col: first_table for col in self._find_unique_columns(first_columns)
- }
+ unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
all_columns = set(unambiguous_columns)
for table, columns in source_columns[1:]:
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 9f8b9f5..30e93ba 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -27,9 +27,7 @@ def qualify_tables(expression, db=None, catalog=None):
for derived_table in scope.ctes + scope.derived_tables:
if not derived_table.args.get("alias"):
alias_ = f"_q_{next(sequence)}"
- derived_table.set(
- "alias", exp.TableAlias(this=exp.to_identifier(alias_))
- )
+ derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
scope.rename_source(None, alias_)
for source in scope.sources.values():
diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py
index 9968108..1761228 100644
--- a/sqlglot/optimizer/schema.py
+++ b/sqlglot/optimizer/schema.py
@@ -57,9 +57,7 @@ class MappingSchema(Schema):
for forbidden in self.forbidden_args:
if table.text(forbidden):
- raise ValueError(
- f"Schema doesn't support {forbidden}. Received: {table.sql()}"
- )
+ raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
return list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index f6f59e8..e816e10 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -104,9 +104,7 @@ class Scope:
elif isinstance(node, exp.CTE):
self._ctes.append(node)
prune = True
- elif isinstance(node, exp.Subquery) and isinstance(
- parent, (exp.From, exp.Join)
- ):
+ elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
self._derived_tables.append(node)
prune = True
elif isinstance(node, exp.Subqueryable):
@@ -195,20 +193,14 @@ class Scope:
self._ensure_collected()
columns = self._raw_columns
- external_columns = [
- column
- for scope in self.subquery_scopes
- for column in scope.external_columns
- ]
+ external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns]
named_outputs = {e.alias_or_name for e in self.expression.expressions}
self._columns = [
c
for c in columns + external_columns
- if not (
- c.find_ancestor(exp.Qualify, exp.Order) and c.name in named_outputs
- )
+ if not (c.find_ancestor(exp.Qualify, exp.Order) and not c.table and c.name in named_outputs)
]
return self._columns
@@ -229,9 +221,7 @@ class Scope:
for table in self.tables:
referenced_names.append(
(
- table.parent.alias
- if isinstance(table.parent, exp.Alias)
- else table.name,
+ table.parent.alias if isinstance(table.parent, exp.Alias) else table.name,
table,
)
)
@@ -274,9 +264,7 @@ class Scope:
sources in the current scope.
"""
if self._external_columns is None:
- self._external_columns = [
- c for c in self.columns if c.table not in self.selected_sources
- ]
+ self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
return self._external_columns
def source_columns(self, source_name):
@@ -310,6 +298,16 @@ class Scope:
columns = self.sources.pop(old_name or "", [])
self.sources[new_name] = columns
+ def add_source(self, name, source):
+ """Add a source to this scope"""
+ self.sources[name] = source
+ self.clear_cache()
+
+ def remove_source(self, name):
+ """Remove a source from this scope"""
+ self.sources.pop(name, None)
+ self.clear_cache()
+
def traverse_scope(expression):
"""
@@ -334,7 +332,7 @@ def traverse_scope(expression):
Args:
expression (exp.Expression): expression to traverse
Returns:
- List[Scope]: scope instances
+ list[Scope]: scope instances
"""
return list(_traverse_scope(Scope(expression)))
@@ -356,9 +354,7 @@ def _traverse_scope(scope):
def _traverse_select(scope):
yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
yield from _traverse_subqueries(scope)
- yield from _traverse_derived_tables(
- scope.derived_tables, scope, ScopeType.DERIVED_TABLE
- )
+ yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
_add_table_sources(scope)
@@ -367,15 +363,11 @@ def _traverse_union(scope):
# The last scope to be yield should be the top most scope
left = None
- for left in _traverse_scope(
- scope.branch(scope.expression.left, scope_type=ScopeType.UNION)
- ):
+ for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)):
yield left
right = None
- for right in _traverse_scope(
- scope.branch(scope.expression.right, scope_type=ScopeType.UNION)
- ):
+ for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
yield right
scope.union = (left, right)
@@ -387,14 +379,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
for derived_table in derived_tables:
for child_scope in _traverse_scope(
scope.branch(
- derived_table
- if isinstance(derived_table, (exp.Unnest, exp.Lateral))
- else derived_table.this,
+ derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
add_sources=sources if scope_type == ScopeType.CTE else None,
outer_column_list=derived_table.alias_column_names,
- scope_type=ScopeType.UNNEST
- if isinstance(derived_table, exp.Unnest)
- else scope_type,
+ scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type,
)
):
yield child_scope
@@ -430,9 +418,7 @@ def _add_table_sources(scope):
def _traverse_subqueries(scope):
for subquery in scope.subqueries:
top = None
- for child_scope in _traverse_scope(
- scope.branch(subquery, scope_type=ScopeType.SUBQUERY)
- ):
+ for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
yield child_scope
top = child_scope
scope.subquery_scopes.append(top)
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 6771153..319e6b6 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -188,9 +188,7 @@ def absorb_and_eliminate(expression):
aa.replace(exp.TRUE if kind == exp.And else exp.FALSE)
elif is_complement(b, ab):
ab.replace(exp.TRUE if kind == exp.And else exp.FALSE)
- elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(
- a.flatten()
- ):
+ elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
a.replace(exp.FALSE if kind == exp.And else exp.TRUE)
elif isinstance(b, kind):
# eliminate
@@ -227,9 +225,7 @@ def simplify_literals(expression):
operands.append(a)
if len(operands) < size:
- return functools.reduce(
- lambda a, b: expression.__class__(this=a, expression=b), operands
- )
+ return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
elif isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 55c81c5..11c6eba 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -89,11 +89,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
return
if isinstance(predicate, exp.Binary):
- key = (
- predicate.right
- if any(node is column for node, *_ in predicate.left.walk())
- else predicate.left
- )
+ key = predicate.right if any(node is column for node, *_ in predicate.left.walk()) else predicate.left
else:
return
@@ -124,9 +120,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
# if the value of the subquery is not an agg or a key, we need to collect it into an array
# so that it can be grouped
if not value.find(exp.AggFunc) and value.this not in group_by:
- select.select(
- f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False
- )
+ select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False)
# exists queries should not have any selects as it only checks if there are any rows
# all selects will be added by the optimizer and only used for join keys
@@ -151,16 +145,12 @@ def decorrelate(select, parent_select, external_columns, sequence):
else:
parent_predicate = _replace(parent_predicate, "TRUE")
elif isinstance(parent_predicate, exp.All):
- parent_predicate = _replace(
- parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
- )
+ parent_predicate = _replace(parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})")
elif isinstance(parent_predicate, exp.Any):
if value.this in group_by:
parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
else:
- parent_predicate = _replace(
- parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})"
- )
+ parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
elif isinstance(parent_predicate, exp.In):
if value.this in group_by:
parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
@@ -178,9 +168,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
if key in group_by:
key.replace(nested)
- parent_predicate = _replace(
- parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
- )
+ parent_predicate = _replace(parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)")
elif isinstance(predicate, exp.EQ):
parent_predicate = _replace(
parent_predicate,
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 9396c50..f46bafe 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -78,6 +78,7 @@ class Parser:
TokenType.TEXT,
TokenType.BINARY,
TokenType.JSON,
+ TokenType.INTERVAL,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
TokenType.DATETIME,
@@ -85,6 +86,12 @@ class Parser:
TokenType.DECIMAL,
TokenType.UUID,
TokenType.GEOGRAPHY,
+ TokenType.GEOMETRY,
+ TokenType.HLLSKETCH,
+ TokenType.SUPER,
+ TokenType.SERIAL,
+ TokenType.SMALLSERIAL,
+ TokenType.BIGSERIAL,
*NESTED_TYPE_TOKENS,
}
@@ -100,13 +107,14 @@ class Parser:
ID_VAR_TOKENS = {
TokenType.VAR,
TokenType.ALTER,
+ TokenType.ALWAYS,
TokenType.BEGIN,
+ TokenType.BOTH,
TokenType.BUCKET,
TokenType.CACHE,
TokenType.COLLATE,
TokenType.COMMIT,
TokenType.CONSTRAINT,
- TokenType.CONVERT,
TokenType.DEFAULT,
TokenType.DELETE,
TokenType.ENGINE,
@@ -115,14 +123,19 @@ class Parser:
TokenType.FALSE,
TokenType.FIRST,
TokenType.FOLLOWING,
+ TokenType.FOR,
TokenType.FORMAT,
TokenType.FUNCTION,
+ TokenType.GENERATED,
+ TokenType.IDENTITY,
TokenType.IF,
TokenType.INDEX,
TokenType.ISNULL,
TokenType.INTERVAL,
TokenType.LAZY,
+ TokenType.LEADING,
TokenType.LOCATION,
+ TokenType.NATURAL,
TokenType.NEXT,
TokenType.ONLY,
TokenType.OPTIMIZE,
@@ -141,6 +154,7 @@ class Parser:
TokenType.TABLE_FORMAT,
TokenType.TEMPORARY,
TokenType.TOP,
+ TokenType.TRAILING,
TokenType.TRUNCATE,
TokenType.TRUE,
TokenType.UNBOUNDED,
@@ -150,18 +164,15 @@ class Parser:
*TYPE_TOKENS,
}
- CASTS = {
- TokenType.CAST,
- TokenType.TRY_CAST,
- }
+ TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL}
+
+ TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}
FUNC_TOKENS = {
- TokenType.CONVERT,
TokenType.CURRENT_DATE,
TokenType.CURRENT_DATETIME,
TokenType.CURRENT_TIMESTAMP,
TokenType.CURRENT_TIME,
- TokenType.EXTRACT,
TokenType.FILTER,
TokenType.FIRST,
TokenType.FORMAT,
@@ -178,7 +189,6 @@ class Parser:
TokenType.DATETIME,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
- *CASTS,
*NESTED_TYPE_TOKENS,
*SUBQUERY_PREDICATES,
}
@@ -215,6 +225,7 @@ class Parser:
FACTOR = {
TokenType.DIV: exp.IntDiv,
+ TokenType.LR_ARROW: exp.Distance,
TokenType.SLASH: exp.Div,
TokenType.STAR: exp.Mul,
}
@@ -299,14 +310,13 @@ class Parser:
PRIMARY_PARSERS = {
TokenType.STRING: lambda _, token: exp.Literal.string(token.text),
TokenType.NUMBER: lambda _, token: exp.Literal.number(token.text),
- TokenType.STAR: lambda self, _: exp.Star(
- **{"except": self._parse_except(), "replace": self._parse_replace()}
- ),
+ TokenType.STAR: lambda self, _: exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}),
TokenType.NULL: lambda *_: exp.Null(),
TokenType.TRUE: lambda *_: exp.Boolean(this=True),
TokenType.FALSE: lambda *_: exp.Boolean(this=False),
TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(),
TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
+ TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
TokenType.INTRODUCER: lambda self, token: self.expression(
exp.Introducer,
this=token.text,
@@ -319,13 +329,16 @@ class Parser:
TokenType.IN: lambda self, this: self._parse_in(this),
TokenType.IS: lambda self, this: self._parse_is(this),
TokenType.LIKE: lambda self, this: self._parse_escape(
- self.expression(exp.Like, this=this, expression=self._parse_type())
+ self.expression(exp.Like, this=this, expression=self._parse_bitwise())
),
TokenType.ILIKE: lambda self, this: self._parse_escape(
- self.expression(exp.ILike, this=this, expression=self._parse_type())
+ self.expression(exp.ILike, this=this, expression=self._parse_bitwise())
),
TokenType.RLIKE: lambda self, this: self.expression(
- exp.RegexpLike, this=this, expression=self._parse_type()
+ exp.RegexpLike, this=this, expression=self._parse_bitwise()
+ ),
+ TokenType.SIMILAR_TO: lambda self, this: self.expression(
+ exp.SimilarTo, this=this, expression=self._parse_bitwise()
),
}
@@ -363,28 +376,21 @@ class Parser:
}
FUNCTION_PARSERS = {
- TokenType.CONVERT: lambda self, _: self._parse_convert(),
- TokenType.EXTRACT: lambda self, _: self._parse_extract(),
- **{
- token_type: lambda self, token_type: self._parse_cast(
- self.STRICT_CAST and token_type == TokenType.CAST
- )
- for token_type in CASTS
- },
+ "CONVERT": lambda self: self._parse_convert(),
+ "EXTRACT": lambda self: self._parse_extract(),
+ "SUBSTRING": lambda self: self._parse_substring(),
+ "TRIM": lambda self: self._parse_trim(),
+ "CAST": lambda self: self._parse_cast(self.STRICT_CAST),
+ "TRY_CAST": lambda self: self._parse_cast(False),
}
QUERY_MODIFIER_PARSERS = {
- "laterals": lambda self: self._parse_laterals(),
- "joins": lambda self: self._parse_joins(),
"where": lambda self: self._parse_where(),
"group": lambda self: self._parse_group(),
"having": lambda self: self._parse_having(),
"qualify": lambda self: self._parse_qualify(),
- "window": lambda self: self._match(TokenType.WINDOW)
- and self._parse_window(self._parse_id_var(), alias=True),
- "distribute": lambda self: self._parse_sort(
- TokenType.DISTRIBUTE_BY, exp.Distribute
- ),
+ "window": lambda self: self._match(TokenType.WINDOW) and self._parse_window(self._parse_id_var(), alias=True),
+ "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute),
"sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
"cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
"order": lambda self: self._parse_order(),
@@ -392,6 +398,8 @@ class Parser:
"offset": lambda self: self._parse_offset(),
}
+ MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
+
CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX}
STRICT_CAST = True
@@ -457,9 +465,7 @@ class Parser:
Returns
the list of syntax trees (:class:`~sqlglot.expressions.Expression`).
"""
- return self._parse(
- parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql
- )
+ return self._parse(parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql)
def parse_into(self, expression_types, raw_tokens, sql=None):
for expression_type in ensure_list(expression_types):
@@ -532,21 +538,13 @@ class Parser:
for k in expression.args:
if k not in expression.arg_types:
- self.raise_error(
- f"Unexpected keyword: '{k}' for {expression.__class__}"
- )
+ self.raise_error(f"Unexpected keyword: '{k}' for {expression.__class__}")
for k, mandatory in expression.arg_types.items():
v = expression.args.get(k)
if mandatory and (v is None or (isinstance(v, list) and not v)):
- self.raise_error(
- f"Required keyword: '{k}' missing for {expression.__class__}"
- )
+ self.raise_error(f"Required keyword: '{k}' missing for {expression.__class__}")
- if (
- args
- and len(args) > len(expression.arg_types)
- and not expression.is_var_len_args
- ):
+ if args and len(args) > len(expression.arg_types) and not expression.is_var_len_args:
self.raise_error(
f"The number of provided arguments ({len(args)}) is greater than "
f"the maximum number of supported arguments ({len(expression.arg_types)})"
@@ -594,11 +592,7 @@ class Parser:
)
expression = self._parse_expression()
- expression = (
- self._parse_set_operations(expression)
- if expression
- else self._parse_select()
- )
+ expression = self._parse_set_operations(expression) if expression else self._parse_select()
self._parse_query_modifiers(expression)
return expression
@@ -618,11 +612,7 @@ class Parser:
)
def _parse_exists(self, not_=False):
- return (
- self._match(TokenType.IF)
- and (not not_ or self._match(TokenType.NOT))
- and self._match(TokenType.EXISTS)
- )
+ return self._match(TokenType.IF) and (not not_ or self._match(TokenType.NOT)) and self._match(TokenType.EXISTS)
def _parse_create(self):
replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE)
@@ -647,11 +637,9 @@ class Parser:
this = self._parse_index()
elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW):
this = self._parse_table(schema=True)
- properties = self._parse_properties(
- this if isinstance(this, exp.Schema) else None
- )
+ properties = self._parse_properties(this if isinstance(this, exp.Schema) else None)
if self._match(TokenType.ALIAS):
- expression = self._parse_select()
+ expression = self._parse_select(nested=True)
return self.expression(
exp.Create,
@@ -682,9 +670,7 @@ class Parser:
if schema and not isinstance(value, exp.Schema):
columns = {v.name.upper() for v in value.expressions}
partitions = [
- expression
- for expression in schema.expressions
- if expression.this.name.upper() in columns
+ expression for expression in schema.expressions if expression.this.name.upper() in columns
]
schema.set(
"expressions",
@@ -811,7 +797,7 @@ class Parser:
this=self._parse_table(schema=True),
exists=self._parse_exists(),
partition=self._parse_partition(),
- expression=self._parse_select(),
+ expression=self._parse_select(nested=True),
overwrite=overwrite,
)
@@ -829,8 +815,7 @@ class Parser:
exp.Update,
**{
"this": self._parse_table(schema=True),
- "expressions": self._match(TokenType.SET)
- and self._parse_csv(self._parse_equality),
+ "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
"from": self._parse_from(),
"where": self._parse_where(),
},
@@ -865,7 +850,7 @@ class Parser:
this=table,
lazy=lazy,
options=options,
- expression=self._parse_select(),
+ expression=self._parse_select(nested=True),
)
def _parse_partition(self):
@@ -894,9 +879,7 @@ class Parser:
self._match_r_paren()
return self.expression(exp.Tuple, expressions=expressions)
- def _parse_select(self, table=None):
- index = self._index
-
+ def _parse_select(self, nested=False, table=False):
if self._match(TokenType.SELECT):
hint = self._parse_hint()
all_ = self._match(TokenType.ALL)
@@ -912,9 +895,7 @@ class Parser:
self.raise_error("Cannot specify both ALL and DISTINCT after SELECT")
limit = self._parse_limit(top=True)
- expressions = self._parse_csv(
- lambda: self._parse_annotation(self._parse_expression())
- )
+ expressions = self._parse_csv(lambda: self._parse_annotation(self._parse_expression()))
this = self.expression(
exp.Select,
@@ -960,19 +941,13 @@ class Parser:
)
else:
self.raise_error(f"{this.key} does not support CTE")
- elif self._match(TokenType.L_PAREN):
- this = self._parse_table() if table else self._parse_select()
-
- if this:
- self._parse_query_modifiers(this)
- self._match_r_paren()
- this = self._parse_subquery(this)
- else:
- self._retreat(index)
+ elif (table or nested) and self._match(TokenType.L_PAREN):
+ this = self._parse_table() if table else self._parse_select(nested=True)
+ self._parse_query_modifiers(this)
+ self._match_r_paren()
+ this = self._parse_subquery(this)
elif self._match(TokenType.VALUES):
- this = self.expression(
- exp.Values, expressions=self._parse_csv(self._parse_value)
- )
+ this = self.expression(exp.Values, expressions=self._parse_csv(self._parse_value))
alias = self._parse_table_alias()
if alias:
this = self.expression(exp.Subquery, this=this, alias=alias)
@@ -1001,7 +976,7 @@ class Parser:
def _parse_table_alias(self):
any_token = self._match(TokenType.ALIAS)
- alias = self._parse_id_var(any_token)
+ alias = self._parse_id_var(any_token=any_token, tokens=self.TABLE_ALIAS_TOKENS)
columns = None
if self._match(TokenType.L_PAREN):
@@ -1021,9 +996,24 @@ class Parser:
return self.expression(exp.Subquery, this=this, alias=self._parse_table_alias())
def _parse_query_modifiers(self, this):
- if not isinstance(this, (exp.Subquery, exp.Subqueryable)):
+ if not isinstance(this, self.MODIFIABLES):
return
+ table = isinstance(this, exp.Table)
+
+ while True:
+ lateral = self._parse_lateral()
+ join = self._parse_join()
+ comma = None if table else self._match(TokenType.COMMA)
+ if lateral:
+ this.append("laterals", lateral)
+ if join:
+ this.append("joins", join)
+ if comma:
+ this.args["from"].append("expressions", self._parse_table())
+ if not (lateral or join or comma):
+ break
+
for key, parser in self.QUERY_MODIFIER_PARSERS.items():
expression = parser(self)
@@ -1032,9 +1022,7 @@ class Parser:
def _parse_annotation(self, expression):
if self._match(TokenType.ANNOTATION):
- return self.expression(
- exp.Annotation, this=self._prev.text, expression=expression
- )
+ return self.expression(exp.Annotation, this=self._prev.text, expression=expression)
return expression
@@ -1052,16 +1040,16 @@ class Parser:
return self.expression(exp.From, expressions=self._parse_csv(self._parse_table))
- def _parse_laterals(self):
- return self._parse_all(self._parse_lateral)
-
def _parse_lateral(self):
if not self._match(TokenType.LATERAL):
return None
- if not self._match(TokenType.VIEW):
- self.raise_error("Expected VIEW after LATERAL")
+ subquery = self._parse_select(table=True)
+ if subquery:
+ return self.expression(exp.Lateral, this=subquery)
+
+ self._match(TokenType.VIEW)
outer = self._match(TokenType.OUTER)
return self.expression(
@@ -1071,31 +1059,27 @@ class Parser:
alias=self.expression(
exp.TableAlias,
this=self._parse_id_var(any_token=False),
- columns=(
- self._parse_csv(self._parse_id_var)
- if self._match(TokenType.ALIAS)
- else None
- ),
+ columns=(self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else None),
),
)
- def _parse_joins(self):
- return self._parse_all(self._parse_join)
-
def _parse_join_side_and_kind(self):
return (
+ self._match(TokenType.NATURAL) and self._prev,
self._match_set(self.JOIN_SIDES) and self._prev,
self._match_set(self.JOIN_KINDS) and self._prev,
)
def _parse_join(self):
- side, kind = self._parse_join_side_and_kind()
+ natural, side, kind = self._parse_join_side_and_kind()
if not self._match(TokenType.JOIN):
return None
kwargs = {"this": self._parse_table()}
+ if natural:
+ kwargs["natural"] = True
if side:
kwargs["side"] = side.text
if kind:
@@ -1120,6 +1104,11 @@ class Parser:
)
def _parse_table(self, schema=False):
+ lateral = self._parse_lateral()
+
+ if lateral:
+ return lateral
+
unnest = self._parse_unnest()
if unnest:
@@ -1172,9 +1161,7 @@ class Parser:
expressions = self._parse_csv(self._parse_column)
self._match_r_paren()
- ordinality = bool(
- self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY)
- )
+ ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY))
alias = self._parse_table_alias()
@@ -1280,17 +1267,13 @@ class Parser:
if not self._match(TokenType.ORDER_BY):
return this
- return self.expression(
- exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
- )
+ return self.expression(exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered))
def _parse_sort(self, token_type, exp_class):
if not self._match(token_type):
return None
- return self.expression(
- exp_class, expressions=self._parse_csv(self._parse_ordered)
- )
+ return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))
def _parse_ordered(self):
this = self._parse_conjunction()
@@ -1305,22 +1288,17 @@ class Parser:
if (
not explicitly_null_ordered
and (
- (asc and self.null_ordering == "nulls_are_small")
- or (desc and self.null_ordering != "nulls_are_small")
+ (asc and self.null_ordering == "nulls_are_small") or (desc and self.null_ordering != "nulls_are_small")
)
and self.null_ordering != "nulls_are_last"
):
nulls_first = True
- return self.expression(
- exp.Ordered, this=this, desc=desc, nulls_first=nulls_first
- )
+ return self.expression(exp.Ordered, this=this, desc=desc, nulls_first=nulls_first)
def _parse_limit(self, this=None, top=False):
if self._match(TokenType.TOP if top else TokenType.LIMIT):
- return self.expression(
- exp.Limit, this=this, expression=self._parse_number()
- )
+ return self.expression(exp.Limit, this=this, expression=self._parse_number())
if self._match(TokenType.FETCH):
direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
direction = self._prev.text if direction else "FIRST"
@@ -1354,7 +1332,7 @@ class Parser:
expression,
this=this,
distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL),
- expression=self._parse_select(),
+ expression=self._parse_select(nested=True),
)
def _parse_expression(self):
@@ -1396,9 +1374,7 @@ class Parser:
this = self.expression(exp.In, this=this, unnest=unnest)
else:
self._match_l_paren()
- expressions = self._parse_csv(
- lambda: self._parse_select() or self._parse_expression()
- )
+ expressions = self._parse_csv(lambda: self._parse_select() or self._parse_expression())
if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable):
this = self.expression(exp.In, this=this, query=expressions[0])
@@ -1430,13 +1406,9 @@ class Parser:
expression=self._parse_term(),
)
elif self._match_pair(TokenType.LT, TokenType.LT):
- this = self.expression(
- exp.BitwiseLeftShift, this=this, expression=self._parse_term()
- )
+ this = self.expression(exp.BitwiseLeftShift, this=this, expression=self._parse_term())
elif self._match_pair(TokenType.GT, TokenType.GT):
- this = self.expression(
- exp.BitwiseRightShift, this=this, expression=self._parse_term()
- )
+ this = self.expression(exp.BitwiseRightShift, this=this, expression=self._parse_term())
else:
break
@@ -1524,7 +1496,7 @@ class Parser:
self.raise_error("Expecting >")
if type_token in self.TIMESTAMPS:
- tz = self._match(TokenType.WITH_TIME_ZONE)
+ tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ
self._match(TokenType.WITHOUT_TIME_ZONE)
if tz:
return exp.DataType(
@@ -1594,16 +1566,14 @@ class Parser:
if query:
expressions = [query]
else:
- expressions = self._parse_csv(
- lambda: self._parse_alias(self._parse_conjunction(), explicit=True)
- )
+ expressions = self._parse_csv(lambda: self._parse_alias(self._parse_conjunction(), explicit=True))
this = list_get(expressions, 0)
self._parse_query_modifiers(this)
self._match_r_paren()
if isinstance(this, exp.Subqueryable):
- return self._parse_subquery(this)
+ return self._parse_set_operations(self._parse_subquery(this))
if len(expressions) > 1:
return self.expression(exp.Tuple, expressions=expressions)
return self.expression(exp.Paren, this=this)
@@ -1611,11 +1581,7 @@ class Parser:
return None
def _parse_field(self, any_token=False):
- return (
- self._parse_primary()
- or self._parse_function()
- or self._parse_id_var(any_token)
- )
+ return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token)
def _parse_function(self):
if not self._curr:
@@ -1628,21 +1594,22 @@ class Parser:
if not self._next or self._next.token_type != TokenType.L_PAREN:
if token_type in self.NO_PAREN_FUNCTIONS:
- return self.expression(
- self._advance() or self.NO_PAREN_FUNCTIONS[token_type]
- )
+ return self.expression(self._advance() or self.NO_PAREN_FUNCTIONS[token_type])
return None
if token_type not in self.FUNC_TOKENS:
return None
- if self._match_set(self.FUNCTION_PARSERS):
- self._advance()
- this = self.FUNCTION_PARSERS[token_type](self, token_type)
+ this = self._curr.text
+ upper = this.upper()
+ self._advance(2)
+
+ parser = self.FUNCTION_PARSERS.get(upper)
+
+ if parser:
+ this = parser(self)
else:
subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type)
- this = self._curr.text
- self._advance(2)
if subquery_predicate and self._curr.token_type in (
TokenType.SELECT,
@@ -1652,7 +1619,7 @@ class Parser:
self._match_r_paren()
return this
- function = self.FUNCTIONS.get(this.upper())
+ function = self.FUNCTIONS.get(upper)
args = self._parse_csv(self._parse_lambda)
if function:
@@ -1700,10 +1667,7 @@ class Parser:
self._retreat(index)
return this
- args = self._parse_csv(
- lambda: self._parse_constraint()
- or self._parse_column_def(self._parse_field())
- )
+ args = self._parse_csv(lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True)))
self._match_r_paren()
return self.expression(exp.Schema, this=this, expressions=args)
@@ -1720,12 +1684,9 @@ class Parser:
break
constraints.append(constraint)
- return self.expression(
- exp.ColumnDef, this=this, kind=kind, constraints=constraints
- )
+ return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints)
def _parse_column_constraint(self):
- kind = None
this = None
if self._match(TokenType.CONSTRAINT):
@@ -1735,28 +1696,28 @@ class Parser:
kind = exp.AutoIncrementColumnConstraint()
elif self._match(TokenType.CHECK):
self._match_l_paren()
- kind = self.expression(
- exp.CheckColumnConstraint, this=self._parse_conjunction()
- )
+ kind = self.expression(exp.CheckColumnConstraint, this=self._parse_conjunction())
self._match_r_paren()
elif self._match(TokenType.COLLATE):
kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var())
elif self._match(TokenType.DEFAULT):
- kind = self.expression(
- exp.DefaultColumnConstraint, this=self._parse_field()
- )
- elif self._match(TokenType.NOT) and self._match(TokenType.NULL):
+ kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_field())
+ elif self._match_pair(TokenType.NOT, TokenType.NULL):
kind = exp.NotNullColumnConstraint()
elif self._match(TokenType.SCHEMA_COMMENT):
- kind = self.expression(
- exp.CommentColumnConstraint, this=self._parse_string()
- )
+ kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string())
elif self._match(TokenType.PRIMARY_KEY):
kind = exp.PrimaryKeyColumnConstraint()
elif self._match(TokenType.UNIQUE):
kind = exp.UniqueColumnConstraint()
-
- if kind is None:
+ elif self._match(TokenType.GENERATED):
+ if self._match(TokenType.BY_DEFAULT):
+ kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=False)
+ else:
+ self._match(TokenType.ALWAYS)
+ kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
+ self._match_pair(TokenType.ALIAS, TokenType.IDENTITY)
+ else:
return None
return self.expression(exp.ColumnConstraint, this=this, kind=kind)
@@ -1864,9 +1825,7 @@ class Parser:
if not self._match(TokenType.END):
self.raise_error("Expected END after CASE", self._prev)
- return self._parse_window(
- self.expression(exp.Case, this=expression, ifs=ifs, default=default)
- )
+ return self._parse_window(self.expression(exp.Case, this=expression, ifs=ifs, default=default))
def _parse_if(self):
if self._match(TokenType.L_PAREN):
@@ -1889,7 +1848,7 @@ class Parser:
if not self._match(TokenType.FROM):
self.raise_error("Expected FROM after EXTRACT", self._prev)
- return self.expression(exp.Extract, this=this, expression=self._parse_type())
+ return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
def _parse_cast(self, strict):
this = self._parse_conjunction()
@@ -1917,12 +1876,54 @@ class Parser:
to = None
return self.expression(exp.Cast, this=this, to=to)
+ def _parse_substring(self):
+ # Postgres supports the form: substring(string [from int] [for int])
+ # https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
+
+ args = self._parse_csv(self._parse_bitwise)
+
+ if self._match(TokenType.FROM):
+ args.append(self._parse_bitwise())
+ if self._match(TokenType.FOR):
+ args.append(self._parse_bitwise())
+
+ this = exp.Substring.from_arg_list(args)
+ self.validate_expression(this, args)
+
+ return this
+
+ def _parse_trim(self):
+ # https://www.w3resource.com/sql/character-functions/trim.php
+ # https://docs.oracle.com/javadb/10.8.3.0/ref/rreftrimfunc.html
+
+ position = None
+ collation = None
+
+ if self._match_set(self.TRIM_TYPES):
+ position = self._prev.text.upper()
+
+ expression = self._parse_term()
+ if self._match(TokenType.FROM):
+ this = self._parse_term()
+ else:
+ this = expression
+ expression = None
+
+ if self._match(TokenType.COLLATE):
+ collation = self._parse_term()
+
+ return self.expression(
+ exp.Trim,
+ this=this,
+ position=position,
+ expression=expression,
+ collation=collation,
+ )
+
def _parse_window(self, this, alias=False):
if self._match(TokenType.FILTER):
self._match_l_paren()
- this = self.expression(
- exp.Filter, this=this, expression=self._parse_where()
- )
+ this = self.expression(exp.Filter, this=this, expression=self._parse_where())
self._match_r_paren()
if self._match(TokenType.WITHIN_GROUP):
@@ -1935,6 +1936,25 @@ class Parser:
self._match_r_paren()
return this
+ # SQL spec defines an optional [ { IGNORE | RESPECT } NULLS ] OVER
+ # Some dialects choose to implement and some do not.
+ # https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html
+
+ # There is some code above in _parse_lambda that handles
+ # SELECT FIRST_VALUE(TABLE.COLUMN IGNORE|RESPECT NULLS) OVER ...
+
+ # The below changes handle
+ # SELECT FIRST_VALUE(TABLE.COLUMN) IGNORE|RESPECT NULLS OVER ...
+
+ # Oracle allows both formats
+ # (https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/img_text/first_value.html)
+ # and Snowflake chose to do the same for familiarity
+ # https://docs.snowflake.com/en/sql-reference/functions/first_value.html#usage-notes
+ if self._match(TokenType.IGNORE_NULLS):
+ this = self.expression(exp.IgnoreNulls, this=this)
+ elif self._match(TokenType.RESPECT_NULLS):
+ this = self.expression(exp.RespectNulls, this=this)
+
# bigquery select from window x AS (partition by ...)
if alias:
self._match(TokenType.ALIAS)
@@ -1992,13 +2012,9 @@ class Parser:
self._match(TokenType.BETWEEN)
return {
- "value": (
- self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW))
- and self._prev.text
- )
+ "value": (self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text)
or self._parse_bitwise(),
- "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING))
- and self._prev.text,
+ "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text,
}
def _parse_alias(self, this, explicit=False):
@@ -2023,22 +2039,16 @@ class Parser:
return this
- def _parse_id_var(self, any_token=True):
+ def _parse_id_var(self, any_token=True, tokens=None):
identifier = self._parse_identifier()
if identifier:
return identifier
- if (
- any_token
- and self._curr
- and self._curr.token_type not in self.RESERVED_KEYWORDS
- ):
+ if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
return self._advance() or exp.Identifier(this=self._prev.text, quoted=False)
- return self._match_set(self.ID_VAR_TOKENS) and exp.Identifier(
- this=self._prev.text, quoted=False
- )
+ return self._match_set(tokens or self.ID_VAR_TOKENS) and exp.Identifier(this=self._prev.text, quoted=False)
def _parse_string(self):
if self._match(TokenType.STRING):
@@ -2077,9 +2087,7 @@ class Parser:
def _parse_star(self):
if self._match(TokenType.STAR):
- return exp.Star(
- **{"except": self._parse_except(), "replace": self._parse_replace()}
- )
+ return exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()})
return None
def _parse_placeholder(self):
@@ -2117,15 +2125,10 @@ class Parser:
this = parse()
while self._match_set(expressions):
- this = self.expression(
- expressions[self._prev.token_type], this=this, expression=parse()
- )
+ this = self.expression(expressions[self._prev.token_type], this=this, expression=parse())
return this
- def _parse_all(self, parse):
- return list(iter(parse, None))
-
def _parse_wrapped_id_vars(self):
self._match_l_paren()
expressions = self._parse_csv(self._parse_id_var)
@@ -2156,10 +2159,7 @@ class Parser:
if not self._curr or not self._next:
return None
- if (
- self._curr.token_type == token_type_a
- and self._next.token_type == token_type_b
- ):
+ if self._curr.token_type == token_type_a and self._next.token_type == token_type_b:
if advance:
self._advance(2)
return True
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index 2006a75..ed0b66c 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -72,9 +72,7 @@ class Step:
if from_:
from_ = from_.expressions
if len(from_) > 1:
- raise UnsupportedError(
- "Multi-from statements are unsupported. Run it through the optimizer"
- )
+ raise UnsupportedError("Multi-from statements are unsupported. Run it through the optimizer")
step = Scan.from_expression(from_[0], ctes)
else:
@@ -104,9 +102,7 @@ class Step:
continue
if operand not in operands:
operands[operand] = f"_a_{next(sequence)}"
- operand.replace(
- exp.column(operands[operand], step.name, quoted=True)
- )
+ operand.replace(exp.column(operands[operand], step.name, quoted=True))
else:
projections.append(e)
@@ -121,14 +117,9 @@ class Step:
aggregate = Aggregate()
aggregate.source = step.name
aggregate.name = step.name
- aggregate.operands = tuple(
- alias(operand, alias_) for operand, alias_ in operands.items()
- )
+ aggregate.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
aggregate.aggregations = aggregations
- aggregate.group = [
- exp.column(e.alias_or_name, step.name, quoted=True)
- for e in group.expressions
- ]
+ aggregate.group = [exp.column(e.alias_or_name, step.name, quoted=True) for e in group.expressions]
aggregate.add_dependency(step)
step = aggregate
@@ -212,9 +203,7 @@ class Scan(Step):
alias_ = expression.alias
if not alias_:
- raise UnsupportedError(
- "Tables/Subqueries must be aliased. Run it through the optimizer"
- )
+ raise UnsupportedError("Tables/Subqueries must be aliased. Run it through the optimizer")
if isinstance(expression, exp.Subquery):
step = Step.from_expression(table, ctes)
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index e4b754d..bd95bc7 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -38,6 +38,7 @@ class TokenType(AutoName):
DARROW = auto()
HASH_ARROW = auto()
DHASH_ARROW = auto()
+ LR_ARROW = auto()
ANNOTATION = auto()
DOLLAR = auto()
@@ -53,6 +54,7 @@ class TokenType(AutoName):
TABLE = auto()
VAR = auto()
BIT_STRING = auto()
+ HEX_STRING = auto()
# types
BOOLEAN = auto()
@@ -78,10 +80,17 @@ class TokenType(AutoName):
UUID = auto()
GEOGRAPHY = auto()
NULLABLE = auto()
+ GEOMETRY = auto()
+ HLLSKETCH = auto()
+ SUPER = auto()
+ SERIAL = auto()
+ SMALLSERIAL = auto()
+ BIGSERIAL = auto()
# keywords
ADD_FILE = auto()
ALIAS = auto()
+ ALWAYS = auto()
ALL = auto()
ALTER = auto()
ANALYZE = auto()
@@ -92,11 +101,12 @@ class TokenType(AutoName):
AUTO_INCREMENT = auto()
BEGIN = auto()
BETWEEN = auto()
+ BOTH = auto()
BUCKET = auto()
+ BY_DEFAULT = auto()
CACHE = auto()
CALL = auto()
CASE = auto()
- CAST = auto()
CHARACTER_SET = auto()
CHECK = auto()
CLUSTER_BY = auto()
@@ -104,7 +114,6 @@ class TokenType(AutoName):
COMMENT = auto()
COMMIT = auto()
CONSTRAINT = auto()
- CONVERT = auto()
CREATE = auto()
CROSS = auto()
CUBE = auto()
@@ -127,22 +136,24 @@ class TokenType(AutoName):
EXCEPT = auto()
EXISTS = auto()
EXPLAIN = auto()
- EXTRACT = auto()
FALSE = auto()
FETCH = auto()
FILTER = auto()
FINAL = auto()
FIRST = auto()
FOLLOWING = auto()
+ FOR = auto()
FOREIGN_KEY = auto()
FORMAT = auto()
FULL = auto()
FUNCTION = auto()
FROM = auto()
+ GENERATED = auto()
GROUP_BY = auto()
GROUPING_SETS = auto()
HAVING = auto()
HINT = auto()
+ IDENTITY = auto()
IF = auto()
IGNORE_NULLS = auto()
ILIKE = auto()
@@ -159,12 +170,14 @@ class TokenType(AutoName):
JOIN = auto()
LATERAL = auto()
LAZY = auto()
+ LEADING = auto()
LEFT = auto()
LIKE = auto()
LIMIT = auto()
LOCATION = auto()
MAP = auto()
MOD = auto()
+ NATURAL = auto()
NEXT = auto()
NO_ACTION = auto()
NULL = auto()
@@ -204,8 +217,10 @@ class TokenType(AutoName):
ROWS = auto()
SCHEMA_COMMENT = auto()
SELECT = auto()
+ SEPARATOR = auto()
SET = auto()
SHOW = auto()
+ SIMILAR_TO = auto()
SOME = auto()
SORT_BY = auto()
STORED = auto()
@@ -213,12 +228,11 @@ class TokenType(AutoName):
TABLE_FORMAT = auto()
TABLE_SAMPLE = auto()
TEMPORARY = auto()
- TIME = auto()
TOP = auto()
THEN = auto()
TRUE = auto()
+ TRAILING = auto()
TRUNCATE = auto()
- TRY_CAST = auto()
UNBOUNDED = auto()
UNCACHE = auto()
UNION = auto()
@@ -272,35 +286,32 @@ class _Tokenizer(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
- klass.QUOTES = dict(
- (quote, quote) if isinstance(quote, str) else (quote[0], quote[1])
- for quote in klass.QUOTES
- )
-
- klass.IDENTIFIERS = dict(
- (identifier, identifier)
- if isinstance(identifier, str)
- else (identifier[0], identifier[1])
- for identifier in klass.IDENTIFIERS
- )
-
- klass.COMMENTS = dict(
- (comment, None) if isinstance(comment, str) else (comment[0], comment[1])
- for comment in klass.COMMENTS
+ klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES)
+ klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS)
+ klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
+ klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)
+ klass._COMMENTS = dict(
+ (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS
)
klass.KEYWORD_TRIE = new_trie(
key.upper()
for key, value in {
**klass.KEYWORDS,
- **{comment: TokenType.COMMENT for comment in klass.COMMENTS},
- **{quote: TokenType.QUOTE for quote in klass.QUOTES},
+ **{comment: TokenType.COMMENT for comment in klass._COMMENTS},
+ **{quote: TokenType.QUOTE for quote in klass._QUOTES},
+ **{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS},
+ **{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS},
}.items()
if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
)
return klass
+ @staticmethod
+ def _delimeter_list_to_dict(list):
+ return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in list)
+
class Tokenizer(metaclass=_Tokenizer):
SINGLE_TOKENS = {
@@ -339,6 +350,10 @@ class Tokenizer(metaclass=_Tokenizer):
QUOTES = ["'"]
+ BIT_STRINGS = []
+
+ HEX_STRINGS = []
+
IDENTIFIERS = ['"']
ESCAPE = "'"
@@ -357,6 +372,7 @@ class Tokenizer(metaclass=_Tokenizer):
"->>": TokenType.DARROW,
"#>": TokenType.HASH_ARROW,
"#>>": TokenType.DHASH_ARROW,
+ "<->": TokenType.LR_ARROW,
"ADD ARCHIVE": TokenType.ADD_FILE,
"ADD ARCHIVES": TokenType.ADD_FILE,
"ADD FILE": TokenType.ADD_FILE,
@@ -374,12 +390,12 @@ class Tokenizer(metaclass=_Tokenizer):
"AUTO_INCREMENT": TokenType.AUTO_INCREMENT,
"BEGIN": TokenType.BEGIN,
"BETWEEN": TokenType.BETWEEN,
+ "BOTH": TokenType.BOTH,
"BUCKET": TokenType.BUCKET,
"CALL": TokenType.CALL,
"CACHE": TokenType.CACHE,
"UNCACHE": TokenType.UNCACHE,
"CASE": TokenType.CASE,
- "CAST": TokenType.CAST,
"CHARACTER SET": TokenType.CHARACTER_SET,
"CHECK": TokenType.CHECK,
"CLUSTER BY": TokenType.CLUSTER_BY,
@@ -387,7 +403,6 @@ class Tokenizer(metaclass=_Tokenizer):
"COMMENT": TokenType.SCHEMA_COMMENT,
"COMMIT": TokenType.COMMIT,
"CONSTRAINT": TokenType.CONSTRAINT,
- "CONVERT": TokenType.CONVERT,
"CREATE": TokenType.CREATE,
"CROSS": TokenType.CROSS,
"CUBE": TokenType.CUBE,
@@ -408,7 +423,6 @@ class Tokenizer(metaclass=_Tokenizer):
"EXCEPT": TokenType.EXCEPT,
"EXISTS": TokenType.EXISTS,
"EXPLAIN": TokenType.EXPLAIN,
- "EXTRACT": TokenType.EXTRACT,
"FALSE": TokenType.FALSE,
"FETCH": TokenType.FETCH,
"FILTER": TokenType.FILTER,
@@ -437,10 +451,12 @@ class Tokenizer(metaclass=_Tokenizer):
"JOIN": TokenType.JOIN,
"LATERAL": TokenType.LATERAL,
"LAZY": TokenType.LAZY,
+ "LEADING": TokenType.LEADING,
"LEFT": TokenType.LEFT,
"LIKE": TokenType.LIKE,
"LIMIT": TokenType.LIMIT,
"LOCATION": TokenType.LOCATION,
+ "NATURAL": TokenType.NATURAL,
"NEXT": TokenType.NEXT,
"NO ACTION": TokenType.NO_ACTION,
"NOT": TokenType.NOT,
@@ -490,8 +506,8 @@ class Tokenizer(metaclass=_Tokenizer):
"TEMPORARY": TokenType.TEMPORARY,
"THEN": TokenType.THEN,
"TRUE": TokenType.TRUE,
+ "TRAILING": TokenType.TRAILING,
"TRUNCATE": TokenType.TRUNCATE,
- "TRY_CAST": TokenType.TRY_CAST,
"UNBOUNDED": TokenType.UNBOUNDED,
"UNION": TokenType.UNION,
"UNNEST": TokenType.UNNEST,
@@ -626,14 +642,12 @@ class Tokenizer(metaclass=_Tokenizer):
break
white_space = self.WHITE_SPACE.get(self._char)
- identifier_end = self.IDENTIFIERS.get(self._char)
+ identifier_end = self._IDENTIFIERS.get(self._char)
if white_space:
if white_space == TokenType.BREAK:
self._col = 1
self._line += 1
- elif self._char == "0" and self._peek == "x":
- self._scan_hex()
elif self._char.isdigit():
self._scan_number()
elif identifier_end:
@@ -666,9 +680,7 @@ class Tokenizer(metaclass=_Tokenizer):
text = self._text if text is None else text
self.tokens.append(Token(token_type, text, self._line, self._col))
- if token_type in self.COMMANDS and (
- len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
- ):
+ if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON):
self._start = self._current
while not self._end and self._peek != ";":
self._advance()
@@ -725,6 +737,8 @@ class Tokenizer(metaclass=_Tokenizer):
if self._scan_string(word):
return
+ if self._scan_numeric_string(word):
+ return
if self._scan_comment(word):
return
@@ -732,10 +746,10 @@ class Tokenizer(metaclass=_Tokenizer):
self._add(self.KEYWORDS[word.upper()])
def _scan_comment(self, comment_start):
- if comment_start not in self.COMMENTS:
+ if comment_start not in self._COMMENTS:
return False
- comment_end = self.COMMENTS[comment_start]
+ comment_end = self._COMMENTS[comment_start]
if comment_end:
comment_end_size = len(comment_end)
@@ -749,15 +763,18 @@ class Tokenizer(metaclass=_Tokenizer):
return True
def _scan_annotation(self):
- while (
- not self._end
- and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK
- and self._peek != ","
- ):
+ while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK and self._peek != ",":
self._advance()
self._add(TokenType.ANNOTATION, self._text[1:])
def _scan_number(self):
+ if self._char == "0":
+ peek = self._peek.upper()
+ if peek == "B":
+ return self._scan_bits()
+ elif peek == "X":
+ return self._scan_hex()
+
decimal = False
scientific = 0
@@ -788,57 +805,71 @@ class Tokenizer(metaclass=_Tokenizer):
else:
return self._add(TokenType.NUMBER)
+ def _scan_bits(self):
+ self._advance()
+ value = self._extract_value()
+ try:
+ self._add(TokenType.BIT_STRING, f"{int(value, 2)}")
+ except ValueError:
+ self._add(TokenType.IDENTIFIER)
+
def _scan_hex(self):
self._advance()
+ value = self._extract_value()
+ try:
+ self._add(TokenType.HEX_STRING, f"{int(value, 16)}")
+ except ValueError:
+ self._add(TokenType.IDENTIFIER)
+ def _extract_value(self):
while True:
char = self._peek.strip()
if char and char not in self.SINGLE_TOKENS:
self._advance()
else:
break
- try:
- self._add(TokenType.BIT_STRING, f"{int(self._text, 16):b}")
- except ValueError:
- self._add(TokenType.IDENTIFIER)
+
+ return self._text
def _scan_string(self, quote):
- quote_end = self.QUOTES.get(quote)
+ quote_end = self._QUOTES.get(quote)
if quote_end is None:
return False
- text = ""
self._advance(len(quote))
- quote_end_size = len(quote_end)
-
- while True:
- if self._char == self.ESCAPE and self._peek == quote_end:
- text += quote
- self._advance(2)
- else:
- if self._chars(quote_end_size) == quote_end:
- if quote_end_size > 1:
- self._advance(quote_end_size - 1)
- break
-
- if self._end:
- raise RuntimeError(
- f"Missing {quote} from {self._line}:{self._start}"
- )
- text += self._char
- self._advance()
+ text = self._extract_string(quote_end)
text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
text = text.replace("\\\\", "\\") if self.ESCAPE == "\\" else text
self._add(TokenType.STRING, text)
return True
+ def _scan_numeric_string(self, string_start):
+ if string_start in self._HEX_STRINGS:
+ delimiters = self._HEX_STRINGS
+ token_type = TokenType.HEX_STRING
+ base = 16
+ elif string_start in self._BIT_STRINGS:
+ delimiters = self._BIT_STRINGS
+ token_type = TokenType.BIT_STRING
+ base = 2
+ else:
+ return False
+
+ self._advance(len(string_start))
+ string_end = delimiters.get(string_start)
+ text = self._extract_string(string_end)
+
+ try:
+ self._add(token_type, f"{int(text, base)}")
+ except ValueError:
+ raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}")
+ return True
+
def _scan_identifier(self, identifier_end):
while self._peek != identifier_end:
if self._end:
- raise RuntimeError(
- f"Missing {identifier_end} from {self._line}:{self._start}"
- )
+ raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}")
self._advance()
self._advance()
self._add(TokenType.IDENTIFIER, self._text[1:-1])
@@ -851,3 +882,24 @@ class Tokenizer(metaclass=_Tokenizer):
else:
break
self._add(self.KEYWORDS.get(self._text.upper(), TokenType.VAR))
+
+ def _extract_string(self, delimiter):
+ text = ""
+ delim_size = len(delimiter)
+
+ while True:
+ if self._char == self.ESCAPE and self._peek == delimiter:
+ text += delimiter
+ self._advance(2)
+ else:
+ if self._chars(delim_size) == delimiter:
+ if delim_size > 1:
+ self._advance(delim_size - 1)
+ break
+
+ if self._end:
+ raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")
+ text += self._char
+ self._advance()
+
+ return text
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index e7ccb8e..7fc71dd 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -12,9 +12,7 @@ def unalias_group(expression):
"""
if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
aliased_selects = {
- e.alias: i
- for i, e in enumerate(expression.parent.expressions, start=1)
- if isinstance(e, exp.Alias)
+ e.alias: i for i, e in enumerate(expression.parent.expressions, start=1) if isinstance(e, exp.Alias)
}
expression = expression.copy()
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 3993565..6b7bfd3 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -36,9 +36,7 @@ class Validator(unittest.TestCase):
for read_dialect, read_sql in (read or {}).items():
with self.subTest(f"{read_dialect} -> {sql}"):
self.assertEqual(
- parse_one(read_sql, read_dialect).sql(
- self.dialect, unsupported_level=ErrorLevel.IGNORE
- ),
+ parse_one(read_sql, read_dialect).sql(self.dialect, unsupported_level=ErrorLevel.IGNORE),
sql,
)
@@ -46,9 +44,7 @@ class Validator(unittest.TestCase):
with self.subTest(f"{sql} -> {write_dialect}"):
if write_sql is UnsupportedError:
with self.assertRaises(UnsupportedError):
- expression.sql(
- write_dialect, unsupported_level=ErrorLevel.RAISE
- )
+ expression.sql(write_dialect, unsupported_level=ErrorLevel.RAISE)
else:
self.assertEqual(
expression.sql(
@@ -82,12 +78,20 @@ class TestDialect(Validator):
"oracle": "CAST(a AS CLOB)",
"postgres": "CAST(a AS TEXT)",
"presto": "CAST(a AS VARCHAR)",
+ "redshift": "CAST(a AS TEXT)",
"snowflake": "CAST(a AS TEXT)",
"spark": "CAST(a AS STRING)",
"starrocks": "CAST(a AS STRING)",
},
)
self.validate_all(
+ "CAST(a AS DATETIME)",
+ write={
+ "postgres": "CAST(a AS TIMESTAMP)",
+ "sqlite": "CAST(a AS DATETIME)",
+ },
+ )
+ self.validate_all(
"CAST(a AS STRING)",
write={
"bigquery": "CAST(a AS STRING)",
@@ -97,6 +101,7 @@ class TestDialect(Validator):
"oracle": "CAST(a AS CLOB)",
"postgres": "CAST(a AS TEXT)",
"presto": "CAST(a AS VARCHAR)",
+ "redshift": "CAST(a AS TEXT)",
"snowflake": "CAST(a AS TEXT)",
"spark": "CAST(a AS STRING)",
"starrocks": "CAST(a AS STRING)",
@@ -112,6 +117,7 @@ class TestDialect(Validator):
"oracle": "CAST(a AS VARCHAR2)",
"postgres": "CAST(a AS VARCHAR)",
"presto": "CAST(a AS VARCHAR)",
+ "redshift": "CAST(a AS VARCHAR)",
"snowflake": "CAST(a AS VARCHAR)",
"spark": "CAST(a AS STRING)",
"starrocks": "CAST(a AS VARCHAR)",
@@ -127,6 +133,7 @@ class TestDialect(Validator):
"oracle": "CAST(a AS VARCHAR2(3))",
"postgres": "CAST(a AS VARCHAR(3))",
"presto": "CAST(a AS VARCHAR(3))",
+ "redshift": "CAST(a AS VARCHAR(3))",
"snowflake": "CAST(a AS VARCHAR(3))",
"spark": "CAST(a AS VARCHAR(3))",
"starrocks": "CAST(a AS VARCHAR(3))",
@@ -142,6 +149,7 @@ class TestDialect(Validator):
"oracle": "CAST(a AS NUMBER)",
"postgres": "CAST(a AS SMALLINT)",
"presto": "CAST(a AS SMALLINT)",
+ "redshift": "CAST(a AS SMALLINT)",
"snowflake": "CAST(a AS SMALLINT)",
"spark": "CAST(a AS SHORT)",
"sqlite": "CAST(a AS INTEGER)",
@@ -149,6 +157,19 @@ class TestDialect(Validator):
},
)
self.validate_all(
+ "TRY_CAST(a AS DOUBLE)",
+ read={
+ "postgres": "CAST(a AS DOUBLE PRECISION)",
+ "redshift": "CAST(a AS DOUBLE PRECISION)",
+ },
+ write={
+ "duckdb": "TRY_CAST(a AS DOUBLE)",
+ "postgres": "CAST(a AS DOUBLE PRECISION)",
+ "redshift": "CAST(a AS DOUBLE PRECISION)",
+ },
+ )
+
+ self.validate_all(
"CAST(a AS DOUBLE)",
write={
"bigquery": "CAST(a AS FLOAT64)",
@@ -159,16 +180,32 @@ class TestDialect(Validator):
"oracle": "CAST(a AS DOUBLE PRECISION)",
"postgres": "CAST(a AS DOUBLE PRECISION)",
"presto": "CAST(a AS DOUBLE)",
+ "redshift": "CAST(a AS DOUBLE PRECISION)",
"snowflake": "CAST(a AS DOUBLE)",
"spark": "CAST(a AS DOUBLE)",
"starrocks": "CAST(a AS DOUBLE)",
},
)
self.validate_all(
- "CAST(a AS TIMESTAMP)", write={"starrocks": "CAST(a AS DATETIME)"}
+ "CAST('1 DAY' AS INTERVAL)",
+ write={
+ "postgres": "CAST('1 DAY' AS INTERVAL)",
+ "redshift": "CAST('1 DAY' AS INTERVAL)",
+ },
)
self.validate_all(
- "CAST(a AS TIMESTAMPTZ)", write={"starrocks": "CAST(a AS DATETIME)"}
+ "CAST(a AS TIMESTAMP)",
+ write={
+ "starrocks": "CAST(a AS DATETIME)",
+ "redshift": "CAST(a AS TIMESTAMP)",
+ },
+ )
+ self.validate_all(
+ "CAST(a AS TIMESTAMPTZ)",
+ write={
+ "starrocks": "CAST(a AS DATETIME)",
+ "redshift": "CAST(a AS TIMESTAMPTZ)",
+ },
)
self.validate_all("CAST(a AS TINYINT)", write={"oracle": "CAST(a AS NUMBER)"})
self.validate_all("CAST(a AS SMALLINT)", write={"oracle": "CAST(a AS NUMBER)"})
@@ -552,6 +589,7 @@ class TestDialect(Validator):
write={
"bigquery": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
+ "oracle": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST",
"hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
@@ -566,6 +604,7 @@ class TestDialect(Validator):
"presto": "JSON_EXTRACT(x, 'y')",
},
write={
+ "oracle": "JSON_EXTRACT(x, 'y')",
"postgres": "x->'y'",
"presto": "JSON_EXTRACT(x, 'y')",
},
@@ -623,6 +662,37 @@ class TestDialect(Validator):
},
)
+ # https://dev.mysql.com/doc/refman/8.0/en/join.html
+ # https://www.postgresql.org/docs/current/queries-table-expressions.html
+ def test_joined_tables(self):
+ self.validate_identity("SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)")
+ self.validate_identity("SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3)")
+ self.validate_identity("SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)")
+ self.validate_identity("SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)")
+
+ self.validate_all(
+ "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)",
+ write={
+ "postgres": "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)",
+ "mysql": "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)",
+ },
+ )
+ self.validate_all(
+ "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)",
+ write={
+ "postgres": "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)",
+ "mysql": "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)",
+ },
+ )
+
+ def test_lateral_subquery(self):
+ self.validate_identity(
+ "SELECT art FROM tbl1 INNER JOIN LATERAL (SELECT art FROM tbl2) AS tbl2 ON tbl1.art = tbl2.art"
+ )
+ self.validate_identity(
+ "SELECT * FROM tbl AS t LEFT JOIN LATERAL (SELECT * FROM b WHERE b.t_id = t.t_id) AS t ON TRUE"
+ )
+
def test_set_operators(self):
self.validate_all(
"SELECT * FROM a UNION SELECT * FROM b",
@@ -731,6 +801,9 @@ class TestDialect(Validator):
)
def test_operators(self):
+ self.validate_identity("some.column LIKE 'foo' || another.column || 'bar' || LOWER(x)")
+ self.validate_identity("some.column LIKE 'foo' + another.column + 'bar'")
+
self.validate_all(
"x ILIKE '%y'",
read={
@@ -874,16 +947,8 @@ class TestDialect(Validator):
"spark": "FILTER(the_array, x -> x > 0)",
},
)
- self.validate_all(
- "SELECT a AS b FROM x GROUP BY b",
- write={
- "duckdb": "SELECT a AS b FROM x GROUP BY b",
- "presto": "SELECT a AS b FROM x GROUP BY 1",
- "hive": "SELECT a AS b FROM x GROUP BY 1",
- "oracle": "SELECT a AS b FROM x GROUP BY 1",
- "spark": "SELECT a AS b FROM x GROUP BY 1",
- },
- )
+
+ def test_limit(self):
self.validate_all(
"SELECT x FROM y LIMIT 10",
write={
@@ -915,6 +980,7 @@ class TestDialect(Validator):
read={
"clickhouse": '`x` + "y"',
"sqlite": '`x` + "y"',
+ "redshift": '"x" + "y"',
},
)
self.validate_all(
@@ -977,5 +1043,36 @@ class TestDialect(Validator):
"oracle": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 CLOB, c2 CLOB(1024))",
"postgres": "CREATE TABLE t (b1 BYTEA, b2 BYTEA(1024), c1 TEXT, c2 TEXT(1024))",
"sqlite": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))",
+ "redshift": "CREATE TABLE t (b1 VARBYTE, b2 VARBYTE(1024), c1 TEXT, c2 TEXT(1024))",
+ },
+ )
+
+ def test_alias(self):
+ self.validate_all(
+ "SELECT a AS b FROM x GROUP BY b",
+ write={
+ "duckdb": "SELECT a AS b FROM x GROUP BY b",
+ "presto": "SELECT a AS b FROM x GROUP BY 1",
+ "hive": "SELECT a AS b FROM x GROUP BY 1",
+ "oracle": "SELECT a AS b FROM x GROUP BY 1",
+ "spark": "SELECT a AS b FROM x GROUP BY 1",
+ },
+ )
+ self.validate_all(
+ "SELECT y x FROM my_table t",
+ write={
+ "hive": "SELECT y AS x FROM my_table AS t",
+ "oracle": "SELECT y AS x FROM my_table t",
+ "postgres": "SELECT y AS x FROM my_table AS t",
+ "sqlite": "SELECT y AS x FROM my_table AS t",
+ },
+ )
+ self.validate_all(
+ "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
+ write={
+ "hive": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
+ "oracle": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 t JOIN cte2 WHERE cte1.a = cte2.c",
+ "postgres": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
+ "sqlite": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
},
)
diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py
index eccd75a..55086e3 100644
--- a/tests/dialects/test_hive.py
+++ b/tests/dialects/test_hive.py
@@ -342,6 +342,21 @@ class TestHive(Validator):
},
)
self.validate_all(
+ "PERCENTILE_APPROX(x, 0.5)",
+ read={
+ "hive": "PERCENTILE_APPROX(x, 0.5)",
+ "presto": "APPROX_PERCENTILE(x, 0.5)",
+ "duckdb": "APPROX_QUANTILE(x, 0.5)",
+ "spark": "PERCENTILE_APPROX(x, 0.5)",
+ },
+ write={
+ "hive": "PERCENTILE_APPROX(x, 0.5)",
+ "presto": "APPROX_PERCENTILE(x, 0.5)",
+ "duckdb": "APPROX_QUANTILE(x, 0.5)",
+ "spark": "PERCENTILE_APPROX(x, 0.5)",
+ },
+ )
+ self.validate_all(
"APPROX_COUNT_DISTINCT(a)",
write={
"duckdb": "APPROX_COUNT_DISTINCT(a)",
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index ee0c5f5..87a3d64 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -15,6 +15,10 @@ class TestMySQL(Validator):
def test_identity(self):
self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo")
+ self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ')")
+ self.validate_identity("SELECT TRIM(TRAILING 'bla' FROM ' XXX ')")
+ self.validate_identity("SELECT TRIM(BOTH 'bla' FROM ' XXX ')")
+ self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')")
def test_introducers(self):
self.validate_all(
@@ -27,12 +31,22 @@ class TestMySQL(Validator):
},
)
- def test_binary_literal(self):
+ def test_hexadecimal_literal(self):
self.validate_all(
"SELECT 0xCC",
write={
- "mysql": "SELECT b'11001100'",
- "spark": "SELECT X'11001100'",
+ "mysql": "SELECT x'CC'",
+ "sqlite": "SELECT x'CC'",
+ "spark": "SELECT X'CC'",
+ "trino": "SELECT X'CC'",
+ "bigquery": "SELECT 0xCC",
+ "oracle": "SELECT 204",
+ },
+ )
+ self.validate_all(
+ "SELECT X'1A'",
+ write={
+ "mysql": "SELECT x'1A'",
},
)
self.validate_all(
@@ -41,10 +55,22 @@ class TestMySQL(Validator):
"mysql": "SELECT `0xz`",
},
)
+
+ def test_bits_literal(self):
+ self.validate_all(
+ "SELECT 0b1011",
+ write={
+ "mysql": "SELECT b'1011'",
+ "postgres": "SELECT b'1011'",
+ "oracle": "SELECT 11",
+ },
+ )
self.validate_all(
- "SELECT 0XCC",
+ "SELECT B'1011'",
write={
- "mysql": "SELECT 0 AS XCC",
+ "mysql": "SELECT b'1011'",
+ "postgres": "SELECT b'1011'",
+ "oracle": "SELECT 11",
},
)
@@ -77,3 +103,19 @@ class TestMySQL(Validator):
"mysql": "SELECT 1",
},
)
+
+ def test_mysql(self):
+ self.validate_all(
+ "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
+ write={
+ "mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR ',')",
+ "sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
+ },
+ )
+ self.validate_all(
+ "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
+ write={
+ "mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')",
+ "sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC, '')",
+ },
+ )
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index 15dbfd0..e0934d7 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -8,9 +8,7 @@ class TestPostgres(Validator):
def test_ddl(self):
self.validate_all(
"CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)",
- write={
- "postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"
- },
+ write={"postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"},
)
self.validate_all(
"CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)",
@@ -42,11 +40,17 @@ class TestPostgres(Validator):
" CONSTRAINT valid_discount CHECK (price > discounted_price))"
},
)
+ self.validate_all(
+ "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)",
+ write={"postgres": "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)"},
+ )
+ self.validate_all(
+ "CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)",
+ write={"postgres": "CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)"},
+ )
with self.assertRaises(ParseError):
- transpile(
- "CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres"
- )
+ transpile("CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres")
with self.assertRaises(ParseError):
transpile(
"CREATE TABLE products (price DECIMAL, CHECK price > 1)",
@@ -54,11 +58,16 @@ class TestPostgres(Validator):
)
def test_postgres(self):
- self.validate_all(
- "CREATE TABLE x (a INT SERIAL)",
- read={"sqlite": "CREATE TABLE x (a INTEGER AUTOINCREMENT)"},
- write={"sqlite": "CREATE TABLE x (a INTEGER AUTOINCREMENT)"},
- )
+ self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END")
+ self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END")
+ self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END")
+ self.validate_identity('SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')')
+ self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')")
+ self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')")
+ self.validate_identity("SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))")
+ self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')")
+ self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
+
self.validate_all(
"CREATE TABLE x (a UUID, b BYTEA)",
write={
@@ -91,3 +100,65 @@ class TestPostgres(Validator):
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
},
)
+ self.validate_all(
+ "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END",
+ write={
+ "hive": "SELECT CASE WHEN SUBSTRING('abcdefg', 1, 2) IN ('ab') THEN 1 ELSE 0 END",
+ "spark": "SELECT CASE WHEN SUBSTRING('abcdefg', 1, 2) IN ('ab') THEN 1 ELSE 0 END",
+ },
+ )
+ self.validate_all(
+ "SELECT * FROM x WHERE SUBSTRING(col1 FROM 3 + LENGTH(col1) - 10 FOR 10) IN (col2)",
+ write={
+ "hive": "SELECT * FROM x WHERE SUBSTRING(col1, 3 + LENGTH(col1) - 10, 10) IN (col2)",
+ "spark": "SELECT * FROM x WHERE SUBSTRING(col1, 3 + LENGTH(col1) - 10, 10) IN (col2)",
+ },
+ )
+ self.validate_all(
+ "SELECT SUBSTRING(CAST(2022 AS CHAR(4)) || LPAD(CAST(3 AS CHAR(2)), 2, '0') FROM 3 FOR 4)",
+ read={
+ "postgres": "SELECT SUBSTRING(2022::CHAR(4) || LPAD(3::CHAR(2), 2, '0') FROM 3 FOR 4)",
+ },
+ )
+ self.validate_all(
+ "SELECT TRIM(BOTH ' XXX ')",
+ write={
+ "mysql": "SELECT TRIM(' XXX ')",
+ "postgres": "SELECT TRIM(' XXX ')",
+ "hive": "SELECT TRIM(' XXX ')",
+ },
+ )
+ self.validate_all(
+ "TRIM(LEADING FROM ' XXX ')",
+ write={
+ "mysql": "LTRIM(' XXX ')",
+ "postgres": "LTRIM(' XXX ')",
+ "hive": "LTRIM(' XXX ')",
+ "presto": "LTRIM(' XXX ')",
+ },
+ )
+ self.validate_all(
+ "TRIM(TRAILING FROM ' XXX ')",
+ write={
+ "mysql": "RTRIM(' XXX ')",
+ "postgres": "RTRIM(' XXX ')",
+ "hive": "RTRIM(' XXX ')",
+ "presto": "RTRIM(' XXX ')",
+ },
+ )
+ self.validate_all(
+ "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss",
+ read={"postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"},
+ )
+ self.validate_all(
+ "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL",
+ read={
+ "postgres": "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL",
+ },
+ )
+ self.validate_all(
+ "SELECT p1.id, p2.id, v1, v2 FROM polygons AS p1, polygons AS p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id <> p2.id",
+ read={
+ "postgres": "SELECT p1.id, p2.id, v1, v2 FROM polygons p1, polygons p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id != p2.id",
+ },
+ )
diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py
new file mode 100644
index 0000000..1ed2bb6
--- /dev/null
+++ b/tests/dialects/test_redshift.py
@@ -0,0 +1,64 @@
+from tests.dialects.test_dialect import Validator
+
+
+class TestRedshift(Validator):
+ dialect = "redshift"
+
+ def test_redshift(self):
+ self.validate_all(
+ 'create table "group" ("col" char(10))',
+ write={
+ "redshift": 'CREATE TABLE "group" ("col" CHAR(10))',
+ "mysql": "CREATE TABLE `group` (`col` CHAR(10))",
+ },
+ )
+ self.validate_all(
+ 'create table if not exists city_slash_id("city/id" integer not null, state char(2) not null)',
+ write={
+ "redshift": 'CREATE TABLE IF NOT EXISTS city_slash_id ("city/id" INTEGER NOT NULL, state CHAR(2) NOT NULL)',
+ "presto": 'CREATE TABLE IF NOT EXISTS city_slash_id ("city/id" INTEGER NOT NULL, state CHAR(2) NOT NULL)',
+ },
+ )
+ self.validate_all(
+ "SELECT ST_AsEWKT(ST_GeomFromEWKT('SRID=4326;POINT(10 20)')::geography)",
+ write={
+ "redshift": "SELECT ST_ASEWKT(CAST(ST_GEOMFROMEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))",
+ "bigquery": "SELECT ST_ASEWKT(CAST(ST_GEOMFROMEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))",
+ },
+ )
+ self.validate_all(
+ "SELECT ST_AsEWKT(ST_GeogFromText('LINESTRING(110 40, 2 3, -10 80, -7 9)')::geometry)",
+ write={
+ "redshift": "SELECT ST_ASEWKT(CAST(ST_GEOGFROMTEXT('LINESTRING(110 40, 2 3, -10 80, -7 9)') AS GEOMETRY))",
+ },
+ )
+ self.validate_all(
+ "SELECT 'abc'::BINARY",
+ write={
+ "redshift": "SELECT CAST('abc' AS VARBYTE)",
+ },
+ )
+ self.validate_all(
+ "SELECT * FROM venue WHERE (venuecity, venuestate) IN (('Miami', 'FL'), ('Tampa', 'FL')) ORDER BY venueid",
+ write={
+ "redshift": "SELECT * FROM venue WHERE (venuecity, venuestate) IN (('Miami', 'FL'), ('Tampa', 'FL')) ORDER BY venueid",
+ },
+ )
+ self.validate_all(
+ 'SELECT tablename, "column" FROM pg_table_def WHERE "column" LIKE \'%start\\_%\' LIMIT 5',
+ write={
+ "redshift": 'SELECT tablename, "column" FROM pg_table_def WHERE "column" LIKE \'%start\\\\_%\' LIMIT 5'
+ },
+ )
+
+ def test_identity(self):
+ self.validate_identity("CAST('bla' AS SUPER)")
+ self.validate_identity("CREATE TABLE real1 (realcol REAL)")
+ self.validate_identity("CAST('foo' AS HLLSKETCH)")
+ self.validate_identity("SELECT DATEADD(day, 1, 'today')")
+ self.validate_identity("'abc' SIMILAR TO '(b|c)%'")
+ self.validate_identity(
+ "SELECT caldate + INTERVAL '1 second' AS dateplus FROM date WHERE caldate = '12-31-2008'"
+ )
+ self.validate_identity("CREATE TABLE datetable (start_date DATE, end_date DATE)")
+ self.validate_identity("SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'")
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index 62f78e1..2eeff52 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -143,3 +143,35 @@ class TestSnowflake(Validator):
"snowflake": r"SELECT 'a \' \\ \\t \\x21 z $ '",
},
)
+
+ def test_null_treatment(self):
+ self.validate_all(
+ r"SELECT FIRST_VALUE(TABLE1.COLUMN1) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1",
+ write={
+ "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
+ },
+ )
+ self.validate_all(
+ r"SELECT FIRST_VALUE(TABLE1.COLUMN1 RESPECT NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1",
+ write={
+ "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
+ },
+ )
+ self.validate_all(
+ r"SELECT FIRST_VALUE(TABLE1.COLUMN1) RESPECT NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1",
+ write={
+ "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) RESPECT NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
+ },
+ )
+ self.validate_all(
+ r"SELECT FIRST_VALUE(TABLE1.COLUMN1 IGNORE NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1",
+ write={
+ "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1 IGNORE NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
+ },
+ )
+ self.validate_all(
+ r"SELECT FIRST_VALUE(TABLE1.COLUMN1) IGNORE NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1",
+ write={
+ "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) IGNORE NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
+ },
+ )
diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py
index a0576de..3cc974c 100644
--- a/tests/dialects/test_sqlite.py
+++ b/tests/dialects/test_sqlite.py
@@ -34,6 +34,7 @@ class TestSQLite(Validator):
write={
"sqlite": "CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)",
"mysql": "CREATE TABLE z (a INT UNIQUE PRIMARY KEY AUTO_INCREMENT)",
+ "postgres": "CREATE TABLE z (a INT GENERATED BY DEFAULT AS IDENTITY NOT NULL UNIQUE PRIMARY KEY)",
},
)
self.validate_all(
@@ -70,3 +71,20 @@ class TestSQLite(Validator):
"sqlite": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
},
)
+
+ def test_hexadecimal_literal(self):
+ self.validate_all(
+ "SELECT 0XCC",
+ write={
+ "sqlite": "SELECT x'CC'",
+ "mysql": "SELECT x'CC'",
+ },
+ )
+
+ def test_window_null_treatment(self):
+ self.validate_all(
+ "SELECT FIRST_VALUE(Name) OVER (PARTITION BY AlbumId ORDER BY Bytes DESC) AS LargestTrack FROM tracks",
+ write={
+ "sqlite": "SELECT FIRST_VALUE(Name) OVER (PARTITION BY AlbumId ORDER BY Bytes DESC) AS LargestTrack FROM tracks"
+ },
+ )
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index 40f11a2..1b4168c 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -318,6 +318,9 @@ SELECT 1 FROM a JOIN b ON a.foo = b.bar JOIN c ON a.foo = c.bar
SELECT 1 FROM a LEFT JOIN b ON a.foo = b.bar JOIN c ON a.foo = c.bar
SELECT 1 FROM a LEFT INNER JOIN b ON a.foo = b.bar
SELECT 1 FROM a LEFT OUTER JOIN b ON a.foo = b.bar
+SELECT 1 FROM a NATURAL JOIN b
+SELECT 1 FROM a NATURAL LEFT JOIN b
+SELECT 1 FROM a NATURAL LEFT OUTER JOIN b
SELECT 1 FROM a OUTER JOIN b ON a.foo = b.bar
SELECT 1 FROM a FULL JOIN b ON a.foo = b.bar
SELECT 1 UNION ALL SELECT 2
@@ -329,6 +332,7 @@ SELECT 1 AS delete, 2 AS alter
SELECT * FROM (x)
SELECT * FROM ((x))
SELECT * FROM ((SELECT 1))
+SELECT * FROM (x LATERAL VIEW EXPLODE(y) JOIN foo)
SELECT * FROM (SELECT 1) AS x
SELECT * FROM (SELECT 1 UNION SELECT 2) AS x
SELECT * FROM (SELECT 1 UNION ALL SELECT 2) AS x
@@ -430,6 +434,7 @@ CREATE TEMPORARY VIEW x AS SELECT a FROM d
CREATE TEMPORARY VIEW IF NOT EXISTS x AS SELECT a FROM d
CREATE TEMPORARY VIEW x AS WITH y AS (SELECT 1) SELECT * FROM y
CREATE TABLE z (a INT, b VARCHAR, c VARCHAR(100), d DECIMAL(5, 3))
+CREATE TABLE z (end INT)
CREATE TABLE z (a ARRAY<TEXT>, b MAP<TEXT, DOUBLE>, c DECIMAL(5, 3))
CREATE TABLE z (a INT, b VARCHAR COMMENT 'z', c VARCHAR(100) COMMENT 'z', d DECIMAL(5, 3))
CREATE TABLE z (a INT(11) DEFAULT UUID())
@@ -466,6 +471,7 @@ CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1
CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS WITH a AS (SELECT 1) SELECT a.* FROM a
CACHE LAZY TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a
CACHE TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a
+CACHE TABLE x AS (SELECT 1 AS y)
CALL catalog.system.iceberg_procedure_name(named_arg_1 => 'arg_1', named_arg_2 => 'arg_2')
INSERT OVERWRITE TABLE a.b PARTITION(ds) SELECT x FROM y
INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD') SELECT x FROM y
@@ -512,3 +518,4 @@ SELECT ? AS ? FROM x WHERE b BETWEEN ? AND ? GROUP BY ?, 1 LIMIT ?
WITH a AS ((SELECT b.foo AS foo, b.bar AS bar FROM b) UNION ALL (SELECT c.foo AS foo, c.bar AS bar FROM c)) SELECT * FROM a
WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a
SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z
+SELECT ((SELECT 1) + 1)
diff --git a/tests/fixtures/optimizer/merge_derived_tables.sql b/tests/fixtures/optimizer/merge_derived_tables.sql
new file mode 100644
index 0000000..c5aa7e9
--- /dev/null
+++ b/tests/fixtures/optimizer/merge_derived_tables.sql
@@ -0,0 +1,63 @@
+-- Simple
+SELECT a, b FROM (SELECT a, b FROM x);
+SELECT x.a AS a, x.b AS b FROM x AS x;
+
+-- Inner table alias is merged
+SELECT a, b FROM (SELECT a, b FROM x AS q) AS r;
+SELECT q.a AS a, q.b AS b FROM x AS q;
+
+-- Double nesting
+SELECT a, b FROM (SELECT a, b FROM (SELECT a, b FROM x));
+SELECT x.a AS a, x.b AS b FROM x AS x;
+
+-- WHERE clause is merged
+SELECT a, SUM(b) FROM (SELECT a, b FROM x WHERE a > 1) GROUP BY a;
+SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 GROUP BY x.a;
+
+-- Outer query has join
+SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b;
+SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1;
+
+-- Join on derived table
+SELECT a, c FROM x JOIN (SELECT b, c FROM y) AS y ON x.b = y.b;
+SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
+
+-- Inner query has a join
+SELECT a, c FROM (SELECT a, c FROM x JOIN y ON x.b = y.b);
+SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
+
+-- Inner query has conflicting name in outer query
+SELECT a, c FROM (SELECT q.a, q.b FROM x AS q) AS x JOIN y AS q ON x.b = q.b;
+SELECT q_2.a AS a, q.c AS c FROM x AS q_2 JOIN y AS q ON q_2.b = q.b;
+
+-- Inner query has conflicting name in joined source
+SELECT x.a, q.c FROM (SELECT a, x.b FROM x JOIN y AS q ON x.b = q.b) AS x JOIN y AS q ON x.b = q.b;
+SELECT x.a AS a, q.c AS c FROM x AS x JOIN y AS q_2 ON x.b = q_2.b JOIN y AS q ON x.b = q.b;
+
+-- Inner query has multiple conflicting names
+SELECT x.a, q.c, r.c FROM (SELECT q.a, r.b FROM x AS q JOIN y AS r ON q.b = r.b) AS x JOIN y AS q ON x.b = q.b JOIN y AS r ON x.b = r.b;
+SELECT q_2.a AS a, q.c AS c, r.c AS c FROM x AS q_2 JOIN y AS r_2 ON q_2.b = r_2.b JOIN y AS q ON r_2.b = q.b JOIN y AS r ON r_2.b = r.b;
+
+-- Inner queries have conflicting names with each other
+SELECT r.b FROM (SELECT b FROM x AS x) AS q JOIN (SELECT b FROM x) AS r ON q.b = r.b;
+SELECT x_2.b AS b FROM x AS x JOIN x AS x_2 ON x.b = x_2.b;
+
+-- WHERE clause in joined derived table is merged
+SELECT x.a, y.c FROM x JOIN (SELECT b, c FROM y WHERE c > 1) AS y;
+SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y WHERE y.c > 1;
+
+-- WHERE clause in outer joined derived table is merged to ON clause
+SELECT x.a, y.c FROM x LEFT JOIN (SELECT b, c FROM y WHERE c > 1) AS y;
+SELECT x.a AS a, y.c AS c FROM x AS x LEFT JOIN y AS y ON y.c > 1;
+
+-- Comma JOIN in outer query
+SELECT x.a, y.c FROM (SELECT a FROM x) AS x, (SELECT c FROM y) AS y;
+SELECT x.a AS a, y.c AS c FROM x AS x, y AS y;
+
+-- Comma JOIN in inner query
+SELECT x.a, x.c FROM (SELECT x.a, z.c FROM x, y AS z) AS x;
+SELECT x.a AS a, z.c AS c FROM x AS x CROSS JOIN y AS z;
+
+-- (Regression) Column in ORDER BY
+SELECT * FROM (SELECT * FROM (SELECT * FROM x)) ORDER BY a LIMIT 1;
+SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a LIMIT 1;
diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql
index f7bbdda..f1d0f7d 100644
--- a/tests/fixtures/optimizer/optimizer.sql
+++ b/tests/fixtures/optimizer/optimizer.sql
@@ -2,11 +2,7 @@ SELECT a, m FROM z LATERAL VIEW EXPLODE([1, 2]) q AS m;
SELECT
"z"."a" AS "a",
"q"."m" AS "m"
-FROM (
- SELECT
- "z"."a" AS "a"
- FROM "z" AS "z"
-) AS "z"
+FROM "z" AS "z"
LATERAL VIEW
EXPLODE(ARRAY(1, 2)) q AS "m";
@@ -91,41 +87,26 @@ FROM (
WHERE (TRUE AND TRUE OR 'a' = 'b') AND a > 1
GROUP BY a;
SELECT
- "d"."a" AS "a",
- SUM("d"."b") AS "_col_1"
-FROM (
+ "x"."a" AS "a",
+ SUM("y"."b") AS "_col_1"
+FROM "x" AS "x"
+LEFT JOIN (
SELECT
- "x"."a" AS "a",
- "y"."b" AS "b"
- FROM (
- SELECT
- "x"."a" AS "a"
- FROM "x" AS "x"
- WHERE
- "x"."a" > 1
- ) AS "x"
- LEFT JOIN (
- SELECT
- MAX("y"."b") AS "_col_0",
- "y"."a" AS "_u_1"
- FROM "y" AS "y"
- GROUP BY
- "y"."a"
- ) AS "_u_0"
- ON "x"."a" = "_u_0"."_u_1"
- JOIN (
- SELECT
- "y"."a" AS "a",
- "y"."b" AS "b"
- FROM "y" AS "y"
- ) AS "y"
- ON "x"."a" = "y"."a"
- WHERE
- "_u_0"."_col_0" >= 0
- AND NOT "_u_0"."_u_1" IS NULL
-) AS "d"
+ MAX("y"."b") AS "_col_0",
+ "y"."a" AS "_u_1"
+ FROM "y" AS "y"
+ GROUP BY
+ "y"."a"
+) AS "_u_0"
+ ON "x"."a" = "_u_0"."_u_1"
+JOIN "y" AS "y"
+ ON "x"."a" = "y"."a"
+WHERE
+ "_u_0"."_col_0" >= 0
+ AND "x"."a" > 1
+ AND NOT "_u_0"."_u_1" IS NULL
GROUP BY
- "d"."a";
+ "x"."a";
(SELECT a FROM x) LIMIT 1;
(
diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql
index 482e231..0b6d382 100644
--- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql
+++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql
@@ -120,36 +120,16 @@ SELECT
"supplier"."s_address" AS "s_address",
"supplier"."s_phone" AS "s_phone",
"supplier"."s_comment" AS "s_comment"
-FROM (
- SELECT
- "part"."p_partkey" AS "p_partkey",
- "part"."p_mfgr" AS "p_mfgr",
- "part"."p_type" AS "p_type",
- "part"."p_size" AS "p_size"
- FROM "part" AS "part"
- WHERE
- "part"."p_size" = 15
- AND "part"."p_type" LIKE '%BRASS'
-) AS "part"
+FROM "part" AS "part"
LEFT JOIN (
SELECT
MIN("partsupp"."ps_supplycost") AS "_col_0",
"partsupp"."ps_partkey" AS "_u_1"
FROM "_e_0" AS "partsupp"
CROSS JOIN "_e_1" AS "region"
- JOIN (
- SELECT
- "nation"."n_nationkey" AS "n_nationkey",
- "nation"."n_regionkey" AS "n_regionkey"
- FROM "nation" AS "nation"
- ) AS "nation"
+ JOIN "nation" AS "nation"
ON "nation"."n_regionkey" = "region"."r_regionkey"
- JOIN (
- SELECT
- "supplier"."s_suppkey" AS "s_suppkey",
- "supplier"."s_nationkey" AS "s_nationkey"
- FROM "supplier" AS "supplier"
- ) AS "supplier"
+ JOIN "supplier" AS "supplier"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey"
GROUP BY
@@ -157,31 +137,17 @@ LEFT JOIN (
) AS "_u_0"
ON "part"."p_partkey" = "_u_0"."_u_1"
CROSS JOIN "_e_1" AS "region"
-JOIN (
- SELECT
- "nation"."n_nationkey" AS "n_nationkey",
- "nation"."n_name" AS "n_name",
- "nation"."n_regionkey" AS "n_regionkey"
- FROM "nation" AS "nation"
-) AS "nation"
+JOIN "nation" AS "nation"
ON "nation"."n_regionkey" = "region"."r_regionkey"
JOIN "_e_0" AS "partsupp"
ON "part"."p_partkey" = "partsupp"."ps_partkey"
-JOIN (
- SELECT
- "supplier"."s_suppkey" AS "s_suppkey",
- "supplier"."s_name" AS "s_name",
- "supplier"."s_address" AS "s_address",
- "supplier"."s_nationkey" AS "s_nationkey",
- "supplier"."s_phone" AS "s_phone",
- "supplier"."s_acctbal" AS "s_acctbal",
- "supplier"."s_comment" AS "s_comment"
- FROM "supplier" AS "supplier"
-) AS "supplier"
+JOIN "supplier" AS "supplier"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey"
WHERE
- "partsupp"."ps_supplycost" = "_u_0"."_col_0"
+ "part"."p_size" = 15
+ AND "part"."p_type" LIKE '%BRASS'
+ AND "partsupp"."ps_supplycost" = "_u_0"."_col_0"
AND NOT "_u_0"."_u_1" IS NULL
ORDER BY
"s_acctbal" DESC,
@@ -224,36 +190,15 @@ SELECT
)) AS "revenue",
CAST("orders"."o_orderdate" AS TEXT) AS "o_orderdate",
"orders"."o_shippriority" AS "o_shippriority"
-FROM (
- SELECT
- "customer"."c_custkey" AS "c_custkey",
- "customer"."c_mktsegment" AS "c_mktsegment"
- FROM "customer" AS "customer"
- WHERE
- "customer"."c_mktsegment" = 'BUILDING'
-) AS "customer"
-JOIN (
- SELECT
- "orders"."o_orderkey" AS "o_orderkey",
- "orders"."o_custkey" AS "o_custkey",
- "orders"."o_orderdate" AS "o_orderdate",
- "orders"."o_shippriority" AS "o_shippriority"
- FROM "orders" AS "orders"
- WHERE
- "orders"."o_orderdate" < '1995-03-15'
-) AS "orders"
+FROM "customer" AS "customer"
+JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
-JOIN (
- SELECT
- "lineitem"."l_orderkey" AS "l_orderkey",
- "lineitem"."l_extendedprice" AS "l_extendedprice",
- "lineitem"."l_discount" AS "l_discount",
- "lineitem"."l_shipdate" AS "l_shipdate"
- FROM "lineitem" AS "lineitem"
- WHERE
- "lineitem"."l_shipdate" > '1995-03-15'
-) AS "lineitem"
+JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
+WHERE
+ "customer"."c_mktsegment" = 'BUILDING'
+ AND "lineitem"."l_shipdate" > '1995-03-15'
+ AND "orders"."o_orderdate" < '1995-03-15'
GROUP BY
"lineitem"."l_orderkey",
"orders"."o_orderdate",
@@ -342,57 +287,22 @@ SELECT
SUM("lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
)) AS "revenue"
-FROM (
- SELECT
- "customer"."c_custkey" AS "c_custkey",
- "customer"."c_nationkey" AS "c_nationkey"
- FROM "customer" AS "customer"
-) AS "customer"
-JOIN (
- SELECT
- "orders"."o_orderkey" AS "o_orderkey",
- "orders"."o_custkey" AS "o_custkey",
- "orders"."o_orderdate" AS "o_orderdate"
- FROM "orders" AS "orders"
- WHERE
- "orders"."o_orderdate" < CAST('1995-01-01' AS DATE)
- AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE)
-) AS "orders"
+FROM "customer" AS "customer"
+JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
-CROSS JOIN (
- SELECT
- "region"."r_regionkey" AS "r_regionkey",
- "region"."r_name" AS "r_name"
- FROM "region" AS "region"
- WHERE
- "region"."r_name" = 'ASIA'
-) AS "region"
-JOIN (
- SELECT
- "nation"."n_nationkey" AS "n_nationkey",
- "nation"."n_name" AS "n_name",
- "nation"."n_regionkey" AS "n_regionkey"
- FROM "nation" AS "nation"
-) AS "nation"
+CROSS JOIN "region" AS "region"
+JOIN "nation" AS "nation"
ON "nation"."n_regionkey" = "region"."r_regionkey"
-JOIN (
- SELECT
- "supplier"."s_suppkey" AS "s_suppkey",
- "supplier"."s_nationkey" AS "s_nationkey"
- FROM "supplier" AS "supplier"
-) AS "supplier"
+JOIN "supplier" AS "supplier"
ON "customer"."c_nationkey" = "supplier"."s_nationkey"
AND "supplier"."s_nationkey" = "nation"."n_nationkey"
-JOIN (
- SELECT
- "lineitem"."l_orderkey" AS "l_orderkey",
- "lineitem"."l_suppkey" AS "l_suppkey",
- "lineitem"."l_extendedprice" AS "l_extendedprice",
- "lineitem"."l_discount" AS "l_discount"
- FROM "lineitem" AS "lineitem"
-) AS "lineitem"
+JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
AND "lineitem"."l_suppkey" = "supplier"."s_suppkey"
+WHERE
+ "orders"."o_orderdate" < CAST('1995-01-01' AS DATE)
+ AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE)
+ AND "region"."r_name" = 'ASIA'
GROUP BY
"nation"."n_name"
ORDER BY
@@ -471,67 +381,37 @@ WITH "_e_0" AS (
OR "nation"."n_name" = 'GERMANY'
)
SELECT
- "shipping"."supp_nation" AS "supp_nation",
- "shipping"."cust_nation" AS "cust_nation",
- "shipping"."l_year" AS "l_year",
- SUM("shipping"."volume") AS "revenue"
-FROM (
- SELECT
- "n1"."n_name" AS "supp_nation",
- "n2"."n_name" AS "cust_nation",
- EXTRACT(year FROM "lineitem"."l_shipdate") AS "l_year",
- "lineitem"."l_extendedprice" * (
+ "n1"."n_name" AS "supp_nation",
+ "n2"."n_name" AS "cust_nation",
+ EXTRACT(year FROM "lineitem"."l_shipdate") AS "l_year",
+ SUM("lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
- ) AS "volume"
- FROM (
- SELECT
- "supplier"."s_suppkey" AS "s_suppkey",
- "supplier"."s_nationkey" AS "s_nationkey"
- FROM "supplier" AS "supplier"
- ) AS "supplier"
- JOIN (
- SELECT
- "lineitem"."l_orderkey" AS "l_orderkey",
- "lineitem"."l_suppkey" AS "l_suppkey",
- "lineitem"."l_extendedprice" AS "l_extendedprice",
- "lineitem"."l_discount" AS "l_discount",
- "lineitem"."l_shipdate" AS "l_shipdate"
- FROM "lineitem" AS "lineitem"
- WHERE
- "lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
- ) AS "lineitem"
- ON "supplier"."s_suppkey" = "lineitem"."l_suppkey"
- JOIN (
- SELECT
- "orders"."o_orderkey" AS "o_orderkey",
- "orders"."o_custkey" AS "o_custkey"
- FROM "orders" AS "orders"
- ) AS "orders"
- ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
- JOIN (
- SELECT
- "customer"."c_custkey" AS "c_custkey",
- "customer"."c_nationkey" AS "c_nationkey"
- FROM "customer" AS "customer"
- ) AS "customer"
- ON "customer"."c_custkey" = "orders"."o_custkey"
- JOIN "_e_0" AS "n1"
- ON "supplier"."s_nationkey" = "n1"."n_nationkey"
- JOIN "_e_0" AS "n2"
- ON "customer"."c_nationkey" = "n2"."n_nationkey"
- AND (
- "n1"."n_name" = 'FRANCE'
- OR "n2"."n_name" = 'FRANCE'
- )
- AND (
- "n1"."n_name" = 'GERMANY'
- OR "n2"."n_name" = 'GERMANY'
- )
-) AS "shipping"
+ )) AS "revenue"
+FROM "supplier" AS "supplier"
+JOIN "lineitem" AS "lineitem"
+ ON "supplier"."s_suppkey" = "lineitem"."l_suppkey"
+JOIN "orders" AS "orders"
+ ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
+JOIN "customer" AS "customer"
+ ON "customer"."c_custkey" = "orders"."o_custkey"
+JOIN "_e_0" AS "n1"
+ ON "supplier"."s_nationkey" = "n1"."n_nationkey"
+JOIN "_e_0" AS "n2"
+ ON "customer"."c_nationkey" = "n2"."n_nationkey"
+ AND (
+ "n1"."n_name" = 'FRANCE'
+ OR "n2"."n_name" = 'FRANCE'
+ )
+ AND (
+ "n1"."n_name" = 'GERMANY'
+ OR "n2"."n_name" = 'GERMANY'
+ )
+WHERE
+ "lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
GROUP BY
- "shipping"."supp_nation",
- "shipping"."cust_nation",
- "shipping"."l_year"
+ "n1"."n_name",
+ "n2"."n_name",
+ EXTRACT(year FROM "lineitem"."l_shipdate")
ORDER BY
"supp_nation",
"cust_nation",
@@ -578,87 +458,37 @@ group by
order by
o_year;
SELECT
- "all_nations"."o_year" AS "o_year",
+ EXTRACT(year FROM "orders"."o_orderdate") AS "o_year",
SUM(CASE
- WHEN "all_nations"."nation" = 'BRAZIL'
- THEN "all_nations"."volume"
+ WHEN "nation_2"."n_name" = 'BRAZIL'
+ THEN "lineitem"."l_extendedprice" * (
+ 1 - "lineitem"."l_discount"
+ )
ELSE 0
- END) / SUM("all_nations"."volume") AS "mkt_share"
-FROM (
- SELECT
- EXTRACT(year FROM "orders"."o_orderdate") AS "o_year",
- "lineitem"."l_extendedprice" * (
+ END) / SUM("lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
- ) AS "volume",
- "n2"."n_name" AS "nation"
- FROM (
- SELECT
- "part"."p_partkey" AS "p_partkey",
- "part"."p_type" AS "p_type"
- FROM "part" AS "part"
- WHERE
- "part"."p_type" = 'ECONOMY ANODIZED STEEL'
- ) AS "part"
- CROSS JOIN (
- SELECT
- "region"."r_regionkey" AS "r_regionkey",
- "region"."r_name" AS "r_name"
- FROM "region" AS "region"
- WHERE
- "region"."r_name" = 'AMERICA'
- ) AS "region"
- JOIN (
- SELECT
- "nation"."n_nationkey" AS "n_nationkey",
- "nation"."n_regionkey" AS "n_regionkey"
- FROM "nation" AS "nation"
- ) AS "n1"
- ON "n1"."n_regionkey" = "region"."r_regionkey"
- JOIN (
- SELECT
- "customer"."c_custkey" AS "c_custkey",
- "customer"."c_nationkey" AS "c_nationkey"
- FROM "customer" AS "customer"
- ) AS "customer"
- ON "customer"."c_nationkey" = "n1"."n_nationkey"
- JOIN (
- SELECT
- "orders"."o_orderkey" AS "o_orderkey",
- "orders"."o_custkey" AS "o_custkey",
- "orders"."o_orderdate" AS "o_orderdate"
- FROM "orders" AS "orders"
- WHERE
- "orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
- ) AS "orders"
- ON "orders"."o_custkey" = "customer"."c_custkey"
- JOIN (
- SELECT
- "lineitem"."l_orderkey" AS "l_orderkey",
- "lineitem"."l_partkey" AS "l_partkey",
- "lineitem"."l_suppkey" AS "l_suppkey",
- "lineitem"."l_extendedprice" AS "l_extendedprice",
- "lineitem"."l_discount" AS "l_discount"
- FROM "lineitem" AS "lineitem"
- ) AS "lineitem"
- ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
- AND "part"."p_partkey" = "lineitem"."l_partkey"
- JOIN (
- SELECT
- "supplier"."s_suppkey" AS "s_suppkey",
- "supplier"."s_nationkey" AS "s_nationkey"
- FROM "supplier" AS "supplier"
- ) AS "supplier"
- ON "supplier"."s_suppkey" = "lineitem"."l_suppkey"
- JOIN (
- SELECT
- "nation"."n_nationkey" AS "n_nationkey",
- "nation"."n_name" AS "n_name"
- FROM "nation" AS "nation"
- ) AS "n2"
- ON "supplier"."s_nationkey" = "n2"."n_nationkey"
-) AS "all_nations"
+ )) AS "mkt_share"
+FROM "part" AS "part"
+CROSS JOIN "region" AS "region"
+JOIN "nation" AS "nation"
+ ON "nation"."n_regionkey" = "region"."r_regionkey"
+JOIN "customer" AS "customer"
+ ON "customer"."c_nationkey" = "nation"."n_nationkey"
+JOIN "orders" AS "orders"
+ ON "orders"."o_custkey" = "customer"."c_custkey"
+JOIN "lineitem" AS "lineitem"
+ ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
+ AND "part"."p_partkey" = "lineitem"."l_partkey"
+JOIN "supplier" AS "supplier"
+ ON "supplier"."s_suppkey" = "lineitem"."l_suppkey"
+JOIN "nation" AS "nation_2"
+ ON "supplier"."s_nationkey" = "nation_2"."n_nationkey"
+WHERE
+ "orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
+ AND "part"."p_type" = 'ECONOMY ANODIZED STEEL'
+ AND "region"."r_name" = 'AMERICA'
GROUP BY
- "all_nations"."o_year"
+ EXTRACT(year FROM "orders"."o_orderdate")
ORDER BY
"o_year";
@@ -698,69 +528,28 @@ order by
nation,
o_year desc;
SELECT
- "profit"."nation" AS "nation",
- "profit"."o_year" AS "o_year",
- SUM("profit"."amount") AS "sum_profit"
-FROM (
- SELECT
- "nation"."n_name" AS "nation",
- EXTRACT(year FROM "orders"."o_orderdate") AS "o_year",
- "lineitem"."l_extendedprice" * (
+ "nation"."n_name" AS "nation",
+ EXTRACT(year FROM "orders"."o_orderdate") AS "o_year",
+ SUM("lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
- ) - "partsupp"."ps_supplycost" * "lineitem"."l_quantity" AS "amount"
- FROM (
- SELECT
- "part"."p_partkey" AS "p_partkey",
- "part"."p_name" AS "p_name"
- FROM "part" AS "part"
- WHERE
- "part"."p_name" LIKE '%green%'
- ) AS "part"
- JOIN (
- SELECT
- "lineitem"."l_orderkey" AS "l_orderkey",
- "lineitem"."l_partkey" AS "l_partkey",
- "lineitem"."l_suppkey" AS "l_suppkey",
- "lineitem"."l_quantity" AS "l_quantity",
- "lineitem"."l_extendedprice" AS "l_extendedprice",
- "lineitem"."l_discount" AS "l_discount"
- FROM "lineitem" AS "lineitem"
- ) AS "lineitem"
- ON "part"."p_partkey" = "lineitem"."l_partkey"
- JOIN (
- SELECT
- "supplier"."s_suppkey" AS "s_suppkey",
- "supplier"."s_nationkey" AS "s_nationkey"
- FROM "supplier" AS "supplier"
- ) AS "supplier"
- ON "supplier"."s_suppkey" = "lineitem"."l_suppkey"
- JOIN (
- SELECT
- "partsupp"."ps_partkey" AS "ps_partkey",
- "partsupp"."ps_suppkey" AS "ps_suppkey",
- "partsupp"."ps_supplycost" AS "ps_supplycost"
- FROM "partsupp" AS "partsupp"
- ) AS "partsupp"
- ON "partsupp"."ps_partkey" = "lineitem"."l_partkey"
- AND "partsupp"."ps_suppkey" = "lineitem"."l_suppkey"
- JOIN (
- SELECT
- "orders"."o_orderkey" AS "o_orderkey",
- "orders"."o_orderdate" AS "o_orderdate"
- FROM "orders" AS "orders"
- ) AS "orders"
- ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
- JOIN (
- SELECT
- "nation"."n_nationkey" AS "n_nationkey",
- "nation"."n_name" AS "n_name"
- FROM "nation" AS "nation"
- ) AS "nation"
- ON "supplier"."s_nationkey" = "nation"."n_nationkey"
-) AS "profit"
+ ) - "partsupp"."ps_supplycost" * "lineitem"."l_quantity") AS "sum_profit"
+FROM "part" AS "part"
+JOIN "lineitem" AS "lineitem"
+ ON "part"."p_partkey" = "lineitem"."l_partkey"
+JOIN "supplier" AS "supplier"
+ ON "supplier"."s_suppkey" = "lineitem"."l_suppkey"
+JOIN "partsupp" AS "partsupp"
+ ON "partsupp"."ps_partkey" = "lineitem"."l_partkey"
+ AND "partsupp"."ps_suppkey" = "lineitem"."l_suppkey"
+JOIN "orders" AS "orders"
+ ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
+JOIN "nation" AS "nation"
+ ON "supplier"."s_nationkey" = "nation"."n_nationkey"
+WHERE
+ "part"."p_name" LIKE '%green%'
GROUP BY
- "profit"."nation",
- "profit"."o_year"
+ "nation"."n_name",
+ EXTRACT(year FROM "orders"."o_orderdate")
ORDER BY
"nation",
"o_year" DESC;
@@ -812,46 +601,17 @@ SELECT
"customer"."c_address" AS "c_address",
"customer"."c_phone" AS "c_phone",
"customer"."c_comment" AS "c_comment"
-FROM (
- SELECT
- "customer"."c_custkey" AS "c_custkey",
- "customer"."c_name" AS "c_name",
- "customer"."c_address" AS "c_address",
- "customer"."c_nationkey" AS "c_nationkey",
- "customer"."c_phone" AS "c_phone",
- "customer"."c_acctbal" AS "c_acctbal",
- "customer"."c_comment" AS "c_comment"
- FROM "customer" AS "customer"
-) AS "customer"
-JOIN (
- SELECT
- "orders"."o_orderkey" AS "o_orderkey",
- "orders"."o_custkey" AS "o_custkey",
- "orders"."o_orderdate" AS "o_orderdate"
- FROM "orders" AS "orders"
- WHERE
- "orders"."o_orderdate" < CAST('1994-01-01' AS DATE)
- AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE)
-) AS "orders"
+FROM "customer" AS "customer"
+JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
-JOIN (
- SELECT
- "lineitem"."l_orderkey" AS "l_orderkey",
- "lineitem"."l_extendedprice" AS "l_extendedprice",
- "lineitem"."l_discount" AS "l_discount",
- "lineitem"."l_returnflag" AS "l_returnflag"
- FROM "lineitem" AS "lineitem"
- WHERE
- "lineitem"."l_returnflag" = 'R'
-) AS "lineitem"
+JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
-JOIN (
- SELECT
- "nation"."n_nationkey" AS "n_nationkey",
- "nation"."n_name" AS "n_name"
- FROM "nation" AS "nation"
-) AS "nation"
+JOIN "nation" AS "nation"
ON "customer"."c_nationkey" = "nation"."n_nationkey"
+WHERE
+ "lineitem"."l_returnflag" = 'R'
+ AND "orders"."o_orderdate" < CAST('1994-01-01' AS DATE)
+ AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE)
GROUP BY
"customer"."c_custkey",
"customer"."c_name",
@@ -910,14 +670,7 @@ WITH "_e_0" AS (
SELECT
"partsupp"."ps_partkey" AS "ps_partkey",
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value"
-FROM (
- SELECT
- "partsupp"."ps_partkey" AS "ps_partkey",
- "partsupp"."ps_suppkey" AS "ps_suppkey",
- "partsupp"."ps_availqty" AS "ps_availqty",
- "partsupp"."ps_supplycost" AS "ps_supplycost"
- FROM "partsupp" AS "partsupp"
-) AS "partsupp"
+FROM "partsupp" AS "partsupp"
JOIN "_e_0" AS "supplier"
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
JOIN "_e_1" AS "nation"
@@ -928,13 +681,7 @@ HAVING
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > (
SELECT
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0"
- FROM (
- SELECT
- "partsupp"."ps_suppkey" AS "ps_suppkey",
- "partsupp"."ps_availqty" AS "ps_availqty",
- "partsupp"."ps_supplycost" AS "ps_supplycost"
- FROM "partsupp" AS "partsupp"
- ) AS "partsupp"
+ FROM "partsupp" AS "partsupp"
JOIN "_e_0" AS "supplier"
ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey"
JOIN "_e_1" AS "nation"
@@ -988,28 +735,15 @@ SELECT
THEN 1
ELSE 0
END) AS "low_line_count"
-FROM (
- SELECT
- "orders"."o_orderkey" AS "o_orderkey",
- "orders"."o_orderpriority" AS "o_orderpriority"
- FROM "orders" AS "orders"
-) AS "orders"
-JOIN (
- SELECT
- "lineitem"."l_orderkey" AS "l_orderkey",
- "lineitem"."l_shipdate" AS "l_shipdate",
- "lineitem"."l_commitdate" AS "l_commitdate",
- "lineitem"."l_receiptdate" AS "l_receiptdate",
- "lineitem"."l_shipmode" AS "l_shipmode"
- FROM "lineitem" AS "lineitem"
- WHERE
- "lineitem"."l_commitdate" < "lineitem"."l_receiptdate"
- AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE)
- AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE)
- AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate"
- AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP')
-) AS "lineitem"
+FROM "orders" AS "orders"
+JOIN "lineitem" AS "lineitem"
ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
+WHERE
+ "lineitem"."l_commitdate" < "lineitem"."l_receiptdate"
+ AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE)
+ AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE)
+ AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate"
+ AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP')
GROUP BY
"lineitem"."l_shipmode"
ORDER BY
@@ -1044,21 +778,10 @@ SELECT
FROM (
SELECT
COUNT("orders"."o_orderkey") AS "c_count"
- FROM (
- SELECT
- "customer"."c_custkey" AS "c_custkey"
- FROM "customer" AS "customer"
- ) AS "customer"
- LEFT JOIN (
- SELECT
- "orders"."o_orderkey" AS "o_orderkey",
- "orders"."o_custkey" AS "o_custkey",
- "orders"."o_comment" AS "o_comment"
- FROM "orders" AS "orders"
- WHERE
- NOT "orders"."o_comment" LIKE '%special%requests%'
- ) AS "orders"
+ FROM "customer" AS "customer"
+ LEFT JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
+ AND NOT "orders"."o_comment" LIKE '%special%requests%'
GROUP BY
"customer"."c_custkey"
) AS "c_orders"
@@ -1094,24 +817,12 @@ SELECT
END) / SUM("lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
)) AS "promo_revenue"
-FROM (
- SELECT
- "lineitem"."l_partkey" AS "l_partkey",
- "lineitem"."l_extendedprice" AS "l_extendedprice",
- "lineitem"."l_discount" AS "l_discount",
- "lineitem"."l_shipdate" AS "l_shipdate"
- FROM "lineitem" AS "lineitem"
- WHERE
- "lineitem"."l_shipdate" < CAST('1995-10-01' AS DATE)
- AND "lineitem"."l_shipdate" >= CAST('1995-09-01' AS DATE)
-) AS "lineitem"
-JOIN (
- SELECT
- "part"."p_partkey" AS "p_partkey",
- "part"."p_type" AS "p_type"
- FROM "part" AS "part"
-) AS "part"
- ON "lineitem"."l_partkey" = "part"."p_partkey";
+FROM "lineitem" AS "lineitem"
+JOIN "part" AS "part"
+ ON "lineitem"."l_partkey" = "part"."p_partkey"
+WHERE
+ "lineitem"."l_shipdate" < CAST('1995-10-01' AS DATE)
+ AND "lineitem"."l_shipdate" >= CAST('1995-09-01' AS DATE);
--------------------------------------
-- TPC-H 15
@@ -1165,14 +876,7 @@ SELECT
"supplier"."s_address" AS "s_address",
"supplier"."s_phone" AS "s_phone",
"revenue"."total_revenue" AS "total_revenue"
-FROM (
- SELECT
- "supplier"."s_suppkey" AS "s_suppkey",
- "supplier"."s_name" AS "s_name",
- "supplier"."s_address" AS "s_address",
- "supplier"."s_phone" AS "s_phone"
- FROM "supplier" AS "supplier"
-) AS "supplier"
+FROM "supplier" AS "supplier"
JOIN "revenue"
ON "revenue"."total_revenue" = (
SELECT
@@ -1221,12 +925,7 @@ SELECT
"part"."p_type" AS "p_type",
"part"."p_size" AS "p_size",
COUNT(DISTINCT "partsupp"."ps_suppkey") AS "supplier_cnt"
-FROM (
- SELECT
- "partsupp"."ps_partkey" AS "ps_partkey",
- "partsupp"."ps_suppkey" AS "ps_suppkey"
- FROM "partsupp" AS "partsupp"
-) AS "partsupp"
+FROM "partsupp" AS "partsupp"
LEFT JOIN (
SELECT
"supplier"."s_suppkey" AS "s_suppkey"
@@ -1237,21 +936,13 @@ LEFT JOIN (
"supplier"."s_suppkey"
) AS "_u_0"
ON "partsupp"."ps_suppkey" = "_u_0"."s_suppkey"
-JOIN (
- SELECT
- "part"."p_partkey" AS "p_partkey",
- "part"."p_brand" AS "p_brand",
- "part"."p_type" AS "p_type",
- "part"."p_size" AS "p_size"
- FROM "part" AS "part"
- WHERE
- "part"."p_brand" <> 'Brand#45'
- AND "part"."p_size" IN (49, 14, 23, 45, 19, 3, 36, 9)
- AND NOT "part"."p_type" LIKE 'MEDIUM POLISHED%'
-) AS "part"
+JOIN "part" AS "part"
ON "part"."p_partkey" = "partsupp"."ps_partkey"
WHERE
"_u_0"."s_suppkey" IS NULL
+ AND "part"."p_brand" <> 'Brand#45'
+ AND "part"."p_size" IN (49, 14, 23, 45, 19, 3, 36, 9)
+ AND NOT "part"."p_type" LIKE 'MEDIUM POLISHED%'
GROUP BY
"part"."p_brand",
"part"."p_type",
@@ -1284,23 +975,8 @@ where
);
SELECT
SUM("lineitem"."l_extendedprice") / 7.0 AS "avg_yearly"
-FROM (
- SELECT
- "lineitem"."l_partkey" AS "l_partkey",
- "lineitem"."l_quantity" AS "l_quantity",
- "lineitem"."l_extendedprice" AS "l_extendedprice"
- FROM "lineitem" AS "lineitem"
-) AS "lineitem"
-JOIN (
- SELECT
- "part"."p_partkey" AS "p_partkey",
- "part"."p_brand" AS "p_brand",
- "part"."p_container" AS "p_container"
- FROM "part" AS "part"
- WHERE
- "part"."p_brand" = 'Brand#23'
- AND "part"."p_container" = 'MED BOX'
-) AS "part"
+FROM "lineitem" AS "lineitem"
+JOIN "part" AS "part"
ON "part"."p_partkey" = "lineitem"."l_partkey"
LEFT JOIN (
SELECT
@@ -1313,6 +989,8 @@ LEFT JOIN (
ON "_u_0"."_u_1" = "part"."p_partkey"
WHERE
"lineitem"."l_quantity" < "_u_0"."_col_0"
+ AND "part"."p_brand" = 'Brand#23'
+ AND "part"."p_container" = 'MED BOX'
AND NOT "_u_0"."_u_1" IS NULL;
--------------------------------------
@@ -1359,20 +1037,8 @@ SELECT
"orders"."o_orderdate" AS "o_orderdate",
"orders"."o_totalprice" AS "o_totalprice",
SUM("lineitem"."l_quantity") AS "_col_5"
-FROM (
- SELECT
- "customer"."c_custkey" AS "c_custkey",
- "customer"."c_name" AS "c_name"
- FROM "customer" AS "customer"
-) AS "customer"
-JOIN (
- SELECT
- "orders"."o_orderkey" AS "o_orderkey",
- "orders"."o_custkey" AS "o_custkey",
- "orders"."o_totalprice" AS "o_totalprice",
- "orders"."o_orderdate" AS "o_orderdate"
- FROM "orders" AS "orders"
-) AS "orders"
+FROM "customer" AS "customer"
+JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
LEFT JOIN (
SELECT
@@ -1385,12 +1051,7 @@ LEFT JOIN (
SUM("lineitem"."l_quantity") > 300
) AS "_u_0"
ON "orders"."o_orderkey" = "_u_0"."l_orderkey"
-JOIN (
- SELECT
- "lineitem"."l_orderkey" AS "l_orderkey",
- "lineitem"."l_quantity" AS "l_quantity"
- FROM "lineitem" AS "lineitem"
-) AS "lineitem"
+JOIN "lineitem" AS "lineitem"
ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
WHERE
NOT "_u_0"."l_orderkey" IS NULL
@@ -1447,24 +1108,8 @@ SELECT
SUM("lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
)) AS "revenue"
-FROM (
- SELECT
- "lineitem"."l_partkey" AS "l_partkey",
- "lineitem"."l_quantity" AS "l_quantity",
- "lineitem"."l_extendedprice" AS "l_extendedprice",
- "lineitem"."l_discount" AS "l_discount",
- "lineitem"."l_shipinstruct" AS "l_shipinstruct",
- "lineitem"."l_shipmode" AS "l_shipmode"
- FROM "lineitem" AS "lineitem"
-) AS "lineitem"
-JOIN (
- SELECT
- "part"."p_partkey" AS "p_partkey",
- "part"."p_brand" AS "p_brand",
- "part"."p_size" AS "p_size",
- "part"."p_container" AS "p_container"
- FROM "part" AS "part"
-) AS "part"
+FROM "lineitem" AS "lineitem"
+JOIN "part" AS "part"
ON (
"part"."p_brand" = 'Brand#12'
AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG')
@@ -1558,14 +1203,7 @@ order by
SELECT
"supplier"."s_name" AS "s_name",
"supplier"."s_address" AS "s_address"
-FROM (
- SELECT
- "supplier"."s_suppkey" AS "s_suppkey",
- "supplier"."s_name" AS "s_name",
- "supplier"."s_address" AS "s_address",
- "supplier"."s_nationkey" AS "s_nationkey"
- FROM "supplier" AS "supplier"
-) AS "supplier"
+FROM "supplier" AS "supplier"
LEFT JOIN (
SELECT
"partsupp"."ps_suppkey" AS "ps_suppkey"
@@ -1604,17 +1242,11 @@ LEFT JOIN (
"partsupp"."ps_suppkey"
) AS "_u_4"
ON "supplier"."s_suppkey" = "_u_4"."ps_suppkey"
-JOIN (
- SELECT
- "nation"."n_nationkey" AS "n_nationkey",
- "nation"."n_name" AS "n_name"
- FROM "nation" AS "nation"
- WHERE
- "nation"."n_name" = 'CANADA'
-) AS "nation"
+JOIN "nation" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
WHERE
- NOT "_u_4"."ps_suppkey" IS NULL
+ "nation"."n_name" = 'CANADA'
+ AND NOT "_u_4"."ps_suppkey" IS NULL
ORDER BY
"s_name";
@@ -1665,24 +1297,9 @@ limit
SELECT
"supplier"."s_name" AS "s_name",
COUNT(*) AS "numwait"
-FROM (
- SELECT
- "supplier"."s_suppkey" AS "s_suppkey",
- "supplier"."s_name" AS "s_name",
- "supplier"."s_nationkey" AS "s_nationkey"
- FROM "supplier" AS "supplier"
-) AS "supplier"
-JOIN (
- SELECT
- "lineitem"."l_orderkey" AS "l_orderkey",
- "lineitem"."l_suppkey" AS "l_suppkey",
- "lineitem"."l_commitdate" AS "l_commitdate",
- "lineitem"."l_receiptdate" AS "l_receiptdate"
- FROM "lineitem" AS "lineitem"
- WHERE
- "lineitem"."l_receiptdate" > "lineitem"."l_commitdate"
-) AS "l1"
- ON "supplier"."s_suppkey" = "l1"."l_suppkey"
+FROM "supplier" AS "supplier"
+JOIN "lineitem" AS "lineitem"
+ ON "supplier"."s_suppkey" = "lineitem"."l_suppkey"
LEFT JOIN (
SELECT
"l2"."l_orderkey" AS "l_orderkey",
@@ -1691,7 +1308,7 @@ LEFT JOIN (
GROUP BY
"l2"."l_orderkey"
) AS "_u_0"
- ON "_u_0"."l_orderkey" = "l1"."l_orderkey"
+ ON "_u_0"."l_orderkey" = "lineitem"."l_orderkey"
LEFT JOIN (
SELECT
"l3"."l_orderkey" AS "l_orderkey",
@@ -1702,31 +1319,20 @@ LEFT JOIN (
GROUP BY
"l3"."l_orderkey"
) AS "_u_2"
- ON "_u_2"."l_orderkey" = "l1"."l_orderkey"
-JOIN (
- SELECT
- "orders"."o_orderkey" AS "o_orderkey",
- "orders"."o_orderstatus" AS "o_orderstatus"
- FROM "orders" AS "orders"
- WHERE
- "orders"."o_orderstatus" = 'F'
-) AS "orders"
- ON "orders"."o_orderkey" = "l1"."l_orderkey"
-JOIN (
- SELECT
- "nation"."n_nationkey" AS "n_nationkey",
- "nation"."n_name" AS "n_name"
- FROM "nation" AS "nation"
- WHERE
- "nation"."n_name" = 'SAUDI ARABIA'
-) AS "nation"
+ ON "_u_2"."l_orderkey" = "lineitem"."l_orderkey"
+JOIN "orders" AS "orders"
+ ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
+JOIN "nation" AS "nation"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
WHERE
(
"_u_2"."l_orderkey" IS NULL
- OR NOT ARRAY_ANY("_u_2"."_u_3", "_x" -> "_x" <> "l1"."l_suppkey")
+ OR NOT ARRAY_ANY("_u_2"."_u_3", "_x" -> "_x" <> "lineitem"."l_suppkey")
)
- AND ARRAY_ANY("_u_0"."_u_1", "_x" -> "_x" <> "l1"."l_suppkey")
+ AND "lineitem"."l_receiptdate" > "lineitem"."l_commitdate"
+ AND "nation"."n_name" = 'SAUDI ARABIA'
+ AND "orders"."o_orderstatus" = 'F'
+ AND ARRAY_ANY("_u_0"."_u_1", "_x" -> "_x" <> "lineitem"."l_suppkey")
AND NOT "_u_0"."l_orderkey" IS NULL
GROUP BY
"supplier"."s_name"
@@ -1776,35 +1382,30 @@ group by
order by
cntrycode;
SELECT
- "custsale"."cntrycode" AS "cntrycode",
+ SUBSTRING("customer"."c_phone", 1, 2) AS "cntrycode",
COUNT(*) AS "numcust",
- SUM("custsale"."c_acctbal") AS "totacctbal"
-FROM (
+ SUM("customer"."c_acctbal") AS "totacctbal"
+FROM "customer" AS "customer"
+LEFT JOIN (
SELECT
- SUBSTRING("customer"."c_phone", 1, 2) AS "cntrycode",
- "customer"."c_acctbal" AS "c_acctbal"
- FROM "customer" AS "customer"
- LEFT JOIN (
+ "orders"."o_custkey" AS "_u_1"
+ FROM "orders" AS "orders"
+ GROUP BY
+ "orders"."o_custkey"
+) AS "_u_0"
+ ON "_u_0"."_u_1" = "customer"."c_custkey"
+WHERE
+ "_u_0"."_u_1" IS NULL
+ AND "customer"."c_acctbal" > (
SELECT
- "orders"."o_custkey" AS "_u_1"
- FROM "orders" AS "orders"
- GROUP BY
- "orders"."o_custkey"
- ) AS "_u_0"
- ON "_u_0"."_u_1" = "customer"."c_custkey"
- WHERE
- "_u_0"."_u_1" IS NULL
- AND "customer"."c_acctbal" > (
- SELECT
- AVG("customer"."c_acctbal") AS "_col_0"
- FROM "customer" AS "customer"
- WHERE
- "customer"."c_acctbal" > 0.00
- AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
- )
- AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
-) AS "custsale"
+ AVG("customer"."c_acctbal") AS "_col_0"
+ FROM "customer" AS "customer"
+ WHERE
+ "customer"."c_acctbal" > 0.00
+ AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
+ )
+ AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')
GROUP BY
- "custsale"."cntrycode"
+ SUBSTRING("customer"."c_phone", 1, 2)
ORDER BY
"cntrycode";
diff --git a/tests/helpers.py b/tests/helpers.py
index d4edb14..ad50483 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -5,9 +5,7 @@ FIXTURES_DIR = os.path.join(FILE_DIR, "fixtures")
def _filter_comments(s):
- return "\n".join(
- [line for line in s.splitlines() if line and not line.startswith("--")]
- )
+ return "\n".join([line for line in s.splitlines() if line and not line.startswith("--")])
def _extract_meta(sql):
@@ -23,9 +21,7 @@ def _extract_meta(sql):
def assert_logger_contains(message, logger, level="error"):
- output = "\n".join(
- str(args[0][0]) for args in getattr(logger, level).call_args_list
- )
+ output = "\n".join(str(args[0][0]) for args in getattr(logger, level).call_args_list)
assert message in output
diff --git a/tests/test_build.py b/tests/test_build.py
index a4cffde..18c0e47 100644
--- a/tests/test_build.py
+++ b/tests/test_build.py
@@ -46,10 +46,7 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl WHERE FALSE",
),
(
- lambda: select("x")
- .from_("tbl")
- .where("x > 0")
- .where("x < 9", append=False),
+ lambda: select("x").from_("tbl").where("x > 0").where("x < 9", append=False),
"SELECT x FROM tbl WHERE x < 9",
),
(
@@ -61,10 +58,7 @@ class TestBuild(unittest.TestCase):
"SELECT x, y FROM tbl GROUP BY x, y",
),
(
- lambda: select("x", "y", "z", "a")
- .from_("tbl")
- .group_by("x, y", "z")
- .group_by("a"),
+ lambda: select("x", "y", "z", "a").from_("tbl").group_by("x, y", "z").group_by("a"),
"SELECT x, y, z, a FROM tbl GROUP BY x, y, z, a",
),
(
@@ -85,9 +79,7 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl JOIN tbl2 ON tbl.y = tbl2.y",
),
(
- lambda: select("x")
- .from_("tbl")
- .join("tbl2", on=["tbl.y = tbl2.y", "a = b"]),
+ lambda: select("x").from_("tbl").join("tbl2", on=["tbl.y = tbl2.y", "a = b"]),
"SELECT x FROM tbl JOIN tbl2 ON tbl.y = tbl2.y AND a = b",
),
(
@@ -95,21 +87,15 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT OUTER JOIN tbl2",
),
(
- lambda: select("x")
- .from_("tbl")
- .join(exp.Table(this="tbl2"), join_type="left outer"),
+ lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer"),
"SELECT x FROM tbl LEFT OUTER JOIN tbl2",
),
(
- lambda: select("x")
- .from_("tbl")
- .join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"),
+ lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"),
"SELECT x FROM tbl LEFT OUTER JOIN tbl2 AS foo",
),
(
- lambda: select("x")
- .from_("tbl")
- .join(select("y").from_("tbl2"), join_type="left outer"),
+ lambda: select("x").from_("tbl").join(select("y").from_("tbl2"), join_type="left outer"),
"SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2)",
),
(
@@ -132,9 +118,7 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2) AS aliased",
),
(
- lambda: select("x")
- .from_("tbl")
- .join(parse_one("left join x", into=exp.Join), on="a=b"),
+ lambda: select("x").from_("tbl").join(parse_one("left join x", into=exp.Join), on="a=b"),
"SELECT x FROM tbl LEFT JOIN x ON a = b",
),
(
@@ -142,9 +126,7 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT JOIN x ON a = b",
),
(
- lambda: select("x")
- .from_("tbl")
- .join("select b from tbl2", on="a=b", join_type="left"),
+ lambda: select("x").from_("tbl").join("select b from tbl2", on="a=b", join_type="left"),
"SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) ON a = b",
),
(
@@ -159,10 +141,7 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) AS aliased ON a = b",
),
(
- lambda: select("x", "COUNT(y)")
- .from_("tbl")
- .group_by("x")
- .having("COUNT(y) > 0"),
+ lambda: select("x", "COUNT(y)").from_("tbl").group_by("x").having("COUNT(y) > 0"),
"SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 0",
),
(
@@ -190,24 +169,15 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl SORT BY x, y DESC",
),
(
- lambda: select("x", "y", "z", "a")
- .from_("tbl")
- .order_by("x, y", "z")
- .order_by("a"),
+ lambda: select("x", "y", "z", "a").from_("tbl").order_by("x, y", "z").order_by("a"),
"SELECT x, y, z, a FROM tbl ORDER BY x, y, z, a",
),
(
- lambda: select("x", "y", "z", "a")
- .from_("tbl")
- .cluster_by("x, y", "z")
- .cluster_by("a"),
+ lambda: select("x", "y", "z", "a").from_("tbl").cluster_by("x, y", "z").cluster_by("a"),
"SELECT x, y, z, a FROM tbl CLUSTER BY x, y, z, a",
),
(
- lambda: select("x", "y", "z", "a")
- .from_("tbl")
- .sort_by("x, y", "z")
- .sort_by("a"),
+ lambda: select("x", "y", "z", "a").from_("tbl").sort_by("x, y", "z").sort_by("a"),
"SELECT x, y, z, a FROM tbl SORT BY x, y, z, a",
),
(lambda: select("x").from_("tbl").limit(10), "SELECT x FROM tbl LIMIT 10"),
@@ -220,21 +190,15 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
- lambda: select("x")
- .from_("tbl")
- .with_("tbl", as_="SELECT x FROM tbl2", recursive=True),
+ lambda: select("x").from_("tbl").with_("tbl", as_="SELECT x FROM tbl2", recursive=True),
"WITH RECURSIVE tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
- lambda: select("x")
- .from_("tbl")
- .with_("tbl", as_=select("x").from_("tbl2")),
+ lambda: select("x").from_("tbl").with_("tbl", as_=select("x").from_("tbl2")),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
- lambda: select("x")
- .from_("tbl")
- .with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")),
+ lambda: select("x").from_("tbl").with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")),
"WITH tbl(x, y) AS (SELECT x, y FROM tbl2) SELECT x FROM tbl",
),
(
@@ -245,72 +209,43 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2), tbl2 AS (SELECT x FROM tbl3) SELECT x FROM tbl",
),
(
- lambda: select("x")
- .from_("tbl")
- .with_("tbl", as_=select("x", "y").from_("tbl2"))
- .select("y"),
+ lambda: select("x").from_("tbl").with_("tbl", as_=select("x", "y").from_("tbl2")).select("y"),
"WITH tbl AS (SELECT x, y FROM tbl2) SELECT x, y FROM tbl",
),
(
- lambda: select("x")
- .with_("tbl", as_=select("x").from_("tbl2"))
- .from_("tbl"),
+ lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
- lambda: select("x")
- .with_("tbl", as_=select("x").from_("tbl2"))
- .from_("tbl")
- .group_by("x"),
+ lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").group_by("x"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl GROUP BY x",
),
(
- lambda: select("x")
- .with_("tbl", as_=select("x").from_("tbl2"))
- .from_("tbl")
- .order_by("x"),
+ lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").order_by("x"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl ORDER BY x",
),
(
- lambda: select("x")
- .with_("tbl", as_=select("x").from_("tbl2"))
- .from_("tbl")
- .limit(10),
+ lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").limit(10),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl LIMIT 10",
),
(
- lambda: select("x")
- .with_("tbl", as_=select("x").from_("tbl2"))
- .from_("tbl")
- .offset(10),
+ lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").offset(10),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl OFFSET 10",
),
(
- lambda: select("x")
- .with_("tbl", as_=select("x").from_("tbl2"))
- .from_("tbl")
- .join("tbl3"),
+ lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").join("tbl3"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl JOIN tbl3",
),
(
- lambda: select("x")
- .with_("tbl", as_=select("x").from_("tbl2"))
- .from_("tbl")
- .distinct(),
+ lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").distinct(),
"WITH tbl AS (SELECT x FROM tbl2) SELECT DISTINCT x FROM tbl",
),
(
- lambda: select("x")
- .with_("tbl", as_=select("x").from_("tbl2"))
- .from_("tbl")
- .where("x > 10"),
+ lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").where("x > 10"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl WHERE x > 10",
),
(
- lambda: select("x")
- .with_("tbl", as_=select("x").from_("tbl2"))
- .from_("tbl")
- .having("x > 20"),
+ lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").having("x > 20"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl HAVING x > 20",
),
(lambda: select("x").from_("tbl").subquery(), "(SELECT x FROM tbl)"),
@@ -324,9 +259,7 @@ class TestBuild(unittest.TestCase):
),
(lambda: from_("tbl").select("x"), "SELECT x FROM tbl"),
(
- lambda: parse_one("SELECT a FROM tbl")
- .assert_is(exp.Select)
- .select("b"),
+ lambda: parse_one("SELECT a FROM tbl").assert_is(exp.Select).select("b"),
"SELECT a, b FROM tbl",
),
(
@@ -368,15 +301,11 @@ class TestBuild(unittest.TestCase):
"SELECT * FROM x WHERE y = 1 AND z = 1",
),
(
- lambda: exp.subquery("select x from tbl", "foo")
- .select("x")
- .where("x > 0"),
+ lambda: exp.subquery("select x from tbl", "foo").select("x").where("x > 0"),
"SELECT x FROM (SELECT x FROM tbl) AS foo WHERE x > 0",
),
(
- lambda: exp.subquery(
- "select x from tbl UNION select x from bar", "unioned"
- ).select("x"),
+ lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select("x"),
"SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned",
),
]:
diff --git a/tests/test_executor.py b/tests/test_executor.py
index 9afa225..c5841d3 100644
--- a/tests/test_executor.py
+++ b/tests/test_executor.py
@@ -27,10 +27,7 @@ class TestExecutor(unittest.TestCase):
)
cls.cache = {}
- cls.sqls = [
- (sql, expected)
- for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")
- ]
+ cls.sqls = [(sql, expected) for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")]
@classmethod
def tearDownClass(cls):
@@ -50,18 +47,17 @@ class TestExecutor(unittest.TestCase):
self.assertEqual(Python().generate(parse_one("'x '''")), r"'x \''")
def test_optimized_tpch(self):
- for sql, optimized in self.sqls[0:20]:
- a = self.cached_execute(sql)
- b = self.conn.execute(optimized).fetchdf()
- self.rename_anonymous(b, a)
- assert_frame_equal(a, b)
+ for i, (sql, optimized) in enumerate(self.sqls[:20], start=1):
+ with self.subTest(f"{i}, {sql}"):
+ a = self.cached_execute(sql)
+ b = self.conn.execute(optimized).fetchdf()
+ self.rename_anonymous(b, a)
+ assert_frame_equal(a, b)
def test_execute_tpch(self):
def to_csv(expression):
if isinstance(expression, exp.Table):
- return parse_one(
- f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}"
- )
+ return parse_one(f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}")
return expression
for sql, _ in self.sqls[0:3]:
diff --git a/tests/test_expressions.py b/tests/test_expressions.py
index eaef022..716e457 100644
--- a/tests/test_expressions.py
+++ b/tests/test_expressions.py
@@ -26,9 +26,7 @@ class TestExpressions(unittest.TestCase):
parse_one("ROW() OVER(Partition by y)"),
parse_one("ROW() OVER (partition BY y)"),
)
- self.assertEqual(
- parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)")
- )
+ self.assertEqual(parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)"))
def test_find(self):
expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y")
@@ -87,9 +85,7 @@ class TestExpressions(unittest.TestCase):
self.assertIsNone(column.find_ancestor(exp.Join))
def test_alias_or_name(self):
- expression = parse_one(
- "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
- )
+ expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
self.assertEqual(
[e.alias_or_name for e in expression.expressions],
["a", "B", "e", "*", "zz", "z"],
@@ -118,9 +114,7 @@ class TestExpressions(unittest.TestCase):
)
def test_named_selects(self):
- expression = parse_one(
- "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
- )
+ expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"])
expression = parse_one(
@@ -196,15 +190,9 @@ class TestExpressions(unittest.TestCase):
def test_sql(self):
self.assertEqual(parse_one("x + y * 2").sql(), "x + y * 2")
- self.assertEqual(
- parse_one('select "x"').sql(dialect="hive", pretty=True), "SELECT\n `x`"
- )
- self.assertEqual(
- parse_one("X + y").sql(identify=True, normalize=True), '"x" + "y"'
- )
- self.assertEqual(
- parse_one("SUM(X)").sql(identify=True, normalize=True), 'SUM("x")'
- )
+ self.assertEqual(parse_one('select "x"').sql(dialect="hive", pretty=True), "SELECT\n `x`")
+ self.assertEqual(parse_one("X + y").sql(identify=True, normalize=True), '"x" + "y"')
+ self.assertEqual(parse_one("SUM(X)").sql(identify=True, normalize=True), 'SUM("x")')
def test_transform_with_arguments(self):
expression = parse_one("a")
@@ -229,15 +217,11 @@ class TestExpressions(unittest.TestCase):
return node
actual_expression_1 = expression.transform(fun)
- self.assertEqual(
- actual_expression_1.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)"
- )
+ self.assertEqual(actual_expression_1.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)")
self.assertIsNot(actual_expression_1, expression)
actual_expression_2 = expression.transform(fun, copy=False)
- self.assertEqual(
- actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)"
- )
+ self.assertEqual(actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)")
self.assertIs(actual_expression_2, expression)
with self.assertRaises(ValueError):
@@ -274,12 +258,8 @@ class TestExpressions(unittest.TestCase):
expression = parse_one("SELECT * FROM (SELECT * FROM x)")
self.assertEqual(len(list(expression.walk())), 9)
self.assertEqual(len(list(expression.walk(bfs=False))), 9)
- self.assertTrue(
- all(isinstance(e, exp.Expression) for e, _, _ in expression.walk())
- )
- self.assertTrue(
- all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False))
- )
+ self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk()))
+ self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False)))
def test_functions(self):
self.assertIsInstance(parse_one("ABS(a)"), exp.Abs)
@@ -303,9 +283,7 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(parse_one("IF(a, b, c)"), exp.If)
self.assertIsInstance(parse_one("INITCAP(a)"), exp.Initcap)
self.assertIsInstance(parse_one("JSON_EXTRACT(a, '$.name')"), exp.JSONExtract)
- self.assertIsInstance(
- parse_one("JSON_EXTRACT_SCALAR(a, '$.name')"), exp.JSONExtractScalar
- )
+ self.assertIsInstance(parse_one("JSON_EXTRACT_SCALAR(a, '$.name')"), exp.JSONExtractScalar)
self.assertIsInstance(parse_one("LEAST(a, b)"), exp.Least)
self.assertIsInstance(parse_one("LN(a)"), exp.Ln)
self.assertIsInstance(parse_one("LOG10(a)"), exp.Log10)
@@ -334,6 +312,7 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(parse_one("TIME_STR_TO_DATE(a)"), exp.TimeStrToDate)
self.assertIsInstance(parse_one("TIME_STR_TO_TIME(a)"), exp.TimeStrToTime)
self.assertIsInstance(parse_one("TIME_STR_TO_UNIX(a)"), exp.TimeStrToUnix)
+ self.assertIsInstance(parse_one("TRIM(LEADING 'b' FROM 'bla')"), exp.Trim)
self.assertIsInstance(parse_one("TS_OR_DS_ADD(a, 1, 'day')"), exp.TsOrDsAdd)
self.assertIsInstance(parse_one("TS_OR_DS_TO_DATE(a)"), exp.TsOrDsToDate)
self.assertIsInstance(parse_one("TS_OR_DS_TO_DATE_STR(a)"), exp.Substring)
@@ -404,12 +383,8 @@ class TestExpressions(unittest.TestCase):
self.assertFalse(exp.to_identifier("x").quoted)
def test_function_normalizer(self):
- self.assertEqual(
- parse_one("HELLO()").sql(normalize_functions="lower"), "hello()"
- )
- self.assertEqual(
- parse_one("hello()").sql(normalize_functions="upper"), "HELLO()"
- )
+ self.assertEqual(parse_one("HELLO()").sql(normalize_functions="lower"), "hello()")
+ self.assertEqual(parse_one("hello()").sql(normalize_functions="upper"), "HELLO()")
self.assertEqual(parse_one("heLLO()").sql(normalize_functions=None), "heLLO()")
self.assertEqual(parse_one("SUM(x)").sql(normalize_functions="lower"), "sum(x)")
self.assertEqual(parse_one("sum(x)").sql(normalize_functions="upper"), "SUM(x)")
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 40540b3..102e141 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -31,9 +31,7 @@ class TestOptimizer(unittest.TestCase):
dialect = meta.get("dialect")
with self.subTest(sql):
self.assertEqual(
- func(parse_one(sql, read=dialect), **kwargs).sql(
- pretty=pretty, dialect=dialect
- ),
+ func(parse_one(sql, read=dialect), **kwargs).sql(pretty=pretty, dialect=dialect),
expected,
)
@@ -86,9 +84,7 @@ class TestOptimizer(unittest.TestCase):
for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"):
with self.subTest(sql):
with self.assertRaises(OptimizeError):
- optimizer.qualify_columns.qualify_columns(
- parse_one(sql), schema=self.schema
- )
+ optimizer.qualify_columns.qualify_columns(parse_one(sql), schema=self.schema)
def test_quote_identities(self):
self.check_file("quote_identities", optimizer.quote_identities.quote_identities)
@@ -100,9 +96,7 @@ class TestOptimizer(unittest.TestCase):
expression = optimizer.pushdown_projections.pushdown_projections(expression)
return expression
- self.check_file(
- "pushdown_projections", pushdown_projections, schema=self.schema
- )
+ self.check_file("pushdown_projections", pushdown_projections, schema=self.schema)
def test_simplify(self):
self.check_file("simplify", optimizer.simplify.simplify)
@@ -115,9 +109,7 @@ class TestOptimizer(unittest.TestCase):
)
def test_pushdown_predicates(self):
- self.check_file(
- "pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates
- )
+ self.check_file("pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates)
def test_expand_multi_table_selects(self):
self.check_file(
@@ -138,10 +130,17 @@ class TestOptimizer(unittest.TestCase):
pretty=True,
)
+ def test_merge_derived_tables(self):
+ def optimize(expression, **kwargs):
+ expression = optimizer.qualify_tables.qualify_tables(expression)
+ expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
+ expression = optimizer.merge_derived_tables.merge_derived_tables(expression)
+ return expression
+
+ self.check_file("merge_derived_tables", optimize, schema=self.schema)
+
def test_tpch(self):
- self.check_file(
- "tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True
- )
+ self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True)
def test_schema(self):
schema = ensure_schema(
@@ -262,9 +261,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(len(scopes), 5)
self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x")
self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
- self.assertEqual(
- scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b"
- )
+ self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y")
self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql())
diff --git a/tests/test_parser.py b/tests/test_parser.py
index 779083d..1054103 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -16,28 +16,23 @@ class TestParser(unittest.TestCase):
self.assertIsInstance(parse_one("array<int>", into=exp.DataType), exp.DataType)
def test_column(self):
- columns = parse_one("select a, ARRAY[1] b, case when 1 then 1 end").find_all(
- exp.Column
- )
+ columns = parse_one("select a, ARRAY[1] b, case when 1 then 1 end").find_all(exp.Column)
assert len(list(columns)) == 1
self.assertIsNotNone(parse_one("date").find(exp.Column))
def test_table(self):
- tables = [
- t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)
- ]
+ tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)]
self.assertEqual(tables, ["a", "b.c", "d"])
def test_select(self):
- self.assertIsNotNone(
- parse_one("select * from (select 1) x order by x.y").args["order"]
- )
- self.assertIsNotNone(
- parse_one("select * from x where a = (select 1) order by x.y").args["order"]
- )
+ self.assertIsNotNone(parse_one("select 1 natural"))
+ self.assertIsNotNone(parse_one("select * from (select 1) x order by x.y").args["order"])
+ self.assertIsNotNone(parse_one("select * from x where a = (select 1) order by x.y").args["order"])
+ self.assertEqual(len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1)
self.assertEqual(
- len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1
+ parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(),
+ """SELECT * FROM x, z LATERAL VIEW EXPLODE(y) CROSS JOIN y""",
)
def test_command(self):
@@ -72,12 +67,8 @@ class TestParser(unittest.TestCase):
)
assert len(expressions) == 2
- assert (
- expressions[0].args["from"].expressions[0].args["this"].args["this"] == "a"
- )
- assert (
- expressions[1].args["from"].expressions[0].args["this"].args["this"] == "b"
- )
+ assert expressions[0].args["from"].expressions[0].args["this"].args["this"] == "a"
+ assert expressions[1].args["from"].expressions[0].args["this"].args["this"] == "b"
def test_expression(self):
ignore = Parser(error_level=ErrorLevel.IGNORE)
@@ -147,13 +138,9 @@ class TestParser(unittest.TestCase):
def test_pretty_config_override(self):
self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x")
with patch("sqlglot.pretty", True):
- self.assertEqual(
- parse_one("SELECT col FROM x").sql(), "SELECT\n col\nFROM x"
- )
+ self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT\n col\nFROM x")
- self.assertEqual(
- parse_one("SELECT col FROM x").sql(pretty=True), "SELECT\n col\nFROM x"
- )
+ self.assertEqual(parse_one("SELECT col FROM x").sql(pretty=True), "SELECT\n col\nFROM x")
@patch("sqlglot.parser.logger")
def test_comment_error_n(self, logger):
diff --git a/tests/test_transpile.py b/tests/test_transpile.py
index 28bcc7a..4bec2ac 100644
--- a/tests/test_transpile.py
+++ b/tests/test_transpile.py
@@ -42,6 +42,20 @@ class TestTranspile(unittest.TestCase):
"SELECT * FROM x WHERE a = ANY (SELECT 1)",
)
+ def test_leading_comma(self):
+ self.validate(
+ "SELECT FOO, BAR, BAZ",
+ "SELECT\n FOO\n , BAR\n , BAZ",
+ leading_comma=True,
+ pretty=True,
+ )
+ # without pretty, this should be a no-op
+ self.validate(
+ "SELECT FOO, BAR, BAZ",
+ "SELECT FOO, BAR, BAZ",
+ leading_comma=True,
+ )
+
def test_space(self):
self.validate("SELECT MIN(3)>MIN(2)", "SELECT MIN(3) > MIN(2)")
self.validate("SELECT MIN(3)>=MIN(2)", "SELECT MIN(3) >= MIN(2)")
@@ -108,6 +122,11 @@ class TestTranspile(unittest.TestCase):
"extract(month from '2021-01-31'::timestamp without time zone)",
"EXTRACT(month FROM CAST('2021-01-31' AS TIMESTAMP))",
)
+ self.validate("extract(week from current_date + 2)", "EXTRACT(week FROM CURRENT_DATE + 2)")
+ self.validate(
+ "EXTRACT(minute FROM datetime1 - datetime2)",
+ "EXTRACT(minute FROM datetime1 - datetime2)",
+ )
def test_if(self):
self.validate(
@@ -122,18 +141,14 @@ class TestTranspile(unittest.TestCase):
"SELECT IF a > 1 THEN b ELSE c END",
"SELECT CASE WHEN a > 1 THEN b ELSE c END",
)
- self.validate(
- "SELECT IF(a > 1, 1) FROM foo", "SELECT CASE WHEN a > 1 THEN 1 END FROM foo"
- )
+ self.validate("SELECT IF(a > 1, 1) FROM foo", "SELECT CASE WHEN a > 1 THEN 1 END FROM foo")
def test_ignore_nulls(self):
self.validate("SELECT COUNT(x RESPECT NULLS)", "SELECT COUNT(x)")
def test_time(self):
self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)")
- self.validate(
- "TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)"
- )
+ self.validate("TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)")
self.validate(
"TIMESTAMP(9) WITH TIME ZONE '2020-01-01'",
"CAST('2020-01-01' AS TIMESTAMPTZ(9))",
@@ -159,9 +174,7 @@ class TestTranspile(unittest.TestCase):
self.validate("DATE '2020-01-01'", "CAST('2020-01-01' AS DATE)")
self.validate("'2020-01-01'::DATE", "CAST('2020-01-01' AS DATE)")
self.validate("STR_TO_TIME('x', 'y')", "STRPTIME('x', 'y')", write="duckdb")
- self.validate(
- "STR_TO_UNIX('x', 'y')", "EPOCH(STRPTIME('x', 'y'))", write="duckdb"
- )
+ self.validate("STR_TO_UNIX('x', 'y')", "EPOCH(STRPTIME('x', 'y'))", write="duckdb")
self.validate("TIME_TO_STR(x, 'y')", "STRFTIME(x, 'y')", write="duckdb")
self.validate("TIME_TO_UNIX(x)", "EPOCH(x)", write="duckdb")
self.validate(
@@ -209,12 +222,8 @@ class TestTranspile(unittest.TestCase):
self.validate("TIME_STR_TO_DATE(x)", "TIME_STR_TO_DATE(x)", write=None)
self.validate("TIME_STR_TO_DATE(x)", "TO_DATE(x)", write="hive")
- self.validate(
- "UNIX_TO_STR(x, 'yyyy-MM-dd HH:mm:ss')", "FROM_UNIXTIME(x)", write="hive"
- )
- self.validate(
- "STR_TO_UNIX(x, 'yyyy-MM-dd HH:mm:ss')", "UNIX_TIMESTAMP(x)", write="hive"
- )
+ self.validate("UNIX_TO_STR(x, 'yyyy-MM-dd HH:mm:ss')", "FROM_UNIXTIME(x)", write="hive")
+ self.validate("STR_TO_UNIX(x, 'yyyy-MM-dd HH:mm:ss')", "UNIX_TIMESTAMP(x)", write="hive")
self.validate("IF(x > 1, x + 1)", "IF(x > 1, x + 1)", write="presto")
self.validate("IF(x > 1, 1 + 1)", "IF(x > 1, 1 + 1)", write="hive")
self.validate("IF(x > 1, 1, 0)", "IF(x > 1, 1, 0)", write="hive")
@@ -232,9 +241,7 @@ class TestTranspile(unittest.TestCase):
)
self.validate("STR_TO_TIME('x', 'y')", "DATE_PARSE('x', 'y')", write="presto")
- self.validate(
- "STR_TO_UNIX('x', 'y')", "TO_UNIXTIME(DATE_PARSE('x', 'y'))", write="presto"
- )
+ self.validate("STR_TO_UNIX('x', 'y')", "TO_UNIXTIME(DATE_PARSE('x', 'y'))", write="presto")
self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="presto")
self.validate("TIME_TO_UNIX(x)", "TO_UNIXTIME(x)", write="presto")
self.validate(
@@ -245,9 +252,7 @@ class TestTranspile(unittest.TestCase):
self.validate("UNIX_TO_TIME(123)", "FROM_UNIXTIME(123)", write="presto")
self.validate("STR_TO_TIME('x', 'y')", "TO_TIMESTAMP('x', 'y')", write="spark")
- self.validate(
- "STR_TO_UNIX('x', 'y')", "UNIX_TIMESTAMP('x', 'y')", write="spark"
- )
+ self.validate("STR_TO_UNIX('x', 'y')", "UNIX_TIMESTAMP('x', 'y')", write="spark")
self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="spark")
self.validate(
@@ -283,9 +288,7 @@ class TestTranspile(unittest.TestCase):
def test_partial(self):
for sql in load_sql_fixtures("partial.sql"):
with self.subTest(sql):
- self.assertEqual(
- transpile(sql, error_level=ErrorLevel.IGNORE)[0], sql.strip()
- )
+ self.assertEqual(transpile(sql, error_level=ErrorLevel.IGNORE)[0], sql.strip())
def test_pretty(self):
for _, sql, pretty in load_sql_fixture_pairs("pretty.sql"):