summaryrefslogtreecommitdiffstats
path: root/sqlglot
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot')
-rw-r--r--sqlglot/__init__.py2
-rw-r--r--sqlglot/dialects/bigquery.py22
-rw-r--r--sqlglot/dialects/dialect.py2
-rw-r--r--sqlglot/dialects/snowflake.py2
-rw-r--r--sqlglot/dialects/spark.py2
-rw-r--r--sqlglot/expressions.py68
-rw-r--r--sqlglot/generator.py20
-rw-r--r--sqlglot/optimizer/merge_subqueries.py13
-rw-r--r--sqlglot/optimizer/qualify_columns.py10
-rw-r--r--sqlglot/optimizer/scope.py122
-rw-r--r--sqlglot/parser.py84
-rw-r--r--sqlglot/tokens.py24
12 files changed, 281 insertions, 90 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index befbc8a..1f7b28c 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -20,7 +20,7 @@ from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "6.2.6"
+__version__ = "6.2.8"
pretty = False
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 432fd8c..40298e7 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -33,10 +33,10 @@ def _date_add_sql(data_type, kind):
return func
-def _subquery_to_unnest_if_values(self, expression):
- if not isinstance(expression.this, exp.Values):
- return self.subquery_sql(expression)
- rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.this.find_all(exp.Tuple)]
+def _derived_table_values_to_unnest(self, expression):
+ if not isinstance(expression.unnest().parent, exp.From):
+ return self.values_sql(expression)
+ rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.find_all(exp.Tuple)]
structs = []
for row in rows:
aliases = [
@@ -99,6 +99,7 @@ class BigQuery(Dialect):
"QUALIFY": TokenType.QUALIFY,
"UNKNOWN": TokenType.NULL,
"WINDOW": TokenType.WINDOW,
+ "NOT DETERMINISTIC": TokenType.VOLATILE,
}
class Parser(Parser):
@@ -140,9 +141,10 @@ class BigQuery(Dialect):
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.VariancePop: rename_func("VAR_POP"),
- exp.Subquery: _subquery_to_unnest_if_values,
+ exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql,
+ exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC",
}
TYPE_MAPPING = {
@@ -160,6 +162,16 @@ class BigQuery(Dialect):
exp.DataType.Type.NVARCHAR: "STRING",
}
+ ROOT_PROPERTIES = {
+ exp.LanguageProperty,
+ exp.ReturnsProperty,
+ exp.VolatilityProperty,
+ }
+
+ WITH_PROPERTIES = {
+ exp.AnonymousProperty,
+ }
+
def in_unnest_op(self, unnest):
return self.sql(unnest)
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 0ab584e..98dc330 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -77,6 +77,7 @@ class Dialect(metaclass=_Dialect):
alias_post_tablesample = False
normalize_functions = "upper"
null_ordering = "nulls_are_small"
+ wrap_derived_values = True
date_format = "'%Y-%m-%d'"
dateint_format = "'%Y%m%d'"
@@ -169,6 +170,7 @@ class Dialect(metaclass=_Dialect):
"alias_post_tablesample": self.alias_post_tablesample,
"normalize_functions": self.normalize_functions,
"null_ordering": self.null_ordering,
+ "wrap_derived_values": self.wrap_derived_values,
**opts,
}
)
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 1b718f7..fb2d900 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -177,6 +177,8 @@ class Snowflake(Dialect):
exp.ReturnsProperty,
exp.LanguageProperty,
exp.SchemaCommentProperty,
+ exp.ExecuteAsProperty,
+ exp.VolatilityProperty,
}
def except_op(self, expression):
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 5446e83..e8da07a 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -47,6 +47,8 @@ def _unix_to_time(self, expression):
class Spark(Hive):
+ wrap_derived_values = False
+
class Parser(Hive.Parser):
FUNCTIONS = {
**Hive.Parser.FUNCTIONS,
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 599c7db..8cdacce 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -213,21 +213,23 @@ class Expression(metaclass=_Expression):
"""
return self.find_ancestor(Select)
- def walk(self, bfs=True):
+ def walk(self, bfs=True, prune=None):
"""
Returns a generator object which visits all nodes in this tree.
Args:
bfs (bool): if set to True the BFS traversal order will be applied,
otherwise the DFS traversal will be used instead.
+ prune ((node, parent, arg_key) -> bool): callable that returns True if
+ the generator should stop traversing this branch of the tree.
Returns:
the generator object.
"""
if bfs:
- yield from self.bfs()
+ yield from self.bfs(prune=prune)
else:
- yield from self.dfs()
+ yield from self.dfs(prune=prune)
def dfs(self, parent=None, key=None, prune=None):
"""
@@ -506,6 +508,10 @@ class DerivedTable(Expression):
return [select.alias_or_name for select in self.selects]
+class UDTF(DerivedTable):
+ pass
+
+
class Annotation(Expression):
arg_types = {
"this": True,
@@ -652,7 +658,13 @@ class Delete(Expression):
class Drop(Expression):
- arg_types = {"this": False, "kind": False, "exists": False}
+ arg_types = {
+ "this": False,
+ "kind": False,
+ "exists": False,
+ "temporary": False,
+ "materialized": False,
+ }
class Filter(Expression):
@@ -827,7 +839,7 @@ class Join(Expression):
return join
-class Lateral(DerivedTable):
+class Lateral(UDTF):
arg_types = {"this": True, "outer": False, "alias": False}
@@ -915,6 +927,14 @@ class LanguageProperty(Property):
pass
+class ExecuteAsProperty(Property):
+ pass
+
+
+class VolatilityProperty(Property):
+ arg_types = {"this": True}
+
+
class Properties(Expression):
arg_types = {"expressions": True}
@@ -1098,7 +1118,7 @@ class Intersect(Union):
pass
-class Unnest(DerivedTable):
+class Unnest(UDTF):
arg_types = {
"expressions": True,
"ordinality": False,
@@ -1116,8 +1136,12 @@ class Update(Expression):
}
-class Values(Expression):
- arg_types = {"expressions": True}
+class Values(UDTF):
+ arg_types = {
+ "expressions": True,
+ "ordinality": False,
+ "alias": False,
+ }
class Var(Expression):
@@ -2033,23 +2057,17 @@ class Func(Condition):
@classmethod
def from_arg_list(cls, args):
- args_num = len(args)
-
- all_arg_keys = list(cls.arg_types)
- # If this function supports variable length argument treat the last argument as such.
- non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
-
- args_dict = {}
- arg_idx = 0
- for arg_key in non_var_len_arg_keys:
- if arg_idx >= args_num:
- break
- if args[arg_idx] is not None:
- args_dict[arg_key] = args[arg_idx]
- arg_idx += 1
-
- if arg_idx < args_num and cls.is_var_len_args:
- args_dict[all_arg_keys[-1]] = args[arg_idx:]
+ if cls.is_var_len_args:
+ all_arg_keys = list(cls.arg_types)
+ # If this function supports variable length argument treat the last argument as such.
+ non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
+ num_non_var = len(non_var_len_arg_keys)
+
+ args_dict = {arg_key: arg for arg, arg_key in zip(args, non_var_len_arg_keys)}
+ args_dict[all_arg_keys[-1]] = args[num_non_var:]
+ else:
+ args_dict = {arg_key: arg for arg, arg_key in zip(args, cls.arg_types)}
+
return cls(**args_dict)
@classmethod
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 9099307..8b356f3 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -49,10 +49,12 @@ class Generator:
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
- exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
+ exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
+ exp.VolatilityProperty: lambda self, e: self.sql(e.name),
}
NULL_ORDERING_SUPPORTED = True
@@ -99,6 +101,7 @@ class Generator:
"unsupported_messages",
"null_ordering",
"max_unsupported",
+ "wrap_derived_values",
"_indent",
"_replace_backslash",
"_escaped_quote_end",
@@ -127,6 +130,7 @@ class Generator:
null_ordering=None,
max_unsupported=3,
leading_comma=False,
+ wrap_derived_values=True,
):
import sqlglot
@@ -150,6 +154,7 @@ class Generator:
self.unsupported_messages = []
self.max_unsupported = max_unsupported
self.null_ordering = null_ordering
+ self.wrap_derived_values = wrap_derived_values
self._indent = indent
self._replace_backslash = self.escape == "\\"
self._escaped_quote_end = self.escape + self.quote_end
@@ -407,7 +412,9 @@ class Generator:
this = self.sql(expression, "this")
kind = expression.args["kind"]
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
- return f"DROP {kind}{exists_sql}{this}"
+ temporary = " TEMPORARY" if expression.args.get("temporary") else ""
+ materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
+ return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}"
def except_sql(self, expression):
return self.prepend_ctes(
@@ -583,7 +590,14 @@ class Generator:
return self.prepend_ctes(expression, sql)
def values_sql(self, expression):
- return f"VALUES{self.seg('')}{self.expressions(expression)}"
+ alias = self.sql(expression, "alias")
+ args = self.expressions(expression)
+ if not alias:
+ return f"VALUES{self.seg('')}{args}"
+ alias = f" AS {alias}" if alias else alias
+ if self.wrap_derived_values:
+ return f"(VALUES{self.seg('')}{args}){alias}"
+ return f"VALUES{self.seg('')}{args}{alias}"
def var_sql(self, expression):
return self.sql(expression, "this")
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index 9d966b7..d29c22b 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -32,8 +32,8 @@ def merge_subqueries(expression, leave_tables_isolated=False):
Returns:
sqlglot.Expression: optimized expression
"""
- merge_ctes(expression, leave_tables_isolated)
- merge_derived_tables(expression, leave_tables_isolated)
+ expression = merge_ctes(expression, leave_tables_isolated)
+ expression = merge_derived_tables(expression, leave_tables_isolated)
return expression
@@ -76,14 +76,14 @@ def merge_ctes(expression, leave_tables_isolated=False):
alias = node_to_replace.alias
else:
alias = table.name
-
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, node_to_replace, alias)
- _merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
+ _merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
_pop_cte(inner_scope)
+ return expression
def merge_derived_tables(expression, leave_tables_isolated=False):
@@ -97,10 +97,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, subquery, alias)
- _merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
+ _merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
+ return expression
def _mergeable(outer_scope, inner_select, leave_tables_isolated):
@@ -229,7 +230,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
continue
columns_to_replace = outer_columns.get(projection_name, [])
for column in columns_to_replace:
- column.replace(expression.unalias())
+ column.replace(expression.unalias().copy())
def _merge_where(outer_scope, inner_scope, from_or_join):
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 0bb947a..72ce256 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -5,8 +5,6 @@ from sqlglot.errors import OptimizeError
from sqlglot.optimizer.schema import ensure_schema
from sqlglot.optimizer.scope import traverse_scope
-SKIP_QUALIFY = (exp.Unnest, exp.Lateral)
-
def qualify_columns(expression, schema):
"""
@@ -35,7 +33,7 @@ def qualify_columns(expression, schema):
_expand_group_by(scope, resolver)
_expand_order_by(scope)
_qualify_columns(scope, resolver)
- if not isinstance(scope.expression, SKIP_QUALIFY):
+ if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver)
_qualify_outputs(scope)
_check_unknown_tables(scope)
@@ -50,7 +48,7 @@ def _pop_table_column_aliases(derived_tables):
(e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
"""
for derived_table in derived_tables:
- if isinstance(derived_table, SKIP_QUALIFY):
+ if isinstance(derived_table, exp.UDTF):
continue
table_alias = derived_table.args.get("alias")
if table_alias:
@@ -202,7 +200,7 @@ def _qualify_columns(scope, resolver):
if not column_table:
column_table = resolver.get_table(column_name)
- if not scope.is_subquery and not scope.is_unnest:
+ if not scope.is_subquery and not scope.is_udtf:
if column_name not in resolver.all_columns:
raise OptimizeError(f"Unknown column: {column_name}")
@@ -296,7 +294,7 @@ def _qualify_outputs(scope):
def _check_unknown_tables(scope):
- if scope.external_columns and not scope.is_unnest and not scope.is_correlated_subquery:
+ if scope.external_columns and not scope.is_udtf and not scope.is_correlated_subquery:
raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index be6cfb9..6332cdd 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -1,5 +1,4 @@
import itertools
-from copy import copy
from enum import Enum, auto
from sqlglot import exp
@@ -12,7 +11,7 @@ class ScopeType(Enum):
DERIVED_TABLE = auto()
CTE = auto()
UNION = auto()
- UNNEST = auto()
+ UDTF = auto()
class Scope:
@@ -70,14 +69,11 @@ class Scope:
self._columns = None
self._external_columns = None
- def branch(self, expression, scope_type, add_sources=None, **kwargs):
+ def branch(self, expression, scope_type, chain_sources=None, **kwargs):
"""Branch from the current scope to a new, inner scope"""
- sources = copy(self.sources)
- if add_sources:
- sources.update(add_sources)
return Scope(
expression=expression.unnest(),
- sources=sources,
+ sources={**self.cte_sources, **(chain_sources or {})},
parent=self,
scope_type=scope_type,
**kwargs,
@@ -90,30 +86,21 @@ class Scope:
self._derived_tables = []
self._raw_columns = []
- # We'll use this variable to pass state into the dfs generator.
- # Whenever we set it to True, we exclude a subtree from traversal.
- prune = False
-
- for node, parent, _ in self.expression.dfs(prune=lambda *_: prune):
- prune = False
-
+ for node, parent, _ in self.walk(bfs=False):
if node is self.expression:
continue
- if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
+ elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
self._raw_columns.append(node)
elif isinstance(node, exp.Table):
self._tables.append(node)
- elif isinstance(node, (exp.Unnest, exp.Lateral)):
+ elif isinstance(node, exp.UDTF):
self._derived_tables.append(node)
elif isinstance(node, exp.CTE):
self._ctes.append(node)
- prune = True
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
self._derived_tables.append(node)
- prune = True
elif isinstance(node, exp.Subqueryable):
self._subqueries.append(node)
- prune = True
self._collected = True
@@ -121,6 +108,43 @@ class Scope:
if not self._collected:
self._collect()
+ def walk(self, bfs=True):
+ return walk_in_scope(self.expression, bfs=bfs)
+
+ def find(self, *expression_types, bfs=True):
+ """
+ Returns the first node in this scope which matches at least one of the specified types.
+
+ This does NOT traverse into subscopes.
+
+ Args:
+ expression_types (type): the expression type(s) to match.
+ bfs (bool): True to use breadth-first search, False to use depth-first.
+
+ Returns:
+ exp.Expression: the node which matches the criteria or None if no node matching
+ the criteria was found.
+ """
+ return next(self.find_all(*expression_types, bfs=bfs), None)
+
+ def find_all(self, *expression_types, bfs=True):
+ """
+ Returns a generator object which visits all nodes in this scope and only yields those that
+ match at least one of the specified expression types.
+
+ This does NOT traverse into subscopes.
+
+ Args:
+ expression_types (type): the expression type(s) to match.
+ bfs (bool): True to use breadth-first search, False to use depth-first.
+
+ Yields:
+ exp.Expression: nodes
+ """
+ for expression, _, _ in self.walk(bfs=bfs):
+ if isinstance(expression, expression_types):
+ yield expression
+
def replace(self, old, new):
"""
Replace `old` with `new`.
@@ -247,6 +271,16 @@ class Scope:
return self._selected_sources
@property
+ def cte_sources(self):
+ """
+ Sources that are CTEs.
+
+ Returns:
+ dict[str, Scope]: Mapping of source alias to Scope
+ """
+ return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte}
+
+ @property
def selects(self):
"""
Select expressions of this scope.
@@ -313,9 +347,9 @@ class Scope:
return self.scope_type == ScopeType.ROOT
@property
- def is_unnest(self):
- """Determine if this scope is an unnest"""
- return self.scope_type == ScopeType.UNNEST
+ def is_udtf(self):
+ """Determine if this scope is a UDTF (User Defined Table Function)"""
+ return self.scope_type == ScopeType.UDTF
@property
def is_correlated_subquery(self):
@@ -348,7 +382,7 @@ class Scope:
Scope: scope instances in depth-first-search post-order
"""
for child_scope in itertools.chain(
- self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes
+ self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes
):
yield from child_scope.traverse()
yield self
@@ -399,7 +433,7 @@ def _traverse_scope(scope):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
- elif isinstance(scope.expression, (exp.Lateral, exp.Unnest)):
+ elif isinstance(scope.expression, exp.UDTF):
pass
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
@@ -410,8 +444,8 @@ def _traverse_scope(scope):
def _traverse_select(scope):
yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
- yield from _traverse_subqueries(scope)
yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
+ yield from _traverse_subqueries(scope)
_add_table_sources(scope)
@@ -437,10 +471,10 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
top = None
for child_scope in _traverse_scope(
scope.branch(
- derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
- add_sources=sources if scope_type == ScopeType.CTE else None,
+ derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
+ chain_sources=sources if scope_type == ScopeType.CTE else None,
outer_column_list=derived_table.alias_column_names,
- scope_type=ScopeType.UNNEST if isinstance(derived_table, exp.Unnest) else scope_type,
+ scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type,
)
):
yield child_scope
@@ -483,3 +517,35 @@ def _traverse_subqueries(scope):
yield child_scope
top = child_scope
scope.subquery_scopes.append(top)
+
+
+def walk_in_scope(expression, bfs=True):
+ """
+ Returns a generator object which visits all nodes in the syntrax tree, stopping at
+ nodes that start child scopes.
+
+ Args:
+ expression (exp.Expression):
+ bfs (bool): if set to True the BFS traversal order will be applied,
+ otherwise the DFS traversal will be used instead.
+
+ Yields:
+ tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
+ """
+ # We'll use this variable to pass state into the dfs generator.
+ # Whenever we set it to True, we exclude a subtree from traversal.
+ prune = False
+
+ for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
+ prune = False
+
+ yield node, parent, key
+
+ if node is expression:
+ continue
+ elif isinstance(node, exp.CTE):
+ prune = True
+ elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
+ prune = True
+ elif isinstance(node, exp.Subqueryable):
+ prune = True
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 72bad92..5f20afc 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -126,6 +126,8 @@ class Parser:
TokenType.CONSTRAINT,
TokenType.DEFAULT,
TokenType.DELETE,
+ TokenType.DETERMINISTIC,
+ TokenType.EXECUTE,
TokenType.ENGINE,
TokenType.ESCAPE,
TokenType.EXPLAIN,
@@ -139,6 +141,7 @@ class Parser:
TokenType.IF,
TokenType.INDEX,
TokenType.ISNULL,
+ TokenType.IMMUTABLE,
TokenType.INTERVAL,
TokenType.LAZY,
TokenType.LANGUAGE,
@@ -163,6 +166,7 @@ class Parser:
TokenType.SEED,
TokenType.SET,
TokenType.SHOW,
+ TokenType.STABLE,
TokenType.STORED,
TokenType.TABLE,
TokenType.TABLE_FORMAT,
@@ -175,6 +179,8 @@ class Parser:
TokenType.UNIQUE,
TokenType.UNPIVOT,
TokenType.PROPERTIES,
+ TokenType.PROCEDURE,
+ TokenType.VOLATILE,
*SUBQUERY_PREDICATES,
*TYPE_TOKENS,
}
@@ -204,7 +210,7 @@ class Parser:
TokenType.DATETIME,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
- *NESTED_TYPE_TOKENS,
+ *TYPE_TOKENS,
*SUBQUERY_PREDICATES,
}
@@ -379,6 +385,13 @@ class Parser:
TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
+ TokenType.EXECUTE: lambda self: self._parse_execute_as(),
+ TokenType.DETERMINISTIC: lambda self: self.expression(
+ exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
+ ),
+ TokenType.IMMUTABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")),
+ TokenType.STABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("STABLE")),
+ TokenType.VOLATILE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")),
}
CONSTRAINT_PARSERS = {
@@ -418,7 +431,7 @@ class Parser:
MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
- CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX}
+ CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX, TokenType.PROCEDURE}
STRICT_CAST = True
@@ -615,18 +628,20 @@ class Parser:
return expression
def _parse_drop(self):
- if self._match(TokenType.TABLE):
- kind = "TABLE"
- elif self._match(TokenType.VIEW):
- kind = "VIEW"
- else:
- self.raise_error("Expected TABLE or View")
+ temporary = self._match(TokenType.TEMPORARY)
+ materialized = self._match(TokenType.MATERIALIZED)
+ kind = self._match_set(self.CREATABLES) and self._prev.text
+ if not kind:
+ self.raise_error("Expected TABLE, VIEW, INDEX, FUNCTION, or PROCEDURE")
+ return
return self.expression(
exp.Drop,
exists=self._parse_exists(),
this=self._parse_table(schema=True),
kind=kind,
+ temporary=temporary,
+ materialized=materialized,
)
def _parse_exists(self, not_=False):
@@ -644,14 +659,15 @@ class Parser:
create_token = self._match_set(self.CREATABLES) and self._prev
if not create_token:
- self.raise_error("Expected TABLE, VIEW, INDEX, or FUNCTION")
+ self.raise_error("Expected TABLE, VIEW, INDEX, FUNCTION, or PROCEDURE")
+ return
exists = self._parse_exists(not_=True)
this = None
expression = None
properties = None
- if create_token.token_type == TokenType.FUNCTION:
+ if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function()
properties = self._parse_properties()
if self._match(TokenType.ALIAS):
@@ -747,7 +763,9 @@ class Parser:
if is_table:
if self._match(TokenType.LT):
value = self.expression(
- exp.Schema, this="TABLE", expressions=self._parse_csv(self._parse_struct_kwargs)
+ exp.Schema,
+ this="TABLE",
+ expressions=self._parse_csv(self._parse_struct_kwargs),
)
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
@@ -763,6 +781,14 @@ class Parser:
is_table=is_table,
)
+ def _parse_execute_as(self):
+ self._match(TokenType.ALIAS)
+ return self.expression(
+ exp.ExecuteAsProperty,
+ this=exp.Literal.string("EXECUTE AS"),
+ value=self._parse_var(),
+ )
+
def _parse_properties(self):
properties = []
@@ -997,7 +1023,12 @@ class Parser:
)
def _parse_subquery(self, this):
- return self.expression(exp.Subquery, this=this, pivots=self._parse_pivots(), alias=self._parse_table_alias())
+ return self.expression(
+ exp.Subquery,
+ this=this,
+ pivots=self._parse_pivots(),
+ alias=self._parse_table_alias(),
+ )
def _parse_query_modifiers(self, this):
if not isinstance(this, self.MODIFIABLES):
@@ -1118,6 +1149,11 @@ class Parser:
if unnest:
return unnest
+ values = self._parse_derived_table_values()
+
+ if values:
+ return values
+
subquery = self._parse_select(table=True)
if subquery:
@@ -1186,6 +1222,24 @@ class Parser:
alias=alias,
)
+ def _parse_derived_table_values(self):
+ is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES)
+ if not is_derived and not self._match(TokenType.VALUES):
+ return None
+
+ expressions = self._parse_csv(self._parse_value)
+
+ if is_derived:
+ self._match_r_paren()
+
+ alias = self._parse_table_alias()
+
+ return self.expression(
+ exp.Values,
+ expressions=expressions,
+ alias=alias,
+ )
+
def _parse_table_sample(self):
if not self._match(TokenType.TABLE_SAMPLE):
return None
@@ -1700,7 +1754,11 @@ class Parser:
return self._parse_window(this)
def _parse_user_defined_function(self):
- this = self._parse_var()
+ this = self._parse_id_var()
+
+ while self._match(TokenType.DOT):
+ this = self.expression(exp.Dot, this=this, expression=self._parse_id_var())
+
if not self._match(TokenType.L_PAREN):
return this
expressions = self._parse_csv(self._parse_udf_kwarg)
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index c81f0db..39bf421 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -136,6 +136,7 @@ class TokenType(AutoName):
DEFAULT = auto()
DELETE = auto()
DESC = auto()
+ DETERMINISTIC = auto()
DISTINCT = auto()
DISTRIBUTE_BY = auto()
DROP = auto()
@@ -144,6 +145,7 @@ class TokenType(AutoName):
ENGINE = auto()
ESCAPE = auto()
EXCEPT = auto()
+ EXECUTE = auto()
EXISTS = auto()
EXPLAIN = auto()
FALSE = auto()
@@ -167,6 +169,7 @@ class TokenType(AutoName):
IF = auto()
IGNORE_NULLS = auto()
ILIKE = auto()
+ IMMUTABLE = auto()
IN = auto()
INDEX = auto()
INNER = auto()
@@ -215,6 +218,7 @@ class TokenType(AutoName):
PLACEHOLDER = auto()
PRECEDING = auto()
PRIMARY_KEY = auto()
+ PROCEDURE = auto()
PROPERTIES = auto()
QUALIFY = auto()
QUOTE = auto()
@@ -238,6 +242,7 @@ class TokenType(AutoName):
SIMILAR_TO = auto()
SOME = auto()
SORT_BY = auto()
+ STABLE = auto()
STORED = auto()
STRUCT = auto()
TABLE_FORMAT = auto()
@@ -258,6 +263,7 @@ class TokenType(AutoName):
USING = auto()
VALUES = auto()
VIEW = auto()
+ VOLATILE = auto()
WHEN = auto()
WHERE = auto()
WINDOW = auto()
@@ -430,6 +436,7 @@ class Tokenizer(metaclass=_Tokenizer):
"DEFAULT": TokenType.DEFAULT,
"DELETE": TokenType.DELETE,
"DESC": TokenType.DESC,
+ "DETERMINISTIC": TokenType.DETERMINISTIC,
"DISTINCT": TokenType.DISTINCT,
"DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
"DROP": TokenType.DROP,
@@ -438,6 +445,7 @@ class Tokenizer(metaclass=_Tokenizer):
"ENGINE": TokenType.ENGINE,
"ESCAPE": TokenType.ESCAPE,
"EXCEPT": TokenType.EXCEPT,
+ "EXECUTE": TokenType.EXECUTE,
"EXISTS": TokenType.EXISTS,
"EXPLAIN": TokenType.EXPLAIN,
"FALSE": TokenType.FALSE,
@@ -456,6 +464,7 @@ class Tokenizer(metaclass=_Tokenizer):
"HAVING": TokenType.HAVING,
"IF": TokenType.IF,
"ILIKE": TokenType.ILIKE,
+ "IMMUTABLE": TokenType.IMMUTABLE,
"IGNORE NULLS": TokenType.IGNORE_NULLS,
"IN": TokenType.IN,
"INDEX": TokenType.INDEX,
@@ -504,6 +513,7 @@ class Tokenizer(metaclass=_Tokenizer):
"PIVOT": TokenType.PIVOT,
"PRECEDING": TokenType.PRECEDING,
"PRIMARY KEY": TokenType.PRIMARY_KEY,
+ "PROCEDURE": TokenType.PROCEDURE,
"RANGE": TokenType.RANGE,
"RECURSIVE": TokenType.RECURSIVE,
"REGEXP": TokenType.RLIKE,
@@ -522,6 +532,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SHOW": TokenType.SHOW,
"SOME": TokenType.SOME,
"SORT BY": TokenType.SORT_BY,
+ "STABLE": TokenType.STABLE,
"STORED": TokenType.STORED,
"TABLE": TokenType.TABLE,
"TABLE_FORMAT": TokenType.TABLE_FORMAT,
@@ -542,6 +553,7 @@ class Tokenizer(metaclass=_Tokenizer):
"USING": TokenType.USING,
"VALUES": TokenType.VALUES,
"VIEW": TokenType.VIEW,
+ "VOLATILE": TokenType.VOLATILE,
"WHEN": TokenType.WHEN,
"WHERE": TokenType.WHERE,
"WITH": TokenType.WITH,
@@ -637,6 +649,7 @@ class Tokenizer(metaclass=_Tokenizer):
"_char",
"_end",
"_peek",
+ "_prev_token_type",
)
def __init__(self):
@@ -657,6 +670,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._char = None
self._end = None
self._peek = None
+ self._prev_token_type = None
def tokenize(self, sql):
self.reset()
@@ -706,8 +720,8 @@ class Tokenizer(metaclass=_Tokenizer):
return self.sql[self._start : self._current]
def _add(self, token_type, text=None):
- text = self._text if text is None else text
- self.tokens.append(Token(token_type, text, self._line, self._col))
+ self._prev_token_type = token_type
+ self.tokens.append(Token(token_type, self._text if text is None else text, self._line, self._col))
if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON):
self._start = self._current
@@ -910,7 +924,11 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance()
else:
break
- self._add(self.KEYWORDS.get(self._text.upper(), TokenType.VAR))
+ self._add(
+ TokenType.VAR
+ if self._prev_token_type == TokenType.PARAMETER
+ else self.KEYWORDS.get(self._text.upper(), TokenType.VAR)
+ )
def _extract_string(self, delimiter):
text = ""