summaryrefslogtreecommitdiffstats
path: root/sqlglot
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-02-12 10:06:28 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-02-12 10:06:28 +0000
commit918abde014f9e5c75dfbe21110c379f7f70435c9 (patch)
tree3419a01e34958bffbd917fa9e600eda126ea3a87 /sqlglot
parentReleasing debian version 10.6.3-1. (diff)
downloadsqlglot-918abde014f9e5c75dfbe21110c379f7f70435c9.tar.xz
sqlglot-918abde014f9e5c75dfbe21110c379f7f70435c9.zip
Merging upstream version 11.0.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--sqlglot/__init__.py3
-rw-r--r--sqlglot/dialects/bigquery.py54
-rw-r--r--sqlglot/dialects/clickhouse.py2
-rw-r--r--sqlglot/dialects/dialect.py124
-rw-r--r--sqlglot/dialects/drill.py33
-rw-r--r--sqlglot/dialects/duckdb.py8
-rw-r--r--sqlglot/dialects/hive.py2
-rw-r--r--sqlglot/dialects/mysql.py7
-rw-r--r--sqlglot/dialects/postgres.py3
-rw-r--r--sqlglot/dialects/redshift.py3
-rw-r--r--sqlglot/dialects/snowflake.py13
-rw-r--r--sqlglot/dialects/spark.py1
-rw-r--r--sqlglot/dialects/sqlite.py1
-rw-r--r--sqlglot/diff.py1
-rw-r--r--sqlglot/errors.py15
-rw-r--r--sqlglot/executor/__init__.py1
-rw-r--r--sqlglot/executor/python.py2
-rw-r--r--sqlglot/expressions.py107
-rw-r--r--sqlglot/generator.py54
-rw-r--r--sqlglot/lineage.py3
-rw-r--r--sqlglot/optimizer/annotate_types.py17
-rw-r--r--sqlglot/optimizer/expand_laterals.py34
-rw-r--r--sqlglot/optimizer/optimizer.py5
-rw-r--r--sqlglot/optimizer/pushdown_projections.py6
-rw-r--r--sqlglot/optimizer/qualify_columns.py30
-rw-r--r--sqlglot/optimizer/qualify_tables.py13
-rw-r--r--sqlglot/optimizer/scope.py20
-rw-r--r--sqlglot/parser.py48
-rw-r--r--sqlglot/tokens.py38
29 files changed, 452 insertions, 196 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 714897f..7b07ae1 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -1,5 +1,6 @@
"""
.. include:: ../README.md
+
----
"""
@@ -39,7 +40,7 @@ if t.TYPE_CHECKING:
T = t.TypeVar("T", bound=Expression)
-__version__ = "10.6.3"
+__version__ = "11.0.1"
pretty = False
"""Whether to format generated SQL by default."""
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 90ae229..6a19b46 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -2,6 +2,8 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
@@ -14,8 +16,10 @@ from sqlglot.dialects.dialect import (
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
+E = t.TypeVar("E", bound=exp.Expression)
+
-def _date_add(expression_class):
+def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]:
def func(args):
interval = seq_get(args, 1)
return expression_class(
@@ -27,26 +31,26 @@ def _date_add(expression_class):
return func
-def _date_trunc(args):
+def _date_trunc(args: t.Sequence) -> exp.Expression:
unit = seq_get(args, 1)
if isinstance(unit, exp.Column):
unit = exp.Var(this=unit.name)
return exp.DateTrunc(this=seq_get(args, 0), expression=unit)
-def _date_add_sql(data_type, kind):
+def _date_add_sql(
+ data_type: str, kind: str
+) -> t.Callable[[generator.Generator, exp.Expression], str]:
def func(self, expression):
this = self.sql(expression, "this")
- unit = self.sql(expression, "unit") or "'day'"
- expression = self.sql(expression, "expression")
- return f"{data_type}_{kind}({this}, INTERVAL {expression} {unit})"
+ return f"{data_type}_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=expression.args.get('unit') or exp.Literal.string('day')))})"
return func
-def _derived_table_values_to_unnest(self, expression):
+def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str:
if not isinstance(expression.unnest().parent, exp.From):
- expression = transforms.remove_precision_parameterized_types(expression)
+ expression = t.cast(exp.Values, transforms.remove_precision_parameterized_types(expression))
return self.values_sql(expression)
rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)]
structs = []
@@ -60,7 +64,7 @@ def _derived_table_values_to_unnest(self, expression):
return self.unnest_sql(unnest_exp)
-def _returnsproperty_sql(self, expression):
+def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str:
this = expression.this
if isinstance(this, exp.Schema):
this = f"{this.this} <{self.expressions(this)}>"
@@ -69,8 +73,8 @@ def _returnsproperty_sql(self, expression):
return f"RETURNS {this}"
-def _create_sql(self, expression):
- kind = expression.args.get("kind")
+def _create_sql(self: generator.Generator, expression: exp.Create) -> str:
+ kind = expression.args["kind"]
returns = expression.find(exp.ReturnsProperty)
if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"):
expression = expression.copy()
@@ -89,6 +93,29 @@ def _create_sql(self, expression):
return self.create_sql(expression)
+def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
+ """Remove references to unnest table aliases since bigquery doesn't allow them.
+
+ These are added by the optimizer's qualify_column step.
+ """
+ if isinstance(expression, exp.Select):
+ unnests = {
+ unnest.alias
+ for unnest in expression.args.get("from", exp.From(expressions=[])).expressions
+ if isinstance(unnest, exp.Unnest) and unnest.alias
+ }
+
+ if unnests:
+ expression = expression.copy()
+
+ for select in expression.expressions:
+ for column in select.find_all(exp.Column):
+ if column.table in unnests:
+ column.set("table", None)
+
+ return expression
+
+
class BigQuery(Dialect):
unnest_column_only = True
time_mapping = {
@@ -110,7 +137,7 @@ class BigQuery(Dialect):
]
COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"]
- ESCAPES = ["\\"]
+ STRING_ESCAPES = ["\\"]
HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = {
@@ -190,6 +217,9 @@ class BigQuery(Dialect):
exp.GroupConcat: rename_func("STRING_AGG"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
+ exp.Select: transforms.preprocess(
+ [_unqualify_unnest], transforms.delegate("select_sql")
+ ),
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 9e8c691..b553df2 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -9,7 +9,7 @@ from sqlglot.parser import parse_var_map
from sqlglot.tokens import TokenType
-def _lower_func(sql):
+def _lower_func(sql: str) -> str:
index = sql.index("(")
return sql[:index].lower() + sql[index:]
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 1b20e0a..176a8ce 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -11,6 +11,8 @@ from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer
from sqlglot.trie import new_trie
+E = t.TypeVar("E", bound=exp.Expression)
+
class Dialects(str, Enum):
DIALECT = ""
@@ -37,14 +39,16 @@ class Dialects(str, Enum):
class _Dialect(type):
- classes: t.Dict[str, Dialect] = {}
+ classes: t.Dict[str, t.Type[Dialect]] = {}
@classmethod
- def __getitem__(cls, key):
+ def __getitem__(cls, key: str) -> t.Type[Dialect]:
return cls.classes[key]
@classmethod
- def get(cls, key, default=None):
+ def get(
+ cls, key: str, default: t.Optional[t.Type[Dialect]] = None
+ ) -> t.Optional[t.Type[Dialect]]:
return cls.classes.get(key, default)
def __new__(cls, clsname, bases, attrs):
@@ -119,7 +123,7 @@ class Dialect(metaclass=_Dialect):
generator_class = None
@classmethod
- def get_or_raise(cls, dialect):
+ def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
if not dialect:
return cls
if isinstance(dialect, _Dialect):
@@ -134,7 +138,9 @@ class Dialect(metaclass=_Dialect):
return result
@classmethod
- def format_time(cls, expression):
+ def format_time(
+ cls, expression: t.Optional[str | exp.Expression]
+ ) -> t.Optional[exp.Expression]:
if isinstance(expression, str):
return exp.Literal.string(
format_time(
@@ -153,26 +159,28 @@ class Dialect(metaclass=_Dialect):
)
return expression
- def parse(self, sql, **opts):
+ def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql)
- def parse_into(self, expression_type, sql, **opts):
+ def parse_into(
+ self, expression_type: exp.IntoType, sql: str, **opts
+ ) -> t.List[t.Optional[exp.Expression]]:
return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql)
- def generate(self, expression, **opts):
+ def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
return self.generator(**opts).generate(expression)
- def transpile(self, code, **opts):
- return self.generate(self.parse(code), **opts)
+ def transpile(self, sql: str, **opts) -> t.List[str]:
+ return [self.generate(expression, **opts) for expression in self.parse(sql)]
@property
- def tokenizer(self):
+ def tokenizer(self) -> Tokenizer:
if not hasattr(self, "_tokenizer"):
- self._tokenizer = self.tokenizer_class()
+ self._tokenizer = self.tokenizer_class() # type: ignore
return self._tokenizer
- def parser(self, **opts):
- return self.parser_class(
+ def parser(self, **opts) -> Parser:
+ return self.parser_class( # type: ignore
**{
"index_offset": self.index_offset,
"unnest_column_only": self.unnest_column_only,
@@ -182,14 +190,15 @@ class Dialect(metaclass=_Dialect):
},
)
- def generator(self, **opts):
- return self.generator_class(
+ def generator(self, **opts) -> Generator:
+ return self.generator_class( # type: ignore
**{
"quote_start": self.quote_start,
"quote_end": self.quote_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
- "escape": self.tokenizer_class.ESCAPES[0],
+ "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
+ "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
"index_offset": self.index_offset,
"time_mapping": self.inverse_time_mapping,
"time_trie": self.inverse_time_trie,
@@ -202,11 +211,10 @@ class Dialect(metaclass=_Dialect):
)
-if t.TYPE_CHECKING:
- DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
+DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
-def rename_func(name):
+def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
def _rename(self, expression):
args = flatten(expression.args.values())
return f"{self.normalize_func(name)}({self.format_args(*args)})"
@@ -214,32 +222,34 @@ def rename_func(name):
return _rename
-def approx_count_distinct_sql(self, expression):
+def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
if expression.args.get("accuracy"):
self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
return f"APPROX_COUNT_DISTINCT({self.format_args(expression.this)})"
-def if_sql(self, expression):
+def if_sql(self: Generator, expression: exp.If) -> str:
expressions = self.format_args(
expression.this, expression.args.get("true"), expression.args.get("false")
)
return f"IF({expressions})"
-def arrow_json_extract_sql(self, expression):
+def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
return self.binary(expression, "->")
-def arrow_json_extract_scalar_sql(self, expression):
+def arrow_json_extract_scalar_sql(
+ self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
+) -> str:
return self.binary(expression, "->>")
-def inline_array_sql(self, expression):
+def inline_array_sql(self: Generator, expression: exp.Array) -> str:
return f"[{self.expressions(expression)}]"
-def no_ilike_sql(self, expression):
+def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
return self.like_sql(
exp.Like(
this=exp.Lower(this=expression.this),
@@ -248,44 +258,44 @@ def no_ilike_sql(self, expression):
)
-def no_paren_current_date_sql(self, expression):
+def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
zone = self.sql(expression, "this")
return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
-def no_recursive_cte_sql(self, expression):
+def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
if expression.args.get("recursive"):
self.unsupported("Recursive CTEs are unsupported")
expression.args["recursive"] = False
return self.with_sql(expression)
-def no_safe_divide_sql(self, expression):
+def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
n = self.sql(expression, "this")
d = self.sql(expression, "expression")
return f"IF({d} <> 0, {n} / {d}, NULL)"
-def no_tablesample_sql(self, expression):
+def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
self.unsupported("TABLESAMPLE unsupported")
return self.sql(expression.this)
-def no_pivot_sql(self, expression):
+def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
self.unsupported("PIVOT unsupported")
return self.sql(expression)
-def no_trycast_sql(self, expression):
+def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
return self.cast_sql(expression)
-def no_properties_sql(self, expression):
+def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
self.unsupported("Properties unsupported")
return ""
-def str_position_sql(self, expression):
+def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
this = self.sql(expression, "this")
substr = self.sql(expression, "substr")
position = self.sql(expression, "position")
@@ -294,13 +304,15 @@ def str_position_sql(self, expression):
return f"STRPOS({this}, {substr})"
-def struct_extract_sql(self, expression):
+def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
this = self.sql(expression, "this")
struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
return f"{this}.{struct_key}"
-def var_map_sql(self, expression, map_func_name="MAP"):
+def var_map_sql(
+ self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
+) -> str:
keys = expression.args["keys"]
values = expression.args["values"]
@@ -315,27 +327,33 @@ def var_map_sql(self, expression, map_func_name="MAP"):
return f"{map_func_name}({self.format_args(*args)})"
-def format_time_lambda(exp_class, dialect, default=None):
+def format_time_lambda(
+ exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
+) -> t.Callable[[t.Sequence], E]:
"""Helper used for time expressions.
- Args
- exp_class (Class): the expression class to instantiate
- dialect (string): sql dialect
- default (Option[bool | str]): the default format, True being time
+ Args:
+ exp_class: the expression class to instantiate.
+ dialect: target sql dialect.
+ default: the default format, True being time.
+
+ Returns:
+ A callable that can be used to return the appropriately formatted time expression.
"""
- def _format_time(args):
+ def _format_time(args: t.Sequence):
return exp_class(
this=seq_get(args, 0),
format=Dialect[dialect].format_time(
- seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
+ seq_get(args, 1)
+ or (Dialect[dialect].time_format if default is True else default or None)
),
)
return _format_time
-def create_with_partitions_sql(self, expression):
+def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
"""
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
@@ -359,19 +377,21 @@ def create_with_partitions_sql(self, expression):
return self.create_sql(expression)
-def parse_date_delta(exp_class, unit_mapping=None):
- def inner_func(args):
+def parse_date_delta(
+ exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
+) -> t.Callable[[t.Sequence], E]:
+ def inner_func(args: t.Sequence) -> E:
unit_based = len(args) == 3
this = seq_get(args, 2) if unit_based else seq_get(args, 0)
expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
- unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
+ unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit # type: ignore
return exp_class(this=this, expression=expression, unit=unit)
return inner_func
-def locate_to_strposition(args):
+def locate_to_strposition(args: t.Sequence) -> exp.Expression:
return exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
@@ -379,22 +399,22 @@ def locate_to_strposition(args):
)
-def strposition_to_locate_sql(self, expression):
+def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
args = self.format_args(
expression.args.get("substr"), expression.this, expression.args.get("position")
)
return f"LOCATE({args})"
-def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str:
+def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
-def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str:
+def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
return f"CAST({self.sql(expression, 'this')} AS DATE)"
-def trim_sql(self, expression):
+def trim_sql(self: Generator, expression: exp.Trim) -> str:
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
remove_chars = self.sql(expression, "expression")
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index d0a0251..1730eaf 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import re
+import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
@@ -16,35 +17,29 @@ from sqlglot.dialects.dialect import (
)
-def _to_timestamp(args):
- # TO_TIMESTAMP accepts either a single double argument or (text, text)
- if len(args) == 1 and args[0].is_number:
- return exp.UnixToTime.from_arg_list(args)
- return format_time_lambda(exp.StrToTime, "drill")(args)
-
-
-def _str_to_time_sql(self, expression):
+def _str_to_time_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
-def _ts_or_ds_to_date_sql(self, expression):
+def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = self.format_time(expression)
if time_format and time_format not in (Drill.time_format, Drill.date_format):
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
return f"CAST({self.sql(expression, 'this')} AS DATE)"
-def _date_add_sql(kind):
- def func(self, expression):
+def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
+ def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
- unit = expression.text("unit").upper() or "DAY"
- expression = self.sql(expression, "expression")
- return f"DATE_{kind}({this}, INTERVAL '{expression}' {unit})"
+ unit = exp.Var(this=expression.text("unit").upper() or "DAY")
+ return (
+ f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
+ )
return func
-def if_sql(self, expression):
+def if_sql(self: generator.Generator, expression: exp.If) -> str:
"""
Drill requires backticks around certain SQL reserved words, IF being one of them, This function
adds the backticks around the keyword IF.
@@ -61,7 +56,7 @@ def if_sql(self, expression):
return f"`IF`({expressions})"
-def _str_to_date(self, expression):
+def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Drill.date_format:
@@ -111,7 +106,7 @@ class Drill(Dialect):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'"]
IDENTIFIERS = ["`"]
- ESCAPES = ["\\"]
+ STRING_ESCAPES = ["\\"]
ENCODE = "utf-8"
class Parser(parser.Parser):
@@ -168,10 +163,10 @@ class Drill(Dialect):
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TryCast: no_trycast_sql,
- exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), INTERVAL '{self.sql(e, 'expression')}' DAY)",
+ exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.Var(this='DAY')))})",
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
}
- def normalize_func(self, name):
+ def normalize_func(self, name: str) -> str:
return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`"
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 95ff95c..959e5e2 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -25,10 +25,9 @@ def _str_to_time_sql(self, expression):
def _ts_or_ds_add(self, expression):
- this = self.sql(expression, "this")
- e = self.sql(expression, "expression")
+ this = expression.args.get("this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
- return f"CAST({this} AS DATE) + INTERVAL {e} {unit}"
+ return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
def _ts_or_ds_to_date_sql(self, expression):
@@ -40,9 +39,8 @@ def _ts_or_ds_to_date_sql(self, expression):
def _date_add(self, expression):
this = self.sql(expression, "this")
- e = self.sql(expression, "expression")
unit = self.sql(expression, "unit").strip("'") or "DAY"
- return f"{this} + INTERVAL {e} {unit}"
+ return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
def _array_sort_sql(self, expression):
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index f2b6eaa..c558b70 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -172,7 +172,7 @@ class Hive(Dialect):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", '"']
IDENTIFIERS = ["`"]
- ESCAPES = ["\\"]
+ STRING_ESCAPES = ["\\"]
ENCODE = "utf-8"
KEYWORDS = {
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index a5bd86b..c2c2c8c 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -89,8 +89,9 @@ def _date_add_sql(kind):
def func(self, expression):
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
- expression = self.sql(expression, "expression")
- return f"DATE_{kind}({this}, INTERVAL {expression} {unit})"
+ return (
+ f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
+ )
return func
@@ -117,7 +118,7 @@ class MySQL(Dialect):
QUOTES = ["'", '"']
COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"]
- ESCAPES = ["'", "\\"]
+ STRING_ESCAPES = ["'", "\\"]
BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 6418032..c709665 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -40,8 +40,7 @@ def _date_add_sql(kind):
expression = expression.copy()
expression.args["is_string"] = True
- expression = self.sql(expression)
- return f"{this} {kind} INTERVAL {expression} {unit}"
+ return f"{this} {kind} {self.sql(exp.Interval(this=expression, unit=unit))}"
return func
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index c3c99eb..813ee5f 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -37,11 +37,10 @@ class Redshift(Postgres):
return this
class Tokenizer(Postgres.Tokenizer):
- ESCAPES = ["\\"]
+ STRING_ESCAPES = ["\\"]
KEYWORDS = {
**Postgres.Tokenizer.KEYWORDS, # type: ignore
- "COPY": TokenType.COMMAND,
"ENCODE": TokenType.ENCODE,
"GEOMETRY": TokenType.GEOMETRY,
"GEOGRAPHY": TokenType.GEOGRAPHY,
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 3b83b02..55a6bd3 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -180,7 +180,7 @@ class Snowflake(Dialect):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
- ESCAPES = ["\\", "'"]
+ STRING_ESCAPES = ["\\", "'"]
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
@@ -191,6 +191,7 @@ class Snowflake(Dialect):
**tokens.Tokenizer.KEYWORDS,
"EXCLUDE": TokenType.EXCEPT,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
+ "PUT": TokenType.COMMAND,
"RENAME": TokenType.REPLACE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
@@ -222,6 +223,7 @@ class Snowflake(Dialect):
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
exp.UnixToTime: _unix_to_time_sql,
+ exp.DayOfWeek: rename_func("DAYOFWEEK"),
}
TYPE_MAPPING = {
@@ -294,3 +296,12 @@ class Snowflake(Dialect):
kind = f" {kind_value}" if kind_value else ""
this = f" {self.sql(expression, 'this')}"
return f"DESCRIBE{kind}{this}"
+
+ def generatedasidentitycolumnconstraint_sql(
+ self, expression: exp.GeneratedAsIdentityColumnConstraint
+ ) -> str:
+ start = expression.args.get("start")
+ start = f" START {start}" if start else ""
+ increment = expression.args.get("increment")
+ increment = f" INCREMENT {increment}" if increment else ""
+ return f"AUTOINCREMENT{start}{increment}"
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 8ef4a87..03ec211 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -157,6 +157,7 @@ class Spark(Hive):
TRANSFORMS.pop(exp.ILike)
WRAP_DERIVED_VALUES = False
+ CREATE_FUNCTION_AS = False
def cast_sql(self, expression: exp.Cast) -> str:
if isinstance(expression.this, exp.Cast) and expression.this.is_type(
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 1b39449..a428dd5 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -49,7 +49,6 @@ class SQLite(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
- "AUTOINCREMENT": TokenType.AUTO_INCREMENT,
}
class Parser(parser.Parser):
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
index 7d5ec21..7530613 100644
--- a/sqlglot/diff.py
+++ b/sqlglot/diff.py
@@ -1,5 +1,6 @@
"""
.. include:: ../posts/sql_diff.md
+
----
"""
diff --git a/sqlglot/errors.py b/sqlglot/errors.py
index b5ef5ad..300c215 100644
--- a/sqlglot/errors.py
+++ b/sqlglot/errors.py
@@ -7,10 +7,17 @@ from sqlglot.helper import AutoName
class ErrorLevel(AutoName):
- IGNORE = auto() # Ignore any parser errors
- WARN = auto() # Log any parser errors with ERROR level
- RAISE = auto() # Collect all parser errors and raise a single exception
- IMMEDIATE = auto() # Immediately raise an exception on the first parser error
+ IGNORE = auto()
+ """Ignore all errors."""
+
+ WARN = auto()
+ """Log all errors."""
+
+ RAISE = auto()
+ """Collect all errors and raise a single exception."""
+
+ IMMEDIATE = auto()
+ """Immediately raise an exception on the first error found."""
class SqlglotError(Exception):
diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py
index 67b4b00..c3d2701 100644
--- a/sqlglot/executor/__init__.py
+++ b/sqlglot/executor/__init__.py
@@ -1,5 +1,6 @@
"""
.. include:: ../../posts/python_sql_engine.md
+
----
"""
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index 29848c6..de570b0 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -408,7 +408,7 @@ def _lambda_sql(self, e: exp.Lambda) -> str:
class Python(Dialect):
class Tokenizer(tokens.Tokenizer):
- ESCAPES = ["\\"]
+ STRING_ESCAPES = ["\\"]
class Generator(generator.Generator):
TRANSFORMS = {
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 6bb083a..6800cd5 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -6,6 +6,7 @@ Every AST node in SQLGlot is represented by a subclass of `Expression`.
This module contains the implementation of all supported `Expression` types. Additionally,
it exposes a number of helper functions, which are mainly used to programmatically build
SQL expressions, such as `sqlglot.expressions.select`.
+
----
"""
@@ -137,6 +138,8 @@ class Expression(metaclass=_Expression):
return field
if isinstance(field, (Identifier, Literal, Var)):
return field.this
+ if isinstance(field, (Star, Null)):
+ return field.name
return ""
@property
@@ -176,13 +179,11 @@ class Expression(metaclass=_Expression):
return self.text("alias")
@property
- def name(self):
+ def name(self) -> str:
return self.text("this")
@property
def alias_or_name(self):
- if isinstance(self, Null):
- return "NULL"
return self.alias or self.name
@property
@@ -589,12 +590,11 @@ class Expression(metaclass=_Expression):
return load(obj)
-if t.TYPE_CHECKING:
- IntoType = t.Union[
- str,
- t.Type[Expression],
- t.Collection[t.Union[str, t.Type[Expression]]],
- ]
+IntoType = t.Union[
+ str,
+ t.Type[Expression],
+ t.Collection[t.Union[str, t.Type[Expression]]],
+]
class Condition(Expression):
@@ -939,7 +939,7 @@ class EncodeColumnConstraint(ColumnConstraintKind):
class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
# this: True -> ALWAYS, this: False -> BY DEFAULT
- arg_types = {"this": True, "start": False, "increment": False}
+ arg_types = {"this": False, "start": False, "increment": False}
class NotNullColumnConstraint(ColumnConstraintKind):
@@ -2390,7 +2390,7 @@ class Star(Expression):
arg_types = {"except": False, "replace": False}
@property
- def name(self):
+ def name(self) -> str:
return "*"
@property
@@ -2413,6 +2413,10 @@ class Placeholder(Expression):
class Null(Condition):
arg_types: t.Dict[str, t.Any] = {}
+ @property
+ def name(self) -> str:
+ return "NULL"
+
class Boolean(Condition):
pass
@@ -2644,7 +2648,9 @@ class Div(Binary):
class Dot(Binary):
- pass
+ @property
+ def name(self) -> str:
+ return self.expression.name
class DPipe(Binary):
@@ -2961,7 +2967,7 @@ class Cast(Func):
arg_types = {"this": True, "to": True}
@property
- def name(self):
+ def name(self) -> str:
return self.this.name
@property
@@ -4027,17 +4033,39 @@ def paren(expression) -> Paren:
SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
-def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
- if alias is None:
+@t.overload
+def to_identifier(name: None, quoted: t.Optional[bool] = None) -> None:
+ ...
+
+
+@t.overload
+def to_identifier(name: str | Identifier, quoted: t.Optional[bool] = None) -> Identifier:
+ ...
+
+
+def to_identifier(name, quoted=None):
+ """Builds an identifier.
+
+ Args:
+ name: The name to turn into an identifier.
+ quoted: Whether or not force quote the identifier.
+
+ Returns:
+ The identifier ast node.
+ """
+
+ if name is None:
return None
- if isinstance(alias, Identifier):
- identifier = alias
- elif isinstance(alias, str):
- if quoted is None:
- quoted = not re.match(SAFE_IDENTIFIER_RE, alias)
- identifier = Identifier(this=alias, quoted=quoted)
+
+ if isinstance(name, Identifier):
+ identifier = name
+ elif isinstance(name, str):
+ identifier = Identifier(
+ this=name,
+ quoted=not re.match(SAFE_IDENTIFIER_RE, name) if quoted is None else quoted,
+ )
else:
- raise ValueError(f"Alias needs to be a string or an Identifier, got: {alias.__class__}")
+ raise ValueError(f"Name needs to be a string or an Identifier, got: {name.__class__}")
return identifier
@@ -4112,20 +4140,31 @@ def to_column(sql_path: str | Column, **kwargs) -> Column:
return Column(this=column_name, table=table_name, **kwargs)
-def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
- """
- Create an Alias expression.
+def alias_(
+ expression: str | Expression,
+ alias: str | Identifier,
+ table: bool | t.Sequence[str | Identifier] = False,
+ quoted: t.Optional[bool] = None,
+ dialect: DialectType = None,
+ **opts,
+):
+ """Create an Alias expression.
+
Example:
>>> alias_('foo', 'bar').sql()
'foo AS bar'
+ >>> alias_('(select 1, 2)', 'bar', table=['a', 'b']).sql()
+ '(SELECT 1, 2) AS bar(a, b)'
+
Args:
- expression (str | Expression): the SQL code strings to parse.
+ expression: the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
- alias (str | Identifier): the alias name to use. If the name has
+ alias: the alias name to use. If the name has
special characters it is quoted.
- table (bool): create a table alias, default false
- dialect (str): the dialect used to parse the input expression.
+ table: Whether or not to create a table alias, can also be a list of columns.
+ quoted: whether or not to quote the alias
+ dialect: the dialect used to parse the input expression.
**opts: other options to use to parse the input expressions.
Returns:
@@ -4135,8 +4174,14 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
alias = to_identifier(alias, quoted=quoted)
if table:
- expression.set("alias", TableAlias(this=alias))
- return expression
+ table_alias = TableAlias(this=alias)
+ exp.set("alias", table_alias)
+
+ if not isinstance(table, bool):
+ for column in table:
+ table_alias.append("columns", to_identifier(column, quoted=quoted))
+
+ return exp
# We don't set the "alias" arg for Window expressions, because that would add an IDENTIFIER node in
# the AST, representing a "named_window" [1] construct (eg. bigquery). What we want is an ALIAS node
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index b95e9bc..0d72fe3 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import logging
+import re
import typing as t
from sqlglot import exp
@@ -11,6 +12,8 @@ from sqlglot.tokens import TokenType
logger = logging.getLogger("sqlglot")
+BACKSLASH_RE = re.compile(r"\\(?!b|f|n|r|t|0)")
+
class Generator:
"""
@@ -28,7 +31,8 @@ class Generator:
identify (bool): if set to True all identifiers will be delimited by the corresponding
character.
normalize (bool): if set to True all identifiers will lower cased
- escape (str): specifies an escape character. Default: '.
+ string_escape (str): specifies a string escape character. Default: '.
+ identifier_escape (str): specifies an identifier escape character. Default: ".
pad (int): determines padding in a formatted string. Default: 2.
indent (int): determines the size of indentation in a formatted string. Default: 4.
unnest_column_only (bool): if true unnest table aliases are considered only as column aliases
@@ -85,6 +89,9 @@ class Generator:
# Wrap derived values in parens, usually standard but spark doesn't support it
WRAP_DERIVED_VALUES = True
+ # Whether or not create function uses an AS before the def.
+ CREATE_FUNCTION_AS = True
+
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -154,7 +161,8 @@ class Generator:
"identifier_end",
"identify",
"normalize",
- "escape",
+ "string_escape",
+ "identifier_escape",
"pad",
"index_offset",
"unnest_column_only",
@@ -167,6 +175,7 @@ class Generator:
"_indent",
"_replace_backslash",
"_escaped_quote_end",
+ "_escaped_identifier_end",
"_leading_comma",
"_max_text_width",
"_comments",
@@ -183,7 +192,8 @@ class Generator:
identifier_end=None,
identify=False,
normalize=False,
- escape=None,
+ string_escape=None,
+ identifier_escape=None,
pad=2,
indent=2,
index_offset=0,
@@ -208,7 +218,8 @@ class Generator:
self.identifier_end = identifier_end or '"'
self.identify = identify
self.normalize = normalize
- self.escape = escape or "'"
+ self.string_escape = string_escape or "'"
+ self.identifier_escape = identifier_escape or '"'
self.pad = pad
self.index_offset = index_offset
self.unnest_column_only = unnest_column_only
@@ -219,8 +230,9 @@ class Generator:
self.max_unsupported = max_unsupported
self.null_ordering = null_ordering
self._indent = indent
- self._replace_backslash = self.escape == "\\"
- self._escaped_quote_end = self.escape + self.quote_end
+ self._replace_backslash = self.string_escape == "\\"
+ self._escaped_quote_end = self.string_escape + self.quote_end
+ self._escaped_identifier_end = self.identifier_escape + self.identifier_end
self._leading_comma = leading_comma
self._max_text_width = max_text_width
self._comments = comments
@@ -441,6 +453,9 @@ class Generator:
def generatedasidentitycolumnconstraint_sql(
self, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
+ this = ""
+ if expression.this is not None:
+ this = " ALWAYS " if expression.this else " BY DEFAULT "
start = expression.args.get("start")
start = f"START WITH {start}" if start else ""
increment = expression.args.get("increment")
@@ -449,9 +464,7 @@ class Generator:
if start or increment:
sequence_opts = f"{start} {increment}"
sequence_opts = f" ({sequence_opts.strip()})"
- return (
- f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY{sequence_opts}"
- )
+ return f"GENERATED{this}AS IDENTITY{sequence_opts}"
def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
@@ -496,7 +509,12 @@ class Generator:
properties_sql = self.sql(properties_exp, "properties")
begin = " BEGIN" if expression.args.get("begin") else ""
expression_sql = self.sql(expression, "expression")
- expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else ""
+ if expression_sql:
+ expression_sql = f"{begin}{self.sep()}{expression_sql}"
+
+ if self.CREATE_FUNCTION_AS or kind != "FUNCTION":
+ expression_sql = f" AS{expression_sql}"
+
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
transient = (
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
@@ -701,6 +719,7 @@ class Generator:
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
text = text.lower() if self.normalize else text
+ text = text.replace(self.identifier_end, self._escaped_identifier_end)
if expression.args.get("quoted") or self.identify:
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
@@ -1121,7 +1140,7 @@ class Generator:
text = expression.this or ""
if expression.is_string:
if self._replace_backslash:
- text = text.replace("\\", "\\\\")
+ text = BACKSLASH_RE.sub(r"\\\\", text)
text = text.replace(self.quote_end, self._escaped_quote_end)
if self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
@@ -1486,9 +1505,16 @@ class Generator:
return f"(SELECT {self.sql(unnest)})"
def interval_sql(self, expression: exp.Interval) -> str:
- this = self.sql(expression, "this")
- this = f" {this}" if this else ""
- unit = self.sql(expression, "unit")
+ this = expression.args.get("this")
+ if this:
+ this = (
+ f" {this}"
+ if isinstance(this, exp.Literal) or isinstance(this, exp.Paren)
+ else f" ({this})"
+ )
+ else:
+ this = ""
+ unit = expression.args.get("unit")
unit = f" {unit}" if unit else ""
return f"INTERVAL{this}{unit}"
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py
index a39ad8c..908f126 100644
--- a/sqlglot/lineage.py
+++ b/sqlglot/lineage.py
@@ -6,6 +6,7 @@ from dataclasses import dataclass, field
from sqlglot import Schema, exp, maybe_parse
from sqlglot.optimizer import Scope, build_scope, optimize
+from sqlglot.optimizer.expand_laterals import expand_laterals
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
@@ -38,7 +39,7 @@ def lineage(
sql: str | exp.Expression,
schema: t.Optional[t.Dict | Schema] = None,
sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
- rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns),
+ rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns, expand_laterals),
dialect: DialectType = None,
) -> Node:
"""Build the lineage graph for a column of a SQL query.
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index bfb2bb8..66f97a9 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -255,12 +255,23 @@ class TypeAnnotator:
for name, source in scope.sources.items():
if not isinstance(source, Scope):
continue
- if isinstance(source.expression, exp.Values):
+ if isinstance(source.expression, exp.UDTF):
+ values = []
+
+ if isinstance(source.expression, exp.Lateral):
+ if isinstance(source.expression.this, exp.Explode):
+ values = [source.expression.this.this]
+ else:
+ values = source.expression.expressions[0].expressions
+
+ if not values:
+ continue
+
selects[name] = {
alias: column
for alias, column in zip(
source.expression.alias_column_names,
- source.expression.expressions[0].expressions,
+ values,
)
}
else:
@@ -272,7 +283,7 @@ class TypeAnnotator:
source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
- elif source:
+ elif source and col.table in selects:
col.type = selects[col.table][col.name].type
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
diff --git a/sqlglot/optimizer/expand_laterals.py b/sqlglot/optimizer/expand_laterals.py
new file mode 100644
index 0000000..59f3fec
--- /dev/null
+++ b/sqlglot/optimizer/expand_laterals.py
@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp
+
+
+def expand_laterals(expression: exp.Expression) -> exp.Expression:
+ """
+ Expand lateral column alias references.
+
+ This assumes `qualify_columns` as already run.
+
+ Example:
+ >>> import sqlglot
+ >>> sql = "SELECT x.a + 1 AS b, b + 1 AS c FROM x"
+ >>> expression = sqlglot.parse_one(sql)
+ >>> expand_laterals(expression).sql()
+ 'SELECT x.a + 1 AS b, x.a + 1 + 1 AS c FROM x'
+
+ Args:
+ expression: expression to optimize
+ Returns:
+ optimized expression
+ """
+ for select in expression.find_all(exp.Select):
+ alias_to_expression: t.Dict[str, exp.Expression] = {}
+ for projection in select.expressions:
+ for column in projection.find_all(exp.Column):
+ if not column.table and column.name in alias_to_expression:
+ column.replace(alias_to_expression[column.name].copy())
+ if isinstance(projection, exp.Alias):
+ alias_to_expression[projection.alias] = projection.this
+ return expression
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index 766e059..96fd56b 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -4,6 +4,7 @@ from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
+from sqlglot.optimizer.expand_laterals import expand_laterals
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.lower_identities import lower_identities
@@ -12,7 +13,7 @@ from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
-from sqlglot.optimizer.qualify_columns import qualify_columns
+from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema
@@ -22,6 +23,8 @@ RULES = (
qualify_tables,
isolate_table_selects,
qualify_columns,
+ expand_laterals,
+ validate_qualify_columns,
pushdown_projections,
normalize,
unnest_subqueries,
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index a73647c..54c5021 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -7,7 +7,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
SELECT_ALL = object()
# Selection to use if selection list is empty
-DEFAULT_SELECTION = alias("1", "_")
+DEFAULT_SELECTION = lambda: alias("1", "_")
def pushdown_projections(expression):
@@ -93,7 +93,7 @@ def _remove_unused_selections(scope, parent_selections):
# If there are no remaining selections, just select a single constant
if not new_selections:
- new_selections.append(DEFAULT_SELECTION.copy())
+ new_selections.append(DEFAULT_SELECTION())
scope.expression.set("expressions", new_selections)
if removed:
@@ -106,5 +106,5 @@ def _remove_indexed_selections(scope, indexes_to_remove):
selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
]
if not new_selections:
- new_selections.append(DEFAULT_SELECTION.copy())
+ new_selections.append(DEFAULT_SELECTION())
scope.expression.set("expressions", new_selections)
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 54425a8..ab13d01 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -37,11 +37,24 @@ def qualify_columns(expression, schema):
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver)
_qualify_outputs(scope)
- _check_unknown_tables(scope)
return expression
+def validate_qualify_columns(expression):
+ """Raise an `OptimizeError` if any columns aren't qualified"""
+ unqualified_columns = []
+ for scope in traverse_scope(expression):
+ if isinstance(scope.expression, exp.Select):
+ unqualified_columns.extend(scope.unqualified_columns)
+ if scope.external_columns and not scope.is_correlated_subquery:
+ raise OptimizeError(f"Unknown table: {scope.external_columns[0].table}")
+
+ if unqualified_columns:
+ raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
+ return expression
+
+
def _pop_table_column_aliases(derived_tables):
"""
Remove table column aliases.
@@ -199,10 +212,6 @@ def _qualify_columns(scope, resolver):
if not column_table:
column_table = resolver.get_table(column_name)
- if not scope.is_subquery and not scope.is_udtf:
- if column_table is None:
- raise OptimizeError(f"Ambiguous column: {column_name}")
-
# column_table can be a '' because bigquery unnest has no table alias
if column_table:
column.set("table", exp.to_identifier(column_table))
@@ -231,10 +240,8 @@ def _qualify_columns(scope, resolver):
for column in columns_missing_from_scope:
column_table = resolver.get_table(column.name)
- if column_table is None:
- raise OptimizeError(f"Ambiguous column: {column.name}")
-
- column.set("table", exp.to_identifier(column_table))
+ if column_table:
+ column.set("table", exp.to_identifier(column_table))
def _expand_stars(scope, resolver):
@@ -322,11 +329,6 @@ def _qualify_outputs(scope):
scope.expression.set("expressions", new_selections)
-def _check_unknown_tables(scope):
- if scope.external_columns and not scope.is_udtf and not scope.is_correlated_subquery:
- raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
-
-
class _Resolver:
"""
Helper for resolving columns.
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 5d8e0d9..65593bd 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -2,7 +2,7 @@ import itertools
from sqlglot import alias, exp
from sqlglot.helper import csv_reader
-from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.scope import Scope, traverse_scope
def qualify_tables(expression, db=None, catalog=None, schema=None):
@@ -25,6 +25,8 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
"""
sequence = itertools.count()
+ next_name = lambda: f"_q_{next(sequence)}"
+
for scope in traverse_scope(expression):
for derived_table in scope.ctes + scope.derived_tables:
if not derived_table.args.get("alias"):
@@ -46,7 +48,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
source = source.replace(
alias(
source.copy(),
- source.this if identifier else f"_q_{next(sequence)}",
+ source.this if identifier else next_name(),
table=True,
)
)
@@ -58,5 +60,12 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
schema.add_table(
source, {k: type(v).__name__ for k, v in zip(header, columns)}
)
+ elif isinstance(source, Scope) and source.is_udtf:
+ udtf = source.expression
+ table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name())
+ udtf.set("alias", table_alias)
+
+ if not table_alias.name:
+ table_alias.set("this", next_name())
return expression
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index badbb87..8565c64 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -237,6 +237,8 @@ class Scope:
ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
if (
not ancestor
+ # Window functions can have an ORDER BY clause
+ or not isinstance(ancestor.parent, exp.Select)
or column.table
or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
):
@@ -479,7 +481,7 @@ def _traverse_scope(scope):
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
elif isinstance(scope.expression, exp.UDTF):
- pass
+ _set_udtf_scope(scope)
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
else:
@@ -509,6 +511,22 @@ def _traverse_union(scope):
scope.union_scopes = [left, right]
+def _set_udtf_scope(scope):
+ parent = scope.expression.parent
+ from_ = parent.args.get("from")
+
+ if not from_:
+ return
+
+ for table in from_.expressions:
+ if isinstance(table, exp.Table):
+ scope.tables.append(table)
+ elif isinstance(table, exp.Subquery):
+ scope.subqueries.append(table)
+ _add_table_sources(scope)
+ _traverse_subqueries(scope)
+
+
def _traverse_derived_tables(derived_tables, scope, scope_type):
sources = {}
is_cte = scope_type == ScopeType.CTE
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index e2b2c54..579c2ce 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -194,6 +194,7 @@ class Parser(metaclass=_Parser):
TokenType.INTERVAL,
TokenType.LAZY,
TokenType.LEADING,
+ TokenType.LEFT,
TokenType.LOCAL,
TokenType.MATERIALIZED,
TokenType.MERGE,
@@ -208,6 +209,7 @@ class Parser(metaclass=_Parser):
TokenType.PRECEDING,
TokenType.RANGE,
TokenType.REFERENCES,
+ TokenType.RIGHT,
TokenType.ROW,
TokenType.ROWS,
TokenType.SCHEMA,
@@ -237,8 +239,10 @@ class Parser(metaclass=_Parser):
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
TokenType.APPLY,
+ TokenType.LEFT,
TokenType.NATURAL,
TokenType.OFFSET,
+ TokenType.RIGHT,
TokenType.WINDOW,
}
@@ -258,6 +262,8 @@ class Parser(metaclass=_Parser):
TokenType.IDENTIFIER,
TokenType.INDEX,
TokenType.ISNULL,
+ TokenType.ILIKE,
+ TokenType.LIKE,
TokenType.MERGE,
TokenType.OFFSET,
TokenType.PRIMARY_KEY,
@@ -971,13 +977,14 @@ class Parser(metaclass=_Parser):
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function(kind=create_token.token_type)
properties = self._parse_properties()
- if self._match(TokenType.ALIAS):
- begin = self._match(TokenType.BEGIN)
- return_ = self._match_text_seq("RETURN")
- expression = self._parse_statement()
- if return_:
- expression = self.expression(exp.Return, this=expression)
+ self._match(TokenType.ALIAS)
+ begin = self._match(TokenType.BEGIN)
+ return_ = self._match_text_seq("RETURN")
+ expression = self._parse_statement()
+
+ if return_:
+ expression = self.expression(exp.Return, this=expression)
elif create_token.token_type == TokenType.INDEX:
this = self._parse_index()
elif create_token.token_type in (
@@ -2163,7 +2170,9 @@ class Parser(metaclass=_Parser):
) -> t.Optional[exp.Expression]:
if self._match(TokenType.TOP if top else TokenType.LIMIT):
limit_paren = self._match(TokenType.L_PAREN)
- limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number())
+ limit_exp = self.expression(
+ exp.Limit, this=this, expression=self._parse_number() if top else self._parse_term()
+ )
if limit_paren:
self._match_r_paren()
@@ -2740,8 +2749,23 @@ class Parser(metaclass=_Parser):
kind: exp.Expression
- if self._match(TokenType.AUTO_INCREMENT):
- kind = exp.AutoIncrementColumnConstraint()
+ if self._match_set((TokenType.AUTO_INCREMENT, TokenType.IDENTITY)):
+ start = None
+ increment = None
+
+ if self._match(TokenType.L_PAREN, advance=False):
+ args = self._parse_wrapped_csv(self._parse_bitwise)
+ start = seq_get(args, 0)
+ increment = seq_get(args, 1)
+ elif self._match_text_seq("START"):
+ start = self._parse_bitwise()
+ self._match_text_seq("INCREMENT")
+ increment = self._parse_bitwise()
+
+ if start and increment:
+ kind = exp.GeneratedAsIdentityColumnConstraint(start=start, increment=increment)
+ else:
+ kind = exp.AutoIncrementColumnConstraint()
elif self._match(TokenType.CHECK):
constraint = self._parse_wrapped(self._parse_conjunction)
kind = self.expression(exp.CheckColumnConstraint, this=constraint)
@@ -3294,8 +3318,8 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.EXCEPT):
return None
if self._match(TokenType.L_PAREN, advance=False):
- return self._parse_wrapped_id_vars()
- return self._parse_csv(self._parse_id_var)
+ return self._parse_wrapped_csv(self._parse_column)
+ return self._parse_csv(self._parse_column)
def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
if not self._match(TokenType.REPLACE):
@@ -3442,7 +3466,7 @@ class Parser(metaclass=_Parser):
def _parse_alter(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.TABLE):
- return None
+ return self._parse_as_command(self._prev)
exists = self._parse_exists()
this = self._parse_table(schema=True)
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index e95057a..8cf17a7 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -357,7 +357,8 @@ class _Tokenizer(type):
klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)
- klass._ESCAPES = set(klass.ESCAPES)
+ klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)
+ klass._IDENTIFIER_ESCAPES = set(klass.IDENTIFIER_ESCAPES)
klass._COMMENTS = dict(
(comment, None) if isinstance(comment, str) else (comment[0], comment[1])
for comment in klass.COMMENTS
@@ -429,9 +430,13 @@ class Tokenizer(metaclass=_Tokenizer):
IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
- ESCAPES = ["'"]
+ STRING_ESCAPES = ["'"]
- _ESCAPES: t.Set[str] = set()
+ _STRING_ESCAPES: t.Set[str] = set()
+
+ IDENTIFIER_ESCAPES = ['"']
+
+ _IDENTIFIER_ESCAPES: t.Set[str] = set()
KEYWORDS = {
**{
@@ -469,6 +474,7 @@ class Tokenizer(metaclass=_Tokenizer):
"ASC": TokenType.ASC,
"AS": TokenType.ALIAS,
"AT TIME ZONE": TokenType.AT_TIME_ZONE,
+ "AUTOINCREMENT": TokenType.AUTO_INCREMENT,
"AUTO_INCREMENT": TokenType.AUTO_INCREMENT,
"BEGIN": TokenType.BEGIN,
"BETWEEN": TokenType.BETWEEN,
@@ -691,6 +697,7 @@ class Tokenizer(metaclass=_Tokenizer):
"ALTER VIEW": TokenType.COMMAND,
"ANALYZE": TokenType.COMMAND,
"CALL": TokenType.COMMAND,
+ "COPY": TokenType.COMMAND,
"EXPLAIN": TokenType.COMMAND,
"OPTIMIZE": TokenType.COMMAND,
"PREPARE": TokenType.COMMAND,
@@ -744,7 +751,7 @@ class Tokenizer(metaclass=_Tokenizer):
)
def __init__(self) -> None:
- self._replace_backslash = "\\" in self._ESCAPES
+ self._replace_backslash = "\\" in self._STRING_ESCAPES
self.reset()
def reset(self) -> None:
@@ -1046,12 +1053,25 @@ class Tokenizer(metaclass=_Tokenizer):
return True
def _scan_identifier(self, identifier_end: str) -> None:
- while self._peek != identifier_end:
+ text = ""
+ identifier_end_is_escape = identifier_end in self._IDENTIFIER_ESCAPES
+
+ while True:
if self._end:
raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}")
+
self._advance()
- self._advance()
- self._add(TokenType.IDENTIFIER, self._text[1:-1])
+ if self._char == identifier_end:
+ if identifier_end_is_escape and self._peek == identifier_end:
+ text += identifier_end # type: ignore
+ self._advance()
+ continue
+
+ break
+
+ text += self._char # type: ignore
+
+ self._add(TokenType.IDENTIFIER, text)
def _scan_var(self) -> None:
while True:
@@ -1072,9 +1092,9 @@ class Tokenizer(metaclass=_Tokenizer):
while True:
if (
- self._char in self._ESCAPES
+ self._char in self._STRING_ESCAPES
and self._peek
- and (self._peek == delimiter or self._peek in self._ESCAPES)
+ and (self._peek == delimiter or self._peek in self._STRING_ESCAPES)
):
text += self._peek
self._advance(2)