author     Daniel Baumann <daniel.baumann@progress-linux.org>  2024-01-31 05:44:41 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2024-01-31 05:44:41 +0000
commit     376de8b6892deca7dc5d83035c047f1e13eb67ea (patch)
tree       334a1753cd914294aa99128fac3fb59bf14dc10f /sqlglot
parent     Releasing debian version 20.9.0-1. (diff)
Merging upstream version 20.11.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/__init__.py                          |   6
-rw-r--r--  sqlglot/_typing.py                           |   5
-rw-r--r--  sqlglot/dataframe/sql/column.py              |   8
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py           |  22
-rw-r--r--  sqlglot/dataframe/sql/functions.py           |   2
-rw-r--r--  sqlglot/dataframe/sql/normalize.py           |   4
-rw-r--r--  sqlglot/dataframe/sql/session.py             |   8
-rw-r--r--  sqlglot/dataframe/sql/window.py              |  14
-rw-r--r--  sqlglot/dialects/bigquery.py                 |  34
-rw-r--r--  sqlglot/dialects/clickhouse.py               |  20
-rw-r--r--  sqlglot/dialects/dialect.py                  |  36
-rw-r--r--  sqlglot/dialects/duckdb.py                   |  20
-rw-r--r--  sqlglot/dialects/hive.py                     |  22
-rw-r--r--  sqlglot/dialects/oracle.py                   |   3
-rw-r--r--  sqlglot/dialects/postgres.py                 |  15
-rw-r--r--  sqlglot/dialects/presto.py                   |   1
-rw-r--r--  sqlglot/dialects/snowflake.py                |  60
-rw-r--r--  sqlglot/dialects/spark.py                    |   6
-rw-r--r--  sqlglot/dialects/spark2.py                   |  27
-rw-r--r--  sqlglot/dialects/tableau.py                  |   1
-rw-r--r--  sqlglot/dialects/tsql.py                     |  37
-rw-r--r--  sqlglot/executor/python.py                   |  12
-rw-r--r--  sqlglot/expressions.py                       | 159
-rw-r--r--  sqlglot/generator.py                         |  50
-rw-r--r--  sqlglot/helper.py                            |  12
-rw-r--r--  sqlglot/jsonpath.py                          | 215
-rw-r--r--  sqlglot/lineage.py                           |   6
-rw-r--r--  sqlglot/optimizer/annotate_types.py          |  30
-rw-r--r--  sqlglot/optimizer/normalize_identifiers.py   |  10
-rw-r--r--  sqlglot/optimizer/qualify_columns.py         |   7
-rw-r--r--  sqlglot/optimizer/qualify_tables.py          |  31
-rw-r--r--  sqlglot/optimizer/scope.py                   |  14
-rw-r--r--  sqlglot/optimizer/simplify.py                |   8
-rw-r--r--  sqlglot/optimizer/unnest_subqueries.py       |  15
-rw-r--r--  sqlglot/parser.py                            | 161
-rw-r--r--  sqlglot/schema.py                            |  19
-rw-r--r--  sqlglot/tokens.py                            |   2
37 files changed, 822 insertions, 280 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 6cf9949..d71c06d 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -87,13 +87,11 @@ def parse(
@t.overload
-def parse_one(sql: str, *, into: t.Type[E], **opts) -> E:
- ...
+def parse_one(sql: str, *, into: t.Type[E], **opts) -> E: ...
@t.overload
-def parse_one(sql: str, **opts) -> Expression:
- ...
+def parse_one(sql: str, **opts) -> Expression: ...
def parse_one(
diff --git a/sqlglot/_typing.py b/sqlglot/_typing.py
index 86d965a..65f307e 100644
--- a/sqlglot/_typing.py
+++ b/sqlglot/_typing.py
@@ -4,10 +4,13 @@ import typing as t
import sqlglot
+if t.TYPE_CHECKING:
+ from typing_extensions import Literal as Lit # noqa
+
# A little hack for backwards compatibility with Python 3.7.
# For example, we might want a TypeVar for objects that support comparison e.g. SupportsRichComparisonT from typeshed.
# But Python 3.7 doesn't support Protocols, so we'd also need typing_extensions, which we don't want as a dependency.
A = t.TypeVar("A", bound=t.Any)
-
+B = t.TypeVar("B", bound="sqlglot.exp.Binary")
E = t.TypeVar("E", bound="sqlglot.exp.Expression")
T = t.TypeVar("T")
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index ca85376..724c5bf 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -144,9 +144,11 @@ class Column:
) -> Column:
ensured_column = None if column is None else cls.ensure_col(column)
ensure_expression_values = {
- k: [Column.ensure_col(x).expression for x in v]
- if is_iterable(v)
- else Column.ensure_col(v).expression
+ k: (
+ [Column.ensure_col(x).expression for x in v]
+ if is_iterable(v)
+ else Column.ensure_col(v).expression
+ )
for k, v in kwargs.items()
if v is not None
}
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 68d36fe..0bacbf9 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -140,12 +140,10 @@ class DataFrame:
return cte, name
@t.overload
- def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
- ...
+ def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ...
@t.overload
- def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
- ...
+ def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ...
def _ensure_list_of_columns(self, cols):
return Column.ensure_cols(ensure_list(cols))
@@ -496,9 +494,11 @@ class DataFrame:
join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
# To match spark behavior, only the join clause gets deduplicated and it gets put at the front of the column list
select_column_names = [
- column.alias_or_name
- if not isinstance(column.expression.this, exp.Star)
- else column.sql()
+ (
+ column.alias_or_name
+ if not isinstance(column.expression.this, exp.Star)
+ else column.sql()
+ )
for column in self_columns + other_columns
]
select_column_names = [
@@ -552,9 +552,11 @@ class DataFrame:
), "The length of items in ascending must equal the number of columns provided"
col_and_ascending = list(zip(columns, ascending))
order_by_columns = [
- exp.Ordered(this=col.expression, desc=not asc)
- if i not in pre_ordered_col_indexes
- else columns[i].column_expression
+ (
+ exp.Ordered(this=col.expression, desc=not asc)
+ if i not in pre_ordered_col_indexes
+ else columns[i].column_expression
+ )
for i, (col, asc) in enumerate(col_and_ascending)
]
return self.copy(expression=self.expression.order_by(*order_by_columns))
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index 141a302..a388cb4 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -661,7 +661,7 @@ def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
def to_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
tz_column = tz if isinstance(tz, Column) else lit(tz)
- return Column.invoke_anonymous_function(timestamp, "TO_UTC_TIMESTAMP", tz_column)
+ return Column.invoke_expression_over_column(timestamp, expression.FromTimeZone, zone=tz_column)
def timestamp_seconds(col: ColumnOrName) -> Column:
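
A quick sketch of what the to_utc_timestamp change above buys (column name and timezone illustrative): the Column now wraps a typed exp.FromTimeZone node instead of an anonymous TO_UTC_TIMESTAMP(...) call, so other dialects can transpile it.

from sqlglot.dataframe.sql import functions as F

# Builds an exp.FromTimeZone node under the hood rather than an anonymous function
col = F.to_utc_timestamp(F.col("ts"), "America/New_York")
print(col.expression.sql(dialect="spark"))  # TO_UTC_TIMESTAMP(ts, 'America/New_York')
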
diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py
index f68bacb..b246641 100644
--- a/sqlglot/dataframe/sql/normalize.py
+++ b/sqlglot/dataframe/sql/normalize.py
@@ -7,11 +7,11 @@ from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.helper import ensure_list
-NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
-
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.session import SparkSession
+ NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
+
def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]):
expr = ensure_list(expr)
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index 4a33ef9..f518ac2 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -82,9 +82,11 @@ class SparkSession:
]
sel_columns = [
- F.col(name).cast(data_type).alias(name).expression
- if data_type is not None
- else F.col(name).expression
+ (
+ F.col(name).cast(data_type).alias(name).expression
+ if data_type is not None
+ else F.col(name).expression
+ )
for name, data_type in column_mapping.items()
]
diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py
index c1d913f..9e2fabd 100644
--- a/sqlglot/dataframe/sql/window.py
+++ b/sqlglot/dataframe/sql/window.py
@@ -90,9 +90,11 @@ class WindowSpec:
**kwargs,
**{
"start_side": "PRECEDING",
- "start": "UNBOUNDED"
- if start <= Window.unboundedPreceding
- else F.lit(start).expression,
+ "start": (
+ "UNBOUNDED"
+ if start <= Window.unboundedPreceding
+ else F.lit(start).expression
+ ),
},
}
if end == Window.currentRow:
@@ -102,9 +104,9 @@ class WindowSpec:
**kwargs,
**{
"end_side": "FOLLOWING",
- "end": "UNBOUNDED"
- if end >= Window.unboundedFollowing
- else F.lit(end).expression,
+ "end": (
+ "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression
+ ),
},
}
return kwargs
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 0151e6c..771ae1a 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -5,7 +5,6 @@ import re
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
-from sqlglot._typing import E
from sqlglot.dialects.dialect import (
Dialect,
NormalizationStrategy,
@@ -30,7 +29,7 @@ from sqlglot.helper import seq_get, split_num_words
from sqlglot.tokens import TokenType
if t.TYPE_CHECKING:
- from typing_extensions import Literal
+ from sqlglot._typing import E, Lit
logger = logging.getLogger("sqlglot")
@@ -47,9 +46,11 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va
exp.alias_(value, column_name)
for value, column_name in zip(
t.expressions,
- alias.columns
- if alias and alias.columns
- else (f"_c{i}" for i in range(len(t.expressions))),
+ (
+ alias.columns
+ if alias and alias.columns
+ else (f"_c{i}" for i in range(len(t.expressions)))
+ ),
)
]
)
@@ -473,12 +474,10 @@ class BigQuery(Dialect):
return table
@t.overload
- def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject:
- ...
+ def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
@t.overload
- def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg:
- ...
+ def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...
def _parse_json_object(self, agg=False):
json_object = super()._parse_json_object()
@@ -546,9 +545,11 @@ class BigQuery(Dialect):
exp.ArrayContains: _array_contains_sql,
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
- exp.CollateProperty: lambda self, e: f"DEFAULT COLLATE {self.sql(e, 'this')}"
- if e.args.get("default")
- else f"COLLATE {self.sql(e, 'this')}",
+ exp.CollateProperty: lambda self, e: (
+ f"DEFAULT COLLATE {self.sql(e, 'this')}"
+ if e.args.get("default")
+ else f"COLLATE {self.sql(e, 'this')}"
+ ),
exp.CountIf: rename_func("COUNTIF"),
exp.Create: _create_sql,
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
@@ -560,6 +561,9 @@ class BigQuery(Dialect):
exp.DatetimeAdd: date_add_interval_sql("DATETIME", "ADD"),
exp.DatetimeSub: date_add_interval_sql("DATETIME", "SUB"),
exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")),
+ exp.FromTimeZone: lambda self, e: self.func(
+ "DATETIME", self.func("TIMESTAMP", e.this, e.args.get("zone")), "'UTC'"
+ ),
exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
exp.GetPath: path_to_jsonpath(),
exp.GroupConcat: rename_func("STRING_AGG"),
@@ -595,9 +599,9 @@ class BigQuery(Dialect):
exp.SHA2: lambda self, e: self.func(
f"SHA256" if e.text("length") == "256" else "SHA512", e.this
),
- exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
- if e.name == "IMMUTABLE"
- else "NOT DETERMINISTIC",
+ exp.StabilityProperty: lambda self, e: (
+ f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
+ ),
exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
exp.StrToTime: lambda self, e: self.func(
"PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index f2e4fe1..1248edc 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -88,6 +88,8 @@ class ClickHouse(Dialect):
"UINT8": TokenType.UTINYINT,
"IPV4": TokenType.IPV4,
"IPV6": TokenType.IPV6,
+ "AGGREGATEFUNCTION": TokenType.AGGREGATEFUNCTION,
+ "SIMPLEAGGREGATEFUNCTION": TokenType.SIMPLEAGGREGATEFUNCTION,
}
SINGLE_TOKENS = {
@@ -548,6 +550,8 @@ class ClickHouse(Dialect):
exp.DataType.Type.UTINYINT: "UInt8",
exp.DataType.Type.IPV4: "IPv4",
exp.DataType.Type.IPV6: "IPv6",
+ exp.DataType.Type.AGGREGATEFUNCTION: "AggregateFunction",
+ exp.DataType.Type.SIMPLEAGGREGATEFUNCTION: "SimpleAggregateFunction",
}
TRANSFORMS = {
@@ -651,12 +655,16 @@ class ClickHouse(Dialect):
def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
return super().after_limit_modifiers(expression) + [
- self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True)
- if expression.args.get("settings")
- else "",
- self.seg("FORMAT ") + self.sql(expression, "format")
- if expression.args.get("format")
- else "",
+ (
+ self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True)
+ if expression.args.get("settings")
+ else ""
+ ),
+ (
+ self.seg("FORMAT ") + self.sql(expression, "format")
+ if expression.args.get("format")
+ else ""
+ ),
]
def parameterizedagg_sql(self, expression: exp.ParameterizedAgg) -> str:
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 7664c40..6be991b 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -5,7 +5,6 @@ from enum import Enum, auto
from functools import reduce
from sqlglot import exp
-from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, seq_get
@@ -14,11 +13,12 @@ from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie
-B = t.TypeVar("B", bound=exp.Binary)
-
DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
+if t.TYPE_CHECKING:
+ from sqlglot._typing import B, E
+
class Dialects(str, Enum):
"""Dialects supported by SQLGLot."""
@@ -381,9 +381,11 @@ class Dialect(metaclass=_Dialect):
):
expression.set(
"this",
- expression.this.upper()
- if self.normalization_strategy is NormalizationStrategy.UPPERCASE
- else expression.this.lower(),
+ (
+ expression.this.upper()
+ if self.normalization_strategy is NormalizationStrategy.UPPERCASE
+ else expression.this.lower()
+ ),
)
return expression
@@ -877,9 +879,11 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp
Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
"""
agg_all_unquoted = agg.transform(
- lambda node: exp.Identifier(this=node.name, quoted=False)
- if isinstance(node, exp.Identifier)
- else node
+ lambda node: (
+ exp.Identifier(this=node.name, quoted=False)
+ if isinstance(node, exp.Identifier)
+ else node
+ )
)
names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
@@ -999,10 +1003,8 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
"""Remove table refs from columns in when statements."""
alias = expression.this.args.get("alias")
- normalize = (
- lambda identifier: self.dialect.normalize_identifier(identifier).name
- if identifier
- else None
+ normalize = lambda identifier: (
+ self.dialect.normalize_identifier(identifier).name if identifier else None
)
targets = {normalize(expression.this.this)}
@@ -1012,9 +1014,11 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
for when in expression.expressions:
when.transform(
- lambda node: exp.column(node.this)
- if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
- else node,
+ lambda node: (
+ exp.column(node.this)
+ if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
+ else node
+ ),
copy=False,
)
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 2343b35..f55ad70 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -148,8 +148,8 @@ def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str
def _rename_unless_within_group(
a: str, b: str
) -> t.Callable[[DuckDB.Generator, exp.Expression], str]:
- return (
- lambda self, expression: self.func(a, *flatten(expression.args.values()))
+ return lambda self, expression: (
+ self.func(a, *flatten(expression.args.values()))
if isinstance(expression.find_ancestor(exp.Select, exp.WithinGroup), exp.WithinGroup)
else self.func(b, *flatten(expression.args.values()))
)
@@ -273,9 +273,11 @@ class DuckDB(Dialect):
PLACEHOLDER_PARSERS = {
**parser.Parser.PLACEHOLDER_PARSERS,
- TokenType.PARAMETER: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
- if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
- else None,
+ TokenType.PARAMETER: lambda self: (
+ self.expression(exp.Placeholder, this=self._prev.text)
+ if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
+ else None
+ ),
}
def _parse_types(
@@ -321,9 +323,11 @@ class DuckDB(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
- exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
- if e.expressions and e.expressions[0].find(exp.Select)
- else inline_array_sql(self, e),
+ exp.Array: lambda self, e: (
+ self.func("ARRAY", e.expressions[0])
+ if e.expressions and e.expressions[0].find(exp.Select)
+ else inline_array_sql(self, e)
+ ),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"),
exp.ArgMin: arg_max_or_min_no_count("ARG_MIN"),
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index dffa41e..060f9bd 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -397,9 +397,11 @@ class Hive(Dialect):
if this and not schema:
return this.transform(
- lambda node: node.replace(exp.DataType.build("text"))
- if isinstance(node, exp.DataType) and node.is_type("char", "varchar")
- else node,
+ lambda node: (
+ node.replace(exp.DataType.build("text"))
+ if isinstance(node, exp.DataType) and node.is_type("char", "varchar")
+ else node
+ ),
copy=False,
)
@@ -409,9 +411,11 @@ class Hive(Dialect):
self,
) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]:
return (
- self._parse_csv(self._parse_conjunction)
- if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY})
- else [],
+ (
+ self._parse_csv(self._parse_conjunction)
+ if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY})
+ else []
+ ),
super()._parse_order(skip_order_token=self._match(TokenType.SORT_BY)),
)
@@ -483,9 +487,9 @@ class Hive(Dialect):
exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)),
exp.Min: min_or_least,
exp.MonthsBetween: lambda self, e: self.func("MONTHS_BETWEEN", e.this, e.expression),
- exp.NotNullColumnConstraint: lambda self, e: ""
- if e.args.get("allow_null")
- else "NOT NULL",
+ exp.NotNullColumnConstraint: lambda self, e: (
+ "" if e.args.get("allow_null") else "NOT NULL"
+ ),
exp.VarMap: var_map_sql,
exp.Create: _create_sql,
exp.Quantile: rename_func("PERCENTILE"),
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 6ad3718..4591d59 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -166,6 +166,7 @@ class Oracle(Dialect):
TABLESAMPLE_KEYWORDS = "SAMPLE"
LAST_DAY_SUPPORTS_DATE_PART = False
SUPPORTS_SELECT_INTO = True
+ TZ_TO_WITH_TIME_ZONE = True
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -179,6 +180,8 @@ class Oracle(Dialect):
exp.DataType.Type.NVARCHAR: "NVARCHAR2",
exp.DataType.Type.NCHAR: "NCHAR",
exp.DataType.Type.TEXT: "CLOB",
+ exp.DataType.Type.TIMETZ: "TIME",
+ exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.BINARY: "BLOB",
exp.DataType.Type.VARBINARY: "BLOB",
}
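
A sketch of what the new TZ_TO_WITH_TIME_ZONE flag and the type mappings above do together (output illustrative): the shorthand tz types are rewritten to Oracle's spelled-out forms.

import sqlglot

print(sqlglot.transpile("CAST(x AS TIMESTAMPTZ)", write="oracle")[0])
# CAST(x AS TIMESTAMP WITH TIME ZONE)
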
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 1ca0a78..87f6b02 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -282,6 +282,12 @@ class Postgres(Dialect):
VAR_SINGLE_TOKENS = {"$"}
class Parser(parser.Parser):
+ PROPERTY_PARSERS = {
+ **parser.Parser.PROPERTY_PARSERS,
+ "SET": lambda self: self.expression(exp.SetConfigProperty, this=self._parse_set()),
+ }
+ PROPERTY_PARSERS.pop("INPUT", None)
+
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE_TRUNC": parse_timestamp_trunc,
@@ -385,9 +391,11 @@ class Postgres(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.AnyValue: any_value_to_max_sql,
- exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
- if isinstance(seq_get(e.expressions, 0), exp.Select)
- else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]",
+ exp.Array: lambda self, e: (
+ f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
+ if isinstance(seq_get(e.expressions, 0), exp.Select)
+ else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]"
+ ),
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
@@ -396,6 +404,7 @@ class Postgres(Dialect):
exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
+ exp.CurrentUser: lambda *_: "CURRENT_USER",
exp.DateAdd: _date_add_sql("+"),
exp.DateDiff: _date_diff_sql,
exp.DateStrToDate: datestrtodate_sql,
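
One effect of the CURRENT_USER transform added above, as a sketch (output illustrative): Postgres expects the bare keyword rather than a function call.

import sqlglot

print(sqlglot.parse_one("SELECT CURRENT_USER").sql(dialect="postgres"))
# SELECT CURRENT_USER
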
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 9b421e7..6cc6030 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -356,6 +356,7 @@ class Presto(Dialect):
exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"),
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.First: _first_last_sql,
+ exp.FromTimeZone: lambda self, e: f"WITH_TIMEZONE({self.sql(e, 'this')}, {self.sql(e, 'zone')}) AT TIME ZONE 'UTC'",
exp.GetPath: path_to_jsonpath(),
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.GroupConcat: lambda self, e: self.func(
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index a8e4a42..281167d 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -3,7 +3,6 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
-from sqlglot._typing import E
from sqlglot.dialects.dialect import (
Dialect,
NormalizationStrategy,
@@ -25,6 +24,9 @@ from sqlglot.expressions import Literal
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
+if t.TYPE_CHECKING:
+ from sqlglot._typing import E
+
def _check_int(s: str) -> bool:
if s[0] in ("-", "+"):
@@ -297,10 +299,7 @@ def _parse_colon_get_path(
if not self._match(TokenType.COLON):
break
- if self._match_set(self.RANGE_PARSERS):
- this = self.RANGE_PARSERS[self._prev.token_type](self, this) or this
-
- return this
+ return self._parse_range(this)
def _parse_timestamp_from_parts(args: t.List) -> exp.Func:
@@ -376,7 +375,7 @@ class Snowflake(Dialect):
and isinstance(expression.parent, exp.Table)
and expression.name.lower() == "dual"
):
- return t.cast(E, expression)
+ return expression # type: ignore
return super().quote_identifier(expression, identify=identify)
@@ -471,6 +470,10 @@ class Snowflake(Dialect):
}
SHOW_PARSERS = {
+ "SCHEMAS": _show_parser("SCHEMAS"),
+ "TERSE SCHEMAS": _show_parser("SCHEMAS"),
+ "OBJECTS": _show_parser("OBJECTS"),
+ "TERSE OBJECTS": _show_parser("OBJECTS"),
"PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"COLUMNS": _show_parser("COLUMNS"),
@@ -580,21 +583,35 @@ class Snowflake(Dialect):
scope = None
scope_kind = None
+ # will identify SHOW TERSE SCHEMAS but not SHOW TERSE PRIMARY KEYS,
+ # which is syntactically valid but has no effect on the output
+ terse = self._tokens[self._index - 2].text.upper() == "TERSE"
+
like = self._parse_string() if self._match(TokenType.LIKE) else None
if self._match(TokenType.IN):
if self._match_text_seq("ACCOUNT"):
scope_kind = "ACCOUNT"
elif self._match_set(self.DB_CREATABLES):
- scope_kind = self._prev.text
+ scope_kind = self._prev.text.upper()
if self._curr:
- scope = self._parse_table()
+ scope = self._parse_table_parts()
elif self._curr:
- scope_kind = "TABLE"
- scope = self._parse_table()
+ scope_kind = "SCHEMA" if this == "OBJECTS" else "TABLE"
+ scope = self._parse_table_parts()
return self.expression(
- exp.Show, this=this, like=like, scope=scope, scope_kind=scope_kind
+ exp.Show,
+ **{
+ "terse": terse,
+ "this": this,
+ "like": like,
+ "scope": scope,
+ "scope_kind": scope_kind,
+ "starts_with": self._match_text_seq("STARTS", "WITH") and self._parse_string(),
+ "limit": self._parse_limit(),
+ "from": self._parse_string() if self._match(TokenType.FROM) else None,
+ },
)
def _parse_alter_table_swap(self) -> exp.SwapTable:
@@ -690,6 +707,9 @@ class Snowflake(Dialect):
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.Explode: rename_func("FLATTEN"),
exp.Extract: rename_func("DATE_PART"),
+ exp.FromTimeZone: lambda self, e: self.func(
+ "CONVERT_TIMEZONE", e.args.get("zone"), "'UTC'", e.this
+ ),
exp.GenerateSeries: lambda self, e: self.func(
"ARRAY_GENERATE_RANGE", e.args["start"], e.args["end"] + 1, e.args.get("step")
),
@@ -820,6 +840,7 @@ class Snowflake(Dialect):
return f"{explode}{alias}"
def show_sql(self, expression: exp.Show) -> str:
+ terse = "TERSE " if expression.args.get("terse") else ""
like = self.sql(expression, "like")
like = f" LIKE {like}" if like else ""
@@ -830,7 +851,19 @@ class Snowflake(Dialect):
if scope_kind:
scope_kind = f" IN {scope_kind}"
- return f"SHOW {expression.name}{like}{scope_kind}{scope}"
+ starts_with = self.sql(expression, "starts_with")
+ if starts_with:
+ starts_with = f" STARTS WITH {starts_with}"
+
+ limit = self.sql(expression, "limit")
+
+ from_ = self.sql(expression, "from")
+ if from_:
+ from_ = f" FROM {from_}"
+
+ return (
+ f"SHOW {terse}{expression.name}{like}{scope_kind}{scope}{starts_with}{limit}{from_}"
+ )
def regexpextract_sql(self, expression: exp.RegexpExtract) -> str:
# Other dialects don't support all of the following parameters, so we need to
@@ -884,3 +917,6 @@ class Snowflake(Dialect):
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, wrapped=False, prefix=self.seg(""), sep=" ")
+
+ def cluster_sql(self, expression: exp.Cluster) -> str:
+ return f"CLUSTER BY ({self.expressions(expression, flat=True)})"
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index ba73ac0..624f76e 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -80,9 +80,9 @@ class Spark(Spark2):
exp.TimestampAdd: lambda self, e: self.func(
"DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
),
- exp.TryCast: lambda self, e: self.trycast_sql(e)
- if e.args.get("safe")
- else self.cast_sql(e),
+ exp.TryCast: lambda self, e: (
+ self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
+ ),
}
TRANSFORMS.pop(exp.AnyValue)
TRANSFORMS.pop(exp.DateDiff)
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index e27ba18..e4bb30e 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -129,10 +129,20 @@ class Spark2(Hive):
"SHIFTRIGHT": binary_from_function(exp.BitwiseRightShift),
"STRING": _parse_as_cast("string"),
"TIMESTAMP": _parse_as_cast("timestamp"),
- "TO_TIMESTAMP": lambda args: _parse_as_cast("timestamp")(args)
- if len(args) == 1
- else format_time_lambda(exp.StrToTime, "spark")(args),
+ "TO_TIMESTAMP": lambda args: (
+ _parse_as_cast("timestamp")(args)
+ if len(args) == 1
+ else format_time_lambda(exp.StrToTime, "spark")(args)
+ ),
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
+ "TO_UTC_TIMESTAMP": lambda args: exp.FromTimeZone(
+ this=exp.cast_unless(
+ seq_get(args, 0) or exp.Var(this=""),
+ exp.DataType.build("timestamp"),
+ exp.DataType.build("timestamp"),
+ ),
+ zone=seq_get(args, 1),
+ ),
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
"WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
}
@@ -188,6 +198,7 @@ class Spark2(Hive):
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.From: transforms.preprocess([_unalias_pivot]),
+ exp.FromTimeZone: lambda self, e: f"TO_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Map: _map_sql,
@@ -255,10 +266,12 @@ class Spark2(Hive):
def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
return super().columndef_sql(
expression,
- sep=": "
- if isinstance(expression.parent, exp.DataType)
- and expression.parent.is_type("struct")
- else sep,
+ sep=(
+ ": "
+ if isinstance(expression.parent, exp.DataType)
+ and expression.parent.is_type("struct")
+ else sep
+ ),
)
class Tokenizer(Hive.Tokenizer):
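
Putting the FromTimeZone pieces from this commit together, a sketch (outputs illustrative; the CAST comes from the cast_unless call above):

import sqlglot

sql = "SELECT TO_UTC_TIMESTAMP(ts, 'America/New_York')"
for dialect in ("spark", "presto", "snowflake"):
    print(sqlglot.transpile(sql, read="spark", write=dialect)[0])
# SELECT TO_UTC_TIMESTAMP(CAST(ts AS TIMESTAMP), 'America/New_York')
# SELECT WITH_TIMEZONE(CAST(ts AS TIMESTAMP), 'America/New_York') AT TIME ZONE 'UTC'
# SELECT CONVERT_TIMEZONE('America/New_York', 'UTC', CAST(ts AS TIMESTAMP))
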
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index 33ec7e1..3795045 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -38,3 +38,4 @@ class Tableau(Dialect):
**parser.Parser.FUNCTIONS,
"COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
}
+ NO_PAREN_IF_COMMANDS = False
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index b9c347c..a5e04da 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -76,9 +76,11 @@ def _format_time_lambda(
format=exp.Literal.string(
format_time(
args[0].name.lower(),
- {**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING}
- if full_format_mapping
- else TSQL.TIME_MAPPING,
+ (
+ {**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING}
+ if full_format_mapping
+ else TSQL.TIME_MAPPING
+ ),
)
),
)
@@ -264,6 +266,15 @@ def _parse_timefromparts(args: t.List) -> exp.TimeFromParts:
)
+def _parse_len(args: t.List) -> exp.Length:
+ this = seq_get(args, 0)
+
+ if this and not this.is_string:
+ this = exp.cast(this, exp.DataType.Type.TEXT)
+
+ return exp.Length(this=this)
+
+
class TSQL(Dialect):
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
@@ -431,7 +442,7 @@ class TSQL(Dialect):
"IIF": exp.If.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
- "LEN": exp.Length.from_arg_list,
+ "LEN": _parse_len,
"REPLICATE": exp.Repeat.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
@@ -469,6 +480,7 @@ class TSQL(Dialect):
ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
STRING_ALIASES = True
+ NO_PAREN_IF_COMMANDS = False
def _parse_projections(self) -> t.List[exp.Expression]:
"""
@@ -478,9 +490,11 @@ class TSQL(Dialect):
See: https://learn.microsoft.com/en-us/sql/t-sql/queries/select-clause-transact-sql?view=sql-server-ver16#syntax
"""
return [
- exp.alias_(projection.expression, projection.this.this, copy=False)
- if isinstance(projection, exp.EQ) and isinstance(projection.this, exp.Column)
- else projection
+ (
+ exp.alias_(projection.expression, projection.this.this, copy=False)
+ if isinstance(projection, exp.EQ) and isinstance(projection.this, exp.Column)
+ else projection
+ )
for projection in super()._parse_projections()
]
@@ -702,7 +716,6 @@ class TSQL(Dialect):
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
exp.LastDay: lambda self, e: self.func("EOMONTH", e.this),
- exp.Length: rename_func("LEN"),
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
@@ -922,3 +935,11 @@ class TSQL(Dialect):
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True, sep=" ")
return f"CONSTRAINT {this} {expressions}"
+
+ def length_sql(self, expression: exp.Length) -> str:
+ this = expression.this
+ if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.TEXT):
+ this_sql = self.sql(this, "this")
+ else:
+ this_sql = self.sql(this)
+ return self.func("LEN", this_sql)
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index 3277e65..7ff9608 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -392,9 +392,9 @@ def _lambda_sql(self, e: exp.Lambda) -> str:
names = {e.name.lower() for e in e.expressions}
e = e.transform(
- lambda n: exp.var(n.name)
- if isinstance(n, exp.Identifier) and n.name.lower() in names
- else n
+ lambda n: (
+ exp.var(n.name) if isinstance(n, exp.Identifier) and n.name.lower() in names else n
+ )
)
return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}"
@@ -438,9 +438,9 @@ class Python(Dialect):
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}",
exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')",
- exp.Is: lambda self, e: self.binary(e, "==")
- if isinstance(e.this, exp.Literal)
- else self.binary(e, "is"),
+ exp.Is: lambda self, e: (
+ self.binary(e, "==") if isinstance(e.this, exp.Literal) else self.binary(e, "is")
+ ),
exp.Lambda: _lambda_sql,
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index ddad8f8..a95a73e 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -23,7 +23,6 @@ from copy import deepcopy
from enum import auto
from functools import reduce
-from sqlglot._typing import E
from sqlglot.errors import ErrorLevel, ParseError
from sqlglot.helper import (
AutoName,
@@ -36,8 +35,7 @@ from sqlglot.helper import (
from sqlglot.tokens import Token
if t.TYPE_CHECKING:
- from typing_extensions import Literal as Lit
-
+ from sqlglot._typing import E, Lit
from sqlglot.dialects.dialect import DialectType
@@ -389,7 +387,7 @@ class Expression(metaclass=_Expression):
ancestor = self.parent
while ancestor and not isinstance(ancestor, expression_types):
ancestor = ancestor.parent
- return t.cast(E, ancestor)
+ return ancestor # type: ignore
@property
def parent_select(self) -> t.Optional[Select]:
@@ -555,12 +553,10 @@ class Expression(metaclass=_Expression):
return new_node
@t.overload
- def replace(self, expression: E) -> E:
- ...
+ def replace(self, expression: E) -> E: ...
@t.overload
- def replace(self, expression: None) -> None:
- ...
+ def replace(self, expression: None) -> None: ...
def replace(self, expression):
"""
@@ -781,13 +777,16 @@ class Expression(metaclass=_Expression):
this=maybe_copy(self, copy),
expressions=[convert(e, copy=copy) for e in expressions],
query=maybe_parse(query, copy=copy, **opts) if query else None,
- unnest=Unnest(
- expressions=[
- maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) for e in ensure_list(unnest)
- ]
- )
- if unnest
- else None,
+ unnest=(
+ Unnest(
+ expressions=[
+ maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts)
+ for e in ensure_list(unnest)
+ ]
+ )
+ if unnest
+ else None
+ ),
)
def between(self, low: t.Any, high: t.Any, copy: bool = True, **opts) -> Between:
@@ -926,7 +925,7 @@ class DerivedTable(Expression):
class Unionable(Expression):
def union(
self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
- ) -> Unionable:
+ ) -> Union:
"""
Builds a UNION expression.
@@ -1134,9 +1133,12 @@ class SetItem(Expression):
class Show(Expression):
arg_types = {
"this": True,
+ "terse": False,
"target": False,
"offset": False,
+ "starts_with": False,
"limit": False,
+ "from": False,
"like": False,
"where": False,
"db": False,
@@ -1274,9 +1276,14 @@ class AlterColumn(Expression):
"using": False,
"default": False,
"drop": False,
+ "comment": False,
}
+class RenameColumn(Expression):
+ arg_types = {"this": True, "to": True, "exists": False}
+
+
class RenameTable(Expression):
pass
@@ -1402,7 +1409,7 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
class GeneratedAsRowColumnConstraint(ColumnConstraintKind):
- arg_types = {"start": True, "hidden": False}
+ arg_types = {"start": False, "hidden": False}
# https://dev.mysql.com/doc/refman/8.0/en/create-table.html
@@ -1667,6 +1674,7 @@ class Index(Expression):
"unique": False,
"primary": False,
"amp": False, # teradata
+ "include": False,
"partition_by": False, # teradata
"where": False, # postgres partial indexes
}
@@ -2016,7 +2024,13 @@ class AutoRefreshProperty(Property):
class BlockCompressionProperty(Property):
- arg_types = {"autotemp": False, "always": False, "default": True, "manual": True, "never": True}
+ arg_types = {
+ "autotemp": False,
+ "always": False,
+ "default": False,
+ "manual": False,
+ "never": False,
+ }
class CharacterSetProperty(Property):
@@ -2089,6 +2103,10 @@ class FreespaceProperty(Property):
arg_types = {"this": True, "percent": False}
+class InheritsProperty(Property):
+ arg_types = {"expressions": True}
+
+
class InputModelProperty(Property):
arg_types = {"this": True}
@@ -2099,11 +2117,11 @@ class OutputModelProperty(Property):
class IsolatedLoadingProperty(Property):
arg_types = {
- "no": True,
- "concurrent": True,
- "for_all": True,
- "for_insert": True,
- "for_none": True,
+ "no": False,
+ "concurrent": False,
+ "for_all": False,
+ "for_insert": False,
+ "for_none": False,
}
@@ -2264,6 +2282,10 @@ class SetProperty(Property):
arg_types = {"multi": True}
+class SetConfigProperty(Property):
+ arg_types = {"this": True}
+
+
class SettingsProperty(Property):
arg_types = {"expressions": True}
@@ -2407,13 +2429,16 @@ class Tuple(Expression):
this=maybe_copy(self, copy),
expressions=[convert(e, copy=copy) for e in expressions],
query=maybe_parse(query, copy=copy, **opts) if query else None,
- unnest=Unnest(
- expressions=[
- maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) for e in ensure_list(unnest)
- ]
- )
- if unnest
- else None,
+ unnest=(
+ Unnest(
+ expressions=[
+ maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts)
+ for e in ensure_list(unnest)
+ ]
+ )
+ if unnest
+ else None
+ ),
)
@@ -3631,6 +3656,8 @@ class DataType(Expression):
class Type(AutoName):
ARRAY = auto()
+ AGGREGATEFUNCTION = auto()
+ SIMPLEAGGREGATEFUNCTION = auto()
BIGDECIMAL = auto()
BIGINT = auto()
BIGSERIAL = auto()
@@ -4162,6 +4189,10 @@ class AtTimeZone(Expression):
arg_types = {"this": True, "zone": True}
+class FromTimeZone(Expression):
+ arg_types = {"this": True, "zone": True}
+
+
class Between(Predicate):
arg_types = {"this": True, "low": True, "high": True}
@@ -5456,8 +5487,7 @@ def maybe_parse(
prefix: t.Optional[str] = None,
copy: bool = False,
**opts,
-) -> E:
- ...
+) -> E: ...
@t.overload
@@ -5469,8 +5499,7 @@ def maybe_parse(
prefix: t.Optional[str] = None,
copy: bool = False,
**opts,
-) -> E:
- ...
+) -> E: ...
def maybe_parse(
@@ -5522,13 +5551,11 @@ def maybe_parse(
@t.overload
-def maybe_copy(instance: None, copy: bool = True) -> None:
- ...
+def maybe_copy(instance: None, copy: bool = True) -> None: ...
@t.overload
-def maybe_copy(instance: E, copy: bool = True) -> E:
- ...
+def maybe_copy(instance: E, copy: bool = True) -> E: ...
def maybe_copy(instance, copy=True):
@@ -6151,15 +6178,13 @@ SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
@t.overload
-def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None:
- ...
+def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None: ...
@t.overload
def to_identifier(
name: str | Identifier, quoted: t.Optional[bool] = None, copy: bool = True
-) -> Identifier:
- ...
+) -> Identifier: ...
def to_identifier(name, quoted=None, copy=True):
@@ -6231,13 +6256,11 @@ def to_interval(interval: str | Literal) -> Interval:
@t.overload
-def to_table(sql_path: str | Table, **kwargs) -> Table:
- ...
+def to_table(sql_path: str | Table, **kwargs) -> Table: ...
@t.overload
-def to_table(sql_path: None, **kwargs) -> None:
- ...
+def to_table(sql_path: None, **kwargs) -> None: ...
def to_table(
@@ -6562,6 +6585,34 @@ def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable:
)
+def rename_column(
+ table_name: str | Table,
+ old_column_name: str | Column,
+ new_column_name: str | Column,
+ exists: t.Optional[bool] = None,
+) -> AlterTable:
+ """Build ALTER TABLE... RENAME COLUMN... expression
+
+ Args:
+ table_name: Name of the table
+ old_column_name: The old name of the column
+ new_column_name: The new name of the column
+ exists: Whether or not to add the `IF EXISTS` clause
+
+ Returns:
+ Alter table expression
+ """
+ table = to_table(table_name)
+ old_column = to_column(old_column_name)
+ new_column = to_column(new_column_name)
+ return AlterTable(
+ this=table,
+ actions=[
+ RenameColumn(this=old_column, to=new_column, exists=exists),
+ ],
+ )
+
+
def convert(value: t.Any, copy: bool = False) -> Expression:
"""Convert a python value into an expression object.
@@ -6581,7 +6632,7 @@ def convert(value: t.Any, copy: bool = False) -> Expression:
if isinstance(value, bool):
return Boolean(this=value)
if value is None or (isinstance(value, float) and math.isnan(value)):
- return NULL
+ return null()
if isinstance(value, numbers.Number):
return Literal.number(value)
if isinstance(value, datetime.datetime):
@@ -6674,9 +6725,11 @@ def table_name(table: Table | str, dialect: DialectType = None, identify: bool =
raise ValueError(f"Cannot parse {table}")
return ".".join(
- part.sql(dialect=dialect, identify=True, copy=False)
- if identify or not SAFE_IDENTIFIER_RE.match(part.name)
- else part.name
+ (
+ part.sql(dialect=dialect, identify=True, copy=False)
+ if identify or not SAFE_IDENTIFIER_RE.match(part.name)
+ else part.name
+ )
for part in table.parts
)
@@ -6942,9 +6995,3 @@ def null() -> Null:
Returns a Null expression.
"""
return Null()
-
-
-# TODO: deprecate this
-TRUE = Boolean(this=True)
-FALSE = Boolean(this=False)
-NULL = Null()
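
The new rename_column builder in use, as a sketch (table and column names illustrative):

from sqlglot import exp

alter = exp.rename_column("my_table", "old_col", "new_col", exists=True)
print(alter.sql())
# ALTER TABLE my_table RENAME COLUMN IF EXISTS old_col TO new_col
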
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 977185f..8e3ff9b 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -77,6 +77,7 @@ class Generator:
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
exp.ExternalProperty: lambda self, e: "EXTERNAL",
exp.HeapProperty: lambda self, e: "HEAP",
+ exp.InheritsProperty: lambda self, e: f"INHERITS ({self.expressions(e, flat=True)})",
exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
exp.InputModelProperty: lambda self, e: f"INPUT{self.sql(e, 'this')}",
exp.IntervalSpan: lambda self, e: f"{self.sql(e, 'this')} TO {self.sql(e, 'expression')}",
@@ -96,6 +97,7 @@ class Generator:
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
+ exp.SetConfigProperty: lambda self, e: self.sql(e, "this"),
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
exp.SqlReadWriteProperty: lambda self, e: e.name,
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
@@ -323,6 +325,7 @@ class Generator:
exp.FileFormatProperty: exp.Properties.Location.POST_WITH,
exp.FreespaceProperty: exp.Properties.Location.POST_NAME,
exp.HeapProperty: exp.Properties.Location.POST_WITH,
+ exp.InheritsProperty: exp.Properties.Location.POST_SCHEMA,
exp.InputModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME,
exp.JournalProperty: exp.Properties.Location.POST_NAME,
@@ -353,6 +356,7 @@ class Generator:
exp.Set: exp.Properties.Location.POST_SCHEMA,
exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA,
exp.SetProperty: exp.Properties.Location.POST_CREATE,
+ exp.SetConfigProperty: exp.Properties.Location.POST_SCHEMA,
exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlReadWriteProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
@@ -568,9 +572,11 @@ class Generator:
def wrap(self, expression: exp.Expression | str) -> str:
this_sql = self.indent(
- self.sql(expression)
- if isinstance(expression, (exp.Select, exp.Union))
- else self.sql(expression, "this"),
+ (
+ self.sql(expression)
+ if isinstance(expression, (exp.Select, exp.Union))
+ else self.sql(expression, "this")
+ ),
level=1,
pad=0,
)
@@ -605,9 +611,11 @@ class Generator:
lines = sql.split("\n")
return "\n".join(
- line
- if (skip_first and i == 0) or (skip_last and i == len(lines) - 1)
- else f"{' ' * (level * self._indent + pad)}{line}"
+ (
+ line
+ if (skip_first and i == 0) or (skip_last and i == len(lines) - 1)
+ else f"{' ' * (level * self._indent + pad)}{line}"
+ )
for i, line in enumerate(lines)
)
@@ -775,7 +783,7 @@ class Generator:
def generatedasrowcolumnconstraint_sql(
self, expression: exp.GeneratedAsRowColumnConstraint
) -> str:
- start = "START" if expression.args["start"] else "END"
+ start = "START" if expression.args.get("start") else "END"
hidden = " HIDDEN" if expression.args.get("hidden") else ""
return f"GENERATED ALWAYS AS ROW {start}{hidden}"
@@ -1111,7 +1119,10 @@ class Generator:
partition_by = self.expressions(expression, key="partition_by", flat=True)
partition_by = f" PARTITION BY {partition_by}" if partition_by else ""
where = self.sql(expression, "where")
- return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{partition_by}{where}"
+ include = self.expressions(expression, key="include", flat=True)
+ if include:
+ include = f" INCLUDE ({include})"
+ return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{include}{partition_by}{where}"
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
@@ -2017,9 +2028,11 @@ class Generator:
def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
return [
self.sql(expression, "qualify"),
- self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
- if expression.args.get("windows")
- else "",
+ (
+ self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
+ if expression.args.get("windows")
+ else ""
+ ),
self.sql(expression, "distribute"),
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
@@ -2552,6 +2565,11 @@ class Generator:
zone = self.sql(expression, "zone")
return f"{this} AT TIME ZONE {zone}"
+ def fromtimezone_sql(self, expression: exp.FromTimeZone) -> str:
+ this = self.sql(expression, "this")
+ zone = self.sql(expression, "zone")
+ return f"{this} AT TIME ZONE {zone} AT TIME ZONE 'UTC'"
+
def add_sql(self, expression: exp.Add) -> str:
return self.binary(expression, "+")
@@ -2669,6 +2687,10 @@ class Generator:
if default:
return f"ALTER COLUMN {this} SET DEFAULT {default}"
+ comment = self.sql(expression, "comment")
+ if comment:
+ return f"ALTER COLUMN {this} COMMENT {comment}"
+
if not expression.args.get("drop"):
self.unsupported("Unsupported ALTER COLUMN syntax")
@@ -2683,6 +2705,12 @@ class Generator:
this = self.sql(expression, "this")
return f"RENAME TO {this}"
+ def renamecolumn_sql(self, expression: exp.RenameColumn) -> str:
+ exists = " IF EXISTS" if expression.args.get("exists") else ""
+ old_column = self.sql(expression, "this")
+ new_column = self.sql(expression, "to")
+ return f"RENAME COLUMN{exists} {old_column} TO {new_column}"
+
def altertable_sql(self, expression: exp.AlterTable) -> str:
actions = expression.args["actions"]
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index 349c8c8..de737be 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -53,13 +53,11 @@ def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
@t.overload
-def ensure_list(value: t.Collection[T]) -> t.List[T]:
- ...
+def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
@t.overload
-def ensure_list(value: T) -> t.List[T]:
- ...
+def ensure_list(value: T) -> t.List[T]: ...
def ensure_list(value):
@@ -81,13 +79,11 @@ def ensure_list(value):
@t.overload
-def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
- ...
+def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...
@t.overload
-def ensure_collection(value: T) -> t.Collection[T]:
- ...
+def ensure_collection(value: T) -> t.Collection[T]: ...
def ensure_collection(value):
diff --git a/sqlglot/jsonpath.py b/sqlglot/jsonpath.py
new file mode 100644
index 0000000..c410d11
--- /dev/null
+++ b/sqlglot/jsonpath.py
@@ -0,0 +1,215 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot.errors import ParseError
+from sqlglot.expressions import SAFE_IDENTIFIER_RE
+from sqlglot.tokens import Token, Tokenizer, TokenType
+
+if t.TYPE_CHECKING:
+ from sqlglot._typing import Lit
+
+
+class JSONPathTokenizer(Tokenizer):
+ SINGLE_TOKENS = {
+ "(": TokenType.L_PAREN,
+ ")": TokenType.R_PAREN,
+ "[": TokenType.L_BRACKET,
+ "]": TokenType.R_BRACKET,
+ ":": TokenType.COLON,
+ ",": TokenType.COMMA,
+ "-": TokenType.DASH,
+ ".": TokenType.DOT,
+ "?": TokenType.PLACEHOLDER,
+ "@": TokenType.PARAMETER,
+ "'": TokenType.QUOTE,
+ '"': TokenType.QUOTE,
+ "$": TokenType.DOLLAR,
+ "*": TokenType.STAR,
+ }
+
+ KEYWORDS = {
+ "..": TokenType.DOT,
+ }
+
+ IDENTIFIER_ESCAPES = ["\\"]
+ STRING_ESCAPES = ["\\"]
+
+
+JSONPathNode = t.Dict[str, t.Any]
+
+
+def _node(kind: str, value: t.Any = None, **kwargs: t.Any) -> JSONPathNode:
+ node = {"kind": kind, **kwargs}
+
+ if value is not None:
+ node["value"] = value
+
+ return node
+
+
+def parse(path: str) -> t.List[JSONPathNode]:
+ """Takes in a JSONPath string and converts into a list of nodes."""
+ tokens = JSONPathTokenizer().tokenize(path)
+ size = len(tokens)
+
+ i = 0
+
+ def _curr() -> t.Optional[TokenType]:
+ return tokens[i].token_type if i < size else None
+
+ def _prev() -> Token:
+ return tokens[i - 1]
+
+ def _advance() -> Token:
+ nonlocal i
+ i += 1
+ return _prev()
+
+ def _error(msg: str) -> str:
+ return f"{msg} at index {i}: {path}"
+
+ @t.overload
+ def _match(token_type: TokenType, raise_unmatched: Lit[True] = True) -> Token:
+ pass
+
+ @t.overload
+ def _match(token_type: TokenType, raise_unmatched: Lit[False] = False) -> t.Optional[Token]:
+ pass
+
+ def _match(token_type, raise_unmatched=False):
+ if _curr() == token_type:
+ return _advance()
+ if raise_unmatched:
+ raise ParseError(_error(f"Expected {token_type}"))
+ return None
+
+ def _parse_literal() -> t.Any:
+ token = _match(TokenType.STRING) or _match(TokenType.IDENTIFIER)
+ if token:
+ return token.text
+ if _match(TokenType.STAR):
+ return _node("wildcard")
+ if _match(TokenType.PLACEHOLDER) or _match(TokenType.L_PAREN):
+ script = _prev().text == "("
+ start = i
+
+ while True:
+ if _match(TokenType.L_BRACKET):
+ _parse_bracket() # nested call which we can throw away
+ if _curr() in (TokenType.R_BRACKET, None):
+ break
+ _advance()
+ return _node(
+ "script" if script else "filter", path[tokens[start].start : tokens[i].end]
+ )
+
+ number = "-" if _match(TokenType.DASH) else ""
+
+ token = _match(TokenType.NUMBER)
+ if token:
+ number += token.text
+
+ if number:
+ return int(number)
+ return False
+
+ def _parse_slice() -> t.Any:
+ start = _parse_literal()
+ end = _parse_literal() if _match(TokenType.COLON) else None
+ step = _parse_literal() if _match(TokenType.COLON) else None
+
+ if end is None and step is None:
+ return start
+ return _node("slice", start=start, end=end, step=step)
+
+ def _parse_bracket() -> JSONPathNode:
+ literal = _parse_slice()
+
+ if isinstance(literal, str) or literal is not False:
+ indexes = [literal]
+ while _match(TokenType.COMMA):
+ literal = _parse_slice()
+
+ if literal:
+ indexes.append(literal)
+
+ if len(indexes) == 1:
+ if isinstance(literal, str):
+ node = _node("key", indexes[0])
+ elif isinstance(literal, dict) and literal["kind"] in ("script", "filter"):
+ node = _node("selector", indexes[0])
+ else:
+ node = _node("subscript", indexes[0])
+ else:
+ node = _node("union", indexes)
+ else:
+ raise ParseError(_error("Cannot have empty segment"))
+
+ _match(TokenType.R_BRACKET, raise_unmatched=True)
+
+ return node
+
+ nodes = []
+
+ while _curr():
+ if _match(TokenType.DOLLAR):
+ nodes.append(_node("root"))
+ elif _match(TokenType.DOT):
+ recursive = _prev().text == ".."
+ value = _match(TokenType.VAR) or _match(TokenType.STAR)
+ nodes.append(
+ _node("recursive" if recursive else "child", value=value.text if value else None)
+ )
+ elif _match(TokenType.L_BRACKET):
+ nodes.append(_parse_bracket())
+ elif _match(TokenType.VAR):
+ nodes.append(_node("key", _prev().text))
+ elif _match(TokenType.STAR):
+ nodes.append(_node("wildcard"))
+ elif _match(TokenType.PARAMETER):
+ nodes.append(_node("current"))
+ else:
+ raise ParseError(_error(f"Unexpected {tokens[i].token_type}"))
+
+ return nodes
+
+
+MAPPING = {
+ "child": lambda n: f".{n['value']}" if n.get("value") is not None else "",
+ "filter": lambda n: f"?{n['value']}",
+ "key": lambda n: (
+ f".{n['value']}" if SAFE_IDENTIFIER_RE.match(n["value"]) else f'[{generate([n["value"]])}]'
+ ),
+ "recursive": lambda n: f"..{n['value']}" if n.get("value") is not None else "..",
+ "root": lambda _: "$",
+ "script": lambda n: f"({n['value']}",
+ "slice": lambda n: ":".join(
+ "" if p is False else generate([p])
+ for p in [n["start"], n["end"], n["step"]]
+ if p is not None
+ ),
+ "selector": lambda n: f"[{generate([n['value']])}]",
+ "subscript": lambda n: f"[{generate([n['value']])}]",
+ "union": lambda n: f"[{','.join(generate([p]) for p in n['value'])}]",
+ "wildcard": lambda _: "*",
+}
+
+
+def generate(
+ nodes: t.List[JSONPathNode],
+ mapping: t.Optional[t.Dict[str, t.Callable[[JSONPathNode], str]]] = None,
+) -> str:
+ mapping = MAPPING if mapping is None else mapping
+ path = []
+
+ for node in nodes:
+ if isinstance(node, dict):
+ path.append(mapping[node["kind"]](node))
+ elif isinstance(node, str):
+ escaped = node.replace('"', '\\"')
+ path.append(f'"{escaped}"')
+ else:
+ path.append(str(node))
+
+ return "".join(path)
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py
index 09bf201..bdd1d14 100644
--- a/sqlglot/lineage.py
+++ b/sqlglot/lineage.py
@@ -41,9 +41,9 @@ class Node:
else:
label = node.expression.sql(pretty=True, dialect=dialect)
source = node.source.transform(
- lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
- if n is node.expression
- else n,
+ lambda n: (
+ exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
+ ),
copy=False,
).sql(pretty=True, dialect=dialect)
title = f"<pre>{source}</pre>"
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index d0168d5..a2a86cd 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -4,7 +4,6 @@ import functools
import typing as t
from sqlglot import exp
-from sqlglot._typing import E
from sqlglot.helper import (
ensure_list,
is_date_unit,
@@ -17,7 +16,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema
if t.TYPE_CHECKING:
- B = t.TypeVar("B", bound=exp.Binary)
+ from sqlglot._typing import B, E
BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
BinaryCoercions = t.Dict[
@@ -480,6 +479,20 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return self._annotate_args(expression)
@t.no_type_check
+ def _annotate_struct_value(
+ self, expression: exp.Expression
+ ) -> t.Optional[exp.DataType] | exp.ColumnDef:
+ alias = expression.args.get("alias")
+ if alias:
+ return exp.ColumnDef(this=alias.copy(), kind=expression.type)
+
+ # Case: key = value or key := value
+ if expression.expression:
+ return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)
+
+ return expression.type
+
+ @t.no_type_check
def _annotate_by_args(
self,
expression: E,
@@ -516,16 +529,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
)
if struct:
- expressions = [
- expr.type
- if not expr.args.get("alias")
- else exp.ColumnDef(this=expr.args["alias"].copy(), kind=expr.type)
- for expr in expressions
- ]
-
self._set_type(
expression,
- exp.DataType(this=exp.DataType.Type.STRUCT, expressions=expressions, nested=True),
+ exp.DataType(
+ this=exp.DataType.Type.STRUCT,
+ expressions=[self._annotate_struct_value(expr) for expr in expressions],
+ nested=True,
+ ),
)
return expression
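
A sketch of the new struct-value annotation (BigQuery syntax; output illustrative): aliased struct entries now become named fields in the inferred type.

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types

ast = annotate_types(sqlglot.parse_one("SELECT STRUCT(1 AS a, 'x' AS b)", read="bigquery"))
print(ast.selects[0].type.sql(dialect="bigquery"))  # STRUCT<a INT64, b STRING>
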
diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py
index 3361a33..f2a0990 100644
--- a/sqlglot/optimizer/normalize_identifiers.py
+++ b/sqlglot/optimizer/normalize_identifiers.py
@@ -3,18 +3,18 @@ from __future__ import annotations
import typing as t
from sqlglot import exp
-from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
+if t.TYPE_CHECKING:
+ from sqlglot._typing import E
+
@t.overload
-def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
- ...
+def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ...
@t.overload
-def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier:
- ...
+def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ...
def normalize_identifiers(expression, dialect=None):
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index a6397ae..1656727 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -4,7 +4,6 @@ import itertools
import typing as t
from sqlglot import alias, exp
-from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
@@ -12,6 +11,9 @@ from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema
+if t.TYPE_CHECKING:
+ from sqlglot._typing import E
+
def qualify_columns(
expression: exp.Expression,
@@ -210,7 +212,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
if not node:
return
- for column, *_ in walk_in_scope(node):
+ for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star):
if not isinstance(column, exp.Column):
continue
@@ -525,6 +527,7 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
selection = alias(
selection,
alias=selection.output_name or f"_col_{i}",
+ copy=False,
)
if aliased_column:
selection.set("alias", exp.to_identifier(aliased_column))
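Editor's note: `walk_in_scope` forwards its `prune` callback to the expression walker, so alias expansion no longer descends into star expressions. The same idea works on any expression (a minimal sketch using `Expression.walk`, which takes the same callback shape):

    from sqlglot import parse_one

    tree = parse_one("SELECT * FROM t WHERE a = 1")
    # Stop descending once a star node is hit; all other nodes are visited as usual.
    for node, *_ in tree.walk(prune=lambda n, *_: n.is_star):
        print(type(node).__name__)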
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index e0fe641..d460e81 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -4,12 +4,14 @@ import itertools
import typing as t
from sqlglot import alias, exp
-from sqlglot._typing import E
from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema
+if t.TYPE_CHECKING:
+ from sqlglot._typing import E
+
def qualify_tables(
expression: E,
@@ -46,6 +48,18 @@ def qualify_tables(
db = exp.parse_identifier(db, dialect=dialect) if db else None
catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
+ def _qualify(table: exp.Table) -> None:
+ if isinstance(table.this, exp.Identifier):
+ if not table.args.get("db"):
+ table.set("db", db)
+ if not table.args.get("catalog") and table.args.get("db"):
+ table.set("catalog", catalog)
+
+ if not isinstance(expression, exp.Subqueryable):
+ for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)):
+ if isinstance(node, exp.Table):
+ _qualify(node)
+
for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
if isinstance(derived_table, exp.Subquery):
@@ -66,11 +80,7 @@ def qualify_tables(
for name, source in scope.sources.items():
if isinstance(source, exp.Table):
- if isinstance(source.this, exp.Identifier):
- if not source.args.get("db"):
- source.set("db", db)
- if not source.args.get("catalog") and source.args.get("db"):
- source.set("catalog", catalog)
+ _qualify(source)
             pivots = source.args.get("pivots")
if not source.alias:
@@ -107,5 +117,14 @@ def qualify_tables(
if isinstance(udtf, exp.Values) and not table_alias.columns:
for i, e in enumerate(udtf.expressions[0].expressions):
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
+ else:
+ for node, parent, _ in scope.walk():
+ if (
+ isinstance(node, exp.Table)
+ and not node.alias
+ and isinstance(parent, (exp.From, exp.Join))
+ ):
+ # Mutates the table by attaching an alias to it
+ alias(node, node.name, copy=False, table=True)
return expression
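Editor's note: taken together, these changes qualify tables even in non-SELECT statements and alias bare tables in FROM/JOIN clauses. An end-to-end sketch (the output shape follows the function's docstring; `cat` and `db` are illustrative names):

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_tables import qualify_tables

    expression = parse_one("SELECT 1 FROM tbl")
    print(qualify_tables(expression, db="db", catalog="cat").sql())
    # SELECT 1 FROM cat.db.tbl AS tbl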
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index a3f08d5..16cd548 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -323,9 +323,14 @@ class Scope:
sources in the current scope.
"""
if self._external_columns is None:
- self._external_columns = [
- c for c in self.columns if c.table not in self.selected_sources
- ]
+ if isinstance(self.expression, exp.Union):
+ left, right = self.union_scopes
+ self._external_columns = left.external_columns + right.external_columns
+ else:
+ self._external_columns = [
+ c for c in self.columns if c.table not in self.selected_sources
+ ]
+
return self._external_columns
@property
@@ -477,11 +482,12 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
Args:
expression (exp.Expression): expression to traverse
+
Returns:
list[Scope]: scope instances
"""
if isinstance(expression, exp.Unionable) or (
- isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable)
+ isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Unionable)
):
return list(_traverse_scope(Scope(expression)))
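Editor's note: with the new branch, a UNION scope reports the external columns of both of its sides instead of an empty list. A small sketch (table and column names are illustrative):

    from sqlglot import parse_one
    from sqlglot.optimizer.scope import traverse_scope

    # x.a is external to both branches of the correlated subquery.
    sql = "SELECT (SELECT y.b FROM y WHERE y.b = x.a UNION SELECT z.b FROM z WHERE z.b = x.a) FROM x"
    for scope in traverse_scope(parse_one(sql)):
        print(type(scope.expression).__name__, [c.sql() for c in scope.external_columns])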
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 25d4e75..d5b9119 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -1068,9 +1068,11 @@ def extract_interval(expression):
def date_literal(date):
return exp.cast(
exp.Literal.string(date),
- exp.DataType.Type.DATETIME
- if isinstance(date, datetime.datetime)
- else exp.DataType.Type.DATE,
+ (
+ exp.DataType.Type.DATETIME
+ if isinstance(date, datetime.datetime)
+ else exp.DataType.Type.DATE
+ ),
)
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 4d35175..26f4159 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -50,11 +50,12 @@ def unnest(select, parent_select, next_alias_name):
):
return
+ clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
+
# This subquery returns a scalar and can just be converted to a cross join
if not isinstance(predicate, (exp.In, exp.Any)):
column = exp.column(select.selects[0].alias_or_name, alias)
- clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
clause_parent_select = clause.parent_select if clause else None
if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
@@ -84,12 +85,18 @@ def unnest(select, parent_select, next_alias_name):
column = _other_operand(predicate)
value = select.selects[0]
- on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
- _replace(predicate, f"NOT {on.right} IS NULL")
+ join_key = exp.column(value.alias, alias)
+ join_key_not_null = join_key.is_(exp.null()).not_()
+
+ if isinstance(clause, exp.Join):
+ _replace(predicate, exp.true())
+ parent_select.where(join_key_not_null, copy=False)
+ else:
+ _replace(predicate, join_key_not_null)
parent_select.join(
select.group_by(value.this, copy=False),
- on=on,
+ on=column.eq(join_key),
join_type="LEFT",
join_alias=alias,
copy=False,
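Editor's note: the net effect is that the string-built `ON` condition is replaced with a structured `column.eq(join_key)`, and when the predicate already lives in a JOIN the NOT NULL check on the join key moves to the WHERE clause. For orientation, the module's documented scalar example (output shown roughly; this hunk refines the IN/ANY path the same way):

    from sqlglot import parse_one
    from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

    sql = "SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1"
    print(unnest_subqueries(parse_one(sql)).sql())
    # Roughly: SELECT * FROM x AS x
    #   LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0
    #   ON x.a = _u_0.a WHERE _u_0.a = 1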
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 790ee0d..c091605 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -12,9 +12,7 @@ from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import TrieResult, in_trie, new_trie
if t.TYPE_CHECKING:
- from typing_extensions import Literal
-
- from sqlglot._typing import E
+ from sqlglot._typing import E, Lit
from sqlglot.dialects.dialect import Dialect, DialectType
logger = logging.getLogger("sqlglot")
@@ -148,6 +146,11 @@ class Parser(metaclass=_Parser):
TokenType.ENUM16,
}
+ AGGREGATE_TYPE_TOKENS = {
+ TokenType.AGGREGATEFUNCTION,
+ TokenType.SIMPLEAGGREGATEFUNCTION,
+ }
+
TYPE_TOKENS = {
TokenType.BIT,
TokenType.BOOLEAN,
@@ -241,6 +244,7 @@ class Parser(metaclass=_Parser):
TokenType.NULL,
*ENUM_TYPE_TOKENS,
*NESTED_TYPE_TOKENS,
+ *AGGREGATE_TYPE_TOKENS,
}
SIGNED_TO_UNSIGNED_TYPE_TOKEN = {
@@ -653,9 +657,11 @@ class Parser(metaclass=_Parser):
PLACEHOLDER_PARSERS = {
TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder),
TokenType.PARAMETER: lambda self: self._parse_parameter(),
- TokenType.COLON: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
- if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
- else None,
+ TokenType.COLON: lambda self: (
+ self.expression(exp.Placeholder, this=self._prev.text)
+ if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
+ else None
+ ),
}
RANGE_PARSERS = {
@@ -705,6 +711,9 @@ class Parser(metaclass=_Parser):
"IMMUTABLE": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
+ "INHERITS": lambda self: self.expression(
+ exp.InheritsProperty, expressions=self._parse_wrapped_csv(self._parse_table)
+ ),
"INPUT": lambda self: self.expression(exp.InputModelProperty, this=self._parse_schema()),
"JOURNAL": lambda self, **kwargs: self._parse_journal(**kwargs),
"LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty),
@@ -822,6 +831,7 @@ class Parser(metaclass=_Parser):
ALTER_PARSERS = {
"ADD": lambda self: self._parse_alter_table_add(),
"ALTER": lambda self: self._parse_alter_table_alter(),
+ "CLUSTER BY": lambda self: self._parse_cluster(wrapped=True),
"DELETE": lambda self: self.expression(exp.Delete, where=self._parse_where()),
"DROP": lambda self: self._parse_alter_table_drop(),
"RENAME": lambda self: self._parse_alter_table_rename(),
@@ -973,6 +983,9 @@ class Parser(metaclass=_Parser):
MODIFIERS_ATTACHED_TO_UNION = True
UNION_MODIFIERS = {"order", "limit", "offset"}
+    # Whether IF statements without parentheses are parsed as commands
+ NO_PAREN_IF_COMMANDS = True
+
__slots__ = (
"error_level",
"error_message_context",
@@ -1207,7 +1220,20 @@ class Parser(metaclass=_Parser):
if index != self._index:
self._advance(index - self._index)
+ def _warn_unsupported(self) -> None:
+ if len(self._tokens) <= 1:
+ return
+
+ # We use _find_sql because self.sql may comprise multiple chunks, and we're only
+        # interested in emitting a warning for the chunk currently being processed.
+ sql = self._find_sql(self._tokens[0], self._tokens[-1])[: self.error_message_context]
+
+ logger.warning(
+ f"'{sql}' contains unsupported syntax. Falling back to parsing as a 'Command'."
+ )
+
def _parse_command(self) -> exp.Command:
+ self._warn_unsupported()
return self.expression(
exp.Command, this=self._prev.text.upper(), expression=self._parse_string()
)
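Editor's note: this surfaces as a logging warning whenever input falls back to `exp.Command`. A sketch of observing it (which statements trigger the fallback depends on the dialect; GRANT is assumed here to be one of them):

    import logging
    from sqlglot import parse_one

    logging.basicConfig()
    # Logs: "'GRANT SELECT ON tbl TO user' contains unsupported syntax.
    #        Falling back to parsing as a 'Command'."
    parse_one("GRANT SELECT ON tbl TO user")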
@@ -1329,8 +1355,10 @@ class Parser(metaclass=_Parser):
start = self._prev
comments = self._prev_comments
- replace = start.text.upper() == "REPLACE" or self._match_pair(
- TokenType.OR, TokenType.REPLACE
+ replace = (
+ start.token_type == TokenType.REPLACE
+ or self._match_pair(TokenType.OR, TokenType.REPLACE)
+ or self._match_pair(TokenType.OR, TokenType.ALTER)
)
unique = self._match(TokenType.UNIQUE)
@@ -1440,6 +1468,9 @@ class Parser(metaclass=_Parser):
exp.Clone, this=self._parse_table(schema=True), shallow=shallow, copy=copy
)
+ if self._curr:
+ return self._parse_as_command(start)
+
return self.expression(
exp.Create,
comments=comments,
@@ -1516,11 +1547,13 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.FileFormatProperty,
- this=self.expression(
- exp.InputOutputFormat, input_format=input_format, output_format=output_format
- )
- if input_format or output_format
- else self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
+ this=(
+ self.expression(
+ exp.InputOutputFormat, input_format=input_format, output_format=output_format
+ )
+ if input_format or output_format
+ else self._parse_var_or_string() or self._parse_number() or self._parse_id_var()
+ ),
)
def _parse_property_assignment(self, exp_class: t.Type[E], **kwargs: t.Any) -> E:
@@ -1632,8 +1665,15 @@ class Parser(metaclass=_Parser):
return self.expression(exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT))
- def _parse_cluster(self) -> exp.Cluster:
- return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered))
+ def _parse_cluster(self, wrapped: bool = False) -> exp.Cluster:
+ return self.expression(
+ exp.Cluster,
+ expressions=(
+ self._parse_wrapped_csv(self._parse_ordered)
+ if wrapped
+ else self._parse_csv(self._parse_ordered)
+ ),
+ )
def _parse_clustered_by(self) -> exp.ClusteredByProperty:
self._match_text_seq("BY")
@@ -2681,6 +2721,8 @@ class Parser(metaclass=_Parser):
else:
columns = None
+ include = self._parse_wrapped_id_vars() if self._match_text_seq("INCLUDE") else None
+
return self.expression(
exp.Index,
this=index,
@@ -2690,6 +2732,7 @@ class Parser(metaclass=_Parser):
unique=unique,
primary=primary,
amp=amp,
+ include=include,
partition_by=self._parse_partition_by(),
where=self._parse_where(),
)
@@ -3380,8 +3423,8 @@ class Parser(metaclass=_Parser):
def _parse_comparison(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_range, self.COMPARISON)
- def _parse_range(self) -> t.Optional[exp.Expression]:
- this = self._parse_bitwise()
+ def _parse_range(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
+ this = this or self._parse_bitwise()
negate = self._match(TokenType.NOT)
if self._match_set(self.RANGE_PARSERS):
@@ -3535,14 +3578,21 @@ class Parser(metaclass=_Parser):
return self._parse_tokens(self._parse_factor, self.TERM)
def _parse_factor(self) -> t.Optional[exp.Expression]:
- if self.EXPONENT:
- factor = self._parse_tokens(self._parse_exponent, self.FACTOR)
- else:
- factor = self._parse_tokens(self._parse_unary, self.FACTOR)
- if isinstance(factor, exp.Div):
- factor.args["typed"] = self.dialect.TYPED_DIVISION
- factor.args["safe"] = self.dialect.SAFE_DIVISION
- return factor
+ parse_method = self._parse_exponent if self.EXPONENT else self._parse_unary
+ this = parse_method()
+
+ while self._match_set(self.FACTOR):
+ this = self.expression(
+ self.FACTOR[self._prev.token_type],
+ this=this,
+ comments=self._prev_comments,
+ expression=parse_method(),
+ )
+ if isinstance(this, exp.Div):
+ this.args["typed"] = self.dialect.TYPED_DIVISION
+ this.args["safe"] = self.dialect.SAFE_DIVISION
+
+ return this
def _parse_exponent(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_unary, self.EXPONENT)
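Editor's note: because the rewritten loop stamps `typed`/`safe` on each `Div` as it is built, rather than only on the final result of `_parse_tokens`, every division in a chain now carries the dialect's semantics. A quick check (sketch; the printed flag values depend on the dialect read):

    from sqlglot import exp, parse_one

    for div in parse_one("SELECT a / b / c", read="mysql").find_all(exp.Div):
        print(div.sql(), div.args.get("typed"), div.args.get("safe"))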
@@ -3617,6 +3667,7 @@ class Parser(metaclass=_Parser):
return exp.DataType.build(type_name, udt=True)
else:
+ self._retreat(self._index - 1)
return None
else:
return None
@@ -3631,6 +3682,7 @@ class Parser(metaclass=_Parser):
nested = type_token in self.NESTED_TYPE_TOKENS
is_struct = type_token in self.STRUCT_TYPE_TOKENS
+ is_aggregate = type_token in self.AGGREGATE_TYPE_TOKENS
expressions = None
maybe_func = False
@@ -3645,6 +3697,18 @@ class Parser(metaclass=_Parser):
)
elif type_token in self.ENUM_TYPE_TOKENS:
expressions = self._parse_csv(self._parse_equality)
+ elif is_aggregate:
+ func_or_ident = self._parse_function(anonymous=True) or self._parse_id_var(
+ any_token=False, tokens=(TokenType.VAR,)
+ )
+ if not func_or_ident or not self._match(TokenType.COMMA):
+ return None
+ expressions = self._parse_csv(
+ lambda: self._parse_types(
+ check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
+ )
+ )
+ expressions.insert(0, func_or_ident)
else:
expressions = self._parse_csv(self._parse_type_size)
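Editor's note: these hooks exist for ClickHouse's `AggregateFunction`/`SimpleAggregateFunction` column types, whose first argument is a function rather than a type. A hedged sketch (assuming the ClickHouse tokenizer maps these keywords onto the new tokens, as the tokens.py hunk below suggests):

    from sqlglot import parse_one

    ddl = parse_one(
        "CREATE TABLE t (q AggregateFunction(quantiles(0.5, 0.9), UInt64))",
        read="clickhouse",
    )
    print(ddl.sql("clickhouse"))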
@@ -4413,6 +4477,10 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
else:
index = self._index - 1
+
+ if self.NO_PAREN_IF_COMMANDS and index == 0:
+ return self._parse_as_command(self._prev)
+
condition = self._parse_conjunction()
if not condition:
@@ -4624,12 +4692,10 @@ class Parser(metaclass=_Parser):
return None
@t.overload
- def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject:
- ...
+ def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
@t.overload
- def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg:
- ...
+ def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...
def _parse_json_object(self, agg=False):
star = self._parse_star()
@@ -4974,11 +5040,12 @@ class Parser(metaclass=_Parser):
if alias:
this = self.expression(exp.Alias, comments=comments, this=this, alias=alias)
+ column = this.this
# Moves the comment next to the alias in `expr /* comment */ AS alias`
- if not this.comments and this.this.comments:
- this.comments = this.this.comments
- this.this.comments = None
+ if not this.comments and column and column.comments:
+ this.comments = column.comments
+ column.comments = None
return this
@@ -5244,7 +5311,7 @@ class Parser(metaclass=_Parser):
if self._match_text_seq("CHECK"):
expression = self._parse_wrapped(self._parse_conjunction)
- enforced = self._match_text_seq("ENFORCED")
+ enforced = self._match_text_seq("ENFORCED") or False
return self.expression(
exp.AddConstraint, this=this, expression=expression, enforced=enforced
@@ -5278,6 +5345,8 @@ class Parser(metaclass=_Parser):
return self.expression(exp.AlterColumn, this=column, drop=True)
if self._match_pair(TokenType.SET, TokenType.DEFAULT):
return self.expression(exp.AlterColumn, this=column, default=self._parse_conjunction())
+ if self._match(TokenType.COMMENT):
+ return self.expression(exp.AlterColumn, this=column, comment=self._parse_string())
self._match_text_seq("SET", "DATA")
return self.expression(
@@ -5298,7 +5367,18 @@ class Parser(metaclass=_Parser):
self._retreat(index)
return self._parse_csv(self._parse_drop_column)
- def _parse_alter_table_rename(self) -> exp.RenameTable:
+ def _parse_alter_table_rename(self) -> t.Optional[exp.RenameTable | exp.RenameColumn]:
+ if self._match(TokenType.COLUMN):
+ exists = self._parse_exists()
+ old_column = self._parse_column()
+ to = self._match_text_seq("TO")
+ new_column = self._parse_column()
+
+ if old_column is None or to is None or new_column is None:
+ return None
+
+ return self.expression(exp.RenameColumn, this=old_column, to=new_column, exists=exists)
+
self._match_text_seq("TO")
return self.expression(exp.RenameTable, this=self._parse_table(schema=True))
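Editor's note: with this branch, `ALTER TABLE ... RENAME COLUMN` parses into a structured `exp.RenameColumn` instead of falling through to `RenameTable`. For example:

    from sqlglot import exp, parse_one

    alter = parse_one("ALTER TABLE t RENAME COLUMN IF EXISTS a TO b")
    action = alter.args["actions"][0]
    assert isinstance(action, exp.RenameColumn)
    print(action.sql())  # e.g. RENAME COLUMN IF EXISTS a TO b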
@@ -5319,7 +5399,7 @@ class Parser(metaclass=_Parser):
if parser:
actions = ensure_list(parser(self))
- if not self._curr:
+ if not self._curr and actions:
return self.expression(
exp.AlterTable,
this=this,
@@ -5467,6 +5547,7 @@ class Parser(metaclass=_Parser):
self._advance()
text = self._find_sql(start, self._prev)
size = len(start.text)
+ self._warn_unsupported()
return exp.Command(this=text[:size], expression=text[size:])
def _parse_dict_property(self, this: str) -> exp.DictProperty:
@@ -5634,7 +5715,7 @@ class Parser(metaclass=_Parser):
if advance:
self._advance()
return True
- return False
+ return None
def _match_text_seq(self, *texts, advance=True):
index = self._index
@@ -5643,7 +5724,7 @@ class Parser(metaclass=_Parser):
self._advance()
else:
self._retreat(index)
- return False
+ return None
if not advance:
self._retreat(index)
@@ -5651,14 +5732,12 @@ class Parser(metaclass=_Parser):
return True
@t.overload
- def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression:
- ...
+ def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: ...
@t.overload
def _replace_columns_with_dots(
self, this: t.Optional[exp.Expression]
- ) -> t.Optional[exp.Expression]:
- ...
+ ) -> t.Optional[exp.Expression]: ...
def _replace_columns_with_dots(self, this):
if isinstance(this, exp.Dot):
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index 8acd89f..13f72d6 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -106,6 +106,19 @@ class Schema(abc.ABC):
name = column if isinstance(column, str) else column.name
return name in self.column_names(table, dialect=dialect, normalize=normalize)
+ @abc.abstractmethod
+ def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
+ """
+ Returns the schema of a given table.
+
+ Args:
+ table: the target table.
+            raise_on_missing: whether to raise an error if the table's schema is not found.
+
+ Returns:
+ The schema of the target table.
+ """
+
@property
@abc.abstractmethod
def supported_table_args(self) -> t.Tuple[str, ...]:
@@ -156,11 +169,9 @@ class AbstractMappingSchema:
return [table.this.name]
return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]
- def find(
- self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
- ) -> t.Optional[t.Any]:
+ def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
parts = self.table_parts(table)[0 : len(self.supported_table_args)]
- value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
+ value, trie = in_trie(self.mapping_trie, parts)
if value == TrieResult.FAILED:
return None
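Editor's note: promoting `find` to the abstract `Schema` interface (and dropping the `trie` parameter) makes the lookup usable on any schema implementation. A minimal sketch with `MappingSchema`:

    from sqlglot import exp
    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"db": {"t": {"x": "INT"}}})
    # Returns the column -> type mapping for db.t, e.g. {'x': 'INT'}
    print(schema.find(exp.to_table("db.t")))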
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index d8fb98b..8a363d2 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -191,6 +191,8 @@ class TokenType(AutoName):
FIXEDSTRING = auto()
LOWCARDINALITY = auto()
NESTED = auto()
+ AGGREGATEFUNCTION = auto()
+ SIMPLEAGGREGATEFUNCTION = auto()
UNKNOWN = auto()
# keywords