Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/dataframe/sql/functions.py        |  4
-rw-r--r--  sqlglot/dialects/databricks.py            |  3
-rw-r--r--  sqlglot/dialects/duckdb.py                |  7
-rw-r--r--  sqlglot/dialects/oracle.py                | 24
-rw-r--r--  sqlglot/dialects/presto.py                |  2
-rw-r--r--  sqlglot/dialects/redshift.py              | 21
-rw-r--r--  sqlglot/dialects/spark.py                 |  3
-rw-r--r--  sqlglot/dialects/tsql.py                  |  7
-rw-r--r--  sqlglot/executor/table.py                 |  2
-rw-r--r--  sqlglot/expressions.py                    | 36
-rw-r--r--  sqlglot/generator.py                      | 19
-rw-r--r--  sqlglot/helper.py                         |  8
-rw-r--r--  sqlglot/optimizer/normalize.py            | 37
-rw-r--r--  sqlglot/optimizer/optimize_joins.py       |  6
-rw-r--r--  sqlglot/optimizer/pushdown_projections.py | 10
-rw-r--r--  sqlglot/optimizer/scope.py                | 18
-rw-r--r--  sqlglot/optimizer/simplify.py             | 93
-rw-r--r--  sqlglot/parser.py                         | 85
-rw-r--r--  sqlglot/schema.py                         | 51
-rw-r--r--  sqlglot/tokens.py                         |  3
-rw-r--r--  sqlglot/transforms.py                     |  6
21 files changed, 344 insertions, 101 deletions
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index 9ab00d5..d98feee 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -1019,11 +1019,11 @@ def posexplode(col: ColumnOrName) -> Column:
def explode_outer(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "EXPLODE_OUTER")
+ return Column.invoke_expression_over_column(col, expression.ExplodeOuter)
def posexplode_outer(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "POSEXPLODE_OUTER")
+ return Column.invoke_expression_over_column(col, expression.PosexplodeOuter)
def get_json_object(col: ColumnOrName, path: str) -> Column:
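The two functions above now build typed expression nodes instead of anonymous
function calls. A minimal sketch of the effect (assuming the dataframe Column
API's expression attribute, which the rest of this module relies on):

    from sqlglot import expressions as exp
    from sqlglot.dataframe.sql import functions as F

    col = F.explode_outer("data")
    # The column wraps a typed node rather than exp.Anonymous, so dialect
    # transforms can match on ExplodeOuter directly.
    assert isinstance(col.expression, exp.ExplodeOuter)
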
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index a044bc0..314a821 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -10,6 +10,7 @@ from sqlglot.tokens import TokenType
class Databricks(Spark):
class Parser(Spark.Parser):
LOG_DEFAULTS_TO_LN = True
+ STRICT_CAST = True
FUNCTIONS = {
**Spark.Parser.FUNCTIONS,
@@ -51,6 +52,8 @@ class Databricks(Spark):
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
+ TRANSFORMS.pop(exp.TryCast)
+
def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
constraint = expression.find(exp.GeneratedAsIdentityColumnConstraint)
kind = expression.args.get("kind")
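With STRICT_CAST enabled and the inherited exp.TryCast transform popped,
Databricks now leaves TRY_CAST alone instead of rewriting it. A hedged
round-trip sketch:

    import sqlglot

    # TRY_CAST parses with safe=True and is generated back verbatim.
    sql = "SELECT TRY_CAST(x AS INT)"
    print(sqlglot.transpile(sql, read="databricks", write="databricks")[0])
    # Expected: SELECT TRY_CAST(x AS INT)
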
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 352f11a..5b94bcb 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -133,6 +133,10 @@ class DuckDB(Dialect):
"UINTEGER": TokenType.UINT,
"USMALLINT": TokenType.USMALLINT,
"UTINYINT": TokenType.UTINYINT,
+ "TIMESTAMP_S": TokenType.TIMESTAMP_S,
+ "TIMESTAMP_MS": TokenType.TIMESTAMP_MS,
+ "TIMESTAMP_NS": TokenType.TIMESTAMP_NS,
+ "TIMESTAMP_US": TokenType.TIMESTAMP,
}
class Parser(parser.Parser):
@@ -321,6 +325,9 @@ class DuckDB(Dialect):
exp.DataType.Type.UINT: "UINTEGER",
exp.DataType.Type.VARBINARY: "BLOB",
exp.DataType.Type.VARCHAR: "TEXT",
+ exp.DataType.Type.TIMESTAMP_S: "TIMESTAMP_S",
+ exp.DataType.Type.TIMESTAMP_MS: "TIMESTAMP_MS",
+ exp.DataType.Type.TIMESTAMP_NS: "TIMESTAMP_NS",
}
STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}
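A short sketch of the new fixed-precision timestamps: TIMESTAMP_US is
tokenized as the canonical microsecond TIMESTAMP, while the other precisions
round-trip unchanged.

    import sqlglot

    for ts in ("TIMESTAMP_S", "TIMESTAMP_MS", "TIMESTAMP_NS", "TIMESTAMP_US"):
        # TIMESTAMP_S/_MS/_NS are preserved; TIMESTAMP_US comes back as TIMESTAMP.
        print(sqlglot.transpile(f"SELECT CAST(x AS {ts})", read="duckdb", write="duckdb")[0])
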
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 6a007ab..6bdd8d6 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -82,7 +82,6 @@ class Oracle(Dialect):
this=self._parse_format_json(self._parse_bitwise()),
order=self._parse_order(),
),
- "JSON_TABLE": lambda self: self._parse_json_table(),
"XMLTABLE": _parse_xml_table,
}
@@ -96,29 +95,6 @@ class Oracle(Dialect):
# Reference: https://stackoverflow.com/a/336455
DISTINCT_TOKENS = {TokenType.DISTINCT, TokenType.UNIQUE}
- # Note: this is currently incomplete; it only implements the "JSON_value_column" part
- def _parse_json_column_def(self) -> exp.JSONColumnDef:
- this = self._parse_id_var()
- kind = self._parse_types(allow_identifiers=False)
- path = self._match_text_seq("PATH") and self._parse_string()
- return self.expression(exp.JSONColumnDef, this=this, kind=kind, path=path)
-
- def _parse_json_table(self) -> exp.JSONTable:
- this = self._parse_format_json(self._parse_bitwise())
- path = self._match(TokenType.COMMA) and self._parse_string()
- error_handling = self._parse_on_handling("ERROR", "ERROR", "NULL")
- empty_handling = self._parse_on_handling("EMPTY", "ERROR", "NULL")
- self._match(TokenType.COLUMN)
- expressions = self._parse_wrapped_csv(self._parse_json_column_def, optional=True)
-
- return exp.JSONTable(
- this=this,
- expressions=expressions,
- path=path,
- error_handling=error_handling,
- empty_handling=empty_handling,
- )
-
def _parse_json_array(self, expr_type: t.Type[E], **kwargs) -> E:
return self.expression(
expr_type,
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index e5cfa1c..88525a2 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -34,7 +34,7 @@ def _approx_distinct_sql(self: Presto.Generator, expression: exp.ApproxDistinct)
def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str:
- if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
+ if isinstance(expression.this, exp.Explode):
expression = expression.copy()
return self.sql(
exp.Join(
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index b70a8a1..04e78a5 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -58,6 +58,11 @@ class Redshift(Postgres):
"STRTOL": exp.FromBase.from_arg_list,
}
+ NO_PAREN_FUNCTION_PARSERS = {
+ **Postgres.Parser.NO_PAREN_FUNCTION_PARSERS,
+ "APPROXIMATE": lambda self: self._parse_approximate_count(),
+ }
+
def _parse_table(
self,
schema: bool = False,
@@ -93,11 +98,22 @@ class Redshift(Postgres):
return this
- def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
+ def _parse_convert(
+ self, strict: bool, safe: t.Optional[bool] = None
+ ) -> t.Optional[exp.Expression]:
to = self._parse_types()
self._match(TokenType.COMMA)
this = self._parse_bitwise()
- return self.expression(exp.TryCast, this=this, to=to)
+ return self.expression(exp.TryCast, this=this, to=to, safe=safe)
+
+ def _parse_approximate_count(self) -> t.Optional[exp.ApproxDistinct]:
+ index = self._index - 1
+ func = self._parse_function()
+
+ if isinstance(func, exp.Count) and isinstance(func.this, exp.Distinct):
+ return self.expression(exp.ApproxDistinct, this=seq_get(func.this.expressions, 0))
+ self._retreat(index)
+ return None
class Tokenizer(Postgres.Tokenizer):
BIT_STRINGS = []
@@ -144,6 +160,7 @@ class Redshift(Postgres):
**Postgres.Generator.TRANSFORMS,
exp.Concat: concat_to_dpipe_sql,
exp.ConcatWs: concat_ws_to_dpipe_sql,
+ exp.ApproxDistinct: lambda self, e: f"APPROXIMATE COUNT(DISTINCT {self.sql(e, 'this')})",
exp.CurrentTimestamp: lambda self, e: "SYSDATE",
exp.DateAdd: lambda self, e: self.func(
"DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 2eaa2ae..8461920 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -76,6 +76,9 @@ class Spark(Spark2):
exp.TimestampAdd: lambda self, e: self.func(
"DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
),
+ exp.TryCast: lambda self, e: self.trycast_sql(e)
+ if e.args.get("safe")
+ else self.cast_sql(e),
}
TRANSFORMS.pop(exp.AnyValue)
TRANSFORMS.pop(exp.DateDiff)
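Since the parser records safe=True only for genuinely safe casts (SAFE_CAST,
TRY_CAST, TRY_CONVERT), Spark can emit TRY_CAST for those and a plain CAST
for TryCast nodes that merely came from a non-strict dialect. A hedged sketch:

    import sqlglot

    # BigQuery's SAFE_CAST carries safe=True, so Spark keeps the try semantics.
    print(sqlglot.transpile("SELECT SAFE_CAST(x AS INT64)", read="bigquery", write="spark")[0])
    # Expected, roughly: SELECT TRY_CAST(x AS BIGINT)
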
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index d8bea6d..69adb45 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -477,7 +477,9 @@ class TSQL(Dialect):
returns.set("table", table)
return returns
- def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
+ def _parse_convert(
+ self, strict: bool, safe: t.Optional[bool] = None
+ ) -> t.Optional[exp.Expression]:
to = self._parse_types()
self._match(TokenType.COMMA)
this = self._parse_conjunction()
@@ -513,12 +515,13 @@ class TSQL(Dialect):
exp.Cast if strict else exp.TryCast,
to=to,
this=self.expression(exp.TimeToStr, this=this, format=format_norm),
+ safe=safe,
)
elif to.this == DataType.Type.TEXT:
return self.expression(exp.TimeToStr, this=this, format=format_norm)
# Entails a simple cast without any format requirement
- return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+ return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, safe=safe)
def _parse_user_defined_function(
self, kind: t.Optional[TokenType] = None
diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py
index 74b9b7c..7931535 100644
--- a/sqlglot/executor/table.py
+++ b/sqlglot/executor/table.py
@@ -105,7 +105,7 @@ class RowReader:
return self.row[self.columns[column]]
-class Tables(AbstractMappingSchema[Table]):
+class Tables(AbstractMappingSchema):
pass
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 80f1c0f..b94b1e1 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -487,7 +487,7 @@ class Expression(metaclass=_Expression):
"""
for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__):
if not type(node) is self.__class__:
- yield node.unnest() if unnest else node
+ yield node.unnest() if unnest and not isinstance(node, Subquery) else node
def __str__(self) -> str:
return self.sql()
@@ -2107,7 +2107,7 @@ class LockingProperty(Property):
arg_types = {
"this": False,
"kind": True,
- "for_or_in": True,
+ "for_or_in": False,
"lock_type": True,
"override": False,
}
@@ -3605,6 +3605,9 @@ class DataType(Expression):
TIMESTAMP = auto()
TIMESTAMPLTZ = auto()
TIMESTAMPTZ = auto()
+ TIMESTAMP_S = auto()
+ TIMESTAMP_MS = auto()
+ TIMESTAMP_NS = auto()
TINYINT = auto()
TSMULTIRANGE = auto()
TSRANGE = auto()
@@ -3661,6 +3664,9 @@ class DataType(Expression):
Type.TIMESTAMP,
Type.TIMESTAMPTZ,
Type.TIMESTAMPLTZ,
+ Type.TIMESTAMP_S,
+ Type.TIMESTAMP_MS,
+ Type.TIMESTAMP_NS,
Type.DATE,
Type.DATETIME,
Type.DATETIME64,
@@ -4286,7 +4292,7 @@ class Case(Func):
class Cast(Func):
- arg_types = {"this": True, "to": True, "format": False}
+ arg_types = {"this": True, "to": True, "format": False, "safe": False}
@property
def name(self) -> str:
@@ -4538,6 +4544,18 @@ class Explode(Func):
pass
+class ExplodeOuter(Explode):
+ pass
+
+
+class Posexplode(Explode):
+ pass
+
+
+class PosexplodeOuter(Posexplode):
+ pass
+
+
class Floor(Func):
arg_types = {"this": True, "decimals": False}
@@ -4621,14 +4639,18 @@ class JSONArrayAgg(Func):
# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_TABLE.html
# Note: parsing of JSON column definitions is currently incomplete.
class JSONColumnDef(Expression):
- arg_types = {"this": True, "kind": False, "path": False}
+ arg_types = {"this": False, "kind": False, "path": False, "nested_schema": False}
+
+
+class JSONSchema(Expression):
+ arg_types = {"expressions": True}
# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_TABLE.html
class JSONTable(Func):
arg_types = {
"this": True,
- "expressions": True,
+ "schema": True,
"path": False,
"error_handling": False,
"empty_handling": False,
@@ -4790,10 +4812,6 @@ class Nvl2(Func):
arg_types = {"this": True, "true": True, "false": False}
-class Posexplode(Func):
- pass
-
-
# https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict#mlpredict_function
class Predict(Func):
arg_types = {"this": True, "expression": True, "params_struct": False}
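Note the explode hierarchy: the OUTER variants subclass their plain
counterparts, and Posexplode now subclasses Explode, so isinstance checks
written against the base classes (e.g. in presto.py and transforms.py) also
match the new nodes:

    from sqlglot import expressions as exp

    assert issubclass(exp.Posexplode, exp.Explode)
    assert issubclass(exp.ExplodeOuter, exp.Explode)
    assert issubclass(exp.PosexplodeOuter, exp.Posexplode)
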
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 7a2879c..b7e26bb 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -1226,9 +1226,10 @@ class Generator:
kind = expression.args.get("kind")
this = f" {self.sql(expression, 'this')}" if expression.this else ""
for_or_in = expression.args.get("for_or_in")
+ for_or_in = f" {for_or_in}" if for_or_in else ""
lock_type = expression.args.get("lock_type")
override = " OVERRIDE" if expression.args.get("override") else ""
- return f"LOCKING {kind}{this} {for_or_in} {lock_type}{override}"
+ return f"LOCKING {kind}{this}{for_or_in} {lock_type}{override}"
def withdataproperty_sql(self, expression: exp.WithDataProperty) -> str:
data_sql = f"WITH {'NO ' if expression.args.get('no') else ''}DATA"
@@ -2179,13 +2180,21 @@ class Generator:
)
def jsoncolumndef_sql(self, expression: exp.JSONColumnDef) -> str:
+ path = self.sql(expression, "path")
+ path = f" PATH {path}" if path else ""
+ nested_schema = self.sql(expression, "nested_schema")
+
+ if nested_schema:
+ return f"NESTED{path} {nested_schema}"
+
this = self.sql(expression, "this")
kind = self.sql(expression, "kind")
kind = f" {kind}" if kind else ""
- path = self.sql(expression, "path")
- path = f" PATH {path}" if path else ""
return f"{this}{kind}{path}"
+ def jsonschema_sql(self, expression: exp.JSONSchema) -> str:
+ return self.func("COLUMNS", *expression.expressions)
+
def jsontable_sql(self, expression: exp.JSONTable) -> str:
this = self.sql(expression, "this")
path = self.sql(expression, "path")
@@ -2194,9 +2203,9 @@ class Generator:
error_handling = f" {error_handling}" if error_handling else ""
empty_handling = expression.args.get("empty_handling")
empty_handling = f" {empty_handling}" if empty_handling else ""
- columns = f" COLUMNS ({self.expressions(expression, skip_first=True)})"
+ schema = self.sql(expression, "schema")
return self.func(
- "JSON_TABLE", this, suffix=f"{path}{error_handling}{empty_handling}{columns})"
+ "JSON_TABLE", this, suffix=f"{path}{error_handling}{empty_handling} {schema})"
)
def openjsoncolumndef_sql(self, expression: exp.OpenJSONColumnDef) -> str:
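A sketch of the NESTED column support end to end, using Oracle's JSON_TABLE
syntax (exact type spellings depend on the write dialect):

    import sqlglot

    sql = (
        "SELECT * FROM JSON_TABLE(data, '$' COLUMNS ("
        "id NUMBER PATH '$.id', "
        "NESTED PATH '$.items[*]' COLUMNS (name VARCHAR2 PATH '$.name')))"
    )
    print(sqlglot.transpile(sql, read="oracle", write="oracle")[0])
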
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index 00d49ae..74b61e3 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -441,6 +441,14 @@ def first(it: t.Iterable[T]) -> T:
def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
+ """
+ Merges a sequence of ranges, represented as tuples (low, high) whose values
+ belong to some totally-ordered set.
+
+ Example:
+ >>> merge_ranges([(1, 3), (2, 6)])
+ [(1, 6)]
+ """
if not ranges:
return []
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index 1db094e..8d82b2d 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -6,6 +6,7 @@ from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.generator import cached_generator
from sqlglot.helper import while_changing
+from sqlglot.optimizer.scope import find_all_in_scope
from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
logger = logging.getLogger("sqlglot")
@@ -63,15 +64,33 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
return expression
-def normalized(expression, dnf=False):
- ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
+def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
+ """
+ Checks whether a given expression is in a normal form of interest.
- return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root))
+ Example:
+ >>> from sqlglot import parse_one
+ >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
+ True
+ >>> normalized(parse_one("(a OR b) AND c")) # Checks CNF by default
+ True
+ >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
+ False
+ Args:
+ expression: The expression to check if it's normalized.
+ dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
+ Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
+ """
+ ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
+ return not any(
+ connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root)
+ )
-def normalization_distance(expression, dnf=False):
+
+def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int:
"""
- The difference in the number of predicates between the current expression and the normalized form.
+ The difference in the number of predicates between a given expression and its normalized form.
This is used as an estimate of the cost of the conversion which is exponential in complexity.
@@ -82,10 +101,12 @@ def normalization_distance(expression, dnf=False):
4
Args:
- expression (sqlglot.Expression): expression to compute distance
- dnf (bool): compute to dnf distance instead
+ expression: The expression to compute the normalization distance for.
+ dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
+ Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
+
Returns:
- int: difference
+ The normalization distance.
"""
return sum(_predicate_lengths(expression, dnf)) - (
sum(1 for _ in expression.find_all(exp.Connector)) + 1
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
index 9d401fc..1530456 100644
--- a/sqlglot/optimizer/optimize_joins.py
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -39,10 +39,14 @@ def optimize_joins(expression):
if len(other_table_names(dep)) < 2:
continue
+ operator = type(on)
for predicate in on.flatten():
if name in exp.column_table_names(predicate):
predicate.replace(exp.true())
- join.on(predicate, copy=False)
+ predicate = exp._combine(
+ [join.args.get("on"), predicate], operator, copy=False
+ )
+ join.on(predicate, append=False, copy=False)
expression = reorder_joins(expression)
expression = normalize(expression)
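A hedged sketch of the fix: predicates lifted into a cross join's ON clause
are now recombined with the source clause's own connective (here OR) instead
of being unconditionally AND-ed on via join.on:

    from sqlglot import parse_one
    from sqlglot.optimizer.optimize_joins import optimize_joins

    # Hypothetical query: both disjuncts reference the cross-joined table c,
    # so both move to c's join and must stay OR-ed together.
    sql = "SELECT * FROM a CROSS JOIN c JOIN b ON a.y = c.y OR b.z = c.z"
    print(optimize_joins(parse_one(sql)).sql())
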
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index b51601f..4bc3bd2 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -9,7 +9,9 @@ from sqlglot.schema import ensure_schema
SELECT_ALL = object()
# Selection to use if selection list is empty
-DEFAULT_SELECTION = lambda: alias("1", "_")
+DEFAULT_SELECTION = lambda is_agg: alias(
+ exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_"
+)
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
@@ -98,6 +100,7 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count):
new_selections = []
removed = False
star = False
+ is_agg = False
select_all = SELECT_ALL in parent_selections
@@ -112,6 +115,9 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count):
star = True
removed = True
+ if not is_agg and selection.find(exp.AggFunc):
+ is_agg = True
+
if star:
resolver = Resolver(scope, schema)
names = {s.alias_or_name for s in new_selections}
@@ -124,7 +130,7 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count):
# If there are no remaining selections, just select a single constant
if not new_selections:
- new_selections.append(DEFAULT_SELECTION())
+ new_selections.append(DEFAULT_SELECTION(is_agg))
scope.expression.select(*new_selections, append=False, copy=False)
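The is_agg flag preserves cardinality: if the pruned projections contained
aggregates, substituting a bare literal would turn a one-row aggregate
subquery into one row per source row. A hedged sketch:

    import sqlglot
    from sqlglot.optimizer.pushdown_projections import pushdown_projections

    # The inner SUM is unused, but the subquery must keep returning a single
    # row, so the placeholder selection becomes MAX(1) rather than 1.
    sql = "SELECT 1 AS c FROM (SELECT SUM(x) AS s FROM t) AS sub"
    print(pushdown_projections(sqlglot.parse_one(sql)).sql())
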
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 435899a..4af5b49 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -137,8 +137,8 @@ class Scope:
if not self._collected:
self._collect()
- def walk(self, bfs=True):
- return walk_in_scope(self.expression, bfs=bfs)
+ def walk(self, bfs=True, prune=None):
+ return walk_in_scope(self.expression, bfs=bfs, prune=prune)
def find(self, *expression_types, bfs=True):
return find_in_scope(self.expression, expression_types, bfs=bfs)
@@ -731,7 +731,7 @@ def _traverse_ddl(scope):
yield from _traverse_scope(query_scope)
-def walk_in_scope(expression, bfs=True):
+def walk_in_scope(expression, bfs=True, prune=None):
"""
Returns a generator object which visits all nodes in the syntax tree, stopping at
nodes that start child scopes.
@@ -740,16 +740,20 @@ def walk_in_scope(expression, bfs=True):
expression (exp.Expression):
bfs (bool): if set to True the BFS traversal order will be applied,
otherwise the DFS traversal will be used instead.
+ prune ((node, parent, arg_key) -> bool): callable that returns True if
+ the generator should stop traversing this branch of the tree.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
"""
# We'll use this variable to pass state into the dfs generator.
# Whenever we set it to True, we exclude a subtree from traversal.
- prune = False
+ crossed_scope_boundary = False
- for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
- prune = False
+ for node, parent, key in expression.walk(
+ bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
+ ):
+ crossed_scope_boundary = False
yield node, parent, key
@@ -765,7 +769,7 @@ def walk_in_scope(expression, bfs=True):
or isinstance(node, exp.UDTF)
or isinstance(node, exp.Subqueryable)
):
- prune = True
+ crossed_scope_boundary = True
if isinstance(node, (exp.Subquery, exp.UDTF)):
# The following args are not actually in the inner scope, so we should visit them
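A sketch of the new prune hook, mirroring how propagate_constants uses it
below to avoid descending into conditional branches:

    from sqlglot import exp, parse_one
    from sqlglot.optimizer.scope import walk_in_scope

    tree = parse_one("SELECT CASE WHEN a = 1 THEN b ELSE c END FROM t WHERE a = 1")
    # The exp.If node itself is yielded, but its children (the WHEN condition
    # and THEN branch) are skipped; the WHERE clause is still visited.
    visited = [
        node for node, *_ in walk_in_scope(tree, prune=lambda n, *_: isinstance(n, exp.If))
    ]
    assert not any(isinstance(n, exp.Column) and n.name == "b" for n in visited)
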
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 51214c4..849643c 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -5,9 +5,11 @@ import typing as t
from collections import deque
from decimal import Decimal
+import sqlglot
from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, merge_ranges, while_changing
+from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
# Final means that an expression should not be simplified
FINAL = "final"
@@ -17,7 +19,7 @@ class UnsupportedUnit(Exception):
pass
-def simplify(expression):
+def simplify(expression, constant_propagation=False):
"""
Rewrite sqlglot AST to simplify expressions.
@@ -29,6 +31,8 @@ def simplify(expression):
Args:
expression (sqlglot.Expression): expression to simplify
+ constant_propagation: whether or not the constant propagation rule should be used
+
Returns:
sqlglot.Expression: simplified expression
"""
@@ -67,13 +71,16 @@ def simplify(expression):
node = absorb_and_eliminate(node, root)
node = simplify_concat(node)
+ if constant_propagation:
+ node = propagate_constants(node, root)
+
exp.replace_children(node, lambda e: _simplify(e, False))
# Post-order transformations
node = simplify_not(node)
node = flatten(node)
node = simplify_connectors(node, root)
- node = remove_compliments(node, root)
+ node = remove_complements(node, root)
node = simplify_coalesce(node)
node.parent = expression.parent
node = simplify_literals(node, root)
@@ -287,19 +294,19 @@ def _simplify_comparison(expression, left, right, or_=False):
return None
-def remove_compliments(expression, root=True):
+def remove_complements(expression, root=True):
"""
- Removing compliments.
+ Removing complements.
A AND NOT A -> FALSE
A OR NOT A -> TRUE
"""
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
- compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
+ complement = exp.false() if isinstance(expression, exp.And) else exp.true()
for a, b in itertools.permutations(expression.flatten(), 2):
if is_complement(a, b):
- return compliment
+ return complement
return expression
@@ -369,6 +376,51 @@ def absorb_and_eliminate(expression, root=True):
return expression
+def propagate_constants(expression, root=True):
+ """
+ Propagate constants for conjunctions in DNF:
+
+ SELECT * FROM t WHERE a = b AND b = 5 becomes
+ SELECT * FROM t WHERE a = 5 AND b = 5
+
+ Reference: https://www.sqlite.org/optoverview.html
+ """
+
+ if (
+ isinstance(expression, exp.And)
+ and (root or not expression.same_parent)
+ and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
+ ):
+ constant_mapping = {}
+ for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
+ if isinstance(expr, exp.EQ):
+ l, r = expr.left, expr.right
+
+ # TODO: create a helper that can be used to detect nested literal expressions such
+ # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
+ if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
+ pass
+ elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
+ l, r = r, l
+ else:
+ continue
+
+ constant_mapping[l] = (id(l), r)
+
+ if constant_mapping:
+ for column in find_all_in_scope(expression, exp.Column):
+ parent = column.parent
+ column_id, constant = constant_mapping.get(column) or (None, None)
+ if (
+ column_id is not None
+ and id(column) != column_id
+ and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
+ ):
+ column.replace(constant.copy())
+
+ return expression
+
+
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
exp.DateAdd: exp.Sub,
exp.DateSub: exp.Add,
@@ -609,21 +661,38 @@ SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
def simplify_concat(expression):
"""Reduces all groups that contain string literals by concatenating them."""
- if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
+ if not isinstance(expression, CONCATS) or (
+ # We can't reduce a CONCAT_WS call if we don't statically know the separator
+ isinstance(expression, exp.ConcatWs)
+ and not expression.expressions[0].is_string
+ ):
return expression
+ if isinstance(expression, exp.ConcatWs):
+ sep_expr, *expressions = expression.expressions
+ sep = sep_expr.name
+ concat_type = exp.ConcatWs
+ else:
+ expressions = expression.expressions
+ sep = ""
+ concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
+
new_args = []
for is_string_group, group in itertools.groupby(
- expression.expressions or expression.flatten(), lambda e: e.is_string
+ expressions or expression.flatten(), lambda e: e.is_string
):
if is_string_group:
- new_args.append(exp.Literal.string("".join(string.name for string in group)))
+ new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
else:
new_args.extend(group)
- # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
- concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
- return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)
+ if len(new_args) == 1 and new_args[0].is_string:
+ return new_args[0]
+
+ if concat_type is exp.ConcatWs:
+ new_args = [sep_expr] + new_args
+
+ return concat_type(expressions=new_args)
DateRange = t.Tuple[datetime.date, datetime.date]
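The new constant_propagation flag in action, matching the example added to
the docstring above:

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import simplify

    expr = parse_one("SELECT * FROM t WHERE a = b AND b = 5")
    print(simplify(expr, constant_propagation=True).sql())
    # Expected: SELECT * FROM t WHERE a = 5 AND b = 5
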
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 510abfb..8de76ca 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -160,6 +160,9 @@ class Parser(metaclass=_Parser):
TokenType.TIME,
TokenType.TIMETZ,
TokenType.TIMESTAMP,
+ TokenType.TIMESTAMP_S,
+ TokenType.TIMESTAMP_MS,
+ TokenType.TIMESTAMP_NS,
TokenType.TIMESTAMPTZ,
TokenType.TIMESTAMPLTZ,
TokenType.DATETIME,
@@ -792,17 +795,18 @@ class Parser(metaclass=_Parser):
"DECODE": lambda self: self._parse_decode(),
"EXTRACT": lambda self: self._parse_extract(),
"JSON_OBJECT": lambda self: self._parse_json_object(),
+ "JSON_TABLE": lambda self: self._parse_json_table(),
"LOG": lambda self: self._parse_logarithm(),
"MATCH": lambda self: self._parse_match_against(),
"OPENJSON": lambda self: self._parse_open_json(),
"POSITION": lambda self: self._parse_position(),
"PREDICT": lambda self: self._parse_predict(),
- "SAFE_CAST": lambda self: self._parse_cast(False),
+ "SAFE_CAST": lambda self: self._parse_cast(False, safe=True),
"STRING_AGG": lambda self: self._parse_string_agg(),
"SUBSTRING": lambda self: self._parse_substring(),
"TRIM": lambda self: self._parse_trim(),
- "TRY_CAST": lambda self: self._parse_cast(False),
- "TRY_CONVERT": lambda self: self._parse_convert(False),
+ "TRY_CAST": lambda self: self._parse_cast(False, safe=True),
+ "TRY_CONVERT": lambda self: self._parse_convert(False, safe=True),
}
QUERY_MODIFIER_PARSERS = {
@@ -4135,7 +4139,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.AnyValue, this=this, having=having, max=is_max)
- def _parse_cast(self, strict: bool) -> exp.Expression:
+ def _parse_cast(self, strict: bool, safe: t.Optional[bool] = None) -> exp.Expression:
this = self._parse_conjunction()
if not self._match(TokenType.ALIAS):
@@ -4176,7 +4180,9 @@ class Parser(metaclass=_Parser):
return this
- return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt)
+ return self.expression(
+ exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt, safe=safe
+ )
def _parse_concat(self) -> t.Optional[exp.Expression]:
args = self._parse_csv(self._parse_conjunction)
@@ -4230,7 +4236,9 @@ class Parser(metaclass=_Parser):
order = self._parse_order(this=seq_get(args, 0))
return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
- def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
+ def _parse_convert(
+ self, strict: bool, safe: t.Optional[bool] = None
+ ) -> t.Optional[exp.Expression]:
this = self._parse_bitwise()
if self._match(TokenType.USING):
@@ -4242,7 +4250,7 @@ class Parser(metaclass=_Parser):
else:
to = None
- return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+ return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, safe=safe)
def _parse_decode(self) -> t.Optional[exp.Decode | exp.Case]:
"""
@@ -4347,6 +4355,50 @@ class Parser(metaclass=_Parser):
encoding=encoding,
)
+ # Note: this is currently incomplete; it only implements the "JSON_value_column" part
+ def _parse_json_column_def(self) -> exp.JSONColumnDef:
+ if not self._match_text_seq("NESTED"):
+ this = self._parse_id_var()
+ kind = self._parse_types(allow_identifiers=False)
+ nested = None
+ else:
+ this = None
+ kind = None
+ nested = True
+
+ path = self._match_text_seq("PATH") and self._parse_string()
+ nested_schema = nested and self._parse_json_schema()
+
+ return self.expression(
+ exp.JSONColumnDef,
+ this=this,
+ kind=kind,
+ path=path,
+ nested_schema=nested_schema,
+ )
+
+ def _parse_json_schema(self) -> exp.JSONSchema:
+ self._match_text_seq("COLUMNS")
+ return self.expression(
+ exp.JSONSchema,
+ expressions=self._parse_wrapped_csv(self._parse_json_column_def, optional=True),
+ )
+
+ def _parse_json_table(self) -> exp.JSONTable:
+ this = self._parse_format_json(self._parse_bitwise())
+ path = self._match(TokenType.COMMA) and self._parse_string()
+ error_handling = self._parse_on_handling("ERROR", "ERROR", "NULL")
+ empty_handling = self._parse_on_handling("EMPTY", "ERROR", "NULL")
+ schema = self._parse_json_schema()
+
+ return exp.JSONTable(
+ this=this,
+ schema=schema,
+ path=path,
+ error_handling=error_handling,
+ empty_handling=empty_handling,
+ )
+
def _parse_logarithm(self) -> exp.Func:
# Default argument order is base, expression
args = self._parse_csv(self._parse_range)
@@ -4973,7 +5025,17 @@ class Parser(metaclass=_Parser):
self._match(TokenType.ON)
on = self._parse_conjunction()
+ return self.expression(
+ exp.Merge,
+ this=target,
+ using=using,
+ on=on,
+ expressions=self._parse_when_matched(),
+ )
+
+ def _parse_when_matched(self) -> t.List[exp.When]:
whens = []
+
while self._match(TokenType.WHEN):
matched = not self._match(TokenType.NOT)
self._match_text_seq("MATCHED")
@@ -5020,14 +5082,7 @@ class Parser(metaclass=_Parser):
then=then,
)
)
-
- return self.expression(
- exp.Merge,
- this=target,
- using=using,
- on=on,
- expressions=whens,
- )
+ return whens
def _parse_show(self) -> t.Optional[exp.Expression]:
parser = self._find_parser(self.SHOW_PARSERS, self.SHOW_TRIE)
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index f0b279b..778378c 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -5,7 +5,6 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
-from sqlglot._typing import T
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import ParseError, SchemaError
from sqlglot.helper import dict_depth
@@ -71,7 +70,7 @@ class Schema(abc.ABC):
def get_column_type(
self,
table: exp.Table | str,
- column: exp.Column,
+ column: exp.Column | str,
dialect: DialectType = None,
normalize: t.Optional[bool] = None,
) -> exp.DataType:
@@ -88,6 +87,28 @@ class Schema(abc.ABC):
The resulting column type.
"""
+ def has_column(
+ self,
+ table: exp.Table | str,
+ column: exp.Column | str,
+ dialect: DialectType = None,
+ normalize: t.Optional[bool] = None,
+ ) -> bool:
+ """
+ Returns whether or not `column` appears in `table`'s schema.
+
+ Args:
+ table: the source table.
+ column: the target column.
+ dialect: the SQL dialect that will be used to parse `table` if it's a string.
+ normalize: whether to normalize identifiers according to the dialect of interest.
+
+ Returns:
+ True if the column appears in the schema, False otherwise.
+ """
+ name = column if isinstance(column, str) else column.name
+ return name in self.column_names(table, dialect=dialect, normalize=normalize)
+
@property
@abc.abstractmethod
def supported_table_args(self) -> t.Tuple[str, ...]:
@@ -101,7 +122,7 @@ class Schema(abc.ABC):
return True
-class AbstractMappingSchema(t.Generic[T]):
+class AbstractMappingSchema:
def __init__(
self,
mapping: t.Optional[t.Dict] = None,
@@ -140,7 +161,7 @@ class AbstractMappingSchema(t.Generic[T]):
def find(
self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
- ) -> t.Optional[T]:
+ ) -> t.Optional[t.Any]:
parts = self.table_parts(table)[0 : len(self.supported_table_args)]
value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
@@ -170,7 +191,7 @@ class AbstractMappingSchema(t.Generic[T]):
)
-class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
+class MappingSchema(AbstractMappingSchema, Schema):
"""
Schema based on a nested mapping.
@@ -287,7 +308,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
def get_column_type(
self,
table: exp.Table | str,
- column: exp.Column,
+ column: exp.Column | str,
dialect: DialectType = None,
normalize: t.Optional[bool] = None,
) -> exp.DataType:
@@ -304,10 +325,26 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
if isinstance(column_type, exp.DataType):
return column_type
elif isinstance(column_type, str):
- return self._to_data_type(column_type.upper(), dialect=dialect)
+ return self._to_data_type(column_type, dialect=dialect)
return exp.DataType.build("unknown")
+ def has_column(
+ self,
+ table: exp.Table | str,
+ column: exp.Column | str,
+ dialect: DialectType = None,
+ normalize: t.Optional[bool] = None,
+ ) -> bool:
+ normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
+
+ normalized_column_name = self._normalize_name(
+ column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
+ )
+
+ table_schema = self.find(normalized_table, raise_on_missing=False)
+ return normalized_column_name in table_schema if table_schema else False
+
def _normalize(self, schema: t.Dict) -> t.Dict:
"""
Normalizes all identifiers in the schema.
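A usage sketch for the new has_column hook:

    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"t": {"x": "INT", "y": "TEXT"}})
    assert schema.has_column("t", "x")
    assert not schema.has_column("t", "z")
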
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 4ab01dd..c883858 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -121,6 +121,9 @@ class TokenType(AutoName):
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
+ TIMESTAMP_S = auto()
+ TIMESTAMP_MS = auto()
+ TIMESTAMP_NS = auto()
DATETIME = auto()
DATETIME64 = auto()
DATE = auto()
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index ac9dd81..8feee52 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -189,9 +189,9 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
# we use list here because expression.selects is mutated inside the loop
for select in expression.selects.copy():
- explode = select.find(exp.Explode, exp.Posexplode)
+ explode = select.find(exp.Explode)
- if isinstance(explode, (exp.Explode, exp.Posexplode)):
+ if explode:
pos_alias = ""
explode_alias = ""
@@ -204,7 +204,7 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
alias = select.replace(exp.alias_(select.this, "", copy=False))
else:
alias = select.replace(exp.alias_(select, ""))
- explode = alias.find(exp.Explode, exp.Posexplode)
+ explode = alias.find(exp.Explode)
assert explode
is_posexplode = isinstance(explode, exp.Posexplode)