author     Daniel Baumann <daniel.baumann@progress-linux.org>   2023-10-16 11:37:39 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>   2023-10-16 11:37:39 +0000
commit     f10d022e11dcd1015db1a74ce9f4198ebdcb7f40 (patch)
tree       ac7bdc1d214a0f97f991cff14e933f4895ee68e1 /sqlglot
parent     Releasing progress-linux version 18.11.6-1. (diff)
Merging upstream version 18.13.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
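
Among the library-level additions landing with 18.13.0 are an opt-in constant-propagation rule in the simplifier and a Schema.has_column helper (see the simplify.py and schema.py hunks below). A minimal sketch of how they might be exercised from Python; the table, columns, and printed results are illustrative assumptions, not part of this patch:

    import sqlglot
    from sqlglot.optimizer.simplify import simplify
    from sqlglot.schema import MappingSchema

    # Constant propagation is opt-in via the new keyword argument;
    # "a = b AND b = 5" should rewrite to "a = 5 AND b = 5".
    expr = sqlglot.parse_one("SELECT * FROM t WHERE a = b AND b = 5")
    print(simplify(expr, constant_propagation=True).sql())

    # has_column is a new convenience wrapper around column_names.
    schema = MappingSchema({"t": {"a": "INT", "b": "INT"}})
    print(schema.has_column("t", "a"))  # expected: True
    print(schema.has_column("t", "z"))  # expected: False
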
Diffstat (limited to 'sqlglot')
 sqlglot/dataframe/sql/functions.py        |   4
 sqlglot/dialects/databricks.py            |   3
 sqlglot/dialects/duckdb.py                |   7
 sqlglot/dialects/oracle.py                |  24
 sqlglot/dialects/presto.py                |   2
 sqlglot/dialects/redshift.py              |  21
 sqlglot/dialects/spark.py                 |   3
 sqlglot/dialects/tsql.py                  |   7
 sqlglot/executor/table.py                 |   2
 sqlglot/expressions.py                    |  36
 sqlglot/generator.py                      |  19
 sqlglot/helper.py                         |   8
 sqlglot/optimizer/normalize.py            |  37
 sqlglot/optimizer/optimize_joins.py       |   6
 sqlglot/optimizer/pushdown_projections.py |  10
 sqlglot/optimizer/scope.py                |  18
 sqlglot/optimizer/simplify.py             |  93
 sqlglot/parser.py                         |  85
 sqlglot/schema.py                         |  51
 sqlglot/tokens.py                         |   3
 sqlglot/transforms.py                     |   6
 21 files changed, 344 insertions(+), 101 deletions(-)
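
Before the diff itself, a similar hedged sketch of two dialect-level changes (the redshift.py and duckdb.py hunks below); the query text and expected output are assumptions for illustration only:

    import sqlglot

    # Redshift's APPROXIMATE COUNT(DISTINCT ...) now parses to exp.ApproxDistinct
    # and is generated back in the same form.
    sql = "SELECT APPROXIMATE COUNT(DISTINCT customer_id) FROM orders"
    print(sqlglot.transpile(sql, read="redshift", write="redshift")[0])

    # DuckDB's TIMESTAMP_S / TIMESTAMP_MS / TIMESTAMP_NS types are now recognized
    # end to end, so a cast to one of them should round-trip unchanged.
    print(sqlglot.transpile("CAST(x AS TIMESTAMP_NS)", read="duckdb", write="duckdb")[0])
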
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 9ab00d5..d98feee 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -1019,11 +1019,11 @@ def posexplode(col: ColumnOrName) -> Column: def explode_outer(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "EXPLODE_OUTER") + return Column.invoke_expression_over_column(col, expression.ExplodeOuter) def posexplode_outer(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "POSEXPLODE_OUTER") + return Column.invoke_expression_over_column(col, expression.PosexplodeOuter) def get_json_object(col: ColumnOrName, path: str) -> Column: diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index a044bc0..314a821 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -10,6 +10,7 @@ from sqlglot.tokens import TokenType class Databricks(Spark): class Parser(Spark.Parser): LOG_DEFAULTS_TO_LN = True + STRICT_CAST = True FUNCTIONS = { **Spark.Parser.FUNCTIONS, @@ -51,6 +52,8 @@ class Databricks(Spark): exp.ToChar: lambda self, e: self.function_fallback_sql(e), } + TRANSFORMS.pop(exp.TryCast) + def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str: constraint = expression.find(exp.GeneratedAsIdentityColumnConstraint) kind = expression.args.get("kind") diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 352f11a..5b94bcb 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -133,6 +133,10 @@ class DuckDB(Dialect): "UINTEGER": TokenType.UINT, "USMALLINT": TokenType.USMALLINT, "UTINYINT": TokenType.UTINYINT, + "TIMESTAMP_S": TokenType.TIMESTAMP_S, + "TIMESTAMP_MS": TokenType.TIMESTAMP_MS, + "TIMESTAMP_NS": TokenType.TIMESTAMP_NS, + "TIMESTAMP_US": TokenType.TIMESTAMP, } class Parser(parser.Parser): @@ -321,6 +325,9 @@ class DuckDB(Dialect): exp.DataType.Type.UINT: "UINTEGER", exp.DataType.Type.VARBINARY: "BLOB", exp.DataType.Type.VARCHAR: "TEXT", + exp.DataType.Type.TIMESTAMP_S: "TIMESTAMP_S", + exp.DataType.Type.TIMESTAMP_MS: "TIMESTAMP_MS", + exp.DataType.Type.TIMESTAMP_NS: "TIMESTAMP_NS", } STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"} diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 6a007ab..6bdd8d6 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -82,7 +82,6 @@ class Oracle(Dialect): this=self._parse_format_json(self._parse_bitwise()), order=self._parse_order(), ), - "JSON_TABLE": lambda self: self._parse_json_table(), "XMLTABLE": _parse_xml_table, } @@ -96,29 +95,6 @@ class Oracle(Dialect): # Reference: https://stackoverflow.com/a/336455 DISTINCT_TOKENS = {TokenType.DISTINCT, TokenType.UNIQUE} - # Note: this is currently incomplete; it only implements the "JSON_value_column" part - def _parse_json_column_def(self) -> exp.JSONColumnDef: - this = self._parse_id_var() - kind = self._parse_types(allow_identifiers=False) - path = self._match_text_seq("PATH") and self._parse_string() - return self.expression(exp.JSONColumnDef, this=this, kind=kind, path=path) - - def _parse_json_table(self) -> exp.JSONTable: - this = self._parse_format_json(self._parse_bitwise()) - path = self._match(TokenType.COMMA) and self._parse_string() - error_handling = self._parse_on_handling("ERROR", "ERROR", "NULL") - empty_handling = self._parse_on_handling("EMPTY", "ERROR", "NULL") - self._match(TokenType.COLUMN) - expressions = 
self._parse_wrapped_csv(self._parse_json_column_def, optional=True) - - return exp.JSONTable( - this=this, - expressions=expressions, - path=path, - error_handling=error_handling, - empty_handling=empty_handling, - ) - def _parse_json_array(self, expr_type: t.Type[E], **kwargs) -> E: return self.expression( expr_type, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index e5cfa1c..88525a2 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -34,7 +34,7 @@ def _approx_distinct_sql(self: Presto.Generator, expression: exp.ApproxDistinct) def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str: - if isinstance(expression.this, (exp.Explode, exp.Posexplode)): + if isinstance(expression.this, exp.Explode): expression = expression.copy() return self.sql( exp.Join( diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index b70a8a1..04e78a5 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -58,6 +58,11 @@ class Redshift(Postgres): "STRTOL": exp.FromBase.from_arg_list, } + NO_PAREN_FUNCTION_PARSERS = { + **Postgres.Parser.NO_PAREN_FUNCTION_PARSERS, + "APPROXIMATE": lambda self: self._parse_approximate_count(), + } + def _parse_table( self, schema: bool = False, @@ -93,11 +98,22 @@ class Redshift(Postgres): return this - def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]: + def _parse_convert( + self, strict: bool, safe: t.Optional[bool] = None + ) -> t.Optional[exp.Expression]: to = self._parse_types() self._match(TokenType.COMMA) this = self._parse_bitwise() - return self.expression(exp.TryCast, this=this, to=to) + return self.expression(exp.TryCast, this=this, to=to, safe=safe) + + def _parse_approximate_count(self) -> t.Optional[exp.ApproxDistinct]: + index = self._index - 1 + func = self._parse_function() + + if isinstance(func, exp.Count) and isinstance(func.this, exp.Distinct): + return self.expression(exp.ApproxDistinct, this=seq_get(func.this.expressions, 0)) + self._retreat(index) + return None class Tokenizer(Postgres.Tokenizer): BIT_STRINGS = [] @@ -144,6 +160,7 @@ class Redshift(Postgres): **Postgres.Generator.TRANSFORMS, exp.Concat: concat_to_dpipe_sql, exp.ConcatWs: concat_ws_to_dpipe_sql, + exp.ApproxDistinct: lambda self, e: f"APPROXIMATE COUNT(DISTINCT {self.sql(e, 'this')})", exp.CurrentTimestamp: lambda self, e: "SYSDATE", exp.DateAdd: lambda self, e: self.func( "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 2eaa2ae..8461920 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -76,6 +76,9 @@ class Spark(Spark2): exp.TimestampAdd: lambda self, e: self.func( "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this ), + exp.TryCast: lambda self, e: self.trycast_sql(e) + if e.args.get("safe") + else self.cast_sql(e), } TRANSFORMS.pop(exp.AnyValue) TRANSFORMS.pop(exp.DateDiff) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index d8bea6d..69adb45 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -477,7 +477,9 @@ class TSQL(Dialect): returns.set("table", table) return returns - def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]: + def _parse_convert( + self, strict: bool, safe: t.Optional[bool] = None + ) -> t.Optional[exp.Expression]: to = self._parse_types() self._match(TokenType.COMMA) this = self._parse_conjunction() @@ -513,12 +515,13 @@ class TSQL(Dialect): exp.Cast if 
strict else exp.TryCast, to=to, this=self.expression(exp.TimeToStr, this=this, format=format_norm), + safe=safe, ) elif to.this == DataType.Type.TEXT: return self.expression(exp.TimeToStr, this=this, format=format_norm) # Entails a simple cast without any format requirement - return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) + return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, safe=safe) def _parse_user_defined_function( self, kind: t.Optional[TokenType] = None diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py index 74b9b7c..7931535 100644 --- a/sqlglot/executor/table.py +++ b/sqlglot/executor/table.py @@ -105,7 +105,7 @@ class RowReader: return self.row[self.columns[column]] -class Tables(AbstractMappingSchema[Table]): +class Tables(AbstractMappingSchema): pass diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 80f1c0f..b94b1e1 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -487,7 +487,7 @@ class Expression(metaclass=_Expression): """ for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__): if not type(node) is self.__class__: - yield node.unnest() if unnest else node + yield node.unnest() if unnest and not isinstance(node, Subquery) else node def __str__(self) -> str: return self.sql() @@ -2107,7 +2107,7 @@ class LockingProperty(Property): arg_types = { "this": False, "kind": True, - "for_or_in": True, + "for_or_in": False, "lock_type": True, "override": False, } @@ -3605,6 +3605,9 @@ class DataType(Expression): TIMESTAMP = auto() TIMESTAMPLTZ = auto() TIMESTAMPTZ = auto() + TIMESTAMP_S = auto() + TIMESTAMP_MS = auto() + TIMESTAMP_NS = auto() TINYINT = auto() TSMULTIRANGE = auto() TSRANGE = auto() @@ -3661,6 +3664,9 @@ class DataType(Expression): Type.TIMESTAMP, Type.TIMESTAMPTZ, Type.TIMESTAMPLTZ, + Type.TIMESTAMP_S, + Type.TIMESTAMP_MS, + Type.TIMESTAMP_NS, Type.DATE, Type.DATETIME, Type.DATETIME64, @@ -4286,7 +4292,7 @@ class Case(Func): class Cast(Func): - arg_types = {"this": True, "to": True, "format": False} + arg_types = {"this": True, "to": True, "format": False, "safe": False} @property def name(self) -> str: @@ -4538,6 +4544,18 @@ class Explode(Func): pass +class ExplodeOuter(Explode): + pass + + +class Posexplode(Explode): + pass + + +class PosexplodeOuter(Posexplode): + pass + + class Floor(Func): arg_types = {"this": True, "decimals": False} @@ -4621,14 +4639,18 @@ class JSONArrayAgg(Func): # https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_TABLE.html # Note: parsing of JSON column definitions is currently incomplete. 
class JSONColumnDef(Expression): - arg_types = {"this": True, "kind": False, "path": False} + arg_types = {"this": False, "kind": False, "path": False, "nested_schema": False} + + +class JSONSchema(Expression): + arg_types = {"expressions": True} # # https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_TABLE.html class JSONTable(Func): arg_types = { "this": True, - "expressions": True, + "schema": True, "path": False, "error_handling": False, "empty_handling": False, @@ -4790,10 +4812,6 @@ class Nvl2(Func): arg_types = {"this": True, "true": True, "false": False} -class Posexplode(Func): - pass - - # https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict#mlpredict_function class Predict(Func): arg_types = {"this": True, "expression": True, "params_struct": False} diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 7a2879c..b7e26bb 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1226,9 +1226,10 @@ class Generator: kind = expression.args.get("kind") this = f" {self.sql(expression, 'this')}" if expression.this else "" for_or_in = expression.args.get("for_or_in") + for_or_in = f" {for_or_in}" if for_or_in else "" lock_type = expression.args.get("lock_type") override = " OVERRIDE" if expression.args.get("override") else "" - return f"LOCKING {kind}{this} {for_or_in} {lock_type}{override}" + return f"LOCKING {kind}{this}{for_or_in} {lock_type}{override}" def withdataproperty_sql(self, expression: exp.WithDataProperty) -> str: data_sql = f"WITH {'NO ' if expression.args.get('no') else ''}DATA" @@ -2179,13 +2180,21 @@ class Generator: ) def jsoncolumndef_sql(self, expression: exp.JSONColumnDef) -> str: + path = self.sql(expression, "path") + path = f" PATH {path}" if path else "" + nested_schema = self.sql(expression, "nested_schema") + + if nested_schema: + return f"NESTED{path} {nested_schema}" + this = self.sql(expression, "this") kind = self.sql(expression, "kind") kind = f" {kind}" if kind else "" - path = self.sql(expression, "path") - path = f" PATH {path}" if path else "" return f"{this}{kind}{path}" + def jsonschema_sql(self, expression: exp.JSONSchema) -> str: + return self.func("COLUMNS", *expression.expressions) + def jsontable_sql(self, expression: exp.JSONTable) -> str: this = self.sql(expression, "this") path = self.sql(expression, "path") @@ -2194,9 +2203,9 @@ class Generator: error_handling = f" {error_handling}" if error_handling else "" empty_handling = expression.args.get("empty_handling") empty_handling = f" {empty_handling}" if empty_handling else "" - columns = f" COLUMNS ({self.expressions(expression, skip_first=True)})" + schema = self.sql(expression, "schema") return self.func( - "JSON_TABLE", this, suffix=f"{path}{error_handling}{empty_handling}{columns})" + "JSON_TABLE", this, suffix=f"{path}{error_handling}{empty_handling} {schema})" ) def openjsoncolumndef_sql(self, expression: exp.OpenJSONColumnDef) -> str: diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 00d49ae..74b61e3 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -441,6 +441,14 @@ def first(it: t.Iterable[T]) -> T: def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]: + """ + Merges a sequence of ranges, represented as tuples (low, high) whose values + belong to some totally-ordered set. 
+ + Example: + >>> merge_ranges([(1, 3), (2, 6)]) + [(1, 6)] + """ if not ranges: return [] diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index 1db094e..8d82b2d 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -6,6 +6,7 @@ from sqlglot import exp from sqlglot.errors import OptimizeError from sqlglot.generator import cached_generator from sqlglot.helper import while_changing +from sqlglot.optimizer.scope import find_all_in_scope from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort logger = logging.getLogger("sqlglot") @@ -63,15 +64,33 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = return expression -def normalized(expression, dnf=False): - ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) +def normalized(expression: exp.Expression, dnf: bool = False) -> bool: + """ + Checks whether a given expression is in a normal form of interest. - return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root)) + Example: + >>> from sqlglot import parse_one + >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True) + True + >>> normalized(parse_one("(a OR b) AND c")) # Checks CNF by default + True + >>> normalized(parse_one("a AND (b OR c)"), dnf=True) + False + Args: + expression: The expression to check if it's normalized. + dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF). + Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). + """ + ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) + return not any( + connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root) + ) -def normalization_distance(expression, dnf=False): + +def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int: """ - The difference in the number of predicates between the current expression and the normalized form. + The difference in the number of predicates between a given expression and its normalized form. This is used as an estimate of the cost of the conversion which is exponential in complexity. @@ -82,10 +101,12 @@ def normalization_distance(expression, dnf=False): 4 Args: - expression (sqlglot.Expression): expression to compute distance - dnf (bool): compute to dnf distance instead + expression: The expression to compute the normalization distance for. + dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF). + Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). + Returns: - int: difference + The normalization distance. 
""" return sum(_predicate_lengths(expression, dnf)) - ( sum(1 for _ in expression.find_all(exp.Connector)) + 1 diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index 9d401fc..1530456 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -39,10 +39,14 @@ def optimize_joins(expression): if len(other_table_names(dep)) < 2: continue + operator = type(on) for predicate in on.flatten(): if name in exp.column_table_names(predicate): predicate.replace(exp.true()) - join.on(predicate, copy=False) + predicate = exp._combine( + [join.args.get("on"), predicate], operator, copy=False + ) + join.on(predicate, append=False, copy=False) expression = reorder_joins(expression) expression = normalize(expression) diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index b51601f..4bc3bd2 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -9,7 +9,9 @@ from sqlglot.schema import ensure_schema SELECT_ALL = object() # Selection to use if selection list is empty -DEFAULT_SELECTION = lambda: alias("1", "_") +DEFAULT_SELECTION = lambda is_agg: alias( + exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_" +) def pushdown_projections(expression, schema=None, remove_unused_selections=True): @@ -98,6 +100,7 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count): new_selections = [] removed = False star = False + is_agg = False select_all = SELECT_ALL in parent_selections @@ -112,6 +115,9 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count): star = True removed = True + if not is_agg and selection.find(exp.AggFunc): + is_agg = True + if star: resolver = Resolver(scope, schema) names = {s.alias_or_name for s in new_selections} @@ -124,7 +130,7 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count): # If there are no remaining selections, just select a single constant if not new_selections: - new_selections.append(DEFAULT_SELECTION()) + new_selections.append(DEFAULT_SELECTION(is_agg)) scope.expression.select(*new_selections, append=False, copy=False) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 435899a..4af5b49 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -137,8 +137,8 @@ class Scope: if not self._collected: self._collect() - def walk(self, bfs=True): - return walk_in_scope(self.expression, bfs=bfs) + def walk(self, bfs=True, prune=None): + return walk_in_scope(self.expression, bfs=bfs, prune=None) def find(self, *expression_types, bfs=True): return find_in_scope(self.expression, expression_types, bfs=bfs) @@ -731,7 +731,7 @@ def _traverse_ddl(scope): yield from _traverse_scope(query_scope) -def walk_in_scope(expression, bfs=True): +def walk_in_scope(expression, bfs=True, prune=None): """ Returns a generator object which visits all nodes in the syntrax tree, stopping at nodes that start child scopes. @@ -740,16 +740,20 @@ def walk_in_scope(expression, bfs=True): expression (exp.Expression): bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead. + prune ((node, parent, arg_key) -> bool): callable that returns True if + the generator should stop traversing this branch of the tree. Yields: tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key """ # We'll use this variable to pass state into the dfs generator. 
# Whenever we set it to True, we exclude a subtree from traversal. - prune = False + crossed_scope_boundary = False - for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune): - prune = False + for node, parent, key in expression.walk( + bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args)) + ): + crossed_scope_boundary = False yield node, parent, key @@ -765,7 +769,7 @@ def walk_in_scope(expression, bfs=True): or isinstance(node, exp.UDTF) or isinstance(node, exp.Subqueryable) ): - prune = True + crossed_scope_boundary = True if isinstance(node, (exp.Subquery, exp.UDTF)): # The following args are not actually in the inner scope, so we should visit them diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 51214c4..849643c 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -5,9 +5,11 @@ import typing as t from collections import deque from decimal import Decimal +import sqlglot from sqlglot import exp from sqlglot.generator import cached_generator from sqlglot.helper import first, merge_ranges, while_changing +from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope # Final means that an expression should not be simplified FINAL = "final" @@ -17,7 +19,7 @@ class UnsupportedUnit(Exception): pass -def simplify(expression): +def simplify(expression, constant_propagation=False): """ Rewrite sqlglot AST to simplify expressions. @@ -29,6 +31,8 @@ def simplify(expression): Args: expression (sqlglot.Expression): expression to simplify + constant_propagation: whether or not the constant propagation rule should be used + Returns: sqlglot.Expression: simplified expression """ @@ -67,13 +71,16 @@ def simplify(expression): node = absorb_and_eliminate(node, root) node = simplify_concat(node) + if constant_propagation: + node = propagate_constants(node, root) + exp.replace_children(node, lambda e: _simplify(e, False)) # Post-order transformations node = simplify_not(node) node = flatten(node) node = simplify_connectors(node, root) - node = remove_compliments(node, root) + node = remove_complements(node, root) node = simplify_coalesce(node) node.parent = expression.parent node = simplify_literals(node, root) @@ -287,19 +294,19 @@ def _simplify_comparison(expression, left, right, or_=False): return None -def remove_compliments(expression, root=True): +def remove_complements(expression, root=True): """ - Removing compliments. + Removing complements. 
A AND NOT A -> FALSE A OR NOT A -> TRUE """ if isinstance(expression, exp.Connector) and (root or not expression.same_parent): - compliment = exp.false() if isinstance(expression, exp.And) else exp.true() + complement = exp.false() if isinstance(expression, exp.And) else exp.true() for a, b in itertools.permutations(expression.flatten(), 2): if is_complement(a, b): - return compliment + return complement return expression @@ -369,6 +376,51 @@ def absorb_and_eliminate(expression, root=True): return expression +def propagate_constants(expression, root=True): + """ + Propagate constants for conjunctions in DNF: + + SELECT * FROM t WHERE a = b AND b = 5 becomes + SELECT * FROM t WHERE a = 5 AND b = 5 + + Reference: https://www.sqlite.org/optoverview.html + """ + + if ( + isinstance(expression, exp.And) + and (root or not expression.same_parent) + and sqlglot.optimizer.normalize.normalized(expression, dnf=True) + ): + constant_mapping = {} + for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)): + if isinstance(expr, exp.EQ): + l, r = expr.left, expr.right + + # TODO: create a helper that can be used to detect nested literal expressions such + # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too + if isinstance(l, exp.Column) and isinstance(r, exp.Literal): + pass + elif isinstance(r, exp.Column) and isinstance(l, exp.Literal): + l, r = r, l + else: + continue + + constant_mapping[l] = (id(l), r) + + if constant_mapping: + for column in find_all_in_scope(expression, exp.Column): + parent = column.parent + column_id, constant = constant_mapping.get(column) or (None, None) + if ( + column_id is not None + and id(column) != column_id + and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) + ): + column.replace(constant.copy()) + + return expression + + INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { exp.DateAdd: exp.Sub, exp.DateSub: exp.Add, @@ -609,21 +661,38 @@ SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe) def simplify_concat(expression): """Reduces all groups that contain string literals by concatenating them.""" - if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs): + if not isinstance(expression, CONCATS) or ( + # We can't reduce a CONCAT_WS call if we don't statically know the separator + isinstance(expression, exp.ConcatWs) + and not expression.expressions[0].is_string + ): return expression + if isinstance(expression, exp.ConcatWs): + sep_expr, *expressions = expression.expressions + sep = sep_expr.name + concat_type = exp.ConcatWs + else: + expressions = expression.expressions + sep = "" + concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat + new_args = [] for is_string_group, group in itertools.groupby( - expression.expressions or expression.flatten(), lambda e: e.is_string + expressions or expression.flatten(), lambda e: e.is_string ): if is_string_group: - new_args.append(exp.Literal.string("".join(string.name for string in group))) + new_args.append(exp.Literal.string(sep.join(string.name for string in group))) else: new_args.extend(group) - # Ensures we preserve the right concat type, i.e. 
whether it's "safe" or not - concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat - return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args) + if len(new_args) == 1 and new_args[0].is_string: + return new_args[0] + + if concat_type is exp.ConcatWs: + new_args = [sep_expr] + new_args + + return concat_type(expressions=new_args) DateRange = t.Tuple[datetime.date, datetime.date] diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 510abfb..8de76ca 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -160,6 +160,9 @@ class Parser(metaclass=_Parser): TokenType.TIME, TokenType.TIMETZ, TokenType.TIMESTAMP, + TokenType.TIMESTAMP_S, + TokenType.TIMESTAMP_MS, + TokenType.TIMESTAMP_NS, TokenType.TIMESTAMPTZ, TokenType.TIMESTAMPLTZ, TokenType.DATETIME, @@ -792,17 +795,18 @@ class Parser(metaclass=_Parser): "DECODE": lambda self: self._parse_decode(), "EXTRACT": lambda self: self._parse_extract(), "JSON_OBJECT": lambda self: self._parse_json_object(), + "JSON_TABLE": lambda self: self._parse_json_table(), "LOG": lambda self: self._parse_logarithm(), "MATCH": lambda self: self._parse_match_against(), "OPENJSON": lambda self: self._parse_open_json(), "POSITION": lambda self: self._parse_position(), "PREDICT": lambda self: self._parse_predict(), - "SAFE_CAST": lambda self: self._parse_cast(False), + "SAFE_CAST": lambda self: self._parse_cast(False, safe=True), "STRING_AGG": lambda self: self._parse_string_agg(), "SUBSTRING": lambda self: self._parse_substring(), "TRIM": lambda self: self._parse_trim(), - "TRY_CAST": lambda self: self._parse_cast(False), - "TRY_CONVERT": lambda self: self._parse_convert(False), + "TRY_CAST": lambda self: self._parse_cast(False, safe=True), + "TRY_CONVERT": lambda self: self._parse_convert(False, safe=True), } QUERY_MODIFIER_PARSERS = { @@ -4135,7 +4139,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.AnyValue, this=this, having=having, max=is_max) - def _parse_cast(self, strict: bool) -> exp.Expression: + def _parse_cast(self, strict: bool, safe: t.Optional[bool] = None) -> exp.Expression: this = self._parse_conjunction() if not self._match(TokenType.ALIAS): @@ -4176,7 +4180,9 @@ class Parser(metaclass=_Parser): return this - return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt) + return self.expression( + exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt, safe=safe + ) def _parse_concat(self) -> t.Optional[exp.Expression]: args = self._parse_csv(self._parse_conjunction) @@ -4230,7 +4236,9 @@ class Parser(metaclass=_Parser): order = self._parse_order(this=seq_get(args, 0)) return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1)) - def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]: + def _parse_convert( + self, strict: bool, safe: t.Optional[bool] = None + ) -> t.Optional[exp.Expression]: this = self._parse_bitwise() if self._match(TokenType.USING): @@ -4242,7 +4250,7 @@ class Parser(metaclass=_Parser): else: to = None - return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) + return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, safe=safe) def _parse_decode(self) -> t.Optional[exp.Decode | exp.Case]: """ @@ -4347,6 +4355,50 @@ class Parser(metaclass=_Parser): encoding=encoding, ) + # Note: this is currently incomplete; it only implements the "JSON_value_column" part + def _parse_json_column_def(self) -> exp.JSONColumnDef: + if not 
self._match_text_seq("NESTED"): + this = self._parse_id_var() + kind = self._parse_types(allow_identifiers=False) + nested = None + else: + this = None + kind = None + nested = True + + path = self._match_text_seq("PATH") and self._parse_string() + nested_schema = nested and self._parse_json_schema() + + return self.expression( + exp.JSONColumnDef, + this=this, + kind=kind, + path=path, + nested_schema=nested_schema, + ) + + def _parse_json_schema(self) -> exp.JSONSchema: + self._match_text_seq("COLUMNS") + return self.expression( + exp.JSONSchema, + expressions=self._parse_wrapped_csv(self._parse_json_column_def, optional=True), + ) + + def _parse_json_table(self) -> exp.JSONTable: + this = self._parse_format_json(self._parse_bitwise()) + path = self._match(TokenType.COMMA) and self._parse_string() + error_handling = self._parse_on_handling("ERROR", "ERROR", "NULL") + empty_handling = self._parse_on_handling("EMPTY", "ERROR", "NULL") + schema = self._parse_json_schema() + + return exp.JSONTable( + this=this, + schema=schema, + path=path, + error_handling=error_handling, + empty_handling=empty_handling, + ) + def _parse_logarithm(self) -> exp.Func: # Default argument order is base, expression args = self._parse_csv(self._parse_range) @@ -4973,7 +5025,17 @@ class Parser(metaclass=_Parser): self._match(TokenType.ON) on = self._parse_conjunction() + return self.expression( + exp.Merge, + this=target, + using=using, + on=on, + expressions=self._parse_when_matched(), + ) + + def _parse_when_matched(self) -> t.List[exp.When]: whens = [] + while self._match(TokenType.WHEN): matched = not self._match(TokenType.NOT) self._match_text_seq("MATCHED") @@ -5020,14 +5082,7 @@ class Parser(metaclass=_Parser): then=then, ) ) - - return self.expression( - exp.Merge, - this=target, - using=using, - on=on, - expressions=whens, - ) + return whens def _parse_show(self) -> t.Optional[exp.Expression]: parser = self._find_parser(self.SHOW_PARSERS, self.SHOW_TRIE) diff --git a/sqlglot/schema.py b/sqlglot/schema.py index f0b279b..778378c 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -5,7 +5,6 @@ import typing as t import sqlglot from sqlglot import expressions as exp -from sqlglot._typing import T from sqlglot.dialects.dialect import Dialect from sqlglot.errors import ParseError, SchemaError from sqlglot.helper import dict_depth @@ -71,7 +70,7 @@ class Schema(abc.ABC): def get_column_type( self, table: exp.Table | str, - column: exp.Column, + column: exp.Column | str, dialect: DialectType = None, normalize: t.Optional[bool] = None, ) -> exp.DataType: @@ -88,6 +87,28 @@ class Schema(abc.ABC): The resulting column type. """ + def has_column( + self, + table: exp.Table | str, + column: exp.Column | str, + dialect: DialectType = None, + normalize: t.Optional[bool] = None, + ) -> bool: + """ + Returns whether or not `column` appears in `table`'s schema. + + Args: + table: the source table. + column: the target column. + dialect: the SQL dialect that will be used to parse `table` if it's a string. + normalize: whether to normalize identifiers according to the dialect of interest. + + Returns: + True if the column appears in the schema, False otherwise. 
+ """ + name = column if isinstance(column, str) else column.name + return name in self.column_names(table, dialect=dialect, normalize=normalize) + @property @abc.abstractmethod def supported_table_args(self) -> t.Tuple[str, ...]: @@ -101,7 +122,7 @@ class Schema(abc.ABC): return True -class AbstractMappingSchema(t.Generic[T]): +class AbstractMappingSchema: def __init__( self, mapping: t.Optional[t.Dict] = None, @@ -140,7 +161,7 @@ class AbstractMappingSchema(t.Generic[T]): def find( self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True - ) -> t.Optional[T]: + ) -> t.Optional[t.Any]: parts = self.table_parts(table)[0 : len(self.supported_table_args)] value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) @@ -170,7 +191,7 @@ class AbstractMappingSchema(t.Generic[T]): ) -class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): +class MappingSchema(AbstractMappingSchema, Schema): """ Schema based on a nested mapping. @@ -287,7 +308,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): def get_column_type( self, table: exp.Table | str, - column: exp.Column, + column: exp.Column | str, dialect: DialectType = None, normalize: t.Optional[bool] = None, ) -> exp.DataType: @@ -304,10 +325,26 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): if isinstance(column_type, exp.DataType): return column_type elif isinstance(column_type, str): - return self._to_data_type(column_type.upper(), dialect=dialect) + return self._to_data_type(column_type, dialect=dialect) return exp.DataType.build("unknown") + def has_column( + self, + table: exp.Table | str, + column: exp.Column | str, + dialect: DialectType = None, + normalize: t.Optional[bool] = None, + ) -> bool: + normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) + + normalized_column_name = self._normalize_name( + column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize + ) + + table_schema = self.find(normalized_table, raise_on_missing=False) + return normalized_column_name in table_schema if table_schema else False + def _normalize(self, schema: t.Dict) -> t.Dict: """ Normalizes all identifiers in the schema. diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 4ab01dd..c883858 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -121,6 +121,9 @@ class TokenType(AutoName): TIMESTAMP = auto() TIMESTAMPTZ = auto() TIMESTAMPLTZ = auto() + TIMESTAMP_S = auto() + TIMESTAMP_MS = auto() + TIMESTAMP_NS = auto() DATETIME = auto() DATETIME64 = auto() DATE = auto() diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index ac9dd81..8feee52 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -189,9 +189,9 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp # we use list here because expression.selects is mutated inside the loop for select in expression.selects.copy(): - explode = select.find(exp.Explode, exp.Posexplode) + explode = select.find(exp.Explode) - if isinstance(explode, (exp.Explode, exp.Posexplode)): + if explode: pos_alias = "" explode_alias = "" @@ -204,7 +204,7 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp alias = select.replace(exp.alias_(select.this, "", copy=False)) else: alias = select.replace(exp.alias_(select, "")) - explode = alias.find(exp.Explode, exp.Posexplode) + explode = alias.find(exp.Explode) assert explode is_posexplode = isinstance(explode, exp.Posexplode) |