diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-04 09:37:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-04 09:37:14 +0000 |
commit | 7b29f6168bf9fcb2d886447066a9bb51675e5665 (patch) | |
tree | ff74c45f55651c73cce0cd58145667de43db9d12 | |
parent | Releasing debian version 6.2.6-1. (diff) | |
download | sqlglot-7b29f6168bf9fcb2d886447066a9bb51675e5665.tar.xz sqlglot-7b29f6168bf9fcb2d886447066a9bb51675e5665.zip |
Merging upstream version 6.2.8.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
-rw-r--r-- | README.md | 6 | ||||
-rw-r--r-- | sqlglot/__init__.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/bigquery.py | 22 | ||||
-rw-r--r-- | sqlglot/dialects/dialect.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/spark.py | 2 | ||||
-rw-r--r-- | sqlglot/expressions.py | 68 | ||||
-rw-r--r-- | sqlglot/generator.py | 20 | ||||
-rw-r--r-- | sqlglot/optimizer/merge_subqueries.py | 13 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 10 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 122 | ||||
-rw-r--r-- | sqlglot/parser.py | 84 | ||||
-rw-r--r-- | sqlglot/tokens.py | 24 | ||||
-rw-r--r-- | tests/dialects/test_bigquery.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 12 | ||||
-rw-r--r-- | tests/fixtures/identity.sql | 9 | ||||
-rw-r--r-- | tests/fixtures/optimizer/merge_subqueries.sql | 8 | ||||
-rw-r--r-- | tests/fixtures/optimizer/optimizer.sql | 17 | ||||
-rw-r--r-- | tests/fixtures/optimizer/pushdown_projections.sql | 12 | ||||
-rw-r--r-- | tests/test_optimizer.py | 19 | ||||
-rw-r--r-- | tests/test_parser.py | 3 |
22 files changed, 363 insertions, 100 deletions
@@ -1,6 +1,6 @@ # SQLGlot -SQLGlot is a no dependency Python SQL parser, transpiler, and optimizer. It can be used to format SQL or translate between different dialects like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects. +SQLGlot is a no dependency Python SQL parser, transpiler, and optimizer. It can be used to format SQL or translate between different dialects like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects. It is a very comprehensive generic SQL parser with a robust [test suite](tests). It is also quite [performant](#benchmarks) while being written purely in Python. @@ -30,7 +30,7 @@ sqlglot.transpile("SELECT EPOCH_MS(1618088028295)", read='duckdb', write='hive') ``` ```sql -SELECT TO_UTC_TIMESTAMP(FROM_UNIXTIME(1618088028295 / 1000, 'yyyy-MM-dd HH:mm:ss'), 'UTC') +SELECT FROM_UNIXTIME(1618088028295 / 1000) ``` SQLGlot can even translate custom time formats. @@ -299,7 +299,7 @@ class Custom(Dialect): } -Dialects["custom"] +Dialect["custom"] ``` ## Benchmarks diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index befbc8a..1f7b28c 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -20,7 +20,7 @@ from sqlglot.generator import Generator from sqlglot.parser import Parser from sqlglot.tokens import Tokenizer, TokenType -__version__ = "6.2.6" +__version__ = "6.2.8" pretty = False diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 432fd8c..40298e7 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -33,10 +33,10 @@ def _date_add_sql(data_type, kind): return func -def _subquery_to_unnest_if_values(self, expression): - if not isinstance(expression.this, exp.Values): - return self.subquery_sql(expression) - rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.this.find_all(exp.Tuple)] +def _derived_table_values_to_unnest(self, expression): + if not isinstance(expression.unnest().parent, exp.From): + return self.values_sql(expression) + rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.find_all(exp.Tuple)] structs = [] for row in rows: aliases = [ @@ -99,6 +99,7 @@ class BigQuery(Dialect): "QUALIFY": TokenType.QUALIFY, "UNKNOWN": TokenType.NULL, "WINDOW": TokenType.WINDOW, + "NOT DETERMINISTIC": TokenType.VOLATILE, } class Parser(Parser): @@ -140,9 +141,10 @@ class BigQuery(Dialect): exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"), exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"), exp.VariancePop: rename_func("VAR_POP"), - exp.Subquery: _subquery_to_unnest_if_values, + exp.Values: _derived_table_values_to_unnest, exp.ReturnsProperty: _returnsproperty_sql, exp.Create: _create_sql, + exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC", } TYPE_MAPPING = { @@ -160,6 +162,16 @@ class BigQuery(Dialect): exp.DataType.Type.NVARCHAR: "STRING", } + ROOT_PROPERTIES = { + exp.LanguageProperty, + exp.ReturnsProperty, + exp.VolatilityProperty, + } + + WITH_PROPERTIES = { + exp.AnonymousProperty, + } + def in_unnest_op(self, unnest): return self.sql(unnest) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 0ab584e..98dc330 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -77,6 +77,7 @@ class Dialect(metaclass=_Dialect): alias_post_tablesample = False normalize_functions = "upper" null_ordering = "nulls_are_small" + wrap_derived_values = True date_format = "'%Y-%m-%d'" dateint_format = "'%Y%m%d'" @@ -169,6 +170,7 @@ class Dialect(metaclass=_Dialect): "alias_post_tablesample": self.alias_post_tablesample, "normalize_functions": self.normalize_functions, "null_ordering": self.null_ordering, + "wrap_derived_values": self.wrap_derived_values, **opts, } ) diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 1b718f7..fb2d900 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -177,6 +177,8 @@ class Snowflake(Dialect): exp.ReturnsProperty, exp.LanguageProperty, exp.SchemaCommentProperty, + exp.ExecuteAsProperty, + exp.VolatilityProperty, } def except_op(self, expression): diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 5446e83..e8da07a 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -47,6 +47,8 @@ def _unix_to_time(self, expression): class Spark(Hive): + wrap_derived_values = False + class Parser(Hive.Parser): FUNCTIONS = { **Hive.Parser.FUNCTIONS, diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 599c7db..8cdacce 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -213,21 +213,23 @@ class Expression(metaclass=_Expression): """ return self.find_ancestor(Select) - def walk(self, bfs=True): + def walk(self, bfs=True, prune=None): """ Returns a generator object which visits all nodes in this tree. Args: bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead. + prune ((node, parent, arg_key) -> bool): callable that returns True if + the generator should stop traversing this branch of the tree. Returns: the generator object. """ if bfs: - yield from self.bfs() + yield from self.bfs(prune=prune) else: - yield from self.dfs() + yield from self.dfs(prune=prune) def dfs(self, parent=None, key=None, prune=None): """ @@ -506,6 +508,10 @@ class DerivedTable(Expression): return [select.alias_or_name for select in self.selects] +class UDTF(DerivedTable): + pass + + class Annotation(Expression): arg_types = { "this": True, @@ -652,7 +658,13 @@ class Delete(Expression): class Drop(Expression): - arg_types = {"this": False, "kind": False, "exists": False} + arg_types = { + "this": False, + "kind": False, + "exists": False, + "temporary": False, + "materialized": False, + } class Filter(Expression): @@ -827,7 +839,7 @@ class Join(Expression): return join -class Lateral(DerivedTable): +class Lateral(UDTF): arg_types = {"this": True, "outer": False, "alias": False} @@ -915,6 +927,14 @@ class LanguageProperty(Property): pass +class ExecuteAsProperty(Property): + pass + + +class VolatilityProperty(Property): + arg_types = {"this": True} + + class Properties(Expression): arg_types = {"expressions": True} @@ -1098,7 +1118,7 @@ class Intersect(Union): pass -class Unnest(DerivedTable): +class Unnest(UDTF): arg_types = { "expressions": True, "ordinality": False, @@ -1116,8 +1136,12 @@ class Update(Expression): } -class Values(Expression): - arg_types = {"expressions": True} +class Values(UDTF): + arg_types = { + "expressions": True, + "ordinality": False, + "alias": False, + } class Var(Expression): @@ -2033,23 +2057,17 @@ class Func(Condition): @classmethod def from_arg_list(cls, args): - args_num = len(args) - - all_arg_keys = list(cls.arg_types) - # If this function supports variable length argument treat the last argument as such. - non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys - - args_dict = {} - arg_idx = 0 - for arg_key in non_var_len_arg_keys: - if arg_idx >= args_num: - break - if args[arg_idx] is not None: - args_dict[arg_key] = args[arg_idx] - arg_idx += 1 - - if arg_idx < args_num and cls.is_var_len_args: - args_dict[all_arg_keys[-1]] = args[arg_idx:] + if cls.is_var_len_args: + all_arg_keys = list(cls.arg_types) + # If this function supports variable length argument treat the last argument as such. + non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys + num_non_var = len(non_var_len_arg_keys) + + args_dict = {arg_key: arg for arg, arg_key in zip(args, non_var_len_arg_keys)} + args_dict[all_arg_keys[-1]] = args[num_non_var:] + else: + args_dict = {arg_key: arg for arg, arg_key in zip(args, cls.arg_types)} + return cls(**args_dict) @classmethod diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 9099307..8b356f3 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -49,10 +49,12 @@ class Generator: exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}", exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})", exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})", exp.LanguageProperty: lambda self, e: self.naked_property(e), exp.LocationProperty: lambda self, e: self.naked_property(e), exp.ReturnsProperty: lambda self, e: self.naked_property(e), - exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})", + exp.ExecuteAsProperty: lambda self, e: self.naked_property(e), + exp.VolatilityProperty: lambda self, e: self.sql(e.name), } NULL_ORDERING_SUPPORTED = True @@ -99,6 +101,7 @@ class Generator: "unsupported_messages", "null_ordering", "max_unsupported", + "wrap_derived_values", "_indent", "_replace_backslash", "_escaped_quote_end", @@ -127,6 +130,7 @@ class Generator: null_ordering=None, max_unsupported=3, leading_comma=False, + wrap_derived_values=True, ): import sqlglot @@ -150,6 +154,7 @@ class Generator: self.unsupported_messages = [] self.max_unsupported = max_unsupported self.null_ordering = null_ordering + self.wrap_derived_values = wrap_derived_values self._indent = indent self._replace_backslash = self.escape == "\\" self._escaped_quote_end = self.escape + self.quote_end @@ -407,7 +412,9 @@ class Generator: this = self.sql(expression, "this") kind = expression.args["kind"] exists_sql = " IF EXISTS " if expression.args.get("exists") else " " - return f"DROP {kind}{exists_sql}{this}" + temporary = " TEMPORARY" if expression.args.get("temporary") else "" + materialized = " MATERIALIZED" if expression.args.get("materialized") else "" + return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}" def except_sql(self, expression): return self.prepend_ctes( @@ -583,7 +590,14 @@ class Generator: return self.prepend_ctes(expression, sql) def values_sql(self, expression): - return f"VALUES{self.seg('')}{self.expressions(expression)}" + alias = self.sql(expression, "alias") + args = self.expressions(expression) + if not alias: + return f"VALUES{self.seg('')}{args}" + alias = f" AS {alias}" if alias else alias + if self.wrap_derived_values: + return f"(VALUES{self.seg('')}{args}){alias}" + return f"VALUES{self.seg('')}{args}{alias}" def var_sql(self, expression): return self.sql(expression, "this") diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 9d966b7..d29c22b 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -32,8 +32,8 @@ def merge_subqueries(expression, leave_tables_isolated=False): Returns: sqlglot.Expression: optimized expression """ - merge_ctes(expression, leave_tables_isolated) - merge_derived_tables(expression, leave_tables_isolated) + expression = merge_ctes(expression, leave_tables_isolated) + expression = merge_derived_tables(expression, leave_tables_isolated) return expression @@ -76,14 +76,14 @@ def merge_ctes(expression, leave_tables_isolated=False): alias = node_to_replace.alias else: alias = table.name - _rename_inner_sources(outer_scope, inner_scope, alias) _merge_from(outer_scope, inner_scope, node_to_replace, alias) - _merge_joins(outer_scope, inner_scope, from_or_join) _merge_expressions(outer_scope, inner_scope, alias) + _merge_joins(outer_scope, inner_scope, from_or_join) _merge_where(outer_scope, inner_scope, from_or_join) _merge_order(outer_scope, inner_scope) _pop_cte(inner_scope) + return expression def merge_derived_tables(expression, leave_tables_isolated=False): @@ -97,10 +97,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False): _rename_inner_sources(outer_scope, inner_scope, alias) _merge_from(outer_scope, inner_scope, subquery, alias) - _merge_joins(outer_scope, inner_scope, from_or_join) _merge_expressions(outer_scope, inner_scope, alias) + _merge_joins(outer_scope, inner_scope, from_or_join) _merge_where(outer_scope, inner_scope, from_or_join) _merge_order(outer_scope, inner_scope) + return expression def _mergeable(outer_scope, inner_select, leave_tables_isolated): @@ -229,7 +230,7 @@ def _merge_expressions(outer_scope, inner_scope, alias): continue columns_to_replace = outer_columns.get(projection_name, []) for column in columns_to_replace: - column.replace(expression.unalias()) + column.replace(expression.unalias().copy()) def _merge_where(outer_scope, inner_scope, from_or_join): diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 0bb947a..72ce256 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -5,8 +5,6 @@ from sqlglot.errors import OptimizeError from sqlglot.optimizer.schema import ensure_schema from sqlglot.optimizer.scope import traverse_scope -SKIP_QUALIFY = (exp.Unnest, exp.Lateral) - def qualify_columns(expression, schema): """ @@ -35,7 +33,7 @@ def qualify_columns(expression, schema): _expand_group_by(scope, resolver) _expand_order_by(scope) _qualify_columns(scope, resolver) - if not isinstance(scope.expression, SKIP_QUALIFY): + if not isinstance(scope.expression, exp.UDTF): _expand_stars(scope, resolver) _qualify_outputs(scope) _check_unknown_tables(scope) @@ -50,7 +48,7 @@ def _pop_table_column_aliases(derived_tables): (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2) """ for derived_table in derived_tables: - if isinstance(derived_table, SKIP_QUALIFY): + if isinstance(derived_table, exp.UDTF): continue table_alias = derived_table.args.get("alias") if table_alias: @@ -202,7 +200,7 @@ def _qualify_columns(scope, resolver): if not column_table: column_table = resolver.get_table(column_name) - if not scope.is_subquery and not scope.is_unnest: + if not scope.is_subquery and not scope.is_udtf: if column_name not in resolver.all_columns: raise OptimizeError(f"Unknown column: {column_name}") @@ -296,7 +294,7 @@ def _qualify_outputs(scope): def _check_unknown_tables(scope): - if scope.external_columns and not scope.is_unnest and not scope.is_correlated_subquery: + if scope.external_columns and not scope.is_udtf and not scope.is_correlated_subquery: raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}") diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index be6cfb9..6332cdd 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -1,5 +1,4 @@ import itertools -from copy import copy from enum import Enum, auto from sqlglot import exp @@ -12,7 +11,7 @@ class ScopeType(Enum): DERIVED_TABLE = auto() CTE = auto() UNION = auto() - UNNEST = auto() + UDTF = auto() class Scope: @@ -70,14 +69,11 @@ class Scope: self._columns = None self._external_columns = None - def branch(self, expression, scope_type, add_sources=None, **kwargs): + def branch(self, expression, scope_type, chain_sources=None, **kwargs): """Branch from the current scope to a new, inner scope""" - sources = copy(self.sources) - if add_sources: - sources.update(add_sources) return Scope( expression=expression.unnest(), - sources=sources, + sources={**self.cte_sources, **(chain_sources or {})}, parent=self, scope_type=scope_type, **kwargs, @@ -90,30 +86,21 @@ class Scope: self._derived_tables = [] self._raw_columns = [] - # We'll use this variable to pass state into the dfs generator. - # Whenever we set it to True, we exclude a subtree from traversal. - prune = False - - for node, parent, _ in self.expression.dfs(prune=lambda *_: prune): - prune = False - + for node, parent, _ in self.walk(bfs=False): if node is self.expression: continue - if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): + elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): self._raw_columns.append(node) elif isinstance(node, exp.Table): self._tables.append(node) - elif isinstance(node, (exp.Unnest, exp.Lateral)): + elif isinstance(node, exp.UDTF): self._derived_tables.append(node) elif isinstance(node, exp.CTE): self._ctes.append(node) - prune = True elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): self._derived_tables.append(node) - prune = True elif isinstance(node, exp.Subqueryable): self._subqueries.append(node) - prune = True self._collected = True @@ -121,6 +108,43 @@ class Scope: if not self._collected: self._collect() + def walk(self, bfs=True): + return walk_in_scope(self.expression, bfs=bfs) + + def find(self, *expression_types, bfs=True): + """ + Returns the first node in this scope which matches at least one of the specified types. + + This does NOT traverse into subscopes. + + Args: + expression_types (type): the expression type(s) to match. + bfs (bool): True to use breadth-first search, False to use depth-first. + + Returns: + exp.Expression: the node which matches the criteria or None if no node matching + the criteria was found. + """ + return next(self.find_all(*expression_types, bfs=bfs), None) + + def find_all(self, *expression_types, bfs=True): + """ + Returns a generator object which visits all nodes in this scope and only yields those that + match at least one of the specified expression types. + + This does NOT traverse into subscopes. + + Args: + expression_types (type): the expression type(s) to match. + bfs (bool): True to use breadth-first search, False to use depth-first. + + Yields: + exp.Expression: nodes + """ + for expression, _, _ in self.walk(bfs=bfs): + if isinstance(expression, expression_types): + yield expression + def replace(self, old, new): """ Replace `old` with `new`. @@ -247,6 +271,16 @@ class Scope: return self._selected_sources @property + def cte_sources(self): + """ + Sources that are CTEs. + + Returns: + dict[str, Scope]: Mapping of source alias to Scope + """ + return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte} + + @property def selects(self): """ Select expressions of this scope. @@ -313,9 +347,9 @@ class Scope: return self.scope_type == ScopeType.ROOT @property - def is_unnest(self): - """Determine if this scope is an unnest""" - return self.scope_type == ScopeType.UNNEST + def is_udtf(self): + """Determine if this scope is a UDTF (User Defined Table Function)""" + return self.scope_type == ScopeType.UDTF @property def is_correlated_subquery(self): @@ -348,7 +382,7 @@ class Scope: Scope: scope instances in depth-first-search post-order """ for child_scope in itertools.chain( - self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes + self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes ): yield from child_scope.traverse() yield self @@ -399,7 +433,7 @@ def _traverse_scope(scope): yield from _traverse_select(scope) elif isinstance(scope.expression, exp.Union): yield from _traverse_union(scope) - elif isinstance(scope.expression, (exp.Lateral, exp.Unnest)): + elif isinstance(scope.expression, exp.UDTF): pass elif isinstance(scope.expression, exp.Subquery): yield from _traverse_subqueries(scope) @@ -410,8 +444,8 @@ def _traverse_scope(scope): def _traverse_select(scope): yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE) - yield from _traverse_subqueries(scope) yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE) + yield from _traverse_subqueries(scope) _add_table_sources(scope) @@ -437,10 +471,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type): top = None for child_scope in _traverse_scope( scope.branch( - derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this, - add_sources=sources if scope_type == ScopeType.CTE else None, + derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this, + chain_sources=sources if scope_type == ScopeType.CTE else None, outer_column_list=derived_table.alias_column_names, - scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type, + scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type, ) ): yield child_scope @@ -483,3 +517,35 @@ def _traverse_subqueries(scope): yield child_scope top = child_scope scope.subquery_scopes.append(top) + + +def walk_in_scope(expression, bfs=True): + """ + Returns a generator object which visits all nodes in the syntrax tree, stopping at + nodes that start child scopes. + + Args: + expression (exp.Expression): + bfs (bool): if set to True the BFS traversal order will be applied, + otherwise the DFS traversal will be used instead. + + Yields: + tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key + """ + # We'll use this variable to pass state into the dfs generator. + # Whenever we set it to True, we exclude a subtree from traversal. + prune = False + + for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune): + prune = False + + yield node, parent, key + + if node is expression: + continue + elif isinstance(node, exp.CTE): + prune = True + elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): + prune = True + elif isinstance(node, exp.Subqueryable): + prune = True diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 72bad92..5f20afc 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -126,6 +126,8 @@ class Parser: TokenType.CONSTRAINT, TokenType.DEFAULT, TokenType.DELETE, + TokenType.DETERMINISTIC, + TokenType.EXECUTE, TokenType.ENGINE, TokenType.ESCAPE, TokenType.EXPLAIN, @@ -139,6 +141,7 @@ class Parser: TokenType.IF, TokenType.INDEX, TokenType.ISNULL, + TokenType.IMMUTABLE, TokenType.INTERVAL, TokenType.LAZY, TokenType.LANGUAGE, @@ -163,6 +166,7 @@ class Parser: TokenType.SEED, TokenType.SET, TokenType.SHOW, + TokenType.STABLE, TokenType.STORED, TokenType.TABLE, TokenType.TABLE_FORMAT, @@ -175,6 +179,8 @@ class Parser: TokenType.UNIQUE, TokenType.UNPIVOT, TokenType.PROPERTIES, + TokenType.PROCEDURE, + TokenType.VOLATILE, *SUBQUERY_PREDICATES, *TYPE_TOKENS, } @@ -204,7 +210,7 @@ class Parser: TokenType.DATETIME, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, - *NESTED_TYPE_TOKENS, + *TYPE_TOKENS, *SUBQUERY_PREDICATES, } @@ -379,6 +385,13 @@ class Parser: TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(exp.TableFormatProperty), TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty), TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty), + TokenType.EXECUTE: lambda self: self._parse_execute_as(), + TokenType.DETERMINISTIC: lambda self: self.expression( + exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE") + ), + TokenType.IMMUTABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")), + TokenType.STABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("STABLE")), + TokenType.VOLATILE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")), } CONSTRAINT_PARSERS = { @@ -418,7 +431,7 @@ class Parser: MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table) - CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX} + CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX, TokenType.PROCEDURE} STRICT_CAST = True @@ -615,18 +628,20 @@ class Parser: return expression def _parse_drop(self): - if self._match(TokenType.TABLE): - kind = "TABLE" - elif self._match(TokenType.VIEW): - kind = "VIEW" - else: - self.raise_error("Expected TABLE or View") + temporary = self._match(TokenType.TEMPORARY) + materialized = self._match(TokenType.MATERIALIZED) + kind = self._match_set(self.CREATABLES) and self._prev.text + if not kind: + self.raise_error("Expected TABLE, VIEW, INDEX, FUNCTION, or PROCEDURE") + return return self.expression( exp.Drop, exists=self._parse_exists(), this=self._parse_table(schema=True), kind=kind, + temporary=temporary, + materialized=materialized, ) def _parse_exists(self, not_=False): @@ -644,14 +659,15 @@ class Parser: create_token = self._match_set(self.CREATABLES) and self._prev if not create_token: - self.raise_error("Expected TABLE, VIEW, INDEX, or FUNCTION") + self.raise_error("Expected TABLE, VIEW, INDEX, FUNCTION, or PROCEDURE") + return exists = self._parse_exists(not_=True) this = None expression = None properties = None - if create_token.token_type == TokenType.FUNCTION: + if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): this = self._parse_user_defined_function() properties = self._parse_properties() if self._match(TokenType.ALIAS): @@ -747,7 +763,9 @@ class Parser: if is_table: if self._match(TokenType.LT): value = self.expression( - exp.Schema, this="TABLE", expressions=self._parse_csv(self._parse_struct_kwargs) + exp.Schema, + this="TABLE", + expressions=self._parse_csv(self._parse_struct_kwargs), ) if not self._match(TokenType.GT): self.raise_error("Expecting >") @@ -763,6 +781,14 @@ class Parser: is_table=is_table, ) + def _parse_execute_as(self): + self._match(TokenType.ALIAS) + return self.expression( + exp.ExecuteAsProperty, + this=exp.Literal.string("EXECUTE AS"), + value=self._parse_var(), + ) + def _parse_properties(self): properties = [] @@ -997,7 +1023,12 @@ class Parser: ) def _parse_subquery(self, this): - return self.expression(exp.Subquery, this=this, pivots=self._parse_pivots(), alias=self._parse_table_alias()) + return self.expression( + exp.Subquery, + this=this, + pivots=self._parse_pivots(), + alias=self._parse_table_alias(), + ) def _parse_query_modifiers(self, this): if not isinstance(this, self.MODIFIABLES): @@ -1118,6 +1149,11 @@ class Parser: if unnest: return unnest + values = self._parse_derived_table_values() + + if values: + return values + subquery = self._parse_select(table=True) if subquery: @@ -1186,6 +1222,24 @@ class Parser: alias=alias, ) + def _parse_derived_table_values(self): + is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES) + if not is_derived and not self._match(TokenType.VALUES): + return None + + expressions = self._parse_csv(self._parse_value) + + if is_derived: + self._match_r_paren() + + alias = self._parse_table_alias() + + return self.expression( + exp.Values, + expressions=expressions, + alias=alias, + ) + def _parse_table_sample(self): if not self._match(TokenType.TABLE_SAMPLE): return None @@ -1700,7 +1754,11 @@ class Parser: return self._parse_window(this) def _parse_user_defined_function(self): - this = self._parse_var() + this = self._parse_id_var() + + while self._match(TokenType.DOT): + this = self.expression(exp.Dot, this=this, expression=self._parse_id_var()) + if not self._match(TokenType.L_PAREN): return this expressions = self._parse_csv(self._parse_udf_kwarg) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index c81f0db..39bf421 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -136,6 +136,7 @@ class TokenType(AutoName): DEFAULT = auto() DELETE = auto() DESC = auto() + DETERMINISTIC = auto() DISTINCT = auto() DISTRIBUTE_BY = auto() DROP = auto() @@ -144,6 +145,7 @@ class TokenType(AutoName): ENGINE = auto() ESCAPE = auto() EXCEPT = auto() + EXECUTE = auto() EXISTS = auto() EXPLAIN = auto() FALSE = auto() @@ -167,6 +169,7 @@ class TokenType(AutoName): IF = auto() IGNORE_NULLS = auto() ILIKE = auto() + IMMUTABLE = auto() IN = auto() INDEX = auto() INNER = auto() @@ -215,6 +218,7 @@ class TokenType(AutoName): PLACEHOLDER = auto() PRECEDING = auto() PRIMARY_KEY = auto() + PROCEDURE = auto() PROPERTIES = auto() QUALIFY = auto() QUOTE = auto() @@ -238,6 +242,7 @@ class TokenType(AutoName): SIMILAR_TO = auto() SOME = auto() SORT_BY = auto() + STABLE = auto() STORED = auto() STRUCT = auto() TABLE_FORMAT = auto() @@ -258,6 +263,7 @@ class TokenType(AutoName): USING = auto() VALUES = auto() VIEW = auto() + VOLATILE = auto() WHEN = auto() WHERE = auto() WINDOW = auto() @@ -430,6 +436,7 @@ class Tokenizer(metaclass=_Tokenizer): "DEFAULT": TokenType.DEFAULT, "DELETE": TokenType.DELETE, "DESC": TokenType.DESC, + "DETERMINISTIC": TokenType.DETERMINISTIC, "DISTINCT": TokenType.DISTINCT, "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY, "DROP": TokenType.DROP, @@ -438,6 +445,7 @@ class Tokenizer(metaclass=_Tokenizer): "ENGINE": TokenType.ENGINE, "ESCAPE": TokenType.ESCAPE, "EXCEPT": TokenType.EXCEPT, + "EXECUTE": TokenType.EXECUTE, "EXISTS": TokenType.EXISTS, "EXPLAIN": TokenType.EXPLAIN, "FALSE": TokenType.FALSE, @@ -456,6 +464,7 @@ class Tokenizer(metaclass=_Tokenizer): "HAVING": TokenType.HAVING, "IF": TokenType.IF, "ILIKE": TokenType.ILIKE, + "IMMUTABLE": TokenType.IMMUTABLE, "IGNORE NULLS": TokenType.IGNORE_NULLS, "IN": TokenType.IN, "INDEX": TokenType.INDEX, @@ -504,6 +513,7 @@ class Tokenizer(metaclass=_Tokenizer): "PIVOT": TokenType.PIVOT, "PRECEDING": TokenType.PRECEDING, "PRIMARY KEY": TokenType.PRIMARY_KEY, + "PROCEDURE": TokenType.PROCEDURE, "RANGE": TokenType.RANGE, "RECURSIVE": TokenType.RECURSIVE, "REGEXP": TokenType.RLIKE, @@ -522,6 +532,7 @@ class Tokenizer(metaclass=_Tokenizer): "SHOW": TokenType.SHOW, "SOME": TokenType.SOME, "SORT BY": TokenType.SORT_BY, + "STABLE": TokenType.STABLE, "STORED": TokenType.STORED, "TABLE": TokenType.TABLE, "TABLE_FORMAT": TokenType.TABLE_FORMAT, @@ -542,6 +553,7 @@ class Tokenizer(metaclass=_Tokenizer): "USING": TokenType.USING, "VALUES": TokenType.VALUES, "VIEW": TokenType.VIEW, + "VOLATILE": TokenType.VOLATILE, "WHEN": TokenType.WHEN, "WHERE": TokenType.WHERE, "WITH": TokenType.WITH, @@ -637,6 +649,7 @@ class Tokenizer(metaclass=_Tokenizer): "_char", "_end", "_peek", + "_prev_token_type", ) def __init__(self): @@ -657,6 +670,7 @@ class Tokenizer(metaclass=_Tokenizer): self._char = None self._end = None self._peek = None + self._prev_token_type = None def tokenize(self, sql): self.reset() @@ -706,8 +720,8 @@ class Tokenizer(metaclass=_Tokenizer): return self.sql[self._start : self._current] def _add(self, token_type, text=None): - text = self._text if text is None else text - self.tokens.append(Token(token_type, text, self._line, self._col)) + self._prev_token_type = token_type + self.tokens.append(Token(token_type, self._text if text is None else text, self._line, self._col)) if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON): self._start = self._current @@ -910,7 +924,11 @@ class Tokenizer(metaclass=_Tokenizer): self._advance() else: break - self._add(self.KEYWORDS.get(self._text.upper(), TokenType.VAR)) + self._add( + TokenType.VAR + if self._prev_token_type == TokenType.PARAMETER + else self.KEYWORDS.get(self._text.upper(), TokenType.VAR) + ) def _extract_string(self, delimiter): text = "" diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index c929e59..7110eac 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -239,7 +239,7 @@ class TestBigQuery(Validator): self.validate_all( "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", write={ - "spark": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", + "spark": "SELECT cola, colb FROM VALUES (1, 'test') AS tab(cola, colb)", "bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)])", "snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", }, @@ -253,7 +253,7 @@ class TestBigQuery(Validator): def test_user_defined_functions(self): self.validate_identity( - "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 LANGUAGE js AS 'return x*y;'" + "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'" ) self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)") self.validate_identity("CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index e0ec824..a9a313c 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1009,7 +1009,7 @@ class TestDialect(Validator): self.validate_all( "SELECT * FROM VALUES ('x'), ('y') AS t(z)", write={ - "spark": "SELECT * FROM (VALUES ('x'), ('y')) AS t(z)", + "spark": "SELECT * FROM VALUES ('x'), ('y') AS t(z)", }, ) self.validate_all( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index b7e39a7..2145966 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -293,3 +293,15 @@ class TestSnowflake(Validator): "bigquery": "CREATE TABLE FUNCTION a() RETURNS TABLE <b INT64> AS SELECT 1", }, ) + self.validate_all( + "CREATE FUNCTION a() RETURNS INT IMMUTABLE AS 'SELECT 1'", + write={ + "snowflake": "CREATE FUNCTION a() RETURNS INT IMMUTABLE AS 'SELECT 1'", + }, + ) + + def test_stored_procedures(self): + self.validate_identity("CALL a.b.c(x, y)") + self.validate_identity( + "CREATE PROCEDURE a.b.c(x INT, y VARIANT) RETURNS OBJECT EXECUTE AS CALLER AS 'BEGIN SELECT 1; END;'" + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 2654be1..a0de281 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -50,6 +50,7 @@ a.B() a['x'].C() int.x map.x +a.b.INT(1.234) x IN (-1, 1) x IN ('a', 'a''a') x IN ((1)) @@ -357,6 +358,7 @@ SELECT * REPLACE (a + 1 AS b, b AS C) SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C) SELECT a.* EXCEPT (a, b), b.* REPLACE (a AS b, b AS C) SELECT zoo, animals FROM (VALUES ('oakland', ARRAY('a', 'b')), ('sf', ARRAY('b', 'c'))) AS t(zoo, animals) +SELECT zoo, animals FROM UNNEST(ARRAY(STRUCT('oakland' AS zoo, ARRAY('a', 'b') AS animals), STRUCT('sf' AS zoo, ARRAY('b', 'c') AS animals))) AS t(zoo, animals) WITH a AS (SELECT 1) SELECT 1 UNION ALL SELECT 2 WITH a AS (SELECT 1) SELECT 1 UNION SELECT 2 WITH a AS (SELECT 1) SELECT 1 INTERSECT SELECT 2 @@ -444,6 +446,8 @@ CREATE OR REPLACE TEMPORARY VIEW x AS SELECT * CREATE TEMPORARY VIEW x AS SELECT a FROM d CREATE TEMPORARY VIEW IF NOT EXISTS x AS SELECT a FROM d CREATE TEMPORARY VIEW x AS WITH y AS (SELECT 1) SELECT * FROM y +CREATE MATERIALIZED VIEW x.y.z AS SELECT a FROM b +DROP MATERIALIZED VIEW x.y.z CREATE TABLE z (a INT, b VARCHAR, c VARCHAR(100), d DECIMAL(5, 3)) CREATE TABLE z (end INT) CREATE TABLE z (a ARRAY<TEXT>, b MAP<TEXT, DOUBLE>, c DECIMAL(5, 3)) @@ -471,10 +475,13 @@ CREATE FUNCTION f AS 'g' CREATE FUNCTION a(b INT, c VARCHAR) AS 'SELECT 1' CREATE FUNCTION a() LANGUAGE sql CREATE FUNCTION a() LANGUAGE sql RETURNS INT +CREATE FUNCTION a.b.c() +DROP FUNCTION a.b.c (INT) CREATE INDEX abc ON t (a) CREATE INDEX abc ON t (a, b, b) CREATE UNIQUE INDEX abc ON t (a, b, b) CREATE UNIQUE INDEX IF NOT EXISTS my_idx ON tbl (a, b) +DROP INDEX a.b.c CACHE TABLE x CACHE LAZY TABLE x CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') @@ -484,6 +491,8 @@ CACHE LAZY TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a CACHE TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a CACHE TABLE x AS (SELECT 1 AS y) CALL catalog.system.iceberg_procedure_name(named_arg_1 => 'arg_1', named_arg_2 => 'arg_2') +CREATE PROCEDURE IF NOT EXISTS a.b.c() AS 'DECLARE BEGIN; END' +DROP PROCEDURE a.b.c (INT) INSERT OVERWRITE TABLE a.b PARTITION(ds) SELECT x FROM y INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD') SELECT x FROM y INSERT OVERWRITE TABLE a.b PARTITION(ds, hour) SELECT x FROM y diff --git a/tests/fixtures/optimizer/merge_subqueries.sql b/tests/fixtures/optimizer/merge_subqueries.sql index 35aed3b..e13d3b3 100644 --- a/tests/fixtures/optimizer/merge_subqueries.sql +++ b/tests/fixtures/optimizer/merge_subqueries.sql @@ -97,3 +97,11 @@ WITH x AS (SELECT x.a AS a, x.b AS b FROM x AS x) SELECT x.a AS a, y.b AS b FROM -- Nested CTE SELECT * FROM (WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x); SELECT x.a AS a, x.b AS b FROM x AS x; + +-- Inner select is an expression +SELECT a FROM (SELECT a FROM (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) AS x) AS x; +SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b; + +-- CTE select is an expression +WITH x AS (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) SELECT a FROM (SELECT a FROM x AS x) AS x; +SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b; diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index 0bb742b..eb6761a 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -137,3 +137,20 @@ SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x; SELECT AGGREGATE(ARRAY("x"."a", "x"."b"), 0, ("x", "acc") -> "x" + "acc" + "x"."a") AS "sum_agg" FROM "x" AS "x"; + +SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb); +SELECT + "tab"."cola" AS "cola", + "tab"."colb" AS "colb" +FROM (VALUES + (1, 'test'), + (2, 'test2')) AS "tab"("cola", "colb"); + +# dialect: spark +SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb); +SELECT + `tab`.`cola` AS `cola`, + `tab`.`colb` AS `colb` +FROM VALUES + (1, 'test'), + (2, 'test2') AS `tab`(`cola`, `colb`); diff --git a/tests/fixtures/optimizer/pushdown_projections.sql b/tests/fixtures/optimizer/pushdown_projections.sql index 9deceb6..b03ffab 100644 --- a/tests/fixtures/optimizer/pushdown_projections.sql +++ b/tests/fixtures/optimizer/pushdown_projections.sql @@ -39,3 +39,15 @@ SELECT "_q_0".b AS b FROM (SELECT SUM(x.b) AS b FROM x AS x GROUP BY x.a) AS "_q SELECT b FROM (SELECT a, SUM(b) AS b FROM x ORDER BY a); SELECT "_q_0".b AS b FROM (SELECT x.a AS a, SUM(x.b) AS b FROM x AS x ORDER BY a) AS "_q_0"; + +SELECT x FROM (VALUES(1, 2)) AS q(x, y); +SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y); + +SELECT x FROM UNNEST([1, 2]) AS q(x, y); +SELECT q.x AS x FROM UNNEST(ARRAY(1, 2)) AS q(x, y); + +WITH t1 AS (SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)]) AS "q"("cola", "colb")) SELECT cola FROM t1; +WITH t1 AS (SELECT q.cola AS cola FROM UNNEST(ARRAY(STRUCT(1 AS cola, 'test' AS colb))) AS "q"("cola", "colb")) SELECT t1.cola AS cola FROM t1; + +SELECT x FROM VALUES(1, 2) AS q(x, y); +SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y); diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 8d4aecc..aad84ed 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -5,7 +5,7 @@ from sqlglot import exp, optimizer, parse_one, table from sqlglot.errors import OptimizeError from sqlglot.optimizer.annotate_types import annotate_types from sqlglot.optimizer.schema import MappingSchema, ensure_schema -from sqlglot.optimizer.scope import build_scope, traverse_scope +from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures @@ -264,12 +264,13 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') ON s.b = r.b WHERE s.b > (SELECT MAX(x.a) FROM x WHERE x.b = s.b) """ - for scopes in traverse_scope(parse_one(sql)), list(build_scope(parse_one(sql)).traverse()): + expression = parse_one(sql) + for scopes in traverse_scope(expression), list(build_scope(expression).traverse()): self.assertEqual(len(scopes), 5) self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x") self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y") - self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") - self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y") + self.assertEqual(scopes[2].expression.sql(), "SELECT y.c AS b FROM y") + self.assertEqual(scopes[3].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql()) self.assertEqual(set(scopes[4].sources), {"q", "r", "s"}) @@ -279,6 +280,16 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(len(scopes[4].source_columns("r")), 2) self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"}) + self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"}) + self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b") + self.assertEqual({c.sql() for c in scopes[0].find_all(exp.Column)}, {"x.b"}) + + # Check that we can walk in scope from an arbitrary node + self.assertEqual( + {node.sql() for node, *_ in walk_in_scope(expression.find(exp.Where)) if isinstance(node, exp.Column)}, + {"s.b"}, + ) + def test_literal_type_annotation(self): tests = { "SELECT 5": exp.DataType.Type.INT, diff --git a/tests/test_parser.py b/tests/test_parser.py index 4c46531..4e86516 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -122,6 +122,9 @@ class TestParser(unittest.TestCase): def test_parameter(self): self.assertEqual(parse_one("SELECT @x, @@x, @1").sql(), "SELECT @x, @@x, @1") + def test_var(self): + self.assertEqual(parse_one("SELECT @JOIN, @'foo'").sql(), "SELECT @JOIN, @'foo'") + def test_annotations(self): expression = parse_one( """ |