author     Daniel Baumann <daniel.baumann@progress-linux.org>  2023-08-10 09:23:50 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2023-08-10 09:23:50 +0000
commit     4cc7d5a6dcda8f275b4156a9a23bbe5380be1b53 (patch)
tree       1084b1a2dd9f2782031b4aa79608db08968a5837 /sqlglot
parent     Releasing debian version 17.9.1-1. (diff)
Merging upstream version 17.11.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/dialects/bigquery.py                 12
-rw-r--r--  sqlglot/dialects/clickhouse.py               10
-rw-r--r--  sqlglot/dialects/dialect.py                  23
-rw-r--r--  sqlglot/dialects/drill.py                     6
-rw-r--r--  sqlglot/dialects/duckdb.py                   23
-rw-r--r--  sqlglot/dialects/hive.py                     14
-rw-r--r--  sqlglot/dialects/mysql.py                    10
-rw-r--r--  sqlglot/dialects/postgres.py                  4
-rw-r--r--  sqlglot/dialects/presto.py                   26
-rw-r--r--  sqlglot/dialects/redshift.py                 23
-rw-r--r--  sqlglot/dialects/snowflake.py                 5
-rw-r--r--  sqlglot/dialects/starrocks.py                11
-rw-r--r--  sqlglot/dialects/teradata.py                  9
-rw-r--r--  sqlglot/dialects/tsql.py                     20
-rw-r--r--  sqlglot/expressions.py                       26
-rw-r--r--  sqlglot/generator.py                         45
-rw-r--r--  sqlglot/optimizer/normalize_identifiers.py    3
-rw-r--r--  sqlglot/optimizer/pushdown_projections.py    19
-rw-r--r--  sqlglot/optimizer/qualify_columns.py         84
-rw-r--r--  sqlglot/parser.py                            58
-rw-r--r--  sqlglot/tokens.py                             6
21 files changed, 256 insertions, 181 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index df9065f..71977dd 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -34,7 +34,7 @@ def _date_add_sql(
this = self.sql(expression, "this")
unit = expression.args.get("unit")
unit = exp.var(unit.name.upper() if unit else "DAY")
- interval = exp.Interval(this=expression.expression, unit=unit)
+ interval = exp.Interval(this=expression.expression.copy(), unit=unit)
return f"{data_type}_{kind}({this}, {self.sql(interval)})"
return func
@@ -76,16 +76,12 @@ def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsPrope
def _create_sql(self: generator.Generator, expression: exp.Create) -> str:
kind = expression.args["kind"]
returns = expression.find(exp.ReturnsProperty)
+
if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"):
expression = expression.copy()
expression.set("kind", "TABLE FUNCTION")
- if isinstance(
- expression.expression,
- (
- exp.Subquery,
- exp.Literal,
- ),
- ):
+
+ if isinstance(expression.expression, (exp.Subquery, exp.Literal)):
expression.set("expression", expression.expression.this)
return self.create_sql(expression)
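
The new .copy() here follows a pattern applied across this release: generator helpers that re-wrap a child node (above, into an exp.Interval) copy it first, so emitting SQL never re-parents the child or mutates the parsed tree. A minimal sketch of the invariant, assuming only the public sqlglot API:

    import sqlglot

    # Parse once, generate twice: with the copy in place, generation is a
    # pure function of the tree, so the second pass matches the first.
    tree = sqlglot.parse_one("SELECT DATE_ADD(d, INTERVAL 1 DAY)", read="bigquery")
    first = tree.sql(dialect="bigquery")
    assert tree.sql(dialect="bigquery") == first
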
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index ce1a486..e6b7743 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -77,7 +77,7 @@ class ClickHouse(Dialect):
FUNCTION_PARSERS.pop("MATCH")
NO_PAREN_FUNCTION_PARSERS = parser.Parser.NO_PAREN_FUNCTION_PARSERS.copy()
- NO_PAREN_FUNCTION_PARSERS.pop(TokenType.ANY)
+ NO_PAREN_FUNCTION_PARSERS.pop("ANY")
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS,
@@ -355,6 +355,7 @@ class ClickHouse(Dialect):
def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
# Clickhouse errors out if we try to cast a NULL value to TEXT
+ expression = expression.copy()
return self.func(
"CONCAT",
*[
@@ -389,11 +390,7 @@ class ClickHouse(Dialect):
def oncluster_sql(self, expression: exp.OnCluster) -> str:
return f"ON CLUSTER {self.sql(expression, 'this')}"
- def createable_sql(
- self,
- expression: exp.Create,
- locations: dict[exp.Properties.Location, list[exp.Property]],
- ) -> str:
+ def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str:
kind = self.sql(expression, "kind").upper()
if kind in self.ON_CLUSTER_TARGETS and locations.get(exp.Properties.Location.POST_NAME):
this_name = self.sql(expression.this, "this")
@@ -402,4 +399,5 @@ class ClickHouse(Dialect):
)
this_schema = self.schema_columns_sql(expression.this)
return f"{this_name}{self.sep()}{this_properties}{self.sep()}{this_schema}"
+
return super().createable_sql(expression, locations)
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 05e81ce..1d0584c 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -346,7 +346,9 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str:
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
return self.like_sql(
- exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
+ exp.Like(
+ this=exp.Lower(this=expression.this.copy()), expression=expression.expression.copy()
+ )
)
@@ -410,7 +412,7 @@ def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
this = self.sql(expression, "this")
- struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
+ struct_key = self.sql(exp.Identifier(this=expression.expression.copy(), quoted=True))
return f"{this}.{struct_key}"
@@ -571,6 +573,17 @@ def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
return self.sql(exp.cast(expression.this, "date"))
+# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
+def encode_decode_sql(
+ self: Generator, expression: exp.Expression, name: str, replace: bool = True
+) -> str:
+ charset = expression.args.get("charset")
+ if charset and charset.name.lower() != "utf-8":
+ self.unsupported(f"Expected utf-8 character set, got {charset}.")
+
+ return self.func(name, expression.this, expression.args.get("replace") if replace else None)
+
+
def min_or_least(self: Generator, expression: exp.Min) -> str:
name = "LEAST" if expression.expressions else "MIN"
return rename_func(name)(self, expression)
@@ -588,7 +601,7 @@ def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
cond = expression.this.expressions[0]
self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
- return self.func("sum", exp.func("if", cond, 1, 0))
+ return self.func("sum", exp.func("if", cond.copy(), 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
@@ -625,6 +638,7 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
+ expression = expression.copy()
this, *rest_args = expression.expressions
for arg in rest_args:
this = exp.DPipe(this=this, expression=arg)
@@ -674,11 +688,10 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp
return names
-def simplify_literal(expression: E, copy: bool = True) -> E:
+def simplify_literal(expression: E) -> E:
if not isinstance(expression.expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify
- expression = exp.maybe_copy(expression, copy)
simplify(expression.expression)
return expression
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 26d09ce..1b2681d 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -20,9 +20,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = exp.var(expression.text("unit").upper() or "DAY")
- return (
- f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
- )
+ return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
return func
@@ -145,7 +143,7 @@ class Drill(Dialect):
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.TryCast: no_trycast_sql,
- exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
+ exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression.copy(), unit=exp.var('DAY')))})",
exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
}
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 219b1aa..5428e86 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import (
binary_from_function,
date_trunc_to_time,
datestrtodate_sql,
+ encode_decode_sql,
format_time_lambda,
no_comment_column_constraint_sql,
no_properties_sql,
@@ -32,14 +33,14 @@ from sqlglot.tokens import TokenType
def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
- return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
+ return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}"
def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
op = "+" if isinstance(expression, exp.DateAdd) else "-"
- return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
+ return f"{this} {op} {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}"
# BigQuery -> DuckDB conversion for the DATE function
@@ -167,6 +168,16 @@ class DuckDB(Dialect):
"XOR": binary_from_function(exp.BitwiseXor),
}
+ FUNCTION_PARSERS = {
+ **parser.Parser.FUNCTION_PARSERS,
+ "DECODE": lambda self: self.expression(
+ exp.Decode, this=self._parse_conjunction(), charset=exp.Literal.string("utf-8")
+ ),
+ "ENCODE": lambda self: self.expression(
+ exp.Encode, this=self._parse_conjunction(), charset=exp.Literal.string("utf-8")
+ ),
+ }
+
TYPE_TOKENS = {
*parser.Parser.TYPE_TOKENS,
TokenType.UBIGINT,
@@ -215,7 +226,9 @@ class DuckDB(Dialect):
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
+ exp.Decode: lambda self, e: encode_decode_sql(self, e, "DECODE", replace=False),
exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)",
+ exp.Encode: lambda self, e: encode_decode_sql(self, e, "ENCODE", replace=False),
exp.Explode: rename_func("UNNEST"),
exp.IntDiv: lambda self, e: self.binary(e, "//"),
exp.JSONExtract: arrow_json_extract_sql,
@@ -228,8 +241,8 @@ class DuckDB(Dialect):
exp.MonthsBetween: lambda self, e: self.func(
"DATEDIFF",
"'month'",
- exp.cast(e.expression, "timestamp"),
- exp.cast(e.this, "timestamp"),
+ exp.cast(e.expression, "timestamp", copy=True),
+ exp.cast(e.this, "timestamp", copy=True),
),
exp.Properties: no_properties_sql,
exp.RegexpExtract: regexp_extract_sql,
@@ -290,7 +303,7 @@ class DuckDB(Dialect):
multiplier = 90
if multiplier:
- return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})"
+ return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this.copy(), unit=exp.var('day')))})"
return super().interval_sql(expression)
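
A sketch of what the new ENCODE/DECODE handling is meant to do end to end; the output comment is an expectation, not a verified string:

    import sqlglot

    # DuckDB's ENCODE/DECODE take no charset, so the parser attaches an
    # implicit 'utf-8' literal, which makes Presto's TO_UTF8/FROM_UTF8
    # legal targets for the shared encode_decode_sql helper.
    print(sqlglot.transpile("SELECT ENCODE(col)", read="duckdb", write="presto"))
    # roughly: ['SELECT TO_UTF8(col)']
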
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 4e84085..aa4d845 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -59,7 +59,7 @@ def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateS
if expression.expression.is_number:
modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier)
else:
- modified_increment = expression.expression
+ modified_increment = expression.expression.copy()
if multiplier != 1:
modified_increment = exp.Mul( # type: ignore
this=modified_increment, expression=exp.Literal.number(multiplier)
@@ -272,8 +272,8 @@ class Hive(Dialect):
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
- FUNCTION_PARSERS = {
- **parser.Parser.FUNCTION_PARSERS,
+ NO_PAREN_FUNCTION_PARSERS = {
+ **parser.Parser.NO_PAREN_FUNCTION_PARSERS,
"TRANSFORM": lambda self: self._parse_transform(),
}
@@ -284,10 +284,12 @@ class Hive(Dialect):
),
}
- def _parse_transform(self) -> exp.Transform | exp.QueryTransform:
- args = self._parse_csv(self._parse_lambda)
- self._match_r_paren()
+ def _parse_transform(self) -> t.Optional[exp.Transform | exp.QueryTransform]:
+ if not self._match(TokenType.L_PAREN, advance=False):
+ self._retreat(self._index - 1)
+ return None
+ args = self._parse_wrapped_csv(self._parse_lambda)
row_format_before = self._parse_row_format(match_row=True)
record_writer = None
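
Moving TRANSFORM from FUNCTION_PARSERS to NO_PAREN_FUNCTION_PARSERS, plus the lookahead above, means the keyword is only treated as a query transform when a parenthesis follows; a rough sketch:

    import sqlglot

    # With "(" it is a Hive TRANSFORM clause...
    sqlglot.parse_one("SELECT TRANSFORM(a) USING 'cat' AS (b) FROM t", read="hive")

    # ...without one, the parser retreats and "transform" stays a column name.
    sqlglot.parse_one("SELECT transform FROM t", read="hive")
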
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index a54f076..3cd99e7 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -87,9 +87,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
- return (
- f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
- )
+ return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
return func
@@ -522,7 +520,7 @@ class MySQL(Dialect):
exp.StrToTime: _str_to_date_sql,
exp.TableSample: no_tablesample_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
- exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime")),
+ exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
@@ -556,12 +554,12 @@ class MySQL(Dialect):
def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
# MySQL requires simple literal values for its LIMIT clause.
- expression = simplify_literal(expression)
+ expression = simplify_literal(expression.copy())
return super().limit_sql(expression, top=top)
def offset_sql(self, expression: exp.Offset) -> str:
# MySQL requires simple literal values for its OFFSET clause.
- expression = simplify_literal(expression)
+ expression = simplify_literal(expression.copy())
return super().offset_sql(expression)
def xor_sql(self, expression: exp.Xor) -> str:
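
Since simplify_literal no longer copies internally (see the dialect.py hunk above), callers copy first; the observable behavior should be unchanged. A sketch, with the expected output hedged:

    import sqlglot

    # MySQL only accepts plain literals in LIMIT/OFFSET, so constant
    # arithmetic is folded during generation, on a copy of the clause.
    print(sqlglot.transpile("SELECT * FROM t LIMIT 1 + 1", write="mysql"))
    # roughly: ['SELECT * FROM t LIMIT 2']
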
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index ef100b1..ca44b70 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -40,10 +40,12 @@ DATE_DIFF_FACTOR = {
def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+ expression = expression.copy()
+
this = self.sql(expression, "this")
unit = expression.args.get("unit")
- expression = simplify_literal(expression.copy(), copy=False).expression
+ expression = simplify_literal(expression).expression
if not isinstance(expression, exp.Literal):
self.unsupported("Cannot add non literal")
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 14ec3dd..291b478 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import (
Dialect,
binary_from_function,
date_trunc_to_time,
+ encode_decode_sql,
format_time_lambda,
if_sql,
left_to_substring_sql,
@@ -21,7 +22,6 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
)
from sqlglot.dialects.mysql import MySQL
-from sqlglot.errors import UnsupportedError
from sqlglot.helper import apply_index_offset, seq_get
from sqlglot.tokens import TokenType
@@ -41,6 +41,7 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
+ expression = expression.copy()
return self.sql(
exp.Join(
this=exp.Unnest(
@@ -59,16 +60,6 @@ def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str:
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
-def _decode_sql(self: generator.Generator, expression: exp.Decode) -> str:
- _ensure_utf8(expression.args["charset"])
- return self.func("FROM_UTF8", expression.this, expression.args.get("replace"))
-
-
-def _encode_sql(self: generator.Generator, expression: exp.Encode) -> str:
- _ensure_utf8(expression.args["charset"])
- return f"TO_UTF8({self.sql(expression, 'this')})"
-
-
def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str:
if expression.args.get("asc") == exp.false():
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
@@ -106,14 +97,14 @@ def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDat
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT):
return exp.cast(_str_to_time_sql(self, expression), "DATE").sql(dialect="presto")
- return exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE").sql(dialect="presto")
+ return exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE").sql(dialect="presto")
def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
this = expression.this
if not isinstance(this, exp.CurrentDate):
- this = exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE")
+ this = exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE")
return self.func(
"DATE_ADD",
@@ -123,11 +114,6 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
)
-def _ensure_utf8(charset: exp.Literal) -> None:
- if charset.name.lower() != "utf-8":
- raise UnsupportedError(f"Unsupported charset {charset}")
-
-
def _approx_percentile(args: t.List) -> exp.Expression:
if len(args) == 4:
return exp.ApproxQuantile(
@@ -288,9 +274,9 @@ class Presto(Dialect):
),
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
- exp.Decode: _decode_sql,
+ exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"),
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)",
- exp.Encode: _encode_sql,
+ exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"),
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hex: rename_func("TO_HEX"),
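
The bespoke _decode_sql/_encode_sql pair (and the raising _ensure_utf8) give way to the shared encode_decode_sql helper, which reports a non-utf-8 charset through self.unsupported() instead of raising UnsupportedError. A sketch of the generated forms, with outputs as expectations:

    from sqlglot import exp
    from sqlglot.dialects.presto import Presto

    decode = exp.Decode(this=exp.column("x"), charset=exp.Literal.string("utf-8"))
    encode = exp.Encode(this=exp.column("x"), charset=exp.Literal.string("utf-8"))
    print(Presto().generate(decode))  # roughly: FROM_UTF8(x)
    print(Presto().generate(encode))  # roughly: TO_UTF8(x)
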
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index f687ba7..cdb8d0d 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -3,7 +3,11 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, transforms
-from sqlglot.dialects.dialect import concat_to_dpipe_sql, rename_func
+from sqlglot.dialects.dialect import (
+ concat_to_dpipe_sql,
+ rename_func,
+ ts_or_ds_to_date_sql,
+)
from sqlglot.dialects.postgres import Postgres
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -13,6 +17,14 @@ def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONEx
return f'{self.sql(expression, "this")}."{expression.expression.name}"'
+def _parse_date_add(args: t.List) -> exp.DateAdd:
+ return exp.DateAdd(
+ this=exp.TsOrDsToDate(this=seq_get(args, 2)),
+ expression=seq_get(args, 1),
+ unit=seq_get(args, 0),
+ )
+
+
class Redshift(Postgres):
# https://docs.aws.amazon.com/redshift/latest/dg/r_names.html
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
@@ -32,11 +44,8 @@ class Redshift(Postgres):
expression=seq_get(args, 1),
unit=exp.var("month"),
),
- "DATEADD": lambda args: exp.DateAdd(
- this=exp.TsOrDsToDate(this=seq_get(args, 2)),
- expression=seq_get(args, 1),
- unit=seq_get(args, 0),
- ),
+ "DATEADD": _parse_date_add,
+ "DATE_ADD": _parse_date_add,
"DATEDIFF": lambda args: exp.DateDiff(
this=exp.TsOrDsToDate(this=seq_get(args, 2)),
expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
@@ -123,7 +132,7 @@ class Redshift(Postgres):
exp.SafeConcat: concat_to_dpipe_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
- exp.TsOrDsToDate: lambda self, e: self.sql(e.this),
+ exp.TsOrDsToDate: ts_or_ds_to_date_sql("redshift"),
}
# Postgres maps exp.Pivot to no_pivot_sql, but Redshift supports pivots
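
Factoring the lambda out into _parse_date_add lets DATE_ADD register as a second spelling of DATEADD on the way in; a sketch:

    import sqlglot

    # Both spellings now parse to the same exp.DateAdd node.
    e1 = sqlglot.parse_one("DATEADD(day, 1, created_at)", read="redshift")
    e2 = sqlglot.parse_one("DATE_ADD(day, 1, created_at)", read="redshift")
    assert type(e1) is type(e2)
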
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 499e085..9733a85 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -297,9 +297,10 @@ class Snowflake(Dialect):
return super()._parse_id_var(any_token=any_token, tokens=tokens)
class Tokenizer(tokens.Tokenizer):
- QUOTES = ["'", "$$"]
+ QUOTES = ["'"]
STRING_ESCAPES = ["\\", "'"]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
+ RAW_STRINGS = ["$$"]
COMMENTS = ["--", "//", ("/*", "*/")]
KEYWORDS = {
@@ -363,6 +364,7 @@ class Snowflake(Dialect):
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
+ exp.StartsWith: rename_func("STARTSWITH"),
exp.StrPosition: lambda self, e: self.func(
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
@@ -382,6 +384,7 @@ class Snowflake(Dialect):
exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
exp.UnixToTime: _unix_to_time_sql,
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
+ exp.WeekOfYear: rename_func("WEEKOFYEAR"),
}
TYPE_MAPPING = {
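
Moving "$$" from QUOTES to RAW_STRINGS means dollar-quoted text keeps backslashes literally instead of treating them as escapes; a sketch:

    import sqlglot

    # Backslashes inside $$...$$ survive untouched as a raw string literal.
    e = sqlglot.parse_one(r"SELECT $$C:\Users\temp$$", read="snowflake")
    print(e.sql(dialect="snowflake"))
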
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
index baa62e8..4f6183c 100644
--- a/sqlglot/dialects/starrocks.py
+++ b/sqlglot/dialects/starrocks.py
@@ -17,6 +17,13 @@ class StarRocks(MySQL):
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1), unit=seq_get(args, 0)
),
+ "DATEDIFF": lambda args: exp.DateDiff(
+ this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY")
+ ),
+ "DATE_DIFF": lambda args: exp.DateDiff(
+ this=seq_get(args, 1), expression=seq_get(args, 2), unit=seq_get(args, 0)
+ ),
+ "REGEXP": exp.RegexpLike.from_arg_list,
}
class Generator(MySQL.Generator):
@@ -32,9 +39,11 @@ class StarRocks(MySQL):
TRANSFORMS = {
**MySQL.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
+ exp.DateDiff: lambda self, e: self.func(
+ "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.this, e.expression
+ ),
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
- exp.DateDiff: rename_func("DATEDIFF"),
exp.RegexpLike: rename_func("REGEXP"),
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: lambda self, e: self.func(
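
StarRocks distinguishes DATEDIFF(a, b), which counts days, from DATE_DIFF(unit, a, b) with an explicit unit; both now parse into exp.DateDiff, and generation normalizes to the explicit-unit form. A sketch, output hedged:

    import sqlglot

    print(sqlglot.transpile("SELECT DATEDIFF(d1, d2)", read="starrocks", write="starrocks"))
    # roughly: ["SELECT DATE_DIFF('DAY', d1, d2)"]
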
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 3fac4f5..2be1a62 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least
from sqlglot.tokens import TokenType
@@ -194,11 +196,7 @@ class Teradata(Dialect):
return f"RANGE_N({this} BETWEEN {expressions_sql}{each_sql})"
- def createable_sql(
- self,
- expression: exp.Create,
- locations: dict[exp.Properties.Location, list[exp.Property]],
- ) -> str:
+ def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str:
kind = self.sql(expression, "kind").upper()
if kind == "TABLE" and locations.get(exp.Properties.Location.POST_NAME):
this_name = self.sql(expression.this, "this")
@@ -209,4 +207,5 @@ class Teradata(Dialect):
)
this_schema = self.schema_columns_sql(expression.this)
return f"{this_name}{this_properties}{self.sep()}{this_schema}"
+
return super().createable_sql(expression, locations)
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 0eb0906..131307f 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -395,6 +395,20 @@ class TSQL(Dialect):
CONCAT_NULL_OUTPUTS_STRING = True
+ def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]:
+ """
+ T-SQL supports the syntax alias = expression in the SELECT's projection list,
+ so we transform all parsed Selects to convert their EQ projections into Aliases.
+
+ See: https://learn.microsoft.com/en-us/sql/t-sql/queries/select-clause-transact-sql?view=sql-server-ver16#syntax
+ """
+ return [
+ exp.alias_(projection.expression, projection.this.this, copy=False)
+ if isinstance(projection, exp.EQ) and isinstance(projection.this, exp.Column)
+ else projection
+ for projection in super()._parse_projections()
+ ]
+
def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback:
"""Applies to SQL Server and Azure SQL Database
COMMIT [ { TRAN | TRANSACTION }
@@ -625,11 +639,7 @@ class TSQL(Dialect):
LIMIT_FETCH = "FETCH"
- def createable_sql(
- self,
- expression: exp.Create,
- locations: dict[exp.Properties.Location, list[exp.Property]],
- ) -> str:
+ def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str:
sql = self.sql(expression, "this")
properties = expression.args.get("properties")
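
A sketch of the new projection handling, with the output as an expectation:

    import sqlglot

    # T-SQL's "alias = expression" projections become ordinary aliases at
    # parse time, so downstream dialects only ever see exp.Alias nodes.
    print(sqlglot.transpile("SELECT x = 1, y = a + 1 FROM t", read="tsql"))
    # roughly: ['SELECT 1 AS x, a + 1 AS y FROM t']
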
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index f8e9fee..c207751 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -192,6 +192,13 @@ class Expression(metaclass=_Expression):
return self.text("alias")
@property
+ def alias_column_names(self) -> t.List[str]:
+ table_alias = self.args.get("alias")
+ if not table_alias:
+ return []
+ return [c.name for c in table_alias.args.get("columns") or []]
+
+ @property
def name(self) -> str:
return self.text("this")
@@ -884,13 +891,6 @@ class Predicate(Condition):
class DerivedTable(Expression):
@property
- def alias_column_names(self) -> t.List[str]:
- table_alias = self.args.get("alias")
- if not table_alias:
- return []
- return [c.name for c in table_alias.args.get("columns") or []]
-
- @property
def selects(self) -> t.List[Expression]:
return self.this.selects if isinstance(self.this, Subqueryable) else []
@@ -4860,8 +4860,18 @@ def maybe_parse(
return sqlglot.parse_one(sql, read=dialect, into=into, **opts)
+@t.overload
+def maybe_copy(instance: None, copy: bool = True) -> None:
+ ...
+
+
+@t.overload
def maybe_copy(instance: E, copy: bool = True) -> E:
- return instance.copy() if copy else instance
+ ...
+
+
+def maybe_copy(instance, copy=True):
+ return instance.copy() if copy and instance else instance
def _is_wrong_expression(expression, into):
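
The overloads exist so type checkers see that maybe_copy is now None-safe; a sketch of the contract:

    from sqlglot import exp

    assert exp.maybe_copy(None) is None  # no AttributeError on .copy()
    node = exp.column("a")
    assert exp.maybe_copy(node, copy=False) is node
    assert exp.maybe_copy(node) is not node  # deep copy when requested
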
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index ed0a681..95db795 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import logging
import typing as t
+from collections import defaultdict
from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
@@ -676,15 +677,13 @@ class Generator:
this = f" {this}" if this else ""
return f"UNIQUE{this}"
- def createable_sql(
- self, expression: exp.Create, locations: dict[exp.Properties.Location, list[exp.Property]]
- ) -> str:
+ def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str:
return self.sql(expression, "this")
def create_sql(self, expression: exp.Create) -> str:
kind = self.sql(expression, "kind").upper()
properties = expression.args.get("properties")
- properties_locs = self.locate_properties(properties) if properties else {}
+ properties_locs = self.locate_properties(properties) if properties else defaultdict()
this = self.createable_sql(expression, properties_locs)
@@ -970,9 +969,9 @@ class Generator:
for p in expression.expressions:
p_loc = self.PROPERTIES_LOCATION[p.__class__]
if p_loc == exp.Properties.Location.POST_WITH:
- with_properties.append(p)
+ with_properties.append(p.copy())
elif p_loc == exp.Properties.Location.POST_SCHEMA:
- root_properties.append(p)
+ root_properties.append(p.copy())
return self.root_properties(
exp.Properties(expressions=root_properties)
@@ -1001,30 +1000,13 @@ class Generator:
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, prefix=self.seg("WITH"))
- def locate_properties(
- self, properties: exp.Properties
- ) -> t.Dict[exp.Properties.Location, list[exp.Property]]:
- properties_locs: t.Dict[exp.Properties.Location, list[exp.Property]] = {
- key: [] for key in exp.Properties.Location
- }
-
+ def locate_properties(self, properties: exp.Properties) -> t.DefaultDict:
+ properties_locs = defaultdict(list)
for p in properties.expressions:
p_loc = self.PROPERTIES_LOCATION[p.__class__]
- if p_loc == exp.Properties.Location.POST_NAME:
- properties_locs[exp.Properties.Location.POST_NAME].append(p)
- elif p_loc == exp.Properties.Location.POST_INDEX:
- properties_locs[exp.Properties.Location.POST_INDEX].append(p)
- elif p_loc == exp.Properties.Location.POST_SCHEMA:
- properties_locs[exp.Properties.Location.POST_SCHEMA].append(p)
- elif p_loc == exp.Properties.Location.POST_WITH:
- properties_locs[exp.Properties.Location.POST_WITH].append(p)
- elif p_loc == exp.Properties.Location.POST_CREATE:
- properties_locs[exp.Properties.Location.POST_CREATE].append(p)
- elif p_loc == exp.Properties.Location.POST_ALIAS:
- properties_locs[exp.Properties.Location.POST_ALIAS].append(p)
- elif p_loc == exp.Properties.Location.POST_EXPRESSION:
- properties_locs[exp.Properties.Location.POST_EXPRESSION].append(p)
- elif p_loc == exp.Properties.Location.UNSUPPORTED:
+ if p_loc != exp.Properties.Location.UNSUPPORTED:
+ properties_locs[p_loc].append(p.copy())
+ else:
self.unsupported(f"Unsupported property {p.key}")
return properties_locs
@@ -1646,9 +1628,9 @@ class Generator:
with_offset_limit_modifiers = not isinstance(limit, exp.Limit) or not self.LIMIT_IS_TOP
if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch):
- limit = exp.Limit(expression=limit.args.get("count"))
+ limit = exp.Limit(expression=exp.maybe_copy(limit.args.get("count")))
elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit):
- limit = exp.Fetch(direction="FIRST", count=limit.expression)
+ limit = exp.Fetch(direction="FIRST", count=exp.maybe_copy(limit.expression))
fetch = isinstance(limit, exp.Fetch)
@@ -1955,6 +1937,7 @@ class Generator:
return f"PRIMARY KEY ({expressions}){options}"
def if_sql(self, expression: exp.If) -> str:
+ expression = expression.copy()
return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
def matchagainst_sql(self, expression: exp.MatchAgainst) -> str:
@@ -2261,7 +2244,7 @@ class Generator:
def intdiv_sql(self, expression: exp.IntDiv) -> str:
return self.sql(
exp.Cast(
- this=exp.Div(this=expression.this, expression=expression.expression),
+ this=exp.Div(this=expression.this.copy(), expression=expression.expression.copy()),
to=exp.DataType(this=exp.DataType.Type.INT),
)
)
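
The long elif chain collapses because every supported location is bucketed the same way; a minimal sketch of the defaultdict idiom now used (the names here are illustrative, not the library's):

    from collections import defaultdict

    def bucket_by_location(pairs):
        # Group (location, property) pairs; missing keys start as empty
        # lists, which is also why create_sql can pass a bare defaultdict()
        # when there are no properties at all.
        buckets = defaultdict(list)
        for location, prop in pairs:
            buckets[location].append(prop)
        return buckets
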
diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py
index 9d4860e..54cf02b 100644
--- a/sqlglot/optimizer/normalize_identifiers.py
+++ b/sqlglot/optimizer/normalize_identifiers.py
@@ -41,5 +41,6 @@ def normalize_identifiers(expression, dialect=None):
Returns:
The transformed expression.
"""
- expression = exp.maybe_parse(expression, dialect=dialect)
+ if isinstance(expression, str):
+ expression = exp.to_identifier(expression)
return expression.transform(Dialect.get_or_raise(dialect).normalize_identifier, copy=False)
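
Strings are now taken as identifier names instead of being round-tripped through the SQL parser; a sketch, output hedged:

    from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

    # An unquoted identifier folds to the dialect's resolution case.
    print(normalize_identifiers("foo", dialect="snowflake").sql())  # roughly: FOO
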
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index c81fd00..b51601f 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -31,6 +31,7 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
"""
# Map of Scope to all columns being selected by outer queries.
schema = ensure_schema(schema)
+ source_column_alias_count = {}
referenced_columns = defaultdict(set)
# We build the scope tree (which is traversed in DFS postorder), then iterate
@@ -38,8 +39,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
# columns for a particular scope are completely built by the time we get to it.
for scope in reversed(traverse_scope(expression)):
parent_selections = referenced_columns.get(scope, {SELECT_ALL})
+ alias_count = source_column_alias_count.get(scope, 0)
- if scope.expression.args.get("distinct") or scope.parent and scope.parent.pivots:
+ if scope.expression.args.get("distinct") or (scope.parent and scope.parent.pivots):
# We can't remove columns from SELECT DISTINCT nor UNION DISTINCT. The same holds if
# we select from a pivoted source in the parent scope.
parent_selections = {SELECT_ALL}
@@ -59,7 +61,7 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
if isinstance(scope.expression, exp.Select):
if remove_unused_selections:
- _remove_unused_selections(scope, parent_selections, schema)
+ _remove_unused_selections(scope, parent_selections, schema, alias_count)
if scope.expression.is_star:
continue
@@ -72,15 +74,19 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
selects[table_name].add(col_name)
# Push the selected columns down to the next scope
- for name, (_, source) in scope.selected_sources.items():
+ for name, (node, source) in scope.selected_sources.items():
if isinstance(source, Scope):
columns = selects.get(name) or set()
referenced_columns[source].update(columns)
+ column_aliases = node.alias_column_names
+ if column_aliases:
+ source_column_alias_count[source] = len(column_aliases)
+
return expression
-def _remove_unused_selections(scope, parent_selections, schema):
+def _remove_unused_selections(scope, parent_selections, schema, alias_count):
order = scope.expression.args.get("order")
if order:
@@ -93,11 +99,14 @@ def _remove_unused_selections(scope, parent_selections, schema):
removed = False
star = False
+ select_all = SELECT_ALL in parent_selections
+
for selection in scope.expression.selects:
name = selection.alias_or_name
- if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
+ if select_all or name in parent_selections or name in order_refs or alias_count > 0:
new_selections.append(selection)
+ alias_count -= 1
else:
if selection.is_star:
star = True
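
The alias_count bookkeeping exists because a table alias's column list binds by position: pruning any leading projection would silently shift every alias after it. A sketch through the full optimizer pipeline (the schema and query are illustrative):

    import sqlglot
    from sqlglot.optimizer import optimize

    # Even though the inner "x" is unused, it must survive pruning so that
    # s(a, b) keeps binding "b" to "y".
    tree = sqlglot.parse_one("SELECT b FROM (SELECT x, y FROM t) AS s(a, b)")
    print(optimize(tree, schema={"t": {"x": "INT", "y": "INT"}}).sql())
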
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 9c34cef..952999d 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
+from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema
@@ -58,6 +59,7 @@ def qualify_columns(
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver, using_column_tables, pseudocolumns)
_qualify_outputs(scope)
+
_expand_group_by(scope)
_expand_order_by(scope, resolver)
@@ -85,7 +87,7 @@ def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) ->
"""
Remove table column aliases.
- (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
+ For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
"""
for derived_table in derived_tables:
table_alias = derived_table.args.get("alias")
@@ -111,11 +113,11 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
columns = {}
- for k in scope.selected_sources:
- if k in ordered:
- for column in resolver.get_source_columns(k):
- if column not in columns:
- columns[column] = k
+ for source_name in scope.selected_sources:
+ if source_name in ordered:
+ for column_name in resolver.get_source_columns(source_name):
+ if column_name not in columns:
+ columns[column_name] = source_name
source_table = ordered[-1]
ordered.append(join_table)
@@ -183,6 +185,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
for column, *_ in walk_in_scope(node):
if not isinstance(column, exp.Column):
continue
+
table = resolver.get_table(column.name) if resolve_table and not column.table else None
alias_expr, i = alias_to_expression.get(column.name, (None, 1))
double_agg = (
@@ -198,7 +201,10 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
if literal_index:
column.replace(exp.Literal.number(i))
else:
- column.replace(alias_expr.copy())
+ column = column.replace(exp.paren(alias_expr))
+ simplified = simplify_parens(column)
+ if simplified is not column:
+ column.replace(simplified)
for i, projection in enumerate(scope.expression.selects):
replace_columns(projection)
@@ -213,7 +219,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
scope.clear_cache()
-def _expand_group_by(scope: Scope):
+def _expand_group_by(scope: Scope) -> None:
expression = scope.expression
group = expression.args.get("group")
if not group:
@@ -223,7 +229,7 @@ def _expand_group_by(scope: Scope):
expression.set("group", group)
-def _expand_order_by(scope: Scope, resolver: Resolver):
+def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
order = scope.expression.args.get("order")
if not order:
return
@@ -442,7 +448,7 @@ def _add_replace_columns(
replace_columns[id(table)] = columns
-def _qualify_outputs(scope: Scope):
+def _qualify_outputs(scope: Scope) -> None:
"""Ensure all output columns are aliased"""
new_selections = []
@@ -482,9 +488,9 @@ class Resolver:
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
self.scope = scope
self.schema = schema
- self._source_columns = None
+ self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
- self._all_columns = None
+ self._all_columns: t.Optional[t.Set[str]] = None
self._infer_schema = infer_schema
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
@@ -528,7 +534,7 @@ class Resolver:
return exp.to_identifier(table_name)
@property
- def all_columns(self):
+ def all_columns(self) -> t.Set[str]:
"""All available columns of all sources in this scope"""
if self._all_columns is None:
self._all_columns = {
@@ -536,53 +542,67 @@ class Resolver:
}
return self._all_columns
- def get_source_columns(self, name, only_visible=False):
- """Resolve the source columns for a given source `name`"""
+ def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
+ """Resolve the source columns for a given source `name`."""
if name not in self.scope.sources:
raise OptimizeError(f"Unknown table: {name}")
source = self.scope.sources[name]
- # If referencing a table, return the columns from the schema
if isinstance(source, exp.Table):
- return self.schema.column_names(source, only_visible)
+ columns = self.schema.column_names(source, only_visible)
+ elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
+ columns = source.expression.alias_column_names
+ else:
+ columns = source.expression.named_selects
- if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
- return source.expression.alias_column_names
+ node, _ = self.scope.selected_sources.get(name) or (None, None)
+ if isinstance(node, Scope):
+ column_aliases = node.expression.alias_column_names
+ elif isinstance(node, exp.Expression):
+ column_aliases = node.alias_column_names
+ else:
+ column_aliases = []
- # Otherwise, if referencing another scope, return that scope's named selects
- return source.expression.named_selects
+ # If the source's columns are aliased, their aliases shadow the corresponding column names
+ return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]
- def _get_all_source_columns(self):
+ def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
if self._source_columns is None:
self._source_columns = {
- k: self.get_source_columns(k)
- for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
+ source_name: self.get_source_columns(source_name)
+ for source_name, source in itertools.chain(
+ self.scope.selected_sources.items(), self.scope.lateral_sources.items()
+ )
}
return self._source_columns
- def _get_unambiguous_columns(self, source_columns):
+ def _get_unambiguous_columns(
+ self, source_columns: t.Dict[str, t.List[str]]
+ ) -> t.Dict[str, str]:
"""
Find all the unambiguous columns in sources.
Args:
- source_columns (dict): Mapping of names to source columns
+ source_columns: Mapping of names to source columns.
+
Returns:
- dict: Mapping of column name to source name
+ Mapping of column name to source name.
"""
if not source_columns:
return {}
- source_columns = list(source_columns.items())
+ source_columns_pairs = list(source_columns.items())
- first_table, first_columns = source_columns[0]
+ first_table, first_columns = source_columns_pairs[0]
unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
all_columns = set(unambiguous_columns)
- for table, columns in source_columns[1:]:
+ for table, columns in source_columns_pairs[1:]:
unique = self._find_unique_columns(columns)
ambiguous = set(all_columns).intersection(unique)
all_columns.update(columns)
+
for column in ambiguous:
unambiguous_columns.pop(column, None)
for column in unique.difference(ambiguous):
@@ -591,7 +611,7 @@ class Resolver:
return unambiguous_columns
@staticmethod
- def _find_unique_columns(columns):
+ def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
"""
Find the unique columns in a list of columns.
@@ -601,7 +621,7 @@ class Resolver:
This is necessary because duplicate column names are ambiguous.
"""
- counts = {}
+ counts: t.Dict[str, int] = {}
for column in columns:
counts[column] = counts.get(column, 0) + 1
return {column for column, count in counts.items() if count == 1}
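
get_source_columns now overlays a source's own column names with the selecting scope's alias list, since aliases shadow positionally; a sketch (the schema is illustrative):

    import sqlglot
    from sqlglot.optimizer.qualify_columns import qualify_columns
    from sqlglot.schema import MappingSchema

    # "a" resolves through the alias list s(a), shadowing the inner name "x".
    tree = sqlglot.parse_one("SELECT a FROM (SELECT x FROM t) AS s(a)")
    print(qualify_columns(tree, schema=MappingSchema({"t": {"x": "INT"}})).sql())
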
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index f714c8d..35a1744 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -248,7 +248,6 @@ class Parser(metaclass=_Parser):
TokenType.FILTER,
TokenType.FORMAT,
TokenType.FULL,
- TokenType.IF,
TokenType.IS,
TokenType.ISNULL,
TokenType.INTERVAL,
@@ -708,14 +707,10 @@ class Parser(metaclass=_Parser):
SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE"}
NO_PAREN_FUNCTION_PARSERS = {
- TokenType.ANY: lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
- TokenType.CASE: lambda self: self._parse_case(),
- TokenType.IF: lambda self: self._parse_if(),
- TokenType.NEXT_VALUE_FOR: lambda self: self.expression(
- exp.NextValueFor,
- this=self._parse_column(),
- order=self._match(TokenType.OVER) and self._parse_wrapped(self._parse_order),
- ),
+ "ANY": lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
+ "CASE": lambda self: self._parse_case(),
+ "IF": lambda self: self._parse_if(),
+ "NEXT": lambda self: self._parse_next_value_for(),
}
FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"}
@@ -1162,7 +1157,7 @@ class Parser(metaclass=_Parser):
def _parse_exists(self, not_: bool = False) -> t.Optional[bool]:
return (
- self._match(TokenType.IF)
+ self._match_text_seq("IF")
and (not not_ or self._match(TokenType.NOT))
and self._match(TokenType.EXISTS)
)
@@ -1935,6 +1930,9 @@ class Parser(metaclass=_Parser):
# https://prestodb.io/docs/current/sql/values.html
return self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
+ def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]:
+ return self._parse_expressions()
+
def _parse_select(
self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True
) -> t.Optional[exp.Expression]:
@@ -1974,14 +1972,14 @@ class Parser(metaclass=_Parser):
self.raise_error("Cannot specify both ALL and DISTINCT after SELECT")
limit = self._parse_limit(top=True)
- expressions = self._parse_expressions()
+ projections = self._parse_projections()
this = self.expression(
exp.Select,
kind=kind,
hint=hint,
distinct=distinct,
- expressions=expressions,
+ expressions=projections,
limit=limit,
)
this.comments = comments
@@ -3021,8 +3019,12 @@ class Parser(metaclass=_Parser):
while True:
if self._match_set(self.BITWISE):
this = self.expression(
- self.BITWISE[self._prev.token_type], this=this, expression=self._parse_term()
+ self.BITWISE[self._prev.token_type],
+ this=this,
+ expression=self._parse_term(),
)
+ elif self._match(TokenType.DQMARK):
+ this = self.expression(exp.Coalesce, this=this, expressions=self._parse_term())
elif self._match_pair(TokenType.LT, TokenType.LT):
this = self.expression(
exp.BitwiseLeftShift, this=this, expression=self._parse_term()
@@ -3322,9 +3324,13 @@ class Parser(metaclass=_Parser):
return None
token_type = self._curr.token_type
+ this = self._curr.text
+ upper = this.upper()
- if optional_parens and self._match_set(self.NO_PAREN_FUNCTION_PARSERS):
- return self.NO_PAREN_FUNCTION_PARSERS[token_type](self)
+ parser = self.NO_PAREN_FUNCTION_PARSERS.get(upper)
+ if optional_parens and parser:
+ self._advance()
+ return parser(self)
if not self._next or self._next.token_type != TokenType.L_PAREN:
if optional_parens and token_type in self.NO_PAREN_FUNCTIONS:
@@ -3336,12 +3342,9 @@ class Parser(metaclass=_Parser):
if token_type not in self.FUNC_TOKENS:
return None
- this = self._curr.text
- upper = this.upper()
self._advance(2)
parser = self.FUNCTION_PARSERS.get(upper)
-
if parser and not anonymous:
this = parser(self)
else:
@@ -3368,7 +3371,7 @@ class Parser(metaclass=_Parser):
else:
this = self.expression(exp.Anonymous, this=this, expressions=args)
- self._match(TokenType.R_PAREN, expression=this)
+ self._match_r_paren(this)
return self._parse_window(this)
def _parse_function_parameter(self) -> t.Optional[exp.Expression]:
@@ -3703,7 +3706,11 @@ class Parser(metaclass=_Parser):
self.expression(exp.Slice, expression=self._parse_conjunction())
]
else:
- expressions = self._parse_csv(lambda: self._parse_slice(self._parse_conjunction()))
+ expressions = self._parse_csv(
+ lambda: self._parse_slice(
+ self._parse_alias(self._parse_conjunction(), explicit=True)
+ )
+ )
# https://duckdb.org/docs/sql/data_types/struct.html#creating-structs
if bracket_kind == TokenType.L_BRACE:
@@ -3770,6 +3777,17 @@ class Parser(metaclass=_Parser):
return self._parse_window(this)
+ def _parse_next_value_for(self) -> t.Optional[exp.Expression]:
+ if not self._match_text_seq("VALUE", "FOR"):
+ self._retreat(self._index - 1)
+ return None
+
+ return self.expression(
+ exp.NextValueFor,
+ this=self._parse_column(),
+ order=self._match(TokenType.OVER) and self._parse_wrapped(self._parse_order),
+ )
+
def _parse_extract(self) -> exp.Extract:
this = self._parse_function() or self._parse_var() or self._parse_type()
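
Two of the parser changes above are worth a sketch: the new DQMARK token binds at the bitwise level and parses straight into COALESCE, and NEXT VALUE FOR is now recognized by lookahead from the plain NEXT keyword, so a bare "next" still works as an identifier. Outputs are expectations:

    import sqlglot

    # "??" nests left to right into COALESCE calls.
    print(sqlglot.transpile("SELECT a ?? b ?? 'fallback' FROM t"))
    # roughly: ["SELECT COALESCE(COALESCE(a, b), 'fallback') FROM t"]

    # With the lookahead, NEXT only becomes exp.NextValueFor when
    # "VALUE FOR" follows; otherwise the parser retreats.
    sqlglot.parse_one("SELECT NEXT VALUE FOR seq")
    sqlglot.parse_one("SELECT next FROM t")
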
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 729e47f..81bcc0b 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -21,6 +21,7 @@ class TokenType(AutoName):
PLUS = auto()
COLON = auto()
DCOLON = auto()
+ DQMARK = auto()
SEMICOLON = auto()
STAR = auto()
BACKSLASH = auto()
@@ -215,7 +216,6 @@ class TokenType(AutoName):
GROUPING_SETS = auto()
HAVING = auto()
HINT = auto()
- IF = auto()
IGNORE = auto()
ILIKE = auto()
ILIKE_ANY = auto()
@@ -248,7 +248,6 @@ class TokenType(AutoName):
MOD = auto()
NATURAL = auto()
NEXT = auto()
- NEXT_VALUE_FOR = auto()
NOTNULL = auto()
NULL = auto()
OFFSET = auto()
@@ -504,6 +503,7 @@ class Tokenizer(metaclass=_Tokenizer):
"#>>": TokenType.DHASH_ARROW,
"<->": TokenType.LR_ARROW,
"&&": TokenType.DAMP,
+ "??": TokenType.DQMARK,
"ALL": TokenType.ALL,
"ALWAYS": TokenType.ALWAYS,
"AND": TokenType.AND,
@@ -563,7 +563,6 @@ class Tokenizer(metaclass=_Tokenizer):
"GROUP BY": TokenType.GROUP_BY,
"GROUPING SETS": TokenType.GROUPING_SETS,
"HAVING": TokenType.HAVING,
- "IF": TokenType.IF,
"ILIKE": TokenType.ILIKE,
"IN": TokenType.IN,
"INDEX": TokenType.INDEX,
@@ -586,7 +585,6 @@ class Tokenizer(metaclass=_Tokenizer):
"MERGE": TokenType.MERGE,
"NATURAL": TokenType.NATURAL,
"NEXT": TokenType.NEXT,
- "NEXT VALUE FOR": TokenType.NEXT_VALUE_FOR,
"NOT": TokenType.NOT,
"NOTNULL": TokenType.NOTNULL,
"NULL": TokenType.NULL,