Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/__init__.py                        |   2
-rw-r--r--  sqlglot/__main__.py                        |   7
-rw-r--r--  sqlglot/dialects/__init__.py               |   1
-rw-r--r--  sqlglot/dialects/bigquery.py               |   9
-rw-r--r--  sqlglot/dialects/dialect.py                |  31
-rw-r--r--  sqlglot/dialects/duckdb.py                 |   5
-rw-r--r--  sqlglot/dialects/hive.py                   |  15
-rw-r--r--  sqlglot/dialects/mysql.py                  |  29
-rw-r--r--  sqlglot/dialects/oracle.py                 |   8
-rw-r--r--  sqlglot/dialects/postgres.py               | 116
-rw-r--r--  sqlglot/dialects/presto.py                 |   6
-rw-r--r--  sqlglot/dialects/redshift.py               |  34
-rw-r--r--  sqlglot/dialects/snowflake.py              |   4
-rw-r--r--  sqlglot/dialects/spark.py                  |  15
-rw-r--r--  sqlglot/dialects/sqlite.py                 |   1
-rw-r--r--  sqlglot/dialects/trino.py                  |   3
-rw-r--r--  sqlglot/diff.py                            |  35
-rw-r--r--  sqlglot/executor/__init__.py               |  10
-rw-r--r--  sqlglot/executor/context.py                |   4
-rw-r--r--  sqlglot/executor/python.py                 |  14
-rw-r--r--  sqlglot/executor/table.py                  |   5
-rw-r--r--  sqlglot/expressions.py                     | 169
-rw-r--r--  sqlglot/generator.py                       | 167
-rw-r--r--  sqlglot/optimizer/__init__.py              |   2
-rw-r--r--  sqlglot/optimizer/isolate_table_selects.py |   4
-rw-r--r--  sqlglot/optimizer/merge_derived_tables.py  | 232
-rw-r--r--  sqlglot/optimizer/normalize.py             |  22
-rw-r--r--  sqlglot/optimizer/optimize_joins.py        |   6
-rw-r--r--  sqlglot/optimizer/optimizer.py             |  39
-rw-r--r--  sqlglot/optimizer/pushdown_predicates.py   |  20
-rw-r--r--  sqlglot/optimizer/qualify_columns.py       |  36
-rw-r--r--  sqlglot/optimizer/qualify_tables.py        |   4
-rw-r--r--  sqlglot/optimizer/schema.py                |   4
-rw-r--r--  sqlglot/optimizer/scope.py                 |  58
-rw-r--r--  sqlglot/optimizer/simplify.py              |   8
-rw-r--r--  sqlglot/optimizer/unnest_subqueries.py     |  22
-rw-r--r--  sqlglot/parser.py                          | 404
-rw-r--r--  sqlglot/planner.py                         |  21
-rw-r--r--  sqlglot/tokens.py                          | 184
-rw-r--r--  sqlglot/transforms.py                      |   4
40 files changed, 1082 insertions, 678 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 0007e34..3fa40ce 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -20,7 +20,7 @@ from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "6.0.4"
+__version__ = "6.1.1"
pretty = False
diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py
index 25200c4..4161259 100644
--- a/sqlglot/__main__.py
+++ b/sqlglot/__main__.py
@@ -49,12 +49,7 @@ args = parser.parse_args()
error_level = sqlglot.ErrorLevel[args.error_level.upper()]
if args.parse:
- sqls = [
- repr(expression)
- for expression in sqlglot.parse(
- args.sql, read=args.read, error_level=error_level
- )
- ]
+ sqls = [repr(expression) for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)]
else:
sqls = sqlglot.transpile(
args.sql,
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
index 5aa7d77..f7d03ad 100644
--- a/sqlglot/dialects/__init__.py
+++ b/sqlglot/dialects/__init__.py
@@ -7,6 +7,7 @@ from sqlglot.dialects.mysql import MySQL
from sqlglot.dialects.oracle import Oracle
from sqlglot.dialects.postgres import Postgres
from sqlglot.dialects.presto import Presto
+from sqlglot.dialects.redshift import Redshift
from sqlglot.dialects.snowflake import Snowflake
from sqlglot.dialects.spark import Spark
from sqlglot.dialects.sqlite import SQLite
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index f4e87c3..1f1f90a 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -44,6 +44,7 @@ class BigQuery(Dialect):
]
IDENTIFIERS = ["`"]
ESCAPE = "\\"
+ HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
@@ -120,9 +121,5 @@ class BigQuery(Dialect):
def intersect_op(self, expression):
if not expression.args.get("distinct", False):
- self.unsupported(
- "INTERSECT without DISTINCT is not supported in BigQuery"
- )
- return (
- f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
- )
+ self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
+ return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 8045f7a..f338c81 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -20,6 +20,7 @@ class Dialects(str, Enum):
ORACLE = "oracle"
POSTGRES = "postgres"
PRESTO = "presto"
+ REDSHIFT = "redshift"
SNOWFLAKE = "snowflake"
SPARK = "spark"
SQLITE = "sqlite"
@@ -53,12 +54,19 @@ class _Dialect(type):
klass.generator_class = getattr(klass, "Generator", Generator)
klass.tokenizer = klass.tokenizer_class()
- klass.quote_start, klass.quote_end = list(klass.tokenizer_class.QUOTES.items())[
- 0
- ]
- klass.identifier_start, klass.identifier_end = list(
- klass.tokenizer_class.IDENTIFIERS.items()
- )[0]
+ klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
+ klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0]
+
+ if klass.tokenizer_class._BIT_STRINGS and exp.BitString not in klass.generator_class.TRANSFORMS:
+ bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
+ klass.generator_class.TRANSFORMS[
+ exp.BitString
+ ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
+ if klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS:
+ hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
+ klass.generator_class.TRANSFORMS[
+ exp.HexString
+ ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
return klass
@@ -122,9 +130,7 @@ class Dialect(metaclass=_Dialect):
return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql)
def parse_into(self, expression_type, sql, **opts):
- return self.parser(**opts).parse_into(
- expression_type, self.tokenizer.tokenize(sql), sql
- )
+ return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql)
def generate(self, expression, **opts):
return self.generator(**opts).generate(expression)
@@ -164,9 +170,7 @@ class Dialect(metaclass=_Dialect):
def rename_func(name):
- return (
- lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})"
- )
+ return lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})"
def approx_count_distinct_sql(self, expression):
@@ -260,8 +264,7 @@ def format_time_lambda(exp_class, dialect, default=None):
return exp_class(
this=list_get(args, 0),
format=Dialect[dialect].format_time(
- list_get(args, 1)
- or (Dialect[dialect].time_format if default is True else default)
+ list_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
),
)
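The metaclass hunk above derives BitString/HexString generator TRANSFORMS from each tokenizer's _BIT_STRINGS/_HEX_STRINGS, so dialects with native binary literals can regenerate them. A hedged sketch, assuming the tokenizer stores the literal's decimal integer value (which the `:b` format specifier then re-renders as bits):

    >>> import sqlglot
    >>> sqlglot.transpile("SELECT b'1011'", read="mysql", write="mysql")[0]
    "SELECT b'1011'"
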
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index d83a620..ff3a8b1 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -63,10 +63,7 @@ def _sort_array_reverse(args):
def _struct_pack_sql(self, expression):
- args = [
- self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e)
- for e in expression.expressions
- ]
+ args = [self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) for e in expression.expressions]
return f"STRUCT_PACK({', '.join(args)})"
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index e3f3f39..59aa8fa 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -109,9 +109,7 @@ def _unnest_to_explode_sql(self, expression):
alias=exp.TableAlias(this=alias.this, columns=[column]),
)
)
- for expression, column in zip(
- unnest.expressions, alias.columns if alias else []
- )
+ for expression, column in zip(unnest.expressions, alias.columns if alias else [])
)
return self.join_sql(expression)
@@ -206,14 +204,11 @@ class Hive(Dialect):
substr=list_get(args, 0),
position=list_get(args, 2),
),
- "LOG": (
- lambda args: exp.Log.from_arg_list(args)
- if len(args) > 1
- else exp.Ln.from_arg_list(args)
- ),
+ "LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)),
"MAP": _parse_map,
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
"PERCENTILE": exp.Quantile.from_arg_list,
+ "PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list,
"COLLECT_SET": exp.SetAgg.from_arg_list,
"SIZE": exp.ArraySize.from_arg_list,
"SPLIT": exp.RegexpSplit.from_arg_list,
@@ -262,6 +257,7 @@ class Hive(Dialect):
HiveMap: _map_sql,
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e.args['value'])}",
exp.Quantile: rename_func("PERCENTILE"),
+ exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
exp.RegexpSplit: rename_func("SPLIT"),
exp.SafeDivide: no_safe_divide_sql,
@@ -296,8 +292,7 @@ class Hive(Dialect):
def datatype_sql(self, expression):
if (
- expression.this
- in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR)
+ expression.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR)
and not expression.expressions
):
expression = exp.DataType.build("text")
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 93800a6..87a2c41 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -49,6 +49,21 @@ def _str_to_date_sql(self, expression):
return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
+def _trim_sql(self, expression):
+ target = self.sql(expression, "this")
+ trim_type = self.sql(expression, "position")
+ remove_chars = self.sql(expression, "expression")
+
+ # Use TRIM/LTRIM/RTRIM syntax if the expression isn't mysql-specific
+ if not remove_chars:
+ return self.trim_sql(expression)
+
+ trim_type = f"{trim_type} " if trim_type else ""
+ remove_chars = f"{remove_chars} " if remove_chars else ""
+ from_part = "FROM " if trim_type or remove_chars else ""
+ return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
+
+
def _date_add(expression_class):
def func(args):
interval = list_get(args, 1)
@@ -88,9 +103,12 @@ class MySQL(Dialect):
QUOTES = ["'", '"']
COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"]
+ BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
+ HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
+ "SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
"_BIG5": TokenType.INTRODUCER,
@@ -145,6 +163,15 @@ class MySQL(Dialect):
"STR_TO_DATE": _str_to_date,
}
+ FUNCTION_PARSERS = {
+ **Parser.FUNCTION_PARSERS,
+ "GROUP_CONCAT": lambda self: self.expression(
+ exp.GroupConcat,
+ this=self._parse_lambda(),
+ separator=self._match(TokenType.SEPARATOR) and self._parse_field(),
+ ),
+ }
+
class Generator(Generator):
NULL_ORDERING_SUPPORTED = False
@@ -158,6 +185,8 @@ class MySQL(Dialect):
exp.DateAdd: _date_add_sql("ADD"),
exp.DateSub: _date_add_sql("SUB"),
exp.DateTrunc: _date_trunc_sql,
+ exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
+ exp.Trim: _trim_sql,
}
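Together, the SEPARATOR keyword, the GROUP_CONCAT function parser, and the GroupConcat/Trim transforms let MySQL-specific syntax round-trip. A sketch under the assumption that parser.py (further down) handles the TRIM(... FROM ...) form:

    >>> import sqlglot
    >>> sqlglot.transpile("SELECT GROUP_CONCAT(x SEPARATOR '-') FROM t", read="mysql", write="mysql")[0]
    "SELECT GROUP_CONCAT(x SEPARATOR '-') FROM t"
    >>> sqlglot.transpile("SELECT TRIM(LEADING 'x' FROM y)", read="mysql", write="mysql")[0]
    "SELECT TRIM(LEADING 'x' FROM y)"
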
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 9c8b6f2..91e30b2 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -51,6 +51,14 @@ class Oracle(Dialect):
sep="",
)
+ def alias_sql(self, expression):
+ if isinstance(expression.this, exp.Table):
+ to_sql = self.sql(expression, "alias")
+ # oracle does not allow "AS" between table and alias
+ to_sql = f" {to_sql}" if to_sql else ""
+ return f"{self.sql(expression, 'this')}{to_sql}"
+ return super().alias_sql(expression)
+
def offset_sql(self, expression):
return f"{super().offset_sql(expression)} ROWS"
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 61dff86..c796839 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import (
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.transforms import delegate, preprocess
def _date_add_sql(kind):
@@ -32,11 +33,96 @@ def _date_add_sql(kind):
return func
+def _lateral_sql(self, expression):
+ this = self.sql(expression, "this")
+ if isinstance(expression.this, exp.Subquery):
+ return f"LATERAL{self.sep()}{this}"
+ alias = expression.args["alias"]
+ table = alias.name
+ table = f" {table}" if table else table
+ columns = self.expressions(alias, key="columns", flat=True)
+ columns = f" AS {columns}" if columns else ""
+ return f"LATERAL{self.sep()}{this}{table}{columns}"
+
+
+def _substring_sql(self, expression):
+ this = self.sql(expression, "this")
+ start = self.sql(expression, "start")
+ length = self.sql(expression, "length")
+
+ from_part = f" FROM {start}" if start else ""
+ for_part = f" FOR {length}" if length else ""
+
+ return f"SUBSTRING({this}{from_part}{for_part})"
+
+
+def _trim_sql(self, expression):
+ target = self.sql(expression, "this")
+ trim_type = self.sql(expression, "position")
+ remove_chars = self.sql(expression, "expression")
+ collation = self.sql(expression, "collation")
+
+ # Use TRIM/LTRIM/RTRIM syntax if the expression isn't postgres-specific
+ if not remove_chars and not collation:
+ return self.trim_sql(expression)
+
+ trim_type = f"{trim_type} " if trim_type else ""
+ remove_chars = f"{remove_chars} " if remove_chars else ""
+ from_part = "FROM " if trim_type or remove_chars else ""
+ collation = f" COLLATE {collation}" if collation else ""
+ return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
+
+
+def _auto_increment_to_serial(expression):
+ auto = expression.find(exp.AutoIncrementColumnConstraint)
+
+ if auto:
+ expression = expression.copy()
+ expression.args["constraints"].remove(auto.parent)
+ kind = expression.args["kind"]
+
+ if kind.this == exp.DataType.Type.INT:
+ kind.replace(exp.DataType(this=exp.DataType.Type.SERIAL))
+ elif kind.this == exp.DataType.Type.SMALLINT:
+ kind.replace(exp.DataType(this=exp.DataType.Type.SMALLSERIAL))
+ elif kind.this == exp.DataType.Type.BIGINT:
+ kind.replace(exp.DataType(this=exp.DataType.Type.BIGSERIAL))
+
+ return expression
+
+
+def _serial_to_generated(expression):
+ kind = expression.args["kind"]
+
+ if kind.this == exp.DataType.Type.SERIAL:
+ data_type = exp.DataType(this=exp.DataType.Type.INT)
+ elif kind.this == exp.DataType.Type.SMALLSERIAL:
+ data_type = exp.DataType(this=exp.DataType.Type.SMALLINT)
+ elif kind.this == exp.DataType.Type.BIGSERIAL:
+ data_type = exp.DataType(this=exp.DataType.Type.BIGINT)
+ else:
+ data_type = None
+
+ if data_type:
+ expression = expression.copy()
+ expression.args["kind"].replace(data_type)
+ constraints = expression.args["constraints"]
+ generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False))
+ notnull = exp.ColumnConstraint(kind=exp.NotNullColumnConstraint())
+ if notnull not in constraints:
+ constraints.insert(0, notnull)
+ if generated not in constraints:
+ constraints.insert(0, generated)
+
+ return expression
+
+
class Postgres(Dialect):
null_ordering = "nulls_are_large"
time_format = "'YYYY-MM-DD HH24:MI:SS'"
time_mapping = {
- "AM": "%p", # AM or PM
+ "AM": "%p",
+ "PM": "%p",
"D": "%w", # 1-based day of week
"DD": "%d", # day of month
"DDD": "%j", # zero padded day of year
@@ -65,14 +151,25 @@ class Postgres(Dialect):
}
class Tokenizer(Tokenizer):
+ BIT_STRINGS = [("b'", "'"), ("B'", "'")]
+ HEX_STRINGS = [("x'", "'"), ("X'", "'")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
- "SERIAL": TokenType.AUTO_INCREMENT,
+ "ALWAYS": TokenType.ALWAYS,
+ "BY DEFAULT": TokenType.BY_DEFAULT,
+ "IDENTITY": TokenType.IDENTITY,
+ "FOR": TokenType.FOR,
+ "GENERATED": TokenType.GENERATED,
+ "DOUBLE PRECISION": TokenType.DOUBLE,
+ "BIGSERIAL": TokenType.BIGSERIAL,
+ "SERIAL": TokenType.SERIAL,
+ "SMALLSERIAL": TokenType.SMALLSERIAL,
"UUID": TokenType.UUID,
}
class Parser(Parser):
STRICT_CAST = False
+
FUNCTIONS = {
**Parser.FUNCTIONS,
"TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "postgres"),
@@ -86,14 +183,18 @@ class Postgres(Dialect):
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
exp.DataType.Type.BINARY: "BYTEA",
- }
-
- TOKEN_MAPPING = {
- TokenType.AUTO_INCREMENT: "SERIAL",
+ exp.DataType.Type.DATETIME: "TIMESTAMP",
}
TRANSFORMS = {
**Generator.TRANSFORMS,
+ exp.ColumnDef: preprocess(
+ [
+ _auto_increment_to_serial,
+ _serial_to_generated,
+ ],
+ delegate("columndef_sql"),
+ ),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}",
@@ -102,8 +203,11 @@ class Postgres(Dialect):
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: _date_add_sql("+"),
exp.DateSub: _date_add_sql("-"),
+ exp.Lateral: _lateral_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.Substring: _substring_sql,
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TableSample: no_tablesample_sql,
+ exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
}
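The two ColumnDef preprocessors normalize AUTO_INCREMENT columns to SERIAL types and then expand SERIAL types into GENERATED ... AS IDENTITY columns. A hedged sketch (the exact constraint ordering in the output is an assumption):

    >>> import sqlglot
    >>> sqlglot.transpile("CREATE TABLE t (id SERIAL)", read="postgres", write="postgres")[0]
    'CREATE TABLE t (id INT GENERATED BY DEFAULT AS IDENTITY NOT NULL)'
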
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index ca913e4..7253f7e 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -96,9 +96,7 @@ def _ts_or_ds_to_date_sql(self, expression):
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.time_format, Presto.date_format):
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
- return (
- f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
- )
+ return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
def _ts_or_ds_add_sql(self, expression):
@@ -141,6 +139,7 @@ class Presto(Dialect):
"FROM_UNIXTIME": exp.UnixToTime.from_arg_list,
"STRPOS": exp.StrPosition.from_arg_list,
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
+ "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
}
class Generator(Generator):
@@ -193,6 +192,7 @@ class Presto(Dialect):
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY = {self.sql(e.args['value'])}",
exp.Quantile: _quantile_sql,
+ exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.SortArray: _no_sort_array,
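With ApproxQuantile now mapped in Hive, Presto, and Spark, approximate percentiles should transpile across those dialects. A minimal sketch (expected output, not verified):

    >>> import sqlglot
    >>> sqlglot.transpile("SELECT APPROX_PERCENTILE(x, 0.5) FROM t", read="presto", write="hive")[0]
    'SELECT PERCENTILE_APPROX(x, 0.5) FROM t'
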
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
new file mode 100644
index 0000000..e1f7b78
--- /dev/null
+++ b/sqlglot/dialects/redshift.py
@@ -0,0 +1,34 @@
+from sqlglot import exp
+from sqlglot.dialects.postgres import Postgres
+from sqlglot.tokens import TokenType
+
+
+class Redshift(Postgres):
+ time_format = "'YYYY-MM-DD HH:MI:SS'"
+ time_mapping = {
+ **Postgres.time_mapping,
+ "MON": "%b",
+ "HH": "%H",
+ }
+
+ class Tokenizer(Postgres.Tokenizer):
+ ESCAPE = "\\"
+
+ KEYWORDS = {
+ **Postgres.Tokenizer.KEYWORDS,
+ "GEOMETRY": TokenType.GEOMETRY,
+ "GEOGRAPHY": TokenType.GEOGRAPHY,
+ "HLLSKETCH": TokenType.HLLSKETCH,
+ "SUPER": TokenType.SUPER,
+ "TIME": TokenType.TIMESTAMP,
+ "TIMETZ": TokenType.TIMESTAMPTZ,
+ "VARBYTE": TokenType.BINARY,
+ "SIMILAR TO": TokenType.SIMILAR_TO,
+ }
+
+ class Generator(Postgres.Generator):
+ TYPE_MAPPING = {
+ **Postgres.Generator.TYPE_MAPPING,
+ exp.DataType.Type.BINARY: "VARBYTE",
+ exp.DataType.Type.INT: "INTEGER",
+ }
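Since Redshift subclasses Postgres, it inherits the full dialect and only overrides tokens and type mappings. A hedged sketch of the intended behavior (the output is an assumption):

    >>> import sqlglot
    >>> sqlglot.transpile("CREATE TABLE t (c VARBYTE)", read="redshift", write="postgres")[0]
    'CREATE TABLE t (c BYTEA)'
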
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 148dfb5..8d6ee78 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -23,9 +23,7 @@ def _snowflake_to_timestamp(args):
# case: <numeric_expr> [ , <scale> ]
if second_arg.name not in ["0", "3", "9"]:
- raise ValueError(
- f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9"
- )
+ raise ValueError(f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9")
if second_arg.name == "0":
timescale = exp.UnixToTime.SECONDS
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 89c7ed5..a331191 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -65,12 +65,11 @@ class Spark(Hive):
this=list_get(args, 0),
start=exp.Sub(
this=exp.Length(this=list_get(args, 0)),
- expression=exp.Add(
- this=list_get(args, 1), expression=exp.Literal.number(1)
- ),
+ expression=exp.Add(this=list_get(args, 1), expression=exp.Literal.number(1)),
),
length=list_get(args, 1),
),
+ "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
}
class Generator(Hive.Generator):
@@ -82,11 +81,7 @@ class Spark(Hive):
}
TRANSFORMS = {
- **{
- k: v
- for k, v in Hive.Generator.TRANSFORMS.items()
- if k not in {exp.ArraySort}
- },
+ **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort}},
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
@@ -102,5 +97,5 @@ class Spark(Hive):
HiveMap: _map_sql,
}
- def bitstring_sql(self, expression):
- return f"X'{self.sql(expression, 'this')}'"
+ class Tokenizer(Hive.Tokenizer):
+ HEX_STRINGS = [("X'", "'")]
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 6cf5022..cfdbe1b 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -16,6 +16,7 @@ from sqlglot.tokens import Tokenizer, TokenType
class SQLite(Dialect):
class Tokenizer(Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
+ HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py
index 805106c..9a6f7fe 100644
--- a/sqlglot/dialects/trino.py
+++ b/sqlglot/dialects/trino.py
@@ -8,3 +8,6 @@ class Trino(Presto):
**Presto.Generator.TRANSFORMS,
exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
}
+
+ class Tokenizer(Presto.Tokenizer):
+ HEX_STRINGS = [("X'", "'")]
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
index 8eeb4e9..0567c12 100644
--- a/sqlglot/diff.py
+++ b/sqlglot/diff.py
@@ -115,13 +115,8 @@ class ChangeDistiller:
for kept_source_node_id, kept_target_node_id in matching_set:
source_node = self._source_index[kept_source_node_id]
target_node = self._target_index[kept_target_node_id]
- if (
- not isinstance(source_node, LEAF_EXPRESSION_TYPES)
- or source_node == target_node
- ):
- edit_script.extend(
- self._generate_move_edits(source_node, target_node, matching_set)
- )
+ if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node:
+ edit_script.extend(self._generate_move_edits(source_node, target_node, matching_set))
edit_script.append(Keep(source_node, target_node))
else:
edit_script.append(Update(source_node, target_node))
@@ -132,9 +127,7 @@ class ChangeDistiller:
source_args = [id(e) for e in _expression_only_args(source)]
target_args = [id(e) for e in _expression_only_args(target)]
- args_lcs = set(
- _lcs(source_args, target_args, lambda l, r: (l, r) in matching_set)
- )
+ args_lcs = set(_lcs(source_args, target_args, lambda l, r: (l, r) in matching_set))
move_edits = []
for a in source_args:
@@ -148,14 +141,10 @@ class ChangeDistiller:
matching_set = leaves_matching_set.copy()
ordered_unmatched_source_nodes = {
- id(n[0]): None
- for n in self._source.bfs()
- if id(n[0]) in self._unmatched_source_nodes
+ id(n[0]): None for n in self._source.bfs() if id(n[0]) in self._unmatched_source_nodes
}
ordered_unmatched_target_nodes = {
- id(n[0]): None
- for n in self._target.bfs()
- if id(n[0]) in self._unmatched_target_nodes
+ id(n[0]): None for n in self._target.bfs() if id(n[0]) in self._unmatched_target_nodes
}
for source_node_id in ordered_unmatched_source_nodes:
@@ -169,18 +158,13 @@ class ChangeDistiller:
max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids))
if max_leaves_num:
common_leaves_num = sum(
- 1 if s in source_leaf_ids and t in target_leaf_ids else 0
- for s, t in leaves_matching_set
+ 1 if s in source_leaf_ids and t in target_leaf_ids else 0 for s, t in leaves_matching_set
)
leaf_similarity_score = common_leaves_num / max_leaves_num
else:
leaf_similarity_score = 0.0
- adjusted_t = (
- self.t
- if min(len(source_leaf_ids), len(target_leaf_ids)) > 4
- else 0.4
- )
+ adjusted_t = self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4
if leaf_similarity_score >= 0.8 or (
leaf_similarity_score >= adjusted_t
@@ -217,10 +201,7 @@ class ChangeDistiller:
matching_set = set()
while candidate_matchings:
_, _, source_leaf, target_leaf = heappop(candidate_matchings)
- if (
- id(source_leaf) in self._unmatched_source_nodes
- and id(target_leaf) in self._unmatched_target_nodes
- ):
+ if id(source_leaf) in self._unmatched_source_nodes and id(target_leaf) in self._unmatched_target_nodes:
matching_set.add((id(source_leaf), id(target_leaf)))
self._unmatched_source_nodes.remove(id(source_leaf))
self._unmatched_target_nodes.remove(id(target_leaf))
diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py
index a437431..bca9f3e 100644
--- a/sqlglot/executor/__init__.py
+++ b/sqlglot/executor/__init__.py
@@ -3,11 +3,17 @@ import time
from sqlglot import parse_one
from sqlglot.executor.python import PythonExecutor
-from sqlglot.optimizer import optimize
+from sqlglot.optimizer import RULES, optimize
+from sqlglot.optimizer.merge_derived_tables import merge_derived_tables
from sqlglot.planner import Plan
logger = logging.getLogger("sqlglot")
+OPTIMIZER_RULES = list(RULES)
+
+# The executor needs isolated table selects
+OPTIMIZER_RULES.remove(merge_derived_tables)
+
def execute(sql, schema, read=None):
"""
@@ -28,7 +34,7 @@ def execute(sql, schema, read=None):
"""
expression = parse_one(sql, read=read)
now = time.time()
- expression = optimize(expression, schema)
+ expression = optimize(expression, schema, rules=OPTIMIZER_RULES)
logger.debug("Optimization finished: %f", time.time() - now)
logger.debug("Optimized SQL: %s", expression.sql(pretty=True))
plan = Plan(expression)
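The executor's rule list is plain Python, so the same pattern works anywhere a custom optimizer pipeline is needed; this sketch mirrors the hunk above:

    >>> from sqlglot.optimizer import RULES
    >>> from sqlglot.optimizer.merge_derived_tables import merge_derived_tables
    >>> rules = [rule for rule in RULES if rule is not merge_derived_tables]

Passing `rules=rules` to `optimize` then reproduces the executor's pipeline.
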
diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py
index 457bea7..d265a2c 100644
--- a/sqlglot/executor/context.py
+++ b/sqlglot/executor/context.py
@@ -19,9 +19,7 @@ class Context:
env (Optional[dict]): dictionary of functions within the execution context
"""
self.tables = tables
- self.range_readers = {
- name: table.range_reader for name, table in self.tables.items()
- }
+ self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
self.row_readers = {name: table.reader for name, table in tables.items()}
self.env = {**(env or {}), "scope": self.row_readers}
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index 388a419..610aa4b 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -26,11 +26,7 @@ class PythonExecutor:
while queue:
node = queue.pop()
context = self.context(
- {
- name: table
- for dep in node.dependencies
- for name, table in contexts[dep].tables.items()
- }
+ {name: table for dep in node.dependencies for name, table in contexts[dep].tables.items()}
)
running.add(node)
@@ -151,9 +147,7 @@ class PythonExecutor:
return self.context({name: table for name in ctx.tables})
for name, join in step.joins.items():
- join_context = self.context(
- {**join_context.tables, name: context.tables[name]}
- )
+ join_context = self.context({**join_context.tables, name: context.tables[name]})
if join.get("source_key"):
table = self.hash_join(join, source, name, join_context)
@@ -247,9 +241,7 @@ class PythonExecutor:
if step.operands:
source_table = context.tables[source]
- operand_table = Table(
- source_table.columns + self.table(step.operands).columns
- )
+ operand_table = Table(source_table.columns + self.table(step.operands).columns)
for reader, ctx in context:
operand_table.append(reader.row + ctx.eval_tuple(operands))
diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py
index 6df49f7..80674cb 100644
--- a/sqlglot/executor/table.py
+++ b/sqlglot/executor/table.py
@@ -37,10 +37,7 @@ class Table:
break
lines.append(
- " ".join(
- str(row[column]).rjust(widths[column])[0 : widths[column]]
- for column in self.columns
- )
+ " ".join(str(row[column]).rjust(widths[column])[0 : widths[column]] for column in self.columns)
)
return "\n".join(lines)
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 7acc63d..b983bf9 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -47,10 +47,7 @@ class Expression(metaclass=_Expression):
return hash(
(
self.key,
- tuple(
- (k, tuple(v) if isinstance(v, list) else v)
- for k, v in _norm_args(self).items()
- ),
+ tuple((k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()),
)
)
@@ -116,9 +113,22 @@ class Expression(metaclass=_Expression):
item.parent = parent
return new
+ def append(self, arg_key, value):
+ """
+ Appends value to arg_key if it's a list or sets it as a new list.
+
+ Args:
+ arg_key (str): name of the list expression arg
+ value (Any): value to append to the list
+ """
+ if not isinstance(self.args.get(arg_key), list):
+ self.args[arg_key] = []
+ self.args[arg_key].append(value)
+ self._set_parent(arg_key, value)
+
def set(self, arg_key, value):
"""
- Sets `arg` to `value`.
+ Sets `arg_key` to `value`.
Args:
arg_key (str): name of the expression arg
@@ -267,6 +277,14 @@ class Expression(metaclass=_Expression):
expression = expression.this
return expression
+ def unalias(self):
+ """
+ Returns the inner expression if this is an Alias.
+ """
+ if isinstance(self, Alias):
+ return self.this
+ return self
+
def unnest_operands(self):
"""
Returns unnested operands as a tuple.
@@ -279,9 +297,7 @@ class Expression(metaclass=_Expression):
A AND B AND C -> [A, B, C]
"""
- for node, _, _ in self.dfs(
- prune=lambda n, p, *_: p and not isinstance(n, self.__class__)
- ):
+ for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not isinstance(n, self.__class__)):
if not isinstance(node, self.__class__):
yield node.unnest() if unnest else node
@@ -314,9 +330,7 @@ class Expression(metaclass=_Expression):
args = {
k: ", ".join(
- v.to_s(hide_missing=hide_missing, level=level + 1)
- if hasattr(v, "to_s")
- else str(v)
+ v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
for v in ensure_list(vs)
if v is not None
)
@@ -354,9 +368,7 @@ class Expression(metaclass=_Expression):
new_node.parent = node.parent
return new_node
- replace_children(
- new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs)
- )
+ replace_children(new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs))
return new_node
def replace(self, expression):
@@ -546,6 +558,10 @@ class BitString(Condition):
pass
+class HexString(Condition):
+ pass
+
+
class Column(Condition):
arg_types = {"this": True, "table": False}
@@ -566,35 +582,44 @@ class ColumnConstraint(Expression):
arg_types = {"this": False, "kind": True}
-class AutoIncrementColumnConstraint(Expression):
+class ColumnConstraintKind(Expression):
pass
-class CheckColumnConstraint(Expression):
+class AutoIncrementColumnConstraint(ColumnConstraintKind):
pass
-class CollateColumnConstraint(Expression):
+class CheckColumnConstraint(ColumnConstraintKind):
pass
-class CommentColumnConstraint(Expression):
+class CollateColumnConstraint(ColumnConstraintKind):
pass
-class DefaultColumnConstraint(Expression):
+class CommentColumnConstraint(ColumnConstraintKind):
pass
-class NotNullColumnConstraint(Expression):
+class DefaultColumnConstraint(ColumnConstraintKind):
pass
-class PrimaryKeyColumnConstraint(Expression):
+class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
+ # this: True -> ALWAYS, this: False -> BY DEFAULT
+ arg_types = {"this": True, "expression": False}
+
+
+class NotNullColumnConstraint(ColumnConstraintKind):
pass
-class UniqueColumnConstraint(Expression):
+class PrimaryKeyColumnConstraint(ColumnConstraintKind):
+ pass
+
+
+class UniqueColumnConstraint(ColumnConstraintKind):
pass
@@ -651,9 +676,7 @@ class Identifier(Expression):
return bool(self.args.get("quoted"))
def __eq__(self, other):
- return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(
- other.this
- )
+ return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(other.this)
def __hash__(self):
return hash((self.key, self.this.lower()))
@@ -709,9 +732,7 @@ class Literal(Condition):
def __eq__(self, other):
return (
- isinstance(other, Literal)
- and self.this == other.this
- and self.args["is_string"] == other.args["is_string"]
+ isinstance(other, Literal) and self.this == other.this and self.args["is_string"] == other.args["is_string"]
)
def __hash__(self):
@@ -733,6 +754,7 @@ class Join(Expression):
"side": False,
"kind": False,
"using": False,
+ "natural": False,
}
@property
@@ -743,6 +765,10 @@ class Join(Expression):
def side(self):
return self.text("side").upper()
+ @property
+ def alias_or_name(self):
+ return self.this.alias_or_name
+
def on(self, *expressions, append=True, dialect=None, copy=True, **opts):
"""
Append to or set the ON expressions.
@@ -873,10 +899,6 @@ class Reference(Expression):
arg_types = {"this": True, "expressions": True}
-class Table(Expression):
- arg_types = {"this": True, "db": False, "catalog": False}
-
-
class Tuple(Expression):
arg_types = {"expressions": False}
@@ -986,6 +1008,16 @@ QUERY_MODIFIERS = {
}
+class Table(Expression):
+ arg_types = {
+ "this": True,
+ "db": False,
+ "catalog": False,
+ "laterals": False,
+ "joins": False,
+ }
+
+
class Union(Subqueryable, Expression):
arg_types = {
"with": False,
@@ -1396,7 +1428,9 @@ class Select(Subqueryable, Expression):
join.this.replace(join.this.subquery())
if join_type:
- side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)
+ natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)
+ if natural:
+ join.set("natural", True)
if side:
join.set("side", side.text)
if kind:
@@ -1529,10 +1563,7 @@ class Select(Subqueryable, Expression):
properties_expression = None
if properties:
properties_str = " ".join(
- [
- f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}"
- for k, v in properties.items()
- ]
+ [f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in properties.items()]
)
properties_expression = maybe_parse(
properties_str,
@@ -1654,6 +1685,7 @@ class DataType(Expression):
DECIMAL = auto()
BOOLEAN = auto()
JSON = auto()
+ INTERVAL = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
DATE = auto()
@@ -1662,15 +1694,19 @@ class DataType(Expression):
MAP = auto()
UUID = auto()
GEOGRAPHY = auto()
+ GEOMETRY = auto()
STRUCT = auto()
NULLABLE = auto()
+ HLLSKETCH = auto()
+ SUPER = auto()
+ SERIAL = auto()
+ SMALLSERIAL = auto()
+ BIGSERIAL = auto()
@classmethod
def build(cls, dtype, **kwargs):
return DataType(
- this=dtype
- if isinstance(dtype, DataType.Type)
- else DataType.Type[dtype.upper()],
+ this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()],
**kwargs,
)
@@ -1798,6 +1834,14 @@ class Like(Binary, Predicate):
pass
+class SimilarTo(Binary, Predicate):
+ pass
+
+
+class Distance(Binary):
+ pass
+
+
class LT(Binary, Predicate):
pass
@@ -1899,6 +1943,10 @@ class IgnoreNulls(Expression):
pass
+class RespectNulls(Expression):
+ pass
+
+
# Functions
class Func(Condition):
"""
@@ -1924,9 +1972,7 @@ class Func(Condition):
all_arg_keys = list(cls.arg_types)
# If this function supports variable length argument treat the last argument as such.
- non_var_len_arg_keys = (
- all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
- )
+ non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
args_dict = {}
arg_idx = 0
@@ -1944,9 +1990,7 @@ class Func(Condition):
@classmethod
def sql_names(cls):
if cls is Func:
- raise NotImplementedError(
- "SQL name is only supported by concrete function implementations"
- )
+ raise NotImplementedError("SQL name is only supported by concrete function implementations")
if not hasattr(cls, "_sql_names"):
cls._sql_names = [camel_to_snake_case(cls.__name__)]
return cls._sql_names
@@ -2178,6 +2222,10 @@ class Greatest(Func):
is_var_len_args = True
+class GroupConcat(Func):
+ arg_types = {"this": True, "separator": False}
+
+
class If(Func):
arg_types = {"this": True, "true": True, "false": False}
@@ -2274,6 +2322,10 @@ class Quantile(AggFunc):
arg_types = {"this": True, "quantile": True}
+class ApproxQuantile(Quantile):
+ pass
+
+
class Reduce(Func):
arg_types = {"this": True, "initial": True, "merge": True, "finish": True}
@@ -2306,8 +2358,10 @@ class Split(Func):
arg_types = {"this": True, "expression": True}
+# Start may be omitted in the case of postgres
+# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
class Substring(Func):
- arg_types = {"this": True, "start": True, "length": False}
+ arg_types = {"this": True, "start": False, "length": False}
class StrPosition(Func):
@@ -2379,6 +2433,15 @@ class TimeStrToUnix(Func):
pass
+class Trim(Func):
+ arg_types = {
+ "this": True,
+ "position": False,
+ "expression": False,
+ "collation": False,
+ }
+
+
class TsOrDsAdd(Func, TimeUnit):
arg_types = {"this": True, "expression": True, "unit": False}
@@ -2455,9 +2518,7 @@ def _all_functions():
obj
for _, obj in inspect.getmembers(
sys.modules[__name__],
- lambda obj: inspect.isclass(obj)
- and issubclass(obj, Func)
- and obj not in (AggFunc, Anonymous, Func),
+ lambda obj: inspect.isclass(obj) and issubclass(obj, Func) and obj not in (AggFunc, Anonymous, Func),
)
]
@@ -2633,9 +2694,7 @@ def _apply_conjunction_builder(
def _combine(expressions, operator, dialect=None, **opts):
- expressions = [
- condition(expression, dialect=dialect, **opts) for expression in expressions
- ]
+ expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions]
this = expressions[0]
if expressions[1:]:
this = _wrap_operator(this)
@@ -2809,9 +2868,7 @@ def to_identifier(alias, quoted=None):
quoted = not re.match(SAFE_IDENTIFIER_RE, alias)
identifier = Identifier(this=alias, quoted=quoted)
else:
- raise ValueError(
- f"Alias needs to be a string or an Identifier, got: {alias.__class__}"
- )
+ raise ValueError(f"Alias needs to be a string or an Identifier, got: {alias.__class__}")
return identifier
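Short sketches of the new Expression helpers added above (the outputs are assumptions):

    >>> from sqlglot import parse_one
    >>> parse_one("SELECT a AS b").expressions[0].unalias().sql()
    'a'
    >>> e = parse_one("SELECT a")
    >>> e.append("expressions", parse_one("b"))
    >>> e.sql()
    'SELECT a, b'
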
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 793cff0..a445178 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -41,6 +41,8 @@ class Generator:
max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError.
This is only relevant if unsupported_level is ErrorLevel.RAISE.
Default: 3
leading_comma (bool): whether the comma is leading or trailing in select statements
+ Default: False
"""
TRANSFORMS = {
@@ -108,6 +110,7 @@ class Generator:
"_indent",
"_replace_backslash",
"_escaped_quote_end",
+ "_leading_comma",
)
def __init__(
@@ -131,6 +134,7 @@ class Generator:
unsupported_level=ErrorLevel.WARN,
null_ordering=None,
max_unsupported=3,
+ leading_comma=False,
):
import sqlglot
@@ -157,6 +161,7 @@ class Generator:
self._indent = indent
self._replace_backslash = self.escape == "\\"
self._escaped_quote_end = self.escape + self.quote_end
+ self._leading_comma = leading_comma
def generate(self, expression):
"""
@@ -178,9 +183,7 @@ class Generator:
for msg in self.unsupported_messages:
logger.warning(msg)
elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
- raise UnsupportedError(
- concat_errors(self.unsupported_messages, self.max_unsupported)
- )
+ raise UnsupportedError(concat_errors(self.unsupported_messages, self.max_unsupported))
return sql
@@ -197,9 +200,7 @@ class Generator:
def wrap(self, expression):
this_sql = self.indent(
- self.sql(expression)
- if isinstance(expression, (exp.Select, exp.Union))
- else self.sql(expression, "this"),
+ self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"),
level=1,
pad=0,
)
@@ -251,9 +252,7 @@ class Generator:
return transform
if not isinstance(expression, exp.Expression):
- raise ValueError(
- f"Expected an Expression. Received {type(expression)}: {expression}"
- )
+ raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
exp_handler_name = f"{expression.key}_sql"
if hasattr(self, exp_handler_name):
@@ -276,11 +275,7 @@ class Generator:
lazy = " LAZY" if expression.args.get("lazy") else ""
table = self.sql(expression, "this")
options = expression.args.get("options")
- options = (
- f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})"
- if options
- else ""
- )
+ options = f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})" if options else ""
sql = self.sql(expression, "expression")
sql = f" AS{self.sep()}{sql}" if sql else ""
sql = f"CACHE{lazy} TABLE {table}{options}{sql}"
@@ -306,9 +301,7 @@ class Generator:
def columndef_sql(self, expression):
column = self.sql(expression, "this")
kind = self.sql(expression, "kind")
- constraints = self.expressions(
- expression, key="constraints", sep=" ", flat=True
- )
+ constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
if not constraints:
return f"{column} {kind}"
@@ -338,6 +331,9 @@ class Generator:
default = self.sql(expression, "this")
return f"DEFAULT {default}"
+ def generatedasidentitycolumnconstraint_sql(self, expression):
+ return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY"
+
def notnullcolumnconstraint_sql(self, _):
return "NOT NULL"
@@ -384,7 +380,10 @@ class Generator:
return f"{alias}{columns}"
def bitstring_sql(self, expression):
- return f"b'{self.sql(expression, 'this')}'"
+ return self.sql(expression, "this")
+
+ def hexstring_sql(self, expression):
+ return self.sql(expression, "this")
def datatype_sql(self, expression):
type_value = expression.this
@@ -452,10 +451,7 @@ class Generator:
def partition_sql(self, expression):
keys = csv(
- *[
- f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"]
- for k, v in expression.args.get("this")
- ]
+ *[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")]
)
return f"PARTITION({keys})"
@@ -470,9 +466,9 @@ class Generator:
elif p_class in self.WITH_PROPERTIES:
with_properties.append(p)
- return self.root_properties(
- exp.Properties(expressions=root_properties)
- ) + self.with_properties(exp.Properties(expressions=with_properties))
+ return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties(
+ exp.Properties(expressions=with_properties)
+ )
def root_properties(self, properties):
if properties.expressions:
@@ -508,11 +504,7 @@ class Generator:
kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO"
this = self.sql(expression, "this")
exists = " IF EXISTS " if expression.args.get("exists") else " "
- partition_sql = (
- self.sql(expression, "partition")
- if expression.args.get("partition")
- else ""
- )
+ partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else ""
expression_sql = self.sql(expression, "expression")
sep = self.sep() if partition_sql else ""
sql = f"INSERT {kind} {this}{exists}{partition_sql}{sep}{expression_sql}"
@@ -531,7 +523,7 @@ class Generator:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
def table_sql(self, expression):
- return ".".join(
+ table = ".".join(
part
for part in [
self.sql(expression, "catalog"),
@@ -541,6 +533,10 @@ class Generator:
if part
)
+ laterals = self.expressions(expression, key="laterals", sep="")
+ joins = self.expressions(expression, key="joins", sep="")
+ return f"{table}{laterals}{joins}"
+
def tablesample_sql(self, expression):
if self.alias_post_tablesample and isinstance(expression.this, exp.Alias):
this = self.sql(expression.this, "this")
@@ -586,11 +582,7 @@ class Generator:
def group_sql(self, expression):
group_by = self.op_expressions("GROUP BY", expression)
grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
- grouping_sets = (
- f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}"
- if grouping_sets
- else ""
- )
+ grouping_sets = f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
cube = self.expressions(expression, key="cube", indent=False)
cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else ""
rollup = self.expressions(expression, key="rollup", indent=False)
@@ -603,7 +595,16 @@ class Generator:
def join_sql(self, expression):
op_sql = self.seg(
- " ".join(op for op in (expression.side, expression.kind, "JOIN") if op)
+ " ".join(
+ op
+ for op in (
+ "NATURAL" if expression.args.get("natural") else None,
+ expression.side,
+ expression.kind,
+ "JOIN",
+ )
+ if op
+ )
)
on_sql = self.sql(expression, "on")
using = expression.args.get("using")
@@ -630,9 +631,9 @@ class Generator:
def lateral_sql(self, expression):
this = self.sql(expression, "this")
- op_sql = self.seg(
- f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}"
- )
+ if isinstance(expression.this, exp.Subquery):
+ return f"LATERAL{self.sep()}{this}"
+ op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}")
alias = expression.args["alias"]
table = alias.name
table = f" {table}" if table else table
@@ -688,21 +689,13 @@ class Generator:
sort_order = " DESC" if desc else ""
nulls_sort_change = ""
- if nulls_first and (
- (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last
- ):
+ if nulls_first and ((asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last):
nulls_sort_change = " NULLS FIRST"
- elif (
- nulls_last
- and ((asc and nulls_are_small) or (desc and nulls_are_large))
- and not nulls_are_last
- ):
+ elif nulls_last and ((asc and nulls_are_small) or (desc and nulls_are_large)) and not nulls_are_last:
nulls_sort_change = " NULLS LAST"
if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
- self.unsupported(
- "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect"
- )
+ self.unsupported("Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect")
nulls_sort_change = ""
return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
@@ -798,14 +791,20 @@ class Generator:
def window_sql(self, expression):
this = self.sql(expression, "this")
+
partition = self.expressions(expression, key="partition_by", flat=True)
partition = f"PARTITION BY {partition}" if partition else ""
+
order = expression.args.get("order")
order_sql = self.order_sql(order, flat=True) if order else ""
+
partition_sql = partition + " " if partition and order else partition
+
spec = expression.args.get("spec")
spec_sql = " " + self.window_spec_sql(spec) if spec else ""
+
alias = self.sql(expression, "alias")
+
if expression.arg_key == "window":
this = f"{self.seg('WINDOW')} {this} AS"
else:
@@ -818,13 +817,8 @@ class Generator:
def window_spec_sql(self, expression):
kind = self.sql(expression, "kind")
- start = csv(
- self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" "
- )
- end = (
- csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ")
- or "CURRENT ROW"
- )
+ start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
+ end = csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") or "CURRENT ROW"
return f"{kind} BETWEEN {start} AND {end}"
def withingroup_sql(self, expression):
@@ -879,6 +873,17 @@ class Generator:
expression_sql = self.sql(expression, "expression")
return f"EXTRACT({this} FROM {expression_sql})"
+ def trim_sql(self, expression):
+ target = self.sql(expression, "this")
+ trim_type = self.sql(expression, "position")
+
+ if trim_type == "LEADING":
+ return f"LTRIM({target})"
+ elif trim_type == "TRAILING":
+ return f"RTRIM({target})"
+ else:
+ return f"TRIM({target})"
+
def check_sql(self, expression):
this = self.sql(expression, key="this")
return f"CHECK ({this})"
@@ -898,9 +903,7 @@ class Generator:
return f"UNIQUE ({columns})"
def if_sql(self, expression):
- return self.case_sql(
- exp.Case(ifs=[expression], default=expression.args.get("false"))
- )
+ return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
def in_sql(self, expression):
query = expression.args.get("query")
@@ -917,7 +920,9 @@ class Generator:
return f"(SELECT {self.sql(unnest)})"
def interval_sql(self, expression):
- return f"INTERVAL {self.sql(expression, 'this')} {self.sql(expression, 'unit')}"
+ unit = self.sql(expression, "unit")
+ unit = f" {unit}" if unit else ""
+ return f"INTERVAL {self.sql(expression, 'this')}{unit}"
def reference_sql(self, expression):
this = self.sql(expression, "this")
@@ -925,9 +930,7 @@ class Generator:
return f"REFERENCES {this}({expressions})"
def anonymous_sql(self, expression):
- args = self.indent(
- self.expressions(expression, flat=True), skip_first=True, skip_last=True
- )
+ args = self.indent(self.expressions(expression, flat=True), skip_first=True, skip_last=True)
return f"{self.normalize_func(self.sql(expression, 'this'))}({args})"
def paren_sql(self, expression):
@@ -1006,6 +1009,9 @@ class Generator:
def ignorenulls_sql(self, expression):
return f"{self.sql(expression, 'this')} IGNORE NULLS"
+ def respectnulls_sql(self, expression):
+ return f"{self.sql(expression, 'this')} RESPECT NULLS"
+
def intdiv_sql(self, expression):
return self.sql(
exp.Cast(
@@ -1023,6 +1029,9 @@ class Generator:
def div_sql(self, expression):
return self.binary(expression, "/")
+ def distance_sql(self, expression):
+ return self.binary(expression, "<->")
+
def dot_sql(self, expression):
return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}"
@@ -1047,6 +1056,9 @@ class Generator:
def like_sql(self, expression):
return self.binary(expression, "LIKE")
+ def similarto_sql(self, expression):
+ return self.binary(expression, "SIMILAR TO")
+
def lt_sql(self, expression):
return self.binary(expression, "<")
@@ -1069,14 +1081,10 @@ class Generator:
return self.binary(expression, "-")
def trycast_sql(self, expression):
- return (
- f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
- )
+ return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
def binary(self, expression, op):
- return (
- f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
- )
+ return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
def function_fallback_sql(self, expression):
args = []
@@ -1089,9 +1097,7 @@ class Generator:
return f"{self.normalize_func(expression.sql_name())}({args_str})"
def format_time(self, expression):
- return format_time(
- self.sql(expression, "format"), self.time_mapping, self.time_trie
- )
+ return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie)
def expressions(self, expression, key=None, flat=False, indent=True, sep=", "):
expressions = expression.args.get(key or "expressions")
@@ -1102,7 +1108,14 @@ class Generator:
if flat:
return sep.join(self.sql(e) for e in expressions)
- expressions = self.sep(sep).join(self.sql(e) for e in expressions)
+ sql = (self.sql(e) for e in expressions)
+ # the only time leading_comma changes the output is if pretty print is enabled
+ if self._leading_comma and self.pretty:
+ pad = " " * self.pad
+ expressions = "\n".join(f"{sep}{s}" if i > 0 else f"{pad}{s}" for i, s in enumerate(sql))
+ else:
+ expressions = self.sep(sep).join(sql)
+
if indent:
return self.indent(expressions, skip_first=False)
return expressions
@@ -1116,9 +1129,7 @@ class Generator:
def set_operation(self, expression, op):
this = self.sql(expression, "this")
op = self.seg(op)
- return self.query_modifiers(
- expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
- )
+ return self.query_modifiers(expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}")
def token_sql(self, token_type):
return self.TOKEN_MAPPING.get(token_type, token_type.name)
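A sketch of the new leading_comma option, which only takes effect with pretty printing; the exact indentation is an assumption based on the expressions() hunk above:

    >>> import sqlglot
    >>> print(sqlglot.transpile("SELECT a, b FROM t", pretty=True, leading_comma=True)[0])
    SELECT
        a
      , b
    FROM t
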
diff --git a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py
index a4c4cc2..d1146ca 100644
--- a/sqlglot/optimizer/__init__.py
+++ b/sqlglot/optimizer/__init__.py
@@ -1,2 +1,2 @@
-from sqlglot.optimizer.optimizer import optimize
+from sqlglot.optimizer.optimizer import RULES, optimize
from sqlglot.optimizer.schema import Schema
diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py
index c2e021e..e060739 100644
--- a/sqlglot/optimizer/isolate_table_selects.py
+++ b/sqlglot/optimizer/isolate_table_selects.py
@@ -13,9 +13,7 @@ def isolate_table_selects(expression):
continue
if not isinstance(source.parent, exp.Alias):
- raise OptimizeError(
- "Tables require an alias. Run qualify_tables optimization."
- )
+ raise OptimizeError("Tables require an alias. Run qualify_tables optimization.")
parent = source.parent
diff --git a/sqlglot/optimizer/merge_derived_tables.py b/sqlglot/optimizer/merge_derived_tables.py
new file mode 100644
index 0000000..8b161fb
--- /dev/null
+++ b/sqlglot/optimizer/merge_derived_tables.py
@@ -0,0 +1,232 @@
+from collections import defaultdict
+
+from sqlglot import expressions as exp
+from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.simplify import simplify
+
+
+def merge_derived_tables(expression):
+ """
+ Rewrite sqlglot AST to merge derived tables into the outer query.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)")
+ >>> merge_derived_tables(expression).sql()
+ 'SELECT x.a FROM x'
+
+ Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
+
+ Args:
+ expression (sqlglot.Expression): expression to optimize
+ Returns:
+ sqlglot.Expression: optimized expression
+ """
+ for outer_scope in traverse_scope(expression):
+ for subquery in outer_scope.derived_tables:
+ inner_select = subquery.unnest()
+ if (
+ isinstance(outer_scope.expression, exp.Select)
+ and isinstance(inner_select, exp.Select)
+ and _mergeable(inner_select)
+ ):
+ alias = subquery.alias_or_name
+ from_or_join = subquery.find_ancestor(exp.From, exp.Join)
+ inner_scope = outer_scope.sources[alias]
+
+ _rename_inner_sources(outer_scope, inner_scope, alias)
+ _merge_from(outer_scope, inner_scope, subquery)
+ _merge_joins(outer_scope, inner_scope, from_or_join)
+ _merge_expressions(outer_scope, inner_scope, alias)
+ _merge_where(outer_scope, inner_scope, from_or_join)
+ _merge_order(outer_scope, inner_scope)
+ return expression
+
+
+# If a derived table has these Select args, it can't be merged
+UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
+ "expressions",
+ "from",
+ "joins",
+ "where",
+ "order",
+}
+
+
+def _mergeable(inner_select):
+ """
+ Return True if `inner_select` can be merged into outer query.
+
+ Args:
+ inner_select (exp.Select)
+ Returns:
+ bool: True if can be merged
+ """
+ return (
+ isinstance(inner_select, exp.Select)
+ and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
+ and inner_select.args.get("from")
+ and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
+ )
+
+
+def _rename_inner_sources(outer_scope, inner_scope, alias):
+ """
+ Renames any sources in the inner query that conflict with names in the outer query.
+
+ Args:
+ outer_scope (sqlglot.optimizer.scope.Scope)
+ inner_scope (sqlglot.optimizer.scope.Scope)
+ alias (str)
+ """
+ taken = set(outer_scope.selected_sources)
+ conflicts = taken.intersection(set(inner_scope.selected_sources))
+ conflicts = conflicts - {alias}
+
+ for conflict in conflicts:
+ new_name = _find_new_name(taken, conflict)
+
+ source, _ = inner_scope.selected_sources[conflict]
+ new_alias = exp.to_identifier(new_name)
+
+ if isinstance(source, exp.Subquery):
+ source.set("alias", exp.TableAlias(this=new_alias))
+ elif isinstance(source, exp.Table) and isinstance(source.parent, exp.Alias):
+ source.parent.set("alias", new_alias)
+ elif isinstance(source, exp.Table):
+ source.replace(exp.alias_(source.copy(), new_alias))
+
+ for column in inner_scope.source_columns(conflict):
+ column.set("table", exp.to_identifier(new_name))
+
+ inner_scope.rename_source(conflict, new_name)
+
+
+def _find_new_name(taken, base):
+ """
+ Searches for a new source name.
+
+ Args:
+ taken (set[str]): set of taken names
+ base (str): base name to alter
+ """
+ i = 2
+ new = f"{base}_{i}"
+ while new in taken:
+ i += 1
+ new = f"{base}_{i}"
+ return new
+
+
+def _merge_from(outer_scope, inner_scope, subquery):
+ """
+ Merge FROM clause of inner query into outer query.
+
+ Args:
+ outer_scope (sqlglot.optimizer.scope.Scope)
+ inner_scope (sqlglot.optimizer.scope.Scope)
+ subquery (exp.Subquery)
+ """
+ new_subquery = inner_scope.expression.args.get("from").expressions[0]
+ subquery.replace(new_subquery)
+ outer_scope.remove_source(subquery.alias_or_name)
+ outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
+
+
+def _merge_joins(outer_scope, inner_scope, from_or_join):
+ """
+ Merge JOIN clauses of inner query into outer query.
+
+ Args:
+ outer_scope (sqlglot.optimizer.scope.Scope)
+ inner_scope (sqlglot.optimizer.scope.Scope)
+ from_or_join (exp.From|exp.Join)
+ """
+
+ new_joins = []
+ comma_joins = inner_scope.expression.args.get("from").expressions[1:]
+ for subquery in comma_joins:
+ new_joins.append(exp.Join(this=subquery, kind="CROSS"))
+ outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name])
+
+ joins = inner_scope.expression.args.get("joins") or []
+ for join in joins:
+ new_joins.append(join)
+ outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name])
+
+ if new_joins:
+ outer_joins = outer_scope.expression.args.get("joins", [])
+
+ # Maintain the join order
+ if isinstance(from_or_join, exp.From):
+ position = 0
+ else:
+ position = outer_joins.index(from_or_join) + 1
+ outer_joins[position:position] = new_joins
+
+ outer_scope.expression.set("joins", outer_joins)
+
+
+def _merge_expressions(outer_scope, inner_scope, alias):
+ """
+ Merge projections of inner query into outer query.
+
+ Args:
+ outer_scope (sqlglot.optimizer.scope.Scope)
+ inner_scope (sqlglot.optimizer.scope.Scope)
+ alias (str)
+ """
+    # Collect all outer columns that reference the alias of the inner query
+ outer_columns = defaultdict(list)
+ for column in outer_scope.columns:
+ if column.table == alias:
+ outer_columns[column.name].append(column)
+
+ # Replace columns with the projection expression in the inner query
+ for expression in inner_scope.expression.expressions:
+ projection_name = expression.alias_or_name
+ if not projection_name:
+ continue
+ columns_to_replace = outer_columns.get(projection_name, [])
+ for column in columns_to_replace:
+ column.replace(expression.unalias())
+
+
+def _merge_where(outer_scope, inner_scope, from_or_join):
+ """
+ Merge WHERE clause of inner query into outer query.
+
+ Args:
+ outer_scope (sqlglot.optimizer.scope.Scope)
+ inner_scope (sqlglot.optimizer.scope.Scope)
+ from_or_join (exp.From|exp.Join)
+ """
+ where = inner_scope.expression.args.get("where")
+ if not where or not where.this:
+ return
+
+ if isinstance(from_or_join, exp.Join) and from_or_join.side:
+        # The subquery sits under an outer join, so its WHERE predicates must move to the ON clause
+ from_or_join.on(where.this, copy=False)
+ from_or_join.set("on", simplify(from_or_join.args.get("on")))
+ else:
+ outer_scope.expression.where(where.this, copy=False)
+ outer_scope.expression.set("where", simplify(outer_scope.expression.args.get("where")))
+
+
+def _merge_order(outer_scope, inner_scope):
+ """
+ Merge ORDER clause of inner query into outer query.
+
+ Args:
+ outer_scope (sqlglot.optimizer.scope.Scope)
+ inner_scope (sqlglot.optimizer.scope.Scope)
+ """
+ if (
+ any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"])
+ or len(outer_scope.selected_sources) != 1
+ or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions)
+ ):
+ return
+
+ outer_scope.expression.set("order", inner_scope.expression.args.get("order"))
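
A minimal sketch of the new rule in isolation, assuming the import path registered in the optimizer package below (exact output formatting may vary by version):

    import sqlglot
    from sqlglot.optimizer.merge_derived_tables import merge_derived_tables

    # A plain projection merges into the outer query.
    merged = merge_derived_tables(sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)"))
    print(merged.sql())  # SELECT x.a FROM x

    # An aggregate projection trips the exp.AggFunc check in _mergeable,
    # so the derived table is left in place.
    kept = merge_derived_tables(sqlglot.parse_one("SELECT s FROM (SELECT SUM(x.a) AS s FROM x)"))
    print(kept.sql())
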
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index 2c9f89c..ab30d7a 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -22,18 +22,14 @@ def normalize(expression, dnf=False, max_distance=128):
"""
expression = simplify(expression)
- expression = while_changing(
- expression, lambda e: distributive_law(e, dnf, max_distance)
- )
+ expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance))
return simplify(expression)
def normalized(expression, dnf=False):
ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
- return not any(
- connector.find_ancestor(ancestor) for connector in expression.find_all(root)
- )
+ return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root))
def normalization_distance(expression, dnf=False):
@@ -54,9 +50,7 @@ def normalization_distance(expression, dnf=False):
Returns:
int: difference
"""
- return sum(_predicate_lengths(expression, dnf)) - (
- len(list(expression.find_all(exp.Connector))) + 1
- )
+ return sum(_predicate_lengths(expression, dnf)) - (len(list(expression.find_all(exp.Connector))) + 1)
def _predicate_lengths(expression, dnf):
@@ -73,11 +67,7 @@ def _predicate_lengths(expression, dnf):
left, right = expression.args.values()
if isinstance(expression, exp.And if dnf else exp.Or):
- x = [
- a + b
- for a in _predicate_lengths(left, dnf)
- for b in _predicate_lengths(right, dnf)
- ]
+ x = [a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)]
return x
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
@@ -102,9 +92,7 @@ def distributive_law(expression, dnf, max_distance):
to_func = exp.and_ if to_exp == exp.And else exp.or_
if isinstance(a, to_exp) and isinstance(b, to_exp):
- if len(tuple(a.find_all(exp.Connector))) > len(
- tuple(b.find_all(exp.Connector))
- ):
+ if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
return _distribute(a, b, from_func, to_func)
return _distribute(b, a, from_func, to_func)
if isinstance(a, to_exp):
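
The reformatted helpers above all feed the module's public entry point, which can be exercised directly; a sketch (the rendered parentheses may differ slightly):

    import sqlglot
    from sqlglot.optimizer.normalize import normalize

    # Default is conjunctive normal form; dnf=True distributes the other way.
    print(normalize(sqlglot.parse_one("(a AND b) OR c")).sql())
    # e.g. (a OR c) AND (b OR c)
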
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
index 40e4ab1..0c74e36 100644
--- a/sqlglot/optimizer/optimize_joins.py
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -68,8 +68,4 @@ def normalize(expression):
def other_table_names(join, exclude):
- return [
- name
- for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
- if name != exclude
- ]
+ return [name for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) if name != exclude]
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index c03fe3c..c8c2403 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -1,6 +1,7 @@
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
+from sqlglot.optimizer.merge_derived_tables import merge_derived_tables
from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
@@ -10,8 +11,23 @@ from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.quote_identities import quote_identities
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
+RULES = (
+ qualify_tables,
+ isolate_table_selects,
+ qualify_columns,
+ pushdown_projections,
+ normalize,
+ unnest_subqueries,
+ expand_multi_table_selects,
+ pushdown_predicates,
+ optimize_joins,
+ eliminate_subqueries,
+ merge_derived_tables,
+ quote_identities,
+)
-def optimize(expression, schema=None, db=None, catalog=None):
+
+def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwargs):
"""
Rewrite a sqlglot AST into an optimized form.
@@ -25,19 +41,18 @@ def optimize(expression, schema=None, db=None, catalog=None):
3. {catalog: {db: {table: {col: type}}}}
db (str): specify the default database, as might be set by a `USE DATABASE db` statement
catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
+ rules (list): sequence of optimizer rules to use
+        **kwargs: If a rule has a keyword argument with the same name as a key in **kwargs, that value will be passed in.
Returns:
sqlglot.Expression: optimized expression
"""
+ possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
expression = expression.copy()
- expression = qualify_tables(expression, db=db, catalog=catalog)
- expression = isolate_table_selects(expression)
- expression = qualify_columns(expression, schema)
- expression = pushdown_projections(expression)
- expression = normalize(expression)
- expression = unnest_subqueries(expression)
- expression = expand_multi_table_selects(expression)
- expression = pushdown_predicates(expression)
- expression = optimize_joins(expression)
- expression = eliminate_subqueries(expression)
- expression = quote_identities(expression)
+    for rule in rules:
+        # Find any additional rule parameters, beyond `expression`
+ rule_params = rule.__code__.co_varnames
+ rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs}
+
+ expression = rule(expression, **rule_kwargs)
return expression
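
The new `rules` hook inspects each rule's code object and forwards only the keyword arguments the rule actually declares. A sketch of running a custom subset:

    import sqlglot
    from sqlglot.optimizer.optimizer import optimize
    from sqlglot.optimizer.qualify_tables import qualify_tables

    # `db` is forwarded to qualify_tables because its signature names a
    # parameter called `db`; rules that don't declare it never see it.
    expression = sqlglot.parse_one("SELECT 1 FROM y")
    print(optimize(expression, rules=[qualify_tables], db="db").sql())
    # e.g. SELECT 1 FROM db.y AS y
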
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index e757322..a070d70 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -42,11 +42,7 @@ def pushdown(condition, sources):
condition = condition.replace(simplify(condition))
cnf_like = normalized(condition) or not normalized(condition, dnf=True)
- predicates = list(
- condition.flatten()
- if isinstance(condition, exp.And if cnf_like else exp.Or)
- else [condition]
- )
+ predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])
if cnf_like:
pushdown_cnf(predicates, sources)
@@ -105,17 +101,11 @@ def pushdown_dnf(predicates, scope):
for column in predicate.find_all(exp.Column):
if column.table == table:
condition = column.find_ancestor(exp.Condition)
- predicate_condition = (
- exp.and_(predicate_condition, condition)
- if predicate_condition
- else condition
- )
+ predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition
if predicate_condition:
conditions[table] = (
- exp.or_(conditions[table], predicate_condition)
- if table in conditions
- else predicate_condition
+ exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition
)
for name, node in nodes.items():
@@ -133,9 +123,7 @@ def pushdown_dnf(predicates, scope):
def nodes_for_predicate(predicate, sources):
nodes = {}
tables = exp.column_table_names(predicate)
- where_condition = isinstance(
- predicate.find_ancestor(exp.Join, exp.Where), exp.Where
- )
+ where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
for table in tables:
node, source = sources.get(table) or (None, None)
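
A sketch of the rule these cleanups live in (fixture-style aliases as in sqlglot's tests; exact output may differ):

    import sqlglot
    from sqlglot.optimizer.pushdown_predicates import pushdown_predicates

    sql = "SELECT y.a FROM (SELECT x.a FROM x AS x) AS y WHERE y.a = 1"
    print(pushdown_predicates(sqlglot.parse_one(sql)).sql())
    # e.g. SELECT y.a FROM (SELECT x.a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE
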
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 394f49e..0bb947a 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -226,9 +226,7 @@ def _expand_stars(scope, resolver):
tables = list(scope.selected_sources)
_add_except_columns(expression, tables, except_columns)
_add_replace_columns(expression, tables, replace_columns)
- elif isinstance(expression, exp.Column) and isinstance(
- expression.this, exp.Star
- ):
+ elif isinstance(expression, exp.Column) and isinstance(expression.this, exp.Star):
tables = [expression.table]
_add_except_columns(expression.this, tables, except_columns)
_add_replace_columns(expression.this, tables, replace_columns)
@@ -245,9 +243,7 @@ def _expand_stars(scope, resolver):
if name not in except_columns.get(table_id, set()):
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table)
- new_selections.append(
- alias(column, alias_) if alias_ != name else column
- )
+ new_selections.append(alias(column, alias_) if alias_ != name else column)
scope.expression.set("expressions", new_selections)
@@ -280,9 +276,7 @@ def _qualify_outputs(scope):
"""Ensure all output columns are aliased"""
new_selections = []
- for i, (selection, aliased_column) in enumerate(
- itertools.zip_longest(scope.selects, scope.outer_column_list)
- ):
+ for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)):
if isinstance(selection, exp.Column):
# convoluted setter because a simple selection.replace(alias) would require a copy
alias_ = alias(exp.column(""), alias=selection.name)
@@ -302,11 +296,7 @@ def _qualify_outputs(scope):
def _check_unknown_tables(scope):
- if (
- scope.external_columns
- and not scope.is_unnest
- and not scope.is_correlated_subquery
- ):
+ if scope.external_columns and not scope.is_unnest and not scope.is_correlated_subquery:
raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
@@ -334,20 +324,14 @@ class _Resolver:
(str) table name
"""
if self._unambiguous_columns is None:
- self._unambiguous_columns = self._get_unambiguous_columns(
- self._get_all_source_columns()
- )
+ self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns())
return self._unambiguous_columns.get(column_name)
@property
def all_columns(self):
"""All available columns of all sources in this scope"""
if self._all_columns is None:
- self._all_columns = set(
- column
- for columns in self._get_all_source_columns().values()
- for column in columns
- )
+ self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns)
return self._all_columns
def get_source_columns(self, name):
@@ -369,9 +353,7 @@ class _Resolver:
def _get_all_source_columns(self):
if self._source_columns is None:
- self._source_columns = {
- k: self.get_source_columns(k) for k in self.scope.selected_sources
- }
+ self._source_columns = {k: self.get_source_columns(k) for k in self.scope.selected_sources}
return self._source_columns
def _get_unambiguous_columns(self, source_columns):
@@ -389,9 +371,7 @@ class _Resolver:
source_columns = list(source_columns.items())
first_table, first_columns = source_columns[0]
- unambiguous_columns = {
- col: first_table for col in self._find_unique_columns(first_columns)
- }
+ unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
all_columns = set(unambiguous_columns)
for table, columns in source_columns[1:]:
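
The _Resolver caching above backs the public rule; a quick sketch of its effect, with the schema passed as a plain mapping:

    import sqlglot
    from sqlglot.optimizer.qualify_columns import qualify_columns

    expression = sqlglot.parse_one("SELECT a, b FROM x")
    print(qualify_columns(expression, schema={"x": {"a": "INT", "b": "INT"}}).sql())
    # e.g. SELECT x.a AS a, x.b AS b FROM x
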
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 9f8b9f5..30e93ba 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -27,9 +27,7 @@ def qualify_tables(expression, db=None, catalog=None):
for derived_table in scope.ctes + scope.derived_tables:
if not derived_table.args.get("alias"):
alias_ = f"_q_{next(sequence)}"
- derived_table.set(
- "alias", exp.TableAlias(this=exp.to_identifier(alias_))
- )
+ derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
scope.rename_source(None, alias_)
for source in scope.sources.values():
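
A sketch of both behaviors in this hunk, default qualification and the generated `_q_` aliases:

    import sqlglot
    from sqlglot.optimizer.qualify_tables import qualify_tables

    print(qualify_tables(sqlglot.parse_one("SELECT 1 FROM z"), db="db", catalog="c").sql())
    # e.g. SELECT 1 FROM c.db.z AS z

    print(qualify_tables(sqlglot.parse_one("SELECT 1 FROM (SELECT 1)")).sql())
    # e.g. SELECT 1 FROM (SELECT 1) AS _q_0
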
diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py
index 9968108..1761228 100644
--- a/sqlglot/optimizer/schema.py
+++ b/sqlglot/optimizer/schema.py
@@ -57,9 +57,7 @@ class MappingSchema(Schema):
for forbidden in self.forbidden_args:
if table.text(forbidden):
- raise ValueError(
- f"Schema doesn't support {forbidden}. Received: {table.sql()}"
- )
+ raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
return list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index f6f59e8..e816e10 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -104,9 +104,7 @@ class Scope:
elif isinstance(node, exp.CTE):
self._ctes.append(node)
prune = True
- elif isinstance(node, exp.Subquery) and isinstance(
- parent, (exp.From, exp.Join)
- ):
+ elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
self._derived_tables.append(node)
prune = True
elif isinstance(node, exp.Subqueryable):
@@ -195,20 +193,14 @@ class Scope:
self._ensure_collected()
columns = self._raw_columns
- external_columns = [
- column
- for scope in self.subquery_scopes
- for column in scope.external_columns
- ]
+ external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns]
named_outputs = {e.alias_or_name for e in self.expression.expressions}
self._columns = [
c
for c in columns + external_columns
- if not (
- c.find_ancestor(exp.Qualify, exp.Order) and c.name in named_outputs
- )
+ if not (c.find_ancestor(exp.Qualify, exp.Order) and not c.table and c.name in named_outputs)
]
return self._columns
@@ -229,9 +221,7 @@ class Scope:
for table in self.tables:
referenced_names.append(
(
- table.parent.alias
- if isinstance(table.parent, exp.Alias)
- else table.name,
+ table.parent.alias if isinstance(table.parent, exp.Alias) else table.name,
table,
)
)
@@ -274,9 +264,7 @@ class Scope:
sources in the current scope.
"""
if self._external_columns is None:
- self._external_columns = [
- c for c in self.columns if c.table not in self.selected_sources
- ]
+ self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
return self._external_columns
def source_columns(self, source_name):
@@ -310,6 +298,16 @@ class Scope:
columns = self.sources.pop(old_name or "", [])
self.sources[new_name] = columns
+ def add_source(self, name, source):
+ """Add a source to this scope"""
+ self.sources[name] = source
+ self.clear_cache()
+
+ def remove_source(self, name):
+ """Remove a source from this scope"""
+ self.sources.pop(name, None)
+ self.clear_cache()
+
def traverse_scope(expression):
"""
@@ -334,7 +332,7 @@ def traverse_scope(expression):
Args:
expression (exp.Expression): expression to traverse
Returns:
- List[Scope]: scope instances
+ list[Scope]: scope instances
"""
return list(_traverse_scope(Scope(expression)))
@@ -356,9 +354,7 @@ def _traverse_scope(scope):
def _traverse_select(scope):
yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
yield from _traverse_subqueries(scope)
- yield from _traverse_derived_tables(
- scope.derived_tables, scope, ScopeType.DERIVED_TABLE
- )
+ yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
_add_table_sources(scope)
@@ -367,15 +363,11 @@ def _traverse_union(scope):
# The last scope to be yield should be the top most scope
left = None
- for left in _traverse_scope(
- scope.branch(scope.expression.left, scope_type=ScopeType.UNION)
- ):
+ for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)):
yield left
right = None
- for right in _traverse_scope(
- scope.branch(scope.expression.right, scope_type=ScopeType.UNION)
- ):
+ for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
yield right
scope.union = (left, right)
@@ -387,14 +379,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
for derived_table in derived_tables:
for child_scope in _traverse_scope(
scope.branch(
- derived_table
- if isinstance(derived_table, (exp.Unnest, exp.Lateral))
- else derived_table.this,
+ derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
add_sources=sources if scope_type == ScopeType.CTE else None,
outer_column_list=derived_table.alias_column_names,
- scope_type=ScopeType.UNNEST
- if isinstance(derived_table, exp.Unnest)
- else scope_type,
+ scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type,
)
):
yield child_scope
@@ -430,9 +418,7 @@ def _add_table_sources(scope):
def _traverse_subqueries(scope):
for subquery in scope.subqueries:
top = None
- for child_scope in _traverse_scope(
- scope.branch(subquery, scope_type=ScopeType.SUBQUERY)
- ):
+ for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
yield child_scope
top = child_scope
scope.subquery_scopes.append(top)
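
The two new mutators exist so rules like merge_derived_tables can edit a scope without leaving stale cached properties behind. A sketch:

    import sqlglot
    from sqlglot.optimizer.scope import traverse_scope

    # Scopes are yielded inner-to-outer; the last one is the root SELECT.
    *_, outer = traverse_scope(sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y"))

    source = outer.sources["y"]
    outer.remove_source("y")        # drops the source and clears the cache
    outer.add_source("y2", source)  # re-registers it under a new name
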
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 6771153..319e6b6 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -188,9 +188,7 @@ def absorb_and_eliminate(expression):
aa.replace(exp.TRUE if kind == exp.And else exp.FALSE)
elif is_complement(b, ab):
ab.replace(exp.TRUE if kind == exp.And else exp.FALSE)
- elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(
- a.flatten()
- ):
+ elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
a.replace(exp.FALSE if kind == exp.And else exp.TRUE)
elif isinstance(b, kind):
# eliminate
@@ -227,9 +225,7 @@ def simplify_literals(expression):
operands.append(a)
if len(operands) < size:
- return functools.reduce(
- lambda a, b: expression.__class__(this=a, expression=b), operands
- )
+ return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
elif isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
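
Both reformatted branches are easiest to see through the module's entry point; a sketch:

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    print(simplify(sqlglot.parse_one("1 + 2 * 3")).sql())                # 7
    print(simplify(sqlglot.parse_one("TRUE AND TRUE AND a = 1")).sql())  # a = 1
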
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 55c81c5..11c6eba 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -89,11 +89,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
return
if isinstance(predicate, exp.Binary):
- key = (
- predicate.right
- if any(node is column for node, *_ in predicate.left.walk())
- else predicate.left
- )
+ key = predicate.right if any(node is column for node, *_ in predicate.left.walk()) else predicate.left
else:
return
@@ -124,9 +120,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
# if the value of the subquery is not an agg or a key, we need to collect it into an array
# so that it can be grouped
if not value.find(exp.AggFunc) and value.this not in group_by:
- select.select(
- f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False
- )
+ select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False)
# exists queries should not have any selects as it only checks if there are any rows
# all selects will be added by the optimizer and only used for join keys
@@ -151,16 +145,12 @@ def decorrelate(select, parent_select, external_columns, sequence):
else:
parent_predicate = _replace(parent_predicate, "TRUE")
elif isinstance(parent_predicate, exp.All):
- parent_predicate = _replace(
- parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
- )
+ parent_predicate = _replace(parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})")
elif isinstance(parent_predicate, exp.Any):
if value.this in group_by:
parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
else:
- parent_predicate = _replace(
- parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})"
- )
+ parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
elif isinstance(parent_predicate, exp.In):
if value.this in group_by:
parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
@@ -178,9 +168,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
if key in group_by:
key.replace(nested)
- parent_predicate = _replace(
- parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
- )
+ parent_predicate = _replace(parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)")
elif isinstance(predicate, exp.EQ):
parent_predicate = _replace(
parent_predicate,
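
The decorrelation rewrites above produce fairly involved SQL, so a sketch that only shows the entry point rather than asserting exact output:

    import sqlglot
    from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

    sql = "SELECT * FROM x WHERE x.a IN (SELECT y.a FROM y WHERE y.b = x.b)"
    # The correlated subquery becomes a grouped derived table joined on the
    # correlation key; ARRAY_AGG/ARRAY_ANY appear when values must be grouped.
    print(unnest_subqueries(sqlglot.parse_one(sql)).sql())
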
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 9396c50..f46bafe 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -78,6 +78,7 @@ class Parser:
TokenType.TEXT,
TokenType.BINARY,
TokenType.JSON,
+ TokenType.INTERVAL,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
TokenType.DATETIME,
@@ -85,6 +86,12 @@ class Parser:
TokenType.DECIMAL,
TokenType.UUID,
TokenType.GEOGRAPHY,
+ TokenType.GEOMETRY,
+ TokenType.HLLSKETCH,
+ TokenType.SUPER,
+ TokenType.SERIAL,
+ TokenType.SMALLSERIAL,
+ TokenType.BIGSERIAL,
*NESTED_TYPE_TOKENS,
}
@@ -100,13 +107,14 @@ class Parser:
ID_VAR_TOKENS = {
TokenType.VAR,
TokenType.ALTER,
+ TokenType.ALWAYS,
TokenType.BEGIN,
+ TokenType.BOTH,
TokenType.BUCKET,
TokenType.CACHE,
TokenType.COLLATE,
TokenType.COMMIT,
TokenType.CONSTRAINT,
- TokenType.CONVERT,
TokenType.DEFAULT,
TokenType.DELETE,
TokenType.ENGINE,
@@ -115,14 +123,19 @@ class Parser:
TokenType.FALSE,
TokenType.FIRST,
TokenType.FOLLOWING,
+ TokenType.FOR,
TokenType.FORMAT,
TokenType.FUNCTION,
+ TokenType.GENERATED,
+ TokenType.IDENTITY,
TokenType.IF,
TokenType.INDEX,
TokenType.ISNULL,
TokenType.INTERVAL,
TokenType.LAZY,
+ TokenType.LEADING,
TokenType.LOCATION,
+ TokenType.NATURAL,
TokenType.NEXT,
TokenType.ONLY,
TokenType.OPTIMIZE,
@@ -141,6 +154,7 @@ class Parser:
TokenType.TABLE_FORMAT,
TokenType.TEMPORARY,
TokenType.TOP,
+ TokenType.TRAILING,
TokenType.TRUNCATE,
TokenType.TRUE,
TokenType.UNBOUNDED,
@@ -150,18 +164,15 @@ class Parser:
*TYPE_TOKENS,
}
- CASTS = {
- TokenType.CAST,
- TokenType.TRY_CAST,
- }
+ TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL}
+
+ TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}
FUNC_TOKENS = {
- TokenType.CONVERT,
TokenType.CURRENT_DATE,
TokenType.CURRENT_DATETIME,
TokenType.CURRENT_TIMESTAMP,
TokenType.CURRENT_TIME,
- TokenType.EXTRACT,
TokenType.FILTER,
TokenType.FIRST,
TokenType.FORMAT,
@@ -178,7 +189,6 @@ class Parser:
TokenType.DATETIME,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
- *CASTS,
*NESTED_TYPE_TOKENS,
*SUBQUERY_PREDICATES,
}
@@ -215,6 +225,7 @@ class Parser:
FACTOR = {
TokenType.DIV: exp.IntDiv,
+ TokenType.LR_ARROW: exp.Distance,
TokenType.SLASH: exp.Div,
TokenType.STAR: exp.Mul,
}
@@ -299,14 +310,13 @@ class Parser:
PRIMARY_PARSERS = {
TokenType.STRING: lambda _, token: exp.Literal.string(token.text),
TokenType.NUMBER: lambda _, token: exp.Literal.number(token.text),
- TokenType.STAR: lambda self, _: exp.Star(
- **{"except": self._parse_except(), "replace": self._parse_replace()}
- ),
+ TokenType.STAR: lambda self, _: exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}),
TokenType.NULL: lambda *_: exp.Null(),
TokenType.TRUE: lambda *_: exp.Boolean(this=True),
TokenType.FALSE: lambda *_: exp.Boolean(this=False),
TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(),
TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
+ TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
TokenType.INTRODUCER: lambda self, token: self.expression(
exp.Introducer,
this=token.text,
@@ -319,13 +329,16 @@ class Parser:
TokenType.IN: lambda self, this: self._parse_in(this),
TokenType.IS: lambda self, this: self._parse_is(this),
TokenType.LIKE: lambda self, this: self._parse_escape(
- self.expression(exp.Like, this=this, expression=self._parse_type())
+ self.expression(exp.Like, this=this, expression=self._parse_bitwise())
),
TokenType.ILIKE: lambda self, this: self._parse_escape(
- self.expression(exp.ILike, this=this, expression=self._parse_type())
+ self.expression(exp.ILike, this=this, expression=self._parse_bitwise())
),
TokenType.RLIKE: lambda self, this: self.expression(
- exp.RegexpLike, this=this, expression=self._parse_type()
+ exp.RegexpLike, this=this, expression=self._parse_bitwise()
+ ),
+ TokenType.SIMILAR_TO: lambda self, this: self.expression(
+ exp.SimilarTo, this=this, expression=self._parse_bitwise()
),
}
@@ -363,28 +376,21 @@ class Parser:
}
FUNCTION_PARSERS = {
- TokenType.CONVERT: lambda self, _: self._parse_convert(),
- TokenType.EXTRACT: lambda self, _: self._parse_extract(),
- **{
- token_type: lambda self, token_type: self._parse_cast(
- self.STRICT_CAST and token_type == TokenType.CAST
- )
- for token_type in CASTS
- },
+ "CONVERT": lambda self: self._parse_convert(),
+ "EXTRACT": lambda self: self._parse_extract(),
+ "SUBSTRING": lambda self: self._parse_substring(),
+ "TRIM": lambda self: self._parse_trim(),
+ "CAST": lambda self: self._parse_cast(self.STRICT_CAST),
+ "TRY_CAST": lambda self: self._parse_cast(False),
}
QUERY_MODIFIER_PARSERS = {
- "laterals": lambda self: self._parse_laterals(),
- "joins": lambda self: self._parse_joins(),
"where": lambda self: self._parse_where(),
"group": lambda self: self._parse_group(),
"having": lambda self: self._parse_having(),
"qualify": lambda self: self._parse_qualify(),
- "window": lambda self: self._match(TokenType.WINDOW)
- and self._parse_window(self._parse_id_var(), alias=True),
- "distribute": lambda self: self._parse_sort(
- TokenType.DISTRIBUTE_BY, exp.Distribute
- ),
+ "window": lambda self: self._match(TokenType.WINDOW) and self._parse_window(self._parse_id_var(), alias=True),
+ "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute),
"sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
"cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
"order": lambda self: self._parse_order(),
@@ -392,6 +398,8 @@ class Parser:
"offset": lambda self: self._parse_offset(),
}
+ MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
+
CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX}
STRICT_CAST = True
@@ -457,9 +465,7 @@ class Parser:
Returns
the list of syntax trees (:class:`~sqlglot.expressions.Expression`).
"""
- return self._parse(
- parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql
- )
+ return self._parse(parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql)
def parse_into(self, expression_types, raw_tokens, sql=None):
for expression_type in ensure_list(expression_types):
@@ -532,21 +538,13 @@ class Parser:
for k in expression.args:
if k not in expression.arg_types:
- self.raise_error(
- f"Unexpected keyword: '{k}' for {expression.__class__}"
- )
+ self.raise_error(f"Unexpected keyword: '{k}' for {expression.__class__}")
for k, mandatory in expression.arg_types.items():
v = expression.args.get(k)
if mandatory and (v is None or (isinstance(v, list) and not v)):
- self.raise_error(
- f"Required keyword: '{k}' missing for {expression.__class__}"
- )
+ self.raise_error(f"Required keyword: '{k}' missing for {expression.__class__}")
- if (
- args
- and len(args) > len(expression.arg_types)
- and not expression.is_var_len_args
- ):
+ if args and len(args) > len(expression.arg_types) and not expression.is_var_len_args:
self.raise_error(
f"The number of provided arguments ({len(args)}) is greater than "
f"the maximum number of supported arguments ({len(expression.arg_types)})"
@@ -594,11 +592,7 @@ class Parser:
)
expression = self._parse_expression()
- expression = (
- self._parse_set_operations(expression)
- if expression
- else self._parse_select()
- )
+ expression = self._parse_set_operations(expression) if expression else self._parse_select()
self._parse_query_modifiers(expression)
return expression
@@ -618,11 +612,7 @@ class Parser:
)
def _parse_exists(self, not_=False):
- return (
- self._match(TokenType.IF)
- and (not not_ or self._match(TokenType.NOT))
- and self._match(TokenType.EXISTS)
- )
+ return self._match(TokenType.IF) and (not not_ or self._match(TokenType.NOT)) and self._match(TokenType.EXISTS)
def _parse_create(self):
replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE)
@@ -647,11 +637,9 @@ class Parser:
this = self._parse_index()
elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW):
this = self._parse_table(schema=True)
- properties = self._parse_properties(
- this if isinstance(this, exp.Schema) else None
- )
+ properties = self._parse_properties(this if isinstance(this, exp.Schema) else None)
if self._match(TokenType.ALIAS):
- expression = self._parse_select()
+ expression = self._parse_select(nested=True)
return self.expression(
exp.Create,
@@ -682,9 +670,7 @@ class Parser:
if schema and not isinstance(value, exp.Schema):
columns = {v.name.upper() for v in value.expressions}
partitions = [
- expression
- for expression in schema.expressions
- if expression.this.name.upper() in columns
+ expression for expression in schema.expressions if expression.this.name.upper() in columns
]
schema.set(
"expressions",
@@ -811,7 +797,7 @@ class Parser:
this=self._parse_table(schema=True),
exists=self._parse_exists(),
partition=self._parse_partition(),
- expression=self._parse_select(),
+ expression=self._parse_select(nested=True),
overwrite=overwrite,
)
@@ -829,8 +815,7 @@ class Parser:
exp.Update,
**{
"this": self._parse_table(schema=True),
- "expressions": self._match(TokenType.SET)
- and self._parse_csv(self._parse_equality),
+ "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
"from": self._parse_from(),
"where": self._parse_where(),
},
@@ -865,7 +850,7 @@ class Parser:
this=table,
lazy=lazy,
options=options,
- expression=self._parse_select(),
+ expression=self._parse_select(nested=True),
)
def _parse_partition(self):
@@ -894,9 +879,7 @@ class Parser:
self._match_r_paren()
return self.expression(exp.Tuple, expressions=expressions)
- def _parse_select(self, table=None):
- index = self._index
-
+ def _parse_select(self, nested=False, table=False):
if self._match(TokenType.SELECT):
hint = self._parse_hint()
all_ = self._match(TokenType.ALL)
@@ -912,9 +895,7 @@ class Parser:
self.raise_error("Cannot specify both ALL and DISTINCT after SELECT")
limit = self._parse_limit(top=True)
- expressions = self._parse_csv(
- lambda: self._parse_annotation(self._parse_expression())
- )
+ expressions = self._parse_csv(lambda: self._parse_annotation(self._parse_expression()))
this = self.expression(
exp.Select,
@@ -960,19 +941,13 @@ class Parser:
)
else:
self.raise_error(f"{this.key} does not support CTE")
- elif self._match(TokenType.L_PAREN):
- this = self._parse_table() if table else self._parse_select()
-
- if this:
- self._parse_query_modifiers(this)
- self._match_r_paren()
- this = self._parse_subquery(this)
- else:
- self._retreat(index)
+ elif (table or nested) and self._match(TokenType.L_PAREN):
+ this = self._parse_table() if table else self._parse_select(nested=True)
+ self._parse_query_modifiers(this)
+ self._match_r_paren()
+ this = self._parse_subquery(this)
elif self._match(TokenType.VALUES):
- this = self.expression(
- exp.Values, expressions=self._parse_csv(self._parse_value)
- )
+ this = self.expression(exp.Values, expressions=self._parse_csv(self._parse_value))
alias = self._parse_table_alias()
if alias:
this = self.expression(exp.Subquery, this=this, alias=alias)
@@ -1001,7 +976,7 @@ class Parser:
def _parse_table_alias(self):
any_token = self._match(TokenType.ALIAS)
- alias = self._parse_id_var(any_token)
+ alias = self._parse_id_var(any_token=any_token, tokens=self.TABLE_ALIAS_TOKENS)
columns = None
if self._match(TokenType.L_PAREN):
@@ -1021,9 +996,24 @@ class Parser:
return self.expression(exp.Subquery, this=this, alias=self._parse_table_alias())
def _parse_query_modifiers(self, this):
- if not isinstance(this, (exp.Subquery, exp.Subqueryable)):
+ if not isinstance(this, self.MODIFIABLES):
return
+ table = isinstance(this, exp.Table)
+
+ while True:
+ lateral = self._parse_lateral()
+ join = self._parse_join()
+ comma = None if table else self._match(TokenType.COMMA)
+ if lateral:
+ this.append("laterals", lateral)
+ if join:
+ this.append("joins", join)
+ if comma:
+ this.args["from"].append("expressions", self._parse_table())
+ if not (lateral or join or comma):
+ break
+
for key, parser in self.QUERY_MODIFIER_PARSERS.items():
expression = parser(self)
@@ -1032,9 +1022,7 @@ class Parser:
def _parse_annotation(self, expression):
if self._match(TokenType.ANNOTATION):
- return self.expression(
- exp.Annotation, this=self._prev.text, expression=expression
- )
+ return self.expression(exp.Annotation, this=self._prev.text, expression=expression)
return expression
@@ -1052,16 +1040,16 @@ class Parser:
return self.expression(exp.From, expressions=self._parse_csv(self._parse_table))
- def _parse_laterals(self):
- return self._parse_all(self._parse_lateral)
-
def _parse_lateral(self):
if not self._match(TokenType.LATERAL):
return None
- if not self._match(TokenType.VIEW):
- self.raise_error("Expected VIEW after LATERAL")
+ subquery = self._parse_select(table=True)
+ if subquery:
+ return self.expression(exp.Lateral, this=subquery)
+
+ self._match(TokenType.VIEW)
outer = self._match(TokenType.OUTER)
return self.expression(
@@ -1071,31 +1059,27 @@ class Parser:
alias=self.expression(
exp.TableAlias,
this=self._parse_id_var(any_token=False),
- columns=(
- self._parse_csv(self._parse_id_var)
- if self._match(TokenType.ALIAS)
- else None
- ),
+ columns=(self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else None),
),
)
- def _parse_joins(self):
- return self._parse_all(self._parse_join)
-
def _parse_join_side_and_kind(self):
return (
+ self._match(TokenType.NATURAL) and self._prev,
self._match_set(self.JOIN_SIDES) and self._prev,
self._match_set(self.JOIN_KINDS) and self._prev,
)
def _parse_join(self):
- side, kind = self._parse_join_side_and_kind()
+ natural, side, kind = self._parse_join_side_and_kind()
if not self._match(TokenType.JOIN):
return None
kwargs = {"this": self._parse_table()}
+ if natural:
+ kwargs["natural"] = True
if side:
kwargs["side"] = side.text
if kind:
@@ -1120,6 +1104,11 @@ class Parser:
)
def _parse_table(self, schema=False):
+ lateral = self._parse_lateral()
+
+ if lateral:
+ return lateral
+
unnest = self._parse_unnest()
if unnest:
@@ -1172,9 +1161,7 @@ class Parser:
expressions = self._parse_csv(self._parse_column)
self._match_r_paren()
- ordinality = bool(
- self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY)
- )
+ ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY))
alias = self._parse_table_alias()
@@ -1280,17 +1267,13 @@ class Parser:
if not self._match(TokenType.ORDER_BY):
return this
- return self.expression(
- exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
- )
+ return self.expression(exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered))
def _parse_sort(self, token_type, exp_class):
if not self._match(token_type):
return None
- return self.expression(
- exp_class, expressions=self._parse_csv(self._parse_ordered)
- )
+ return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))
def _parse_ordered(self):
this = self._parse_conjunction()
@@ -1305,22 +1288,17 @@ class Parser:
if (
not explicitly_null_ordered
and (
- (asc and self.null_ordering == "nulls_are_small")
- or (desc and self.null_ordering != "nulls_are_small")
+ (asc and self.null_ordering == "nulls_are_small") or (desc and self.null_ordering != "nulls_are_small")
)
and self.null_ordering != "nulls_are_last"
):
nulls_first = True
- return self.expression(
- exp.Ordered, this=this, desc=desc, nulls_first=nulls_first
- )
+ return self.expression(exp.Ordered, this=this, desc=desc, nulls_first=nulls_first)
def _parse_limit(self, this=None, top=False):
if self._match(TokenType.TOP if top else TokenType.LIMIT):
- return self.expression(
- exp.Limit, this=this, expression=self._parse_number()
- )
+ return self.expression(exp.Limit, this=this, expression=self._parse_number())
if self._match(TokenType.FETCH):
direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
direction = self._prev.text if direction else "FIRST"
@@ -1354,7 +1332,7 @@ class Parser:
expression,
this=this,
distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL),
- expression=self._parse_select(),
+ expression=self._parse_select(nested=True),
)
def _parse_expression(self):
@@ -1396,9 +1374,7 @@ class Parser:
this = self.expression(exp.In, this=this, unnest=unnest)
else:
self._match_l_paren()
- expressions = self._parse_csv(
- lambda: self._parse_select() or self._parse_expression()
- )
+ expressions = self._parse_csv(lambda: self._parse_select() or self._parse_expression())
if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable):
this = self.expression(exp.In, this=this, query=expressions[0])
@@ -1430,13 +1406,9 @@ class Parser:
expression=self._parse_term(),
)
elif self._match_pair(TokenType.LT, TokenType.LT):
- this = self.expression(
- exp.BitwiseLeftShift, this=this, expression=self._parse_term()
- )
+ this = self.expression(exp.BitwiseLeftShift, this=this, expression=self._parse_term())
elif self._match_pair(TokenType.GT, TokenType.GT):
- this = self.expression(
- exp.BitwiseRightShift, this=this, expression=self._parse_term()
- )
+ this = self.expression(exp.BitwiseRightShift, this=this, expression=self._parse_term())
else:
break
@@ -1524,7 +1496,7 @@ class Parser:
self.raise_error("Expecting >")
if type_token in self.TIMESTAMPS:
- tz = self._match(TokenType.WITH_TIME_ZONE)
+ tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ
self._match(TokenType.WITHOUT_TIME_ZONE)
if tz:
return exp.DataType(
@@ -1594,16 +1566,14 @@ class Parser:
if query:
expressions = [query]
else:
- expressions = self._parse_csv(
- lambda: self._parse_alias(self._parse_conjunction(), explicit=True)
- )
+ expressions = self._parse_csv(lambda: self._parse_alias(self._parse_conjunction(), explicit=True))
this = list_get(expressions, 0)
self._parse_query_modifiers(this)
self._match_r_paren()
if isinstance(this, exp.Subqueryable):
- return self._parse_subquery(this)
+ return self._parse_set_operations(self._parse_subquery(this))
if len(expressions) > 1:
return self.expression(exp.Tuple, expressions=expressions)
return self.expression(exp.Paren, this=this)
@@ -1611,11 +1581,7 @@ class Parser:
return None
def _parse_field(self, any_token=False):
- return (
- self._parse_primary()
- or self._parse_function()
- or self._parse_id_var(any_token)
- )
+ return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token)
def _parse_function(self):
if not self._curr:
@@ -1628,21 +1594,22 @@ class Parser:
if not self._next or self._next.token_type != TokenType.L_PAREN:
if token_type in self.NO_PAREN_FUNCTIONS:
- return self.expression(
- self._advance() or self.NO_PAREN_FUNCTIONS[token_type]
- )
+ return self.expression(self._advance() or self.NO_PAREN_FUNCTIONS[token_type])
return None
if token_type not in self.FUNC_TOKENS:
return None
- if self._match_set(self.FUNCTION_PARSERS):
- self._advance()
- this = self.FUNCTION_PARSERS[token_type](self, token_type)
+ this = self._curr.text
+ upper = this.upper()
+ self._advance(2)
+
+ parser = self.FUNCTION_PARSERS.get(upper)
+
+ if parser:
+ this = parser(self)
else:
subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type)
- this = self._curr.text
- self._advance(2)
if subquery_predicate and self._curr.token_type in (
TokenType.SELECT,
@@ -1652,7 +1619,7 @@ class Parser:
self._match_r_paren()
return this
- function = self.FUNCTIONS.get(this.upper())
+ function = self.FUNCTIONS.get(upper)
args = self._parse_csv(self._parse_lambda)
if function:
@@ -1700,10 +1667,7 @@ class Parser:
self._retreat(index)
return this
- args = self._parse_csv(
- lambda: self._parse_constraint()
- or self._parse_column_def(self._parse_field())
- )
+ args = self._parse_csv(lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True)))
self._match_r_paren()
return self.expression(exp.Schema, this=this, expressions=args)
@@ -1720,12 +1684,9 @@ class Parser:
break
constraints.append(constraint)
- return self.expression(
- exp.ColumnDef, this=this, kind=kind, constraints=constraints
- )
+ return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints)
def _parse_column_constraint(self):
- kind = None
this = None
if self._match(TokenType.CONSTRAINT):
@@ -1735,28 +1696,28 @@ class Parser:
kind = exp.AutoIncrementColumnConstraint()
elif self._match(TokenType.CHECK):
self._match_l_paren()
- kind = self.expression(
- exp.CheckColumnConstraint, this=self._parse_conjunction()
- )
+ kind = self.expression(exp.CheckColumnConstraint, this=self._parse_conjunction())
self._match_r_paren()
elif self._match(TokenType.COLLATE):
kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var())
elif self._match(TokenType.DEFAULT):
- kind = self.expression(
- exp.DefaultColumnConstraint, this=self._parse_field()
- )
- elif self._match(TokenType.NOT) and self._match(TokenType.NULL):
+ kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_field())
+ elif self._match_pair(TokenType.NOT, TokenType.NULL):
kind = exp.NotNullColumnConstraint()
elif self._match(TokenType.SCHEMA_COMMENT):
- kind = self.expression(
- exp.CommentColumnConstraint, this=self._parse_string()
- )
+ kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string())
elif self._match(TokenType.PRIMARY_KEY):
kind = exp.PrimaryKeyColumnConstraint()
elif self._match(TokenType.UNIQUE):
kind = exp.UniqueColumnConstraint()
-
- if kind is None:
+ elif self._match(TokenType.GENERATED):
+ if self._match(TokenType.BY_DEFAULT):
+ kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=False)
+ else:
+ self._match(TokenType.ALWAYS)
+ kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
+ self._match_pair(TokenType.ALIAS, TokenType.IDENTITY)
+ else:
return None
return self.expression(exp.ColumnConstraint, this=this, kind=kind)
@@ -1864,9 +1825,7 @@ class Parser:
if not self._match(TokenType.END):
self.raise_error("Expected END after CASE", self._prev)
- return self._parse_window(
- self.expression(exp.Case, this=expression, ifs=ifs, default=default)
- )
+ return self._parse_window(self.expression(exp.Case, this=expression, ifs=ifs, default=default))
def _parse_if(self):
if self._match(TokenType.L_PAREN):
@@ -1889,7 +1848,7 @@ class Parser:
if not self._match(TokenType.FROM):
self.raise_error("Expected FROM after EXTRACT", self._prev)
- return self.expression(exp.Extract, this=this, expression=self._parse_type())
+ return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
def _parse_cast(self, strict):
this = self._parse_conjunction()
@@ -1917,12 +1876,54 @@ class Parser:
to = None
return self.expression(exp.Cast, this=this, to=to)
+ def _parse_substring(self):
+ # Postgres supports the form: substring(string [from int] [for int])
+        # https://www.postgresql.org/docs/9.1/functions-string.html (Table 9-6)
+
+ args = self._parse_csv(self._parse_bitwise)
+
+ if self._match(TokenType.FROM):
+ args.append(self._parse_bitwise())
+ if self._match(TokenType.FOR):
+ args.append(self._parse_bitwise())
+
+ this = exp.Substring.from_arg_list(args)
+ self.validate_expression(this, args)
+
+ return this
+
+ def _parse_trim(self):
+ # https://www.w3resource.com/sql/character-functions/trim.php
+ # https://docs.oracle.com/javadb/10.8.3.0/ref/rreftrimfunc.html
+
+ position = None
+ collation = None
+
+ if self._match_set(self.TRIM_TYPES):
+ position = self._prev.text.upper()
+
+ expression = self._parse_term()
+ if self._match(TokenType.FROM):
+ this = self._parse_term()
+ else:
+ this = expression
+ expression = None
+
+ if self._match(TokenType.COLLATE):
+ collation = self._parse_term()
+
+ return self.expression(
+ exp.Trim,
+ this=this,
+ position=position,
+ expression=expression,
+ collation=collation,
+ )
+
def _parse_window(self, this, alias=False):
if self._match(TokenType.FILTER):
self._match_l_paren()
- this = self.expression(
- exp.Filter, this=this, expression=self._parse_where()
- )
+ this = self.expression(exp.Filter, this=this, expression=self._parse_where())
self._match_r_paren()
if self._match(TokenType.WITHIN_GROUP):
@@ -1935,6 +1936,25 @@ class Parser:
self._match_r_paren()
return this
+ # SQL spec defines an optional [ { IGNORE | RESPECT } NULLS ] OVER
+        # Some dialects choose to implement it and some do not.
+ # https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html
+
+ # There is some code above in _parse_lambda that handles
+ # SELECT FIRST_VALUE(TABLE.COLUMN IGNORE|RESPECT NULLS) OVER ...
+
+        # The code below handles
+ # SELECT FIRST_VALUE(TABLE.COLUMN) IGNORE|RESPECT NULLS OVER ...
+
+ # Oracle allows both formats
+ # (https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/img_text/first_value.html)
+ # and Snowflake chose to do the same for familiarity
+ # https://docs.snowflake.com/en/sql-reference/functions/first_value.html#usage-notes
+ if self._match(TokenType.IGNORE_NULLS):
+ this = self.expression(exp.IgnoreNulls, this=this)
+ elif self._match(TokenType.RESPECT_NULLS):
+ this = self.expression(exp.RespectNulls, this=this)
+
# bigquery select from window x AS (partition by ...)
if alias:
self._match(TokenType.ALIAS)
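
A sketch of the newly accepted placement (the in-parens form was already handled in _parse_lambda):

    import sqlglot
    from sqlglot import exp

    ast = sqlglot.parse_one("SELECT FIRST_VALUE(x) IGNORE NULLS OVER (PARTITION BY y) FROM t")
    # The function call is wrapped in an IgnoreNulls node before the window.
    print(ast.find(exp.IgnoreNulls) is not None)  # True
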
@@ -1992,13 +2012,9 @@ class Parser:
self._match(TokenType.BETWEEN)
return {
- "value": (
- self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW))
- and self._prev.text
- )
+ "value": (self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text)
or self._parse_bitwise(),
- "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING))
- and self._prev.text,
+ "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text,
}
def _parse_alias(self, this, explicit=False):
@@ -2023,22 +2039,16 @@ class Parser:
return this
- def _parse_id_var(self, any_token=True):
+ def _parse_id_var(self, any_token=True, tokens=None):
identifier = self._parse_identifier()
if identifier:
return identifier
- if (
- any_token
- and self._curr
- and self._curr.token_type not in self.RESERVED_KEYWORDS
- ):
+ if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
return self._advance() or exp.Identifier(this=self._prev.text, quoted=False)
- return self._match_set(self.ID_VAR_TOKENS) and exp.Identifier(
- this=self._prev.text, quoted=False
- )
+ return self._match_set(tokens or self.ID_VAR_TOKENS) and exp.Identifier(this=self._prev.text, quoted=False)
def _parse_string(self):
if self._match(TokenType.STRING):
@@ -2077,9 +2087,7 @@ class Parser:
def _parse_star(self):
if self._match(TokenType.STAR):
- return exp.Star(
- **{"except": self._parse_except(), "replace": self._parse_replace()}
- )
+ return exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()})
return None
def _parse_placeholder(self):
@@ -2117,15 +2125,10 @@ class Parser:
this = parse()
while self._match_set(expressions):
- this = self.expression(
- expressions[self._prev.token_type], this=this, expression=parse()
- )
+ this = self.expression(expressions[self._prev.token_type], this=this, expression=parse())
return this
- def _parse_all(self, parse):
- return list(iter(parse, None))
-
def _parse_wrapped_id_vars(self):
self._match_l_paren()
expressions = self._parse_csv(self._parse_id_var)
@@ -2156,10 +2159,7 @@ class Parser:
if not self._curr or not self._next:
return None
- if (
- self._curr.token_type == token_type_a
- and self._next.token_type == token_type_b
- ):
+ if self._curr.token_type == token_type_a and self._next.token_type == token_type_b:
if advance:
self._advance(2)
return True
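
Two of the parser-level changes above, NATURAL joins and name-keyed FUNCTION_PARSERS (CAST/TRY_CAST are no longer dedicated tokens), in a sketch:

    import sqlglot

    print(sqlglot.transpile("SELECT * FROM x NATURAL JOIN y")[0])
    # e.g. SELECT * FROM x NATURAL JOIN y

    print(sqlglot.transpile("SELECT TRY_CAST(a AS INT), CAST(b AS TEXT) FROM t")[0])
    # e.g. SELECT TRY_CAST(a AS INT), CAST(b AS TEXT) FROM t
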
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index 2006a75..ed0b66c 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -72,9 +72,7 @@ class Step:
if from_:
from_ = from_.expressions
if len(from_) > 1:
- raise UnsupportedError(
- "Multi-from statements are unsupported. Run it through the optimizer"
- )
+ raise UnsupportedError("Multi-from statements are unsupported. Run it through the optimizer")
step = Scan.from_expression(from_[0], ctes)
else:
@@ -104,9 +102,7 @@ class Step:
continue
if operand not in operands:
operands[operand] = f"_a_{next(sequence)}"
- operand.replace(
- exp.column(operands[operand], step.name, quoted=True)
- )
+ operand.replace(exp.column(operands[operand], step.name, quoted=True))
else:
projections.append(e)
@@ -121,14 +117,9 @@ class Step:
aggregate = Aggregate()
aggregate.source = step.name
aggregate.name = step.name
- aggregate.operands = tuple(
- alias(operand, alias_) for operand, alias_ in operands.items()
- )
+ aggregate.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
aggregate.aggregations = aggregations
- aggregate.group = [
- exp.column(e.alias_or_name, step.name, quoted=True)
- for e in group.expressions
- ]
+ aggregate.group = [exp.column(e.alias_or_name, step.name, quoted=True) for e in group.expressions]
aggregate.add_dependency(step)
step = aggregate
@@ -212,9 +203,7 @@ class Scan(Step):
alias_ = expression.alias
if not alias_:
- raise UnsupportedError(
- "Tables/Subqueries must be aliased. Run it through the optimizer"
- )
+ raise UnsupportedError("Tables/Subqueries must be aliased. Run it through the optimizer")
if isinstance(expression, exp.Subquery):
step = Step.from_expression(table, ctes)
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index e4b754d..bd95bc7 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -38,6 +38,7 @@ class TokenType(AutoName):
DARROW = auto()
HASH_ARROW = auto()
DHASH_ARROW = auto()
+ LR_ARROW = auto()
ANNOTATION = auto()
DOLLAR = auto()
@@ -53,6 +54,7 @@ class TokenType(AutoName):
TABLE = auto()
VAR = auto()
BIT_STRING = auto()
+ HEX_STRING = auto()
# types
BOOLEAN = auto()
@@ -78,10 +80,17 @@ class TokenType(AutoName):
UUID = auto()
GEOGRAPHY = auto()
NULLABLE = auto()
+ GEOMETRY = auto()
+ HLLSKETCH = auto()
+ SUPER = auto()
+ SERIAL = auto()
+ SMALLSERIAL = auto()
+ BIGSERIAL = auto()
# keywords
ADD_FILE = auto()
ALIAS = auto()
+ ALWAYS = auto()
ALL = auto()
ALTER = auto()
ANALYZE = auto()
@@ -92,11 +101,12 @@ class TokenType(AutoName):
AUTO_INCREMENT = auto()
BEGIN = auto()
BETWEEN = auto()
+ BOTH = auto()
BUCKET = auto()
+ BY_DEFAULT = auto()
CACHE = auto()
CALL = auto()
CASE = auto()
- CAST = auto()
CHARACTER_SET = auto()
CHECK = auto()
CLUSTER_BY = auto()
@@ -104,7 +114,6 @@ class TokenType(AutoName):
COMMENT = auto()
COMMIT = auto()
CONSTRAINT = auto()
- CONVERT = auto()
CREATE = auto()
CROSS = auto()
CUBE = auto()
@@ -127,22 +136,24 @@ class TokenType(AutoName):
EXCEPT = auto()
EXISTS = auto()
EXPLAIN = auto()
- EXTRACT = auto()
FALSE = auto()
FETCH = auto()
FILTER = auto()
FINAL = auto()
FIRST = auto()
FOLLOWING = auto()
+ FOR = auto()
FOREIGN_KEY = auto()
FORMAT = auto()
FULL = auto()
FUNCTION = auto()
FROM = auto()
+ GENERATED = auto()
GROUP_BY = auto()
GROUPING_SETS = auto()
HAVING = auto()
HINT = auto()
+ IDENTITY = auto()
IF = auto()
IGNORE_NULLS = auto()
ILIKE = auto()
@@ -159,12 +170,14 @@ class TokenType(AutoName):
JOIN = auto()
LATERAL = auto()
LAZY = auto()
+ LEADING = auto()
LEFT = auto()
LIKE = auto()
LIMIT = auto()
LOCATION = auto()
MAP = auto()
MOD = auto()
+ NATURAL = auto()
NEXT = auto()
NO_ACTION = auto()
NULL = auto()
@@ -204,8 +217,10 @@ class TokenType(AutoName):
ROWS = auto()
SCHEMA_COMMENT = auto()
SELECT = auto()
+ SEPARATOR = auto()
SET = auto()
SHOW = auto()
+ SIMILAR_TO = auto()
SOME = auto()
SORT_BY = auto()
STORED = auto()
@@ -213,12 +228,11 @@ class TokenType(AutoName):
TABLE_FORMAT = auto()
TABLE_SAMPLE = auto()
TEMPORARY = auto()
- TIME = auto()
TOP = auto()
THEN = auto()
TRUE = auto()
+ TRAILING = auto()
TRUNCATE = auto()
- TRY_CAST = auto()
UNBOUNDED = auto()
UNCACHE = auto()
UNION = auto()
@@ -272,35 +286,32 @@ class _Tokenizer(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
- klass.QUOTES = dict(
- (quote, quote) if isinstance(quote, str) else (quote[0], quote[1])
- for quote in klass.QUOTES
- )
-
- klass.IDENTIFIERS = dict(
- (identifier, identifier)
- if isinstance(identifier, str)
- else (identifier[0], identifier[1])
- for identifier in klass.IDENTIFIERS
- )
-
- klass.COMMENTS = dict(
- (comment, None) if isinstance(comment, str) else (comment[0], comment[1])
- for comment in klass.COMMENTS
+        klass._QUOTES = cls._delimiter_list_to_dict(klass.QUOTES)
+        klass._BIT_STRINGS = cls._delimiter_list_to_dict(klass.BIT_STRINGS)
+        klass._HEX_STRINGS = cls._delimiter_list_to_dict(klass.HEX_STRINGS)
+        klass._IDENTIFIERS = cls._delimiter_list_to_dict(klass.IDENTIFIERS)
+ klass._COMMENTS = dict(
+ (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS
)
klass.KEYWORD_TRIE = new_trie(
key.upper()
for key, value in {
**klass.KEYWORDS,
- **{comment: TokenType.COMMENT for comment in klass.COMMENTS},
- **{quote: TokenType.QUOTE for quote in klass.QUOTES},
+ **{comment: TokenType.COMMENT for comment in klass._COMMENTS},
+ **{quote: TokenType.QUOTE for quote in klass._QUOTES},
+ **{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS},
+ **{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS},
}.items()
if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
)
return klass
+ @staticmethod
+    def _delimiter_list_to_dict(items):
+        return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in items)
+
class Tokenizer(metaclass=_Tokenizer):
SINGLE_TOKENS = {
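
A standalone re-implementation of the helper above (hypothetical free-function name) showing how plain strings and (start, end) pairs normalize into one lookup dict:

    def delimiter_list_to_dict(items):
        return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in items)

    print(delimiter_list_to_dict(["'"]))         # {"'": "'"}
    print(delimiter_list_to_dict([("0x", "")]))  # {'0x': ''}
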
@@ -339,6 +350,10 @@ class Tokenizer(metaclass=_Tokenizer):
QUOTES = ["'"]
+ BIT_STRINGS = []
+
+ HEX_STRINGS = []
+
IDENTIFIERS = ['"']
ESCAPE = "'"
@@ -357,6 +372,7 @@ class Tokenizer(metaclass=_Tokenizer):
"->>": TokenType.DARROW,
"#>": TokenType.HASH_ARROW,
"#>>": TokenType.DHASH_ARROW,
+ "<->": TokenType.LR_ARROW,
"ADD ARCHIVE": TokenType.ADD_FILE,
"ADD ARCHIVES": TokenType.ADD_FILE,
"ADD FILE": TokenType.ADD_FILE,
@@ -374,12 +390,12 @@ class Tokenizer(metaclass=_Tokenizer):
"AUTO_INCREMENT": TokenType.AUTO_INCREMENT,
"BEGIN": TokenType.BEGIN,
"BETWEEN": TokenType.BETWEEN,
+ "BOTH": TokenType.BOTH,
"BUCKET": TokenType.BUCKET,
"CALL": TokenType.CALL,
"CACHE": TokenType.CACHE,
"UNCACHE": TokenType.UNCACHE,
"CASE": TokenType.CASE,
- "CAST": TokenType.CAST,
"CHARACTER SET": TokenType.CHARACTER_SET,
"CHECK": TokenType.CHECK,
"CLUSTER BY": TokenType.CLUSTER_BY,
@@ -387,7 +403,6 @@ class Tokenizer(metaclass=_Tokenizer):
"COMMENT": TokenType.SCHEMA_COMMENT,
"COMMIT": TokenType.COMMIT,
"CONSTRAINT": TokenType.CONSTRAINT,
- "CONVERT": TokenType.CONVERT,
"CREATE": TokenType.CREATE,
"CROSS": TokenType.CROSS,
"CUBE": TokenType.CUBE,
@@ -408,7 +423,6 @@ class Tokenizer(metaclass=_Tokenizer):
"EXCEPT": TokenType.EXCEPT,
"EXISTS": TokenType.EXISTS,
"EXPLAIN": TokenType.EXPLAIN,
- "EXTRACT": TokenType.EXTRACT,
"FALSE": TokenType.FALSE,
"FETCH": TokenType.FETCH,
"FILTER": TokenType.FILTER,
@@ -437,10 +451,12 @@ class Tokenizer(metaclass=_Tokenizer):
"JOIN": TokenType.JOIN,
"LATERAL": TokenType.LATERAL,
"LAZY": TokenType.LAZY,
+ "LEADING": TokenType.LEADING,
"LEFT": TokenType.LEFT,
"LIKE": TokenType.LIKE,
"LIMIT": TokenType.LIMIT,
"LOCATION": TokenType.LOCATION,
+ "NATURAL": TokenType.NATURAL,
"NEXT": TokenType.NEXT,
"NO ACTION": TokenType.NO_ACTION,
"NOT": TokenType.NOT,
@@ -490,8 +506,8 @@ class Tokenizer(metaclass=_Tokenizer):
"TEMPORARY": TokenType.TEMPORARY,
"THEN": TokenType.THEN,
"TRUE": TokenType.TRUE,
+ "TRAILING": TokenType.TRAILING,
"TRUNCATE": TokenType.TRUNCATE,
- "TRY_CAST": TokenType.TRY_CAST,
"UNBOUNDED": TokenType.UNBOUNDED,
"UNION": TokenType.UNION,
"UNNEST": TokenType.UNNEST,
@@ -626,14 +642,12 @@ class Tokenizer(metaclass=_Tokenizer):
break
white_space = self.WHITE_SPACE.get(self._char)
- identifier_end = self.IDENTIFIERS.get(self._char)
+ identifier_end = self._IDENTIFIERS.get(self._char)
if white_space:
if white_space == TokenType.BREAK:
self._col = 1
self._line += 1
- elif self._char == "0" and self._peek == "x":
- self._scan_hex()
elif self._char.isdigit():
self._scan_number()
elif identifier_end:
@@ -666,9 +680,7 @@ class Tokenizer(metaclass=_Tokenizer):
text = self._text if text is None else text
self.tokens.append(Token(token_type, text, self._line, self._col))
- if token_type in self.COMMANDS and (
- len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
- ):
+ if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON):
self._start = self._current
while not self._end and self._peek != ";":
self._advance()
@@ -725,6 +737,8 @@ class Tokenizer(metaclass=_Tokenizer):
if self._scan_string(word):
return
+ if self._scan_numeric_string(word):
+ return
if self._scan_comment(word):
return
@@ -732,10 +746,10 @@ class Tokenizer(metaclass=_Tokenizer):
self._add(self.KEYWORDS[word.upper()])
def _scan_comment(self, comment_start):
- if comment_start not in self.COMMENTS:
+ if comment_start not in self._COMMENTS:
return False
- comment_end = self.COMMENTS[comment_start]
+ comment_end = self._COMMENTS[comment_start]
if comment_end:
comment_end_size = len(comment_end)
@@ -749,15 +763,18 @@ class Tokenizer(metaclass=_Tokenizer):
return True
def _scan_annotation(self):
- while (
- not self._end
- and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK
- and self._peek != ","
- ):
+ while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK and self._peek != ",":
self._advance()
self._add(TokenType.ANNOTATION, self._text[1:])
def _scan_number(self):
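+ # A leading "0b"/"0x" marks a binary or hex literal; hand off to the
+ # dedicated scanners instead of parsing a decimal number.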
+ if self._char == "0":
+ peek = self._peek.upper()
+ if peek == "B":
+ return self._scan_bits()
+ elif peek == "X":
+ return self._scan_hex()
+
decimal = False
scientific = 0
@@ -788,57 +805,71 @@ class Tokenizer(metaclass=_Tokenizer):
else:
return self._add(TokenType.NUMBER)
+ def _scan_bits(self):
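+ # Prefix form (0b...): emit the literal's decimal value so generators can
+ # re-render it per dialect, falling back to IDENTIFIER on invalid digits.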
+ self._advance()
+ value = self._extract_value()
+ try:
+ self._add(TokenType.BIT_STRING, f"{int(value, 2)}")
+ except ValueError:
+ self._add(TokenType.IDENTIFIER)
+
def _scan_hex(self):
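+ # Prefix form (0x...): same as _scan_bits but base 16.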
self._advance()
+ value = self._extract_value()
+ try:
+ self._add(TokenType.HEX_STRING, f"{int(value, 16)}")
+ except ValueError:
+ self._add(TokenType.IDENTIFIER)
+
+ def _extract_value(self):
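+ # Consume characters until whitespace or a single-token boundary and
+ # return the accumulated text.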
while True:
char = self._peek.strip()
if char and char not in self.SINGLE_TOKENS:
self._advance()
else:
break
- try:
- self._add(TokenType.BIT_STRING, f"{int(self._text, 16):b}")
- except ValueError:
- self._add(TokenType.IDENTIFIER)
+
+ return self._text
def _scan_string(self, quote):
- quote_end = self.QUOTES.get(quote)
+ quote_end = self._QUOTES.get(quote)
if quote_end is None:
return False
- text = ""
self._advance(len(quote))
- quote_end_size = len(quote_end)
-
- while True:
- if self._char == self.ESCAPE and self._peek == quote_end:
- text += quote
- self._advance(2)
- else:
- if self._chars(quote_end_size) == quote_end:
- if quote_end_size > 1:
- self._advance(quote_end_size - 1)
- break
-
- if self._end:
- raise RuntimeError(
- f"Missing {quote} from {self._line}:{self._start}"
- )
- text += self._char
- self._advance()
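+ # Delegate the escape-aware scan to _extract_string, which is shared
+ # with _scan_numeric_string.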
+ text = self._extract_string(quote_end)
text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
text = text.replace("\\\\", "\\") if self.ESCAPE == "\\" else text
self._add(TokenType.STRING, text)
return True
+ def _scan_numeric_string(self, string_start):
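+ # Delimited binary/hex literals such as b'1010' or x'FF'. Unlike the
+ # prefix forms, invalid digits raise here rather than fall back to IDENTIFIER.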
+ if string_start in self._HEX_STRINGS:
+ delimiters = self._HEX_STRINGS
+ token_type = TokenType.HEX_STRING
+ base = 16
+ elif string_start in self._BIT_STRINGS:
+ delimiters = self._BIT_STRINGS
+ token_type = TokenType.BIT_STRING
+ base = 2
+ else:
+ return False
+
+ self._advance(len(string_start))
+ string_end = delimiters.get(string_start)
+ text = self._extract_string(string_end)
+
+ try:
+ self._add(token_type, f"{int(text, base)}")
+ except ValueError:
+ raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}")
+ return True
+
def _scan_identifier(self, identifier_end):
while self._peek != identifier_end:
if self._end:
- raise RuntimeError(
- f"Missing {identifier_end} from {self._line}:{self._start}"
- )
+ raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}")
self._advance()
self._advance()
self._add(TokenType.IDENTIFIER, self._text[1:-1])
@@ -851,3 +882,24 @@ class Tokenizer(metaclass=_Tokenizer):
else:
break
self._add(self.KEYWORDS.get(self._text.upper(), TokenType.VAR))
+
+ def _extract_string(self, delimiter):
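+ # Shared string reader: handles ESCAPE-prefixed and multi-character
+ # delimiters, and raises if the string is unterminated.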
+ text = ""
+ delim_size = len(delimiter)
+
+ while True:
+ if self._char == self.ESCAPE and self._peek == delimiter:
+ text += delimiter
+ self._advance(2)
+ else:
+ if self._chars(delim_size) == delimiter:
+ if delim_size > 1:
+ self._advance(delim_size - 1)
+ break
+
+ if self._end:
+ raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")
+ text += self._char
+ self._advance()
+
+ return text
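
As a quick illustration of the new numeric-literal scanning (a minimal sketch, assuming the dialect-level Tokenizer subclasses wired up in this changeset, e.g. BigQuery's 0x hex prefix):

    from sqlglot.dialects import BigQuery
    from sqlglot.tokens import TokenType

    # "0xFF" takes the _scan_number -> _scan_hex path; int("0xFF", 16) == 255,
    # so the resulting token's text is the decimal string "255".
    tokens = BigQuery.Tokenizer().tokenize("SELECT 0xFF")
    assert any(t.token_type == TokenType.HEX_STRING and t.text == "255" for t in tokens)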
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index e7ccb8e..7fc71dd 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -12,9 +12,7 @@ def unalias_group(expression):
"""
if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
aliased_selects = {
- e.alias: i
- for i, e in enumerate(expression.parent.expressions, start=1)
- if isinstance(e, exp.Alias)
+ e.alias: i for i, e in enumerate(expression.parent.expressions, start=1) if isinstance(e, exp.Alias)
}
expression = expression.copy()