summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2022-10-04 09:37:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2022-10-04 09:37:14 +0000
commit7b29f6168bf9fcb2d886447066a9bb51675e5665 (patch)
treeff74c45f55651c73cce0cd58145667de43db9d12
parentReleasing debian version 6.2.6-1. (diff)
downloadsqlglot-7b29f6168bf9fcb2d886447066a9bb51675e5665.tar.xz
sqlglot-7b29f6168bf9fcb2d886447066a9bb51675e5665.zip
Merging upstream version 6.2.8.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
-rw-r--r--README.md6
-rw-r--r--sqlglot/__init__.py2
-rw-r--r--sqlglot/dialects/bigquery.py22
-rw-r--r--sqlglot/dialects/dialect.py2
-rw-r--r--sqlglot/dialects/snowflake.py2
-rw-r--r--sqlglot/dialects/spark.py2
-rw-r--r--sqlglot/expressions.py68
-rw-r--r--sqlglot/generator.py20
-rw-r--r--sqlglot/optimizer/merge_subqueries.py13
-rw-r--r--sqlglot/optimizer/qualify_columns.py10
-rw-r--r--sqlglot/optimizer/scope.py122
-rw-r--r--sqlglot/parser.py84
-rw-r--r--sqlglot/tokens.py24
-rw-r--r--tests/dialects/test_bigquery.py4
-rw-r--r--tests/dialects/test_dialect.py2
-rw-r--r--tests/dialects/test_snowflake.py12
-rw-r--r--tests/fixtures/identity.sql9
-rw-r--r--tests/fixtures/optimizer/merge_subqueries.sql8
-rw-r--r--tests/fixtures/optimizer/optimizer.sql17
-rw-r--r--tests/fixtures/optimizer/pushdown_projections.sql12
-rw-r--r--tests/test_optimizer.py19
-rw-r--r--tests/test_parser.py3
22 files changed, 363 insertions, 100 deletions
diff --git a/README.md b/README.md
index 5ab4507..be59b84 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
# SQLGlot
-SQLGlot is a no dependency Python SQL parser, transpiler, and optimizer. It can be used to format SQL or translate between different dialects like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
+SQLGlot is a no dependency Python SQL parser, transpiler, and optimizer. It can be used to format SQL or translate between different dialects like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
It is a very comprehensive generic SQL parser with a robust [test suite](tests). It is also quite [performant](#benchmarks) while being written purely in Python.
@@ -30,7 +30,7 @@ sqlglot.transpile("SELECT EPOCH_MS(1618088028295)", read='duckdb', write='hive')
```
```sql
-SELECT TO_UTC_TIMESTAMP(FROM_UNIXTIME(1618088028295 / 1000, 'yyyy-MM-dd HH:mm:ss'), 'UTC')
+SELECT FROM_UNIXTIME(1618088028295 / 1000)
```
SQLGlot can even translate custom time formats.
@@ -299,7 +299,7 @@ class Custom(Dialect):
}
-Dialects["custom"]
+Dialect["custom"]
```
## Benchmarks
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index befbc8a..1f7b28c 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -20,7 +20,7 @@ from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "6.2.6"
+__version__ = "6.2.8"
pretty = False
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 432fd8c..40298e7 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -33,10 +33,10 @@ def _date_add_sql(data_type, kind):
return func
-def _subquery_to_unnest_if_values(self, expression):
- if not isinstance(expression.this, exp.Values):
- return self.subquery_sql(expression)
- rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.this.find_all(exp.Tuple)]
+def _derived_table_values_to_unnest(self, expression):
+ if not isinstance(expression.unnest().parent, exp.From):
+ return self.values_sql(expression)
+ rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.find_all(exp.Tuple)]
structs = []
for row in rows:
aliases = [
@@ -99,6 +99,7 @@ class BigQuery(Dialect):
"QUALIFY": TokenType.QUALIFY,
"UNKNOWN": TokenType.NULL,
"WINDOW": TokenType.WINDOW,
+ "NOT DETERMINISTIC": TokenType.VOLATILE,
}
class Parser(Parser):
@@ -140,9 +141,10 @@ class BigQuery(Dialect):
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.VariancePop: rename_func("VAR_POP"),
- exp.Subquery: _subquery_to_unnest_if_values,
+ exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql,
+ exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC",
}
TYPE_MAPPING = {
@@ -160,6 +162,16 @@ class BigQuery(Dialect):
exp.DataType.Type.NVARCHAR: "STRING",
}
+ ROOT_PROPERTIES = {
+ exp.LanguageProperty,
+ exp.ReturnsProperty,
+ exp.VolatilityProperty,
+ }
+
+ WITH_PROPERTIES = {
+ exp.AnonymousProperty,
+ }
+
def in_unnest_op(self, unnest):
return self.sql(unnest)
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 0ab584e..98dc330 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -77,6 +77,7 @@ class Dialect(metaclass=_Dialect):
alias_post_tablesample = False
normalize_functions = "upper"
null_ordering = "nulls_are_small"
+ wrap_derived_values = True
date_format = "'%Y-%m-%d'"
dateint_format = "'%Y%m%d'"
@@ -169,6 +170,7 @@ class Dialect(metaclass=_Dialect):
"alias_post_tablesample": self.alias_post_tablesample,
"normalize_functions": self.normalize_functions,
"null_ordering": self.null_ordering,
+ "wrap_derived_values": self.wrap_derived_values,
**opts,
}
)
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 1b718f7..fb2d900 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -177,6 +177,8 @@ class Snowflake(Dialect):
exp.ReturnsProperty,
exp.LanguageProperty,
exp.SchemaCommentProperty,
+ exp.ExecuteAsProperty,
+ exp.VolatilityProperty,
}
def except_op(self, expression):
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 5446e83..e8da07a 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -47,6 +47,8 @@ def _unix_to_time(self, expression):
class Spark(Hive):
+ wrap_derived_values = False
+
class Parser(Hive.Parser):
FUNCTIONS = {
**Hive.Parser.FUNCTIONS,
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 599c7db..8cdacce 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -213,21 +213,23 @@ class Expression(metaclass=_Expression):
"""
return self.find_ancestor(Select)
- def walk(self, bfs=True):
+ def walk(self, bfs=True, prune=None):
"""
Returns a generator object which visits all nodes in this tree.
Args:
bfs (bool): if set to True the BFS traversal order will be applied,
otherwise the DFS traversal will be used instead.
+ prune ((node, parent, arg_key) -> bool): callable that returns True if
+ the generator should stop traversing this branch of the tree.
Returns:
the generator object.
"""
if bfs:
- yield from self.bfs()
+ yield from self.bfs(prune=prune)
else:
- yield from self.dfs()
+ yield from self.dfs(prune=prune)
def dfs(self, parent=None, key=None, prune=None):
"""
@@ -506,6 +508,10 @@ class DerivedTable(Expression):
return [select.alias_or_name for select in self.selects]
+class UDTF(DerivedTable):
+ pass
+
+
class Annotation(Expression):
arg_types = {
"this": True,
@@ -652,7 +658,13 @@ class Delete(Expression):
class Drop(Expression):
- arg_types = {"this": False, "kind": False, "exists": False}
+ arg_types = {
+ "this": False,
+ "kind": False,
+ "exists": False,
+ "temporary": False,
+ "materialized": False,
+ }
class Filter(Expression):
@@ -827,7 +839,7 @@ class Join(Expression):
return join
-class Lateral(DerivedTable):
+class Lateral(UDTF):
arg_types = {"this": True, "outer": False, "alias": False}
@@ -915,6 +927,14 @@ class LanguageProperty(Property):
pass
+class ExecuteAsProperty(Property):
+ pass
+
+
+class VolatilityProperty(Property):
+ arg_types = {"this": True}
+
+
class Properties(Expression):
arg_types = {"expressions": True}
@@ -1098,7 +1118,7 @@ class Intersect(Union):
pass
-class Unnest(DerivedTable):
+class Unnest(UDTF):
arg_types = {
"expressions": True,
"ordinality": False,
@@ -1116,8 +1136,12 @@ class Update(Expression):
}
-class Values(Expression):
- arg_types = {"expressions": True}
+class Values(UDTF):
+ arg_types = {
+ "expressions": True,
+ "ordinality": False,
+ "alias": False,
+ }
class Var(Expression):
@@ -2033,23 +2057,17 @@ class Func(Condition):
@classmethod
def from_arg_list(cls, args):
- args_num = len(args)
-
- all_arg_keys = list(cls.arg_types)
- # If this function supports variable length argument treat the last argument as such.
- non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
-
- args_dict = {}
- arg_idx = 0
- for arg_key in non_var_len_arg_keys:
- if arg_idx >= args_num:
- break
- if args[arg_idx] is not None:
- args_dict[arg_key] = args[arg_idx]
- arg_idx += 1
-
- if arg_idx < args_num and cls.is_var_len_args:
- args_dict[all_arg_keys[-1]] = args[arg_idx:]
+ if cls.is_var_len_args:
+ all_arg_keys = list(cls.arg_types)
+ # If this function supports variable length argument treat the last argument as such.
+ non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
+ num_non_var = len(non_var_len_arg_keys)
+
+ args_dict = {arg_key: arg for arg, arg_key in zip(args, non_var_len_arg_keys)}
+ args_dict[all_arg_keys[-1]] = args[num_non_var:]
+ else:
+ args_dict = {arg_key: arg for arg, arg_key in zip(args, cls.arg_types)}
+
return cls(**args_dict)
@classmethod
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 9099307..8b356f3 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -49,10 +49,12 @@ class Generator:
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
- exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
+ exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
+ exp.VolatilityProperty: lambda self, e: self.sql(e.name),
}
NULL_ORDERING_SUPPORTED = True
@@ -99,6 +101,7 @@ class Generator:
"unsupported_messages",
"null_ordering",
"max_unsupported",
+ "wrap_derived_values",
"_indent",
"_replace_backslash",
"_escaped_quote_end",
@@ -127,6 +130,7 @@ class Generator:
null_ordering=None,
max_unsupported=3,
leading_comma=False,
+ wrap_derived_values=True,
):
import sqlglot
@@ -150,6 +154,7 @@ class Generator:
self.unsupported_messages = []
self.max_unsupported = max_unsupported
self.null_ordering = null_ordering
+ self.wrap_derived_values = wrap_derived_values
self._indent = indent
self._replace_backslash = self.escape == "\\"
self._escaped_quote_end = self.escape + self.quote_end
@@ -407,7 +412,9 @@ class Generator:
this = self.sql(expression, "this")
kind = expression.args["kind"]
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
- return f"DROP {kind}{exists_sql}{this}"
+ temporary = " TEMPORARY" if expression.args.get("temporary") else ""
+ materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
+ return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}"
def except_sql(self, expression):
return self.prepend_ctes(
@@ -583,7 +590,14 @@ class Generator:
return self.prepend_ctes(expression, sql)
def values_sql(self, expression):
- return f"VALUES{self.seg('')}{self.expressions(expression)}"
+ alias = self.sql(expression, "alias")
+ args = self.expressions(expression)
+ if not alias:
+ return f"VALUES{self.seg('')}{args}"
+ alias = f" AS {alias}" if alias else alias
+ if self.wrap_derived_values:
+ return f"(VALUES{self.seg('')}{args}){alias}"
+ return f"VALUES{self.seg('')}{args}{alias}"
def var_sql(self, expression):
return self.sql(expression, "this")
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index 9d966b7..d29c22b 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -32,8 +32,8 @@ def merge_subqueries(expression, leave_tables_isolated=False):
Returns:
sqlglot.Expression: optimized expression
"""
- merge_ctes(expression, leave_tables_isolated)
- merge_derived_tables(expression, leave_tables_isolated)
+ expression = merge_ctes(expression, leave_tables_isolated)
+ expression = merge_derived_tables(expression, leave_tables_isolated)
return expression
@@ -76,14 +76,14 @@ def merge_ctes(expression, leave_tables_isolated=False):
alias = node_to_replace.alias
else:
alias = table.name
-
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, node_to_replace, alias)
- _merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
+ _merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
_pop_cte(inner_scope)
+ return expression
def merge_derived_tables(expression, leave_tables_isolated=False):
@@ -97,10 +97,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, subquery, alias)
- _merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
+ _merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
+ return expression
def _mergeable(outer_scope, inner_select, leave_tables_isolated):
@@ -229,7 +230,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
continue
columns_to_replace = outer_columns.get(projection_name, [])
for column in columns_to_replace:
- column.replace(expression.unalias())
+ column.replace(expression.unalias().copy())
def _merge_where(outer_scope, inner_scope, from_or_join):
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 0bb947a..72ce256 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -5,8 +5,6 @@ from sqlglot.errors import OptimizeError
from sqlglot.optimizer.schema import ensure_schema
from sqlglot.optimizer.scope import traverse_scope
-SKIP_QUALIFY = (exp.Unnest, exp.Lateral)
-
def qualify_columns(expression, schema):
"""
@@ -35,7 +33,7 @@ def qualify_columns(expression, schema):
_expand_group_by(scope, resolver)
_expand_order_by(scope)
_qualify_columns(scope, resolver)
- if not isinstance(scope.expression, SKIP_QUALIFY):
+ if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver)
_qualify_outputs(scope)
_check_unknown_tables(scope)
@@ -50,7 +48,7 @@ def _pop_table_column_aliases(derived_tables):
(e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
"""
for derived_table in derived_tables:
- if isinstance(derived_table, SKIP_QUALIFY):
+ if isinstance(derived_table, exp.UDTF):
continue
table_alias = derived_table.args.get("alias")
if table_alias:
@@ -202,7 +200,7 @@ def _qualify_columns(scope, resolver):
if not column_table:
column_table = resolver.get_table(column_name)
- if not scope.is_subquery and not scope.is_unnest:
+ if not scope.is_subquery and not scope.is_udtf:
if column_name not in resolver.all_columns:
raise OptimizeError(f"Unknown column: {column_name}")
@@ -296,7 +294,7 @@ def _qualify_outputs(scope):
def _check_unknown_tables(scope):
- if scope.external_columns and not scope.is_unnest and not scope.is_correlated_subquery:
+ if scope.external_columns and not scope.is_udtf and not scope.is_correlated_subquery:
raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index be6cfb9..6332cdd 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -1,5 +1,4 @@
import itertools
-from copy import copy
from enum import Enum, auto
from sqlglot import exp
@@ -12,7 +11,7 @@ class ScopeType(Enum):
DERIVED_TABLE = auto()
CTE = auto()
UNION = auto()
- UNNEST = auto()
+ UDTF = auto()
class Scope:
@@ -70,14 +69,11 @@ class Scope:
self._columns = None
self._external_columns = None
- def branch(self, expression, scope_type, add_sources=None, **kwargs):
+ def branch(self, expression, scope_type, chain_sources=None, **kwargs):
"""Branch from the current scope to a new, inner scope"""
- sources = copy(self.sources)
- if add_sources:
- sources.update(add_sources)
return Scope(
expression=expression.unnest(),
- sources=sources,
+ sources={**self.cte_sources, **(chain_sources or {})},
parent=self,
scope_type=scope_type,
**kwargs,
@@ -90,30 +86,21 @@ class Scope:
self._derived_tables = []
self._raw_columns = []
- # We'll use this variable to pass state into the dfs generator.
- # Whenever we set it to True, we exclude a subtree from traversal.
- prune = False
-
- for node, parent, _ in self.expression.dfs(prune=lambda *_: prune):
- prune = False
-
+ for node, parent, _ in self.walk(bfs=False):
if node is self.expression:
continue
- if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
+ elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
self._raw_columns.append(node)
elif isinstance(node, exp.Table):
self._tables.append(node)
- elif isinstance(node, (exp.Unnest, exp.Lateral)):
+ elif isinstance(node, exp.UDTF):
self._derived_tables.append(node)
elif isinstance(node, exp.CTE):
self._ctes.append(node)
- prune = True
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
self._derived_tables.append(node)
- prune = True
elif isinstance(node, exp.Subqueryable):
self._subqueries.append(node)
- prune = True
self._collected = True
@@ -121,6 +108,43 @@ class Scope:
if not self._collected:
self._collect()
+ def walk(self, bfs=True):
+ return walk_in_scope(self.expression, bfs=bfs)
+
+ def find(self, *expression_types, bfs=True):
+ """
+ Returns the first node in this scope which matches at least one of the specified types.
+
+ This does NOT traverse into subscopes.
+
+ Args:
+ expression_types (type): the expression type(s) to match.
+ bfs (bool): True to use breadth-first search, False to use depth-first.
+
+ Returns:
+ exp.Expression: the node which matches the criteria or None if no node matching
+ the criteria was found.
+ """
+ return next(self.find_all(*expression_types, bfs=bfs), None)
+
+ def find_all(self, *expression_types, bfs=True):
+ """
+ Returns a generator object which visits all nodes in this scope and only yields those that
+ match at least one of the specified expression types.
+
+ This does NOT traverse into subscopes.
+
+ Args:
+ expression_types (type): the expression type(s) to match.
+ bfs (bool): True to use breadth-first search, False to use depth-first.
+
+ Yields:
+ exp.Expression: nodes
+ """
+ for expression, _, _ in self.walk(bfs=bfs):
+ if isinstance(expression, expression_types):
+ yield expression
+
def replace(self, old, new):
"""
Replace `old` with `new`.
@@ -247,6 +271,16 @@ class Scope:
return self._selected_sources
@property
+ def cte_sources(self):
+ """
+ Sources that are CTEs.
+
+ Returns:
+ dict[str, Scope]: Mapping of source alias to Scope
+ """
+ return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte}
+
+ @property
def selects(self):
"""
Select expressions of this scope.
@@ -313,9 +347,9 @@ class Scope:
return self.scope_type == ScopeType.ROOT
@property
- def is_unnest(self):
- """Determine if this scope is an unnest"""
- return self.scope_type == ScopeType.UNNEST
+ def is_udtf(self):
+ """Determine if this scope is a UDTF (User Defined Table Function)"""
+ return self.scope_type == ScopeType.UDTF
@property
def is_correlated_subquery(self):
@@ -348,7 +382,7 @@ class Scope:
Scope: scope instances in depth-first-search post-order
"""
for child_scope in itertools.chain(
- self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes
+ self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes
):
yield from child_scope.traverse()
yield self
@@ -399,7 +433,7 @@ def _traverse_scope(scope):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
- elif isinstance(scope.expression, (exp.Lateral, exp.Unnest)):
+ elif isinstance(scope.expression, exp.UDTF):
pass
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
@@ -410,8 +444,8 @@ def _traverse_scope(scope):
def _traverse_select(scope):
yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
- yield from _traverse_subqueries(scope)
yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
+ yield from _traverse_subqueries(scope)
_add_table_sources(scope)
@@ -437,10 +471,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
top = None
for child_scope in _traverse_scope(
scope.branch(
- derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
- add_sources=sources if scope_type == ScopeType.CTE else None,
+ derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
+ chain_sources=sources if scope_type == ScopeType.CTE else None,
outer_column_list=derived_table.alias_column_names,
- scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type,
+ scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type,
)
):
yield child_scope
@@ -483,3 +517,35 @@ def _traverse_subqueries(scope):
yield child_scope
top = child_scope
scope.subquery_scopes.append(top)
+
+
+def walk_in_scope(expression, bfs=True):
+ """
+ Returns a generator object which visits all nodes in the syntrax tree, stopping at
+ nodes that start child scopes.
+
+ Args:
+ expression (exp.Expression):
+ bfs (bool): if set to True the BFS traversal order will be applied,
+ otherwise the DFS traversal will be used instead.
+
+ Yields:
+ tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
+ """
+ # We'll use this variable to pass state into the dfs generator.
+ # Whenever we set it to True, we exclude a subtree from traversal.
+ prune = False
+
+ for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
+ prune = False
+
+ yield node, parent, key
+
+ if node is expression:
+ continue
+ elif isinstance(node, exp.CTE):
+ prune = True
+ elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
+ prune = True
+ elif isinstance(node, exp.Subqueryable):
+ prune = True
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 72bad92..5f20afc 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -126,6 +126,8 @@ class Parser:
TokenType.CONSTRAINT,
TokenType.DEFAULT,
TokenType.DELETE,
+ TokenType.DETERMINISTIC,
+ TokenType.EXECUTE,
TokenType.ENGINE,
TokenType.ESCAPE,
TokenType.EXPLAIN,
@@ -139,6 +141,7 @@ class Parser:
TokenType.IF,
TokenType.INDEX,
TokenType.ISNULL,
+ TokenType.IMMUTABLE,
TokenType.INTERVAL,
TokenType.LAZY,
TokenType.LANGUAGE,
@@ -163,6 +166,7 @@ class Parser:
TokenType.SEED,
TokenType.SET,
TokenType.SHOW,
+ TokenType.STABLE,
TokenType.STORED,
TokenType.TABLE,
TokenType.TABLE_FORMAT,
@@ -175,6 +179,8 @@ class Parser:
TokenType.UNIQUE,
TokenType.UNPIVOT,
TokenType.PROPERTIES,
+ TokenType.PROCEDURE,
+ TokenType.VOLATILE,
*SUBQUERY_PREDICATES,
*TYPE_TOKENS,
}
@@ -204,7 +210,7 @@ class Parser:
TokenType.DATETIME,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
- *NESTED_TYPE_TOKENS,
+ *TYPE_TOKENS,
*SUBQUERY_PREDICATES,
}
@@ -379,6 +385,13 @@ class Parser:
TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
+ TokenType.EXECUTE: lambda self: self._parse_execute_as(),
+ TokenType.DETERMINISTIC: lambda self: self.expression(
+ exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
+ ),
+ TokenType.IMMUTABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")),
+ TokenType.STABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("STABLE")),
+ TokenType.VOLATILE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")),
}
CONSTRAINT_PARSERS = {
@@ -418,7 +431,7 @@ class Parser:
MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
- CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX}
+ CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX, TokenType.PROCEDURE}
STRICT_CAST = True
@@ -615,18 +628,20 @@ class Parser:
return expression
def _parse_drop(self):
- if self._match(TokenType.TABLE):
- kind = "TABLE"
- elif self._match(TokenType.VIEW):
- kind = "VIEW"
- else:
- self.raise_error("Expected TABLE or View")
+ temporary = self._match(TokenType.TEMPORARY)
+ materialized = self._match(TokenType.MATERIALIZED)
+ kind = self._match_set(self.CREATABLES) and self._prev.text
+ if not kind:
+ self.raise_error("Expected TABLE, VIEW, INDEX, FUNCTION, or PROCEDURE")
+ return
return self.expression(
exp.Drop,
exists=self._parse_exists(),
this=self._parse_table(schema=True),
kind=kind,
+ temporary=temporary,
+ materialized=materialized,
)
def _parse_exists(self, not_=False):
@@ -644,14 +659,15 @@ class Parser:
create_token = self._match_set(self.CREATABLES) and self._prev
if not create_token:
- self.raise_error("Expected TABLE, VIEW, INDEX, or FUNCTION")
+ self.raise_error("Expected TABLE, VIEW, INDEX, FUNCTION, or PROCEDURE")
+ return
exists = self._parse_exists(not_=True)
this = None
expression = None
properties = None
- if create_token.token_type == TokenType.FUNCTION:
+ if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function()
properties = self._parse_properties()
if self._match(TokenType.ALIAS):
@@ -747,7 +763,9 @@ class Parser:
if is_table:
if self._match(TokenType.LT):
value = self.expression(
- exp.Schema, this="TABLE", expressions=self._parse_csv(self._parse_struct_kwargs)
+ exp.Schema,
+ this="TABLE",
+ expressions=self._parse_csv(self._parse_struct_kwargs),
)
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
@@ -763,6 +781,14 @@ class Parser:
is_table=is_table,
)
+ def _parse_execute_as(self):
+ self._match(TokenType.ALIAS)
+ return self.expression(
+ exp.ExecuteAsProperty,
+ this=exp.Literal.string("EXECUTE AS"),
+ value=self._parse_var(),
+ )
+
def _parse_properties(self):
properties = []
@@ -997,7 +1023,12 @@ class Parser:
)
def _parse_subquery(self, this):
- return self.expression(exp.Subquery, this=this, pivots=self._parse_pivots(), alias=self._parse_table_alias())
+ return self.expression(
+ exp.Subquery,
+ this=this,
+ pivots=self._parse_pivots(),
+ alias=self._parse_table_alias(),
+ )
def _parse_query_modifiers(self, this):
if not isinstance(this, self.MODIFIABLES):
@@ -1118,6 +1149,11 @@ class Parser:
if unnest:
return unnest
+ values = self._parse_derived_table_values()
+
+ if values:
+ return values
+
subquery = self._parse_select(table=True)
if subquery:
@@ -1186,6 +1222,24 @@ class Parser:
alias=alias,
)
+ def _parse_derived_table_values(self):
+ is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES)
+ if not is_derived and not self._match(TokenType.VALUES):
+ return None
+
+ expressions = self._parse_csv(self._parse_value)
+
+ if is_derived:
+ self._match_r_paren()
+
+ alias = self._parse_table_alias()
+
+ return self.expression(
+ exp.Values,
+ expressions=expressions,
+ alias=alias,
+ )
+
def _parse_table_sample(self):
if not self._match(TokenType.TABLE_SAMPLE):
return None
@@ -1700,7 +1754,11 @@ class Parser:
return self._parse_window(this)
def _parse_user_defined_function(self):
- this = self._parse_var()
+ this = self._parse_id_var()
+
+ while self._match(TokenType.DOT):
+ this = self.expression(exp.Dot, this=this, expression=self._parse_id_var())
+
if not self._match(TokenType.L_PAREN):
return this
expressions = self._parse_csv(self._parse_udf_kwarg)
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index c81f0db..39bf421 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -136,6 +136,7 @@ class TokenType(AutoName):
DEFAULT = auto()
DELETE = auto()
DESC = auto()
+ DETERMINISTIC = auto()
DISTINCT = auto()
DISTRIBUTE_BY = auto()
DROP = auto()
@@ -144,6 +145,7 @@ class TokenType(AutoName):
ENGINE = auto()
ESCAPE = auto()
EXCEPT = auto()
+ EXECUTE = auto()
EXISTS = auto()
EXPLAIN = auto()
FALSE = auto()
@@ -167,6 +169,7 @@ class TokenType(AutoName):
IF = auto()
IGNORE_NULLS = auto()
ILIKE = auto()
+ IMMUTABLE = auto()
IN = auto()
INDEX = auto()
INNER = auto()
@@ -215,6 +218,7 @@ class TokenType(AutoName):
PLACEHOLDER = auto()
PRECEDING = auto()
PRIMARY_KEY = auto()
+ PROCEDURE = auto()
PROPERTIES = auto()
QUALIFY = auto()
QUOTE = auto()
@@ -238,6 +242,7 @@ class TokenType(AutoName):
SIMILAR_TO = auto()
SOME = auto()
SORT_BY = auto()
+ STABLE = auto()
STORED = auto()
STRUCT = auto()
TABLE_FORMAT = auto()
@@ -258,6 +263,7 @@ class TokenType(AutoName):
USING = auto()
VALUES = auto()
VIEW = auto()
+ VOLATILE = auto()
WHEN = auto()
WHERE = auto()
WINDOW = auto()
@@ -430,6 +436,7 @@ class Tokenizer(metaclass=_Tokenizer):
"DEFAULT": TokenType.DEFAULT,
"DELETE": TokenType.DELETE,
"DESC": TokenType.DESC,
+ "DETERMINISTIC": TokenType.DETERMINISTIC,
"DISTINCT": TokenType.DISTINCT,
"DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
"DROP": TokenType.DROP,
@@ -438,6 +445,7 @@ class Tokenizer(metaclass=_Tokenizer):
"ENGINE": TokenType.ENGINE,
"ESCAPE": TokenType.ESCAPE,
"EXCEPT": TokenType.EXCEPT,
+ "EXECUTE": TokenType.EXECUTE,
"EXISTS": TokenType.EXISTS,
"EXPLAIN": TokenType.EXPLAIN,
"FALSE": TokenType.FALSE,
@@ -456,6 +464,7 @@ class Tokenizer(metaclass=_Tokenizer):
"HAVING": TokenType.HAVING,
"IF": TokenType.IF,
"ILIKE": TokenType.ILIKE,
+ "IMMUTABLE": TokenType.IMMUTABLE,
"IGNORE NULLS": TokenType.IGNORE_NULLS,
"IN": TokenType.IN,
"INDEX": TokenType.INDEX,
@@ -504,6 +513,7 @@ class Tokenizer(metaclass=_Tokenizer):
"PIVOT": TokenType.PIVOT,
"PRECEDING": TokenType.PRECEDING,
"PRIMARY KEY": TokenType.PRIMARY_KEY,
+ "PROCEDURE": TokenType.PROCEDURE,
"RANGE": TokenType.RANGE,
"RECURSIVE": TokenType.RECURSIVE,
"REGEXP": TokenType.RLIKE,
@@ -522,6 +532,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SHOW": TokenType.SHOW,
"SOME": TokenType.SOME,
"SORT BY": TokenType.SORT_BY,
+ "STABLE": TokenType.STABLE,
"STORED": TokenType.STORED,
"TABLE": TokenType.TABLE,
"TABLE_FORMAT": TokenType.TABLE_FORMAT,
@@ -542,6 +553,7 @@ class Tokenizer(metaclass=_Tokenizer):
"USING": TokenType.USING,
"VALUES": TokenType.VALUES,
"VIEW": TokenType.VIEW,
+ "VOLATILE": TokenType.VOLATILE,
"WHEN": TokenType.WHEN,
"WHERE": TokenType.WHERE,
"WITH": TokenType.WITH,
@@ -637,6 +649,7 @@ class Tokenizer(metaclass=_Tokenizer):
"_char",
"_end",
"_peek",
+ "_prev_token_type",
)
def __init__(self):
@@ -657,6 +670,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._char = None
self._end = None
self._peek = None
+ self._prev_token_type = None
def tokenize(self, sql):
self.reset()
@@ -706,8 +720,8 @@ class Tokenizer(metaclass=_Tokenizer):
return self.sql[self._start : self._current]
def _add(self, token_type, text=None):
- text = self._text if text is None else text
- self.tokens.append(Token(token_type, text, self._line, self._col))
+ self._prev_token_type = token_type
+ self.tokens.append(Token(token_type, self._text if text is None else text, self._line, self._col))
if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON):
self._start = self._current
@@ -910,7 +924,11 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance()
else:
break
- self._add(self.KEYWORDS.get(self._text.upper(), TokenType.VAR))
+ self._add(
+ TokenType.VAR
+ if self._prev_token_type == TokenType.PARAMETER
+ else self.KEYWORDS.get(self._text.upper(), TokenType.VAR)
+ )
def _extract_string(self, delimiter):
text = ""
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index c929e59..7110eac 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -239,7 +239,7 @@ class TestBigQuery(Validator):
self.validate_all(
"SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
write={
- "spark": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
+ "spark": "SELECT cola, colb FROM VALUES (1, 'test') AS tab(cola, colb)",
"bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)])",
"snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
},
@@ -253,7 +253,7 @@ class TestBigQuery(Validator):
def test_user_defined_functions(self):
self.validate_identity(
- "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 LANGUAGE js AS 'return x*y;'"
+ "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'"
)
self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)")
self.validate_identity("CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t")
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index e0ec824..a9a313c 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -1009,7 +1009,7 @@ class TestDialect(Validator):
self.validate_all(
"SELECT * FROM VALUES ('x'), ('y') AS t(z)",
write={
- "spark": "SELECT * FROM (VALUES ('x'), ('y')) AS t(z)",
+ "spark": "SELECT * FROM VALUES ('x'), ('y') AS t(z)",
},
)
self.validate_all(
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index b7e39a7..2145966 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -293,3 +293,15 @@ class TestSnowflake(Validator):
"bigquery": "CREATE TABLE FUNCTION a() RETURNS TABLE <b INT64> AS SELECT 1",
},
)
+ self.validate_all(
+ "CREATE FUNCTION a() RETURNS INT IMMUTABLE AS 'SELECT 1'",
+ write={
+ "snowflake": "CREATE FUNCTION a() RETURNS INT IMMUTABLE AS 'SELECT 1'",
+ },
+ )
+
+ def test_stored_procedures(self):
+ self.validate_identity("CALL a.b.c(x, y)")
+ self.validate_identity(
+ "CREATE PROCEDURE a.b.c(x INT, y VARIANT) RETURNS OBJECT EXECUTE AS CALLER AS 'BEGIN SELECT 1; END;'"
+ )
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index 2654be1..a0de281 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -50,6 +50,7 @@ a.B()
a['x'].C()
int.x
map.x
+a.b.INT(1.234)
x IN (-1, 1)
x IN ('a', 'a''a')
x IN ((1))
@@ -357,6 +358,7 @@ SELECT * REPLACE (a + 1 AS b, b AS C)
SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C)
SELECT a.* EXCEPT (a, b), b.* REPLACE (a AS b, b AS C)
SELECT zoo, animals FROM (VALUES ('oakland', ARRAY('a', 'b')), ('sf', ARRAY('b', 'c'))) AS t(zoo, animals)
+SELECT zoo, animals FROM UNNEST(ARRAY(STRUCT('oakland' AS zoo, ARRAY('a', 'b') AS animals), STRUCT('sf' AS zoo, ARRAY('b', 'c') AS animals))) AS t(zoo, animals)
WITH a AS (SELECT 1) SELECT 1 UNION ALL SELECT 2
WITH a AS (SELECT 1) SELECT 1 UNION SELECT 2
WITH a AS (SELECT 1) SELECT 1 INTERSECT SELECT 2
@@ -444,6 +446,8 @@ CREATE OR REPLACE TEMPORARY VIEW x AS SELECT *
CREATE TEMPORARY VIEW x AS SELECT a FROM d
CREATE TEMPORARY VIEW IF NOT EXISTS x AS SELECT a FROM d
CREATE TEMPORARY VIEW x AS WITH y AS (SELECT 1) SELECT * FROM y
+CREATE MATERIALIZED VIEW x.y.z AS SELECT a FROM b
+DROP MATERIALIZED VIEW x.y.z
CREATE TABLE z (a INT, b VARCHAR, c VARCHAR(100), d DECIMAL(5, 3))
CREATE TABLE z (end INT)
CREATE TABLE z (a ARRAY<TEXT>, b MAP<TEXT, DOUBLE>, c DECIMAL(5, 3))
@@ -471,10 +475,13 @@ CREATE FUNCTION f AS 'g'
CREATE FUNCTION a(b INT, c VARCHAR) AS 'SELECT 1'
CREATE FUNCTION a() LANGUAGE sql
CREATE FUNCTION a() LANGUAGE sql RETURNS INT
+CREATE FUNCTION a.b.c()
+DROP FUNCTION a.b.c (INT)
CREATE INDEX abc ON t (a)
CREATE INDEX abc ON t (a, b, b)
CREATE UNIQUE INDEX abc ON t (a, b, b)
CREATE UNIQUE INDEX IF NOT EXISTS my_idx ON tbl (a, b)
+DROP INDEX a.b.c
CACHE TABLE x
CACHE LAZY TABLE x
CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value')
@@ -484,6 +491,8 @@ CACHE LAZY TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a
CACHE TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a
CACHE TABLE x AS (SELECT 1 AS y)
CALL catalog.system.iceberg_procedure_name(named_arg_1 => 'arg_1', named_arg_2 => 'arg_2')
+CREATE PROCEDURE IF NOT EXISTS a.b.c() AS 'DECLARE BEGIN; END'
+DROP PROCEDURE a.b.c (INT)
INSERT OVERWRITE TABLE a.b PARTITION(ds) SELECT x FROM y
INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD') SELECT x FROM y
INSERT OVERWRITE TABLE a.b PARTITION(ds, hour) SELECT x FROM y
diff --git a/tests/fixtures/optimizer/merge_subqueries.sql b/tests/fixtures/optimizer/merge_subqueries.sql
index 35aed3b..e13d3b3 100644
--- a/tests/fixtures/optimizer/merge_subqueries.sql
+++ b/tests/fixtures/optimizer/merge_subqueries.sql
@@ -97,3 +97,11 @@ WITH x AS (SELECT x.a AS a, x.b AS b FROM x AS x) SELECT x.a AS a, y.b AS b FROM
-- Nested CTE
SELECT * FROM (WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x);
SELECT x.a AS a, x.b AS b FROM x AS x;
+
+-- Inner select is an expression
+SELECT a FROM (SELECT a FROM (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) AS x) AS x;
+SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b;
+
+-- CTE select is an expression
+WITH x AS (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) SELECT a FROM (SELECT a FROM x AS x) AS x;
+SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b;
diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql
index 0bb742b..eb6761a 100644
--- a/tests/fixtures/optimizer/optimizer.sql
+++ b/tests/fixtures/optimizer/optimizer.sql
@@ -137,3 +137,20 @@ SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
SELECT
AGGREGATE(ARRAY("x"."a", "x"."b"), 0, ("x", "acc") -> "x" + "acc" + "x"."a") AS "sum_agg"
FROM "x" AS "x";
+
+SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb);
+SELECT
+ "tab"."cola" AS "cola",
+ "tab"."colb" AS "colb"
+FROM (VALUES
+ (1, 'test'),
+ (2, 'test2')) AS "tab"("cola", "colb");
+
+# dialect: spark
+SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb);
+SELECT
+ `tab`.`cola` AS `cola`,
+ `tab`.`colb` AS `colb`
+FROM VALUES
+ (1, 'test'),
+ (2, 'test2') AS `tab`(`cola`, `colb`);
diff --git a/tests/fixtures/optimizer/pushdown_projections.sql b/tests/fixtures/optimizer/pushdown_projections.sql
index 9deceb6..b03ffab 100644
--- a/tests/fixtures/optimizer/pushdown_projections.sql
+++ b/tests/fixtures/optimizer/pushdown_projections.sql
@@ -39,3 +39,15 @@ SELECT "_q_0".b AS b FROM (SELECT SUM(x.b) AS b FROM x AS x GROUP BY x.a) AS "_q
SELECT b FROM (SELECT a, SUM(b) AS b FROM x ORDER BY a);
SELECT "_q_0".b AS b FROM (SELECT x.a AS a, SUM(x.b) AS b FROM x AS x ORDER BY a) AS "_q_0";
+
+SELECT x FROM (VALUES(1, 2)) AS q(x, y);
+SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y);
+
+SELECT x FROM UNNEST([1, 2]) AS q(x, y);
+SELECT q.x AS x FROM UNNEST(ARRAY(1, 2)) AS q(x, y);
+
+WITH t1 AS (SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)]) AS "q"("cola", "colb")) SELECT cola FROM t1;
+WITH t1 AS (SELECT q.cola AS cola FROM UNNEST(ARRAY(STRUCT(1 AS cola, 'test' AS colb))) AS "q"("cola", "colb")) SELECT t1.cola AS cola FROM t1;
+
+SELECT x FROM VALUES(1, 2) AS q(x, y);
+SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y);
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 8d4aecc..aad84ed 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -5,7 +5,7 @@ from sqlglot import exp, optimizer, parse_one, table
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.schema import MappingSchema, ensure_schema
-from sqlglot.optimizer.scope import build_scope, traverse_scope
+from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope
from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures
@@ -264,12 +264,13 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
ON s.b = r.b
WHERE s.b > (SELECT MAX(x.a) FROM x WHERE x.b = s.b)
"""
- for scopes in traverse_scope(parse_one(sql)), list(build_scope(parse_one(sql)).traverse()):
+ expression = parse_one(sql)
+ for scopes in traverse_scope(expression), list(build_scope(expression).traverse()):
self.assertEqual(len(scopes), 5)
self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x")
self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
- self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
- self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y")
+ self.assertEqual(scopes[2].expression.sql(), "SELECT y.c AS b FROM y")
+ self.assertEqual(scopes[3].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql())
self.assertEqual(set(scopes[4].sources), {"q", "r", "s"})
@@ -279,6 +280,16 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(len(scopes[4].source_columns("r")), 2)
self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"})
+ self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"})
+ self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b")
+ self.assertEqual({c.sql() for c in scopes[0].find_all(exp.Column)}, {"x.b"})
+
+ # Check that we can walk in scope from an arbitrary node
+ self.assertEqual(
+ {node.sql() for node, *_ in walk_in_scope(expression.find(exp.Where)) if isinstance(node, exp.Column)},
+ {"s.b"},
+ )
+
def test_literal_type_annotation(self):
tests = {
"SELECT 5": exp.DataType.Type.INT,
diff --git a/tests/test_parser.py b/tests/test_parser.py
index 4c46531..4e86516 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -122,6 +122,9 @@ class TestParser(unittest.TestCase):
def test_parameter(self):
self.assertEqual(parse_one("SELECT @x, @@x, @1").sql(), "SELECT @x, @@x, @1")
+ def test_var(self):
+ self.assertEqual(parse_one("SELECT @JOIN, @'foo'").sql(), "SELECT @JOIN, @'foo'")
+
def test_annotations(self):
expression = parse_one(
"""