summaryrefslogtreecommitdiffstats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r--sqlglot/expressions.py159
1 files changed, 121 insertions, 38 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index de615d6..599c7db 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -1,13 +1,17 @@
-import inspect
import numbers
import re
-import sys
from collections import deque
from copy import deepcopy
from enum import auto
from sqlglot.errors import ParseError
-from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list, list_get
+from sqlglot.helper import (
+ AutoName,
+ camel_to_snake_case,
+ ensure_list,
+ list_get,
+ subclasses,
+)
class _Expression(type):
@@ -31,12 +35,13 @@ class Expression(metaclass=_Expression):
key = None
arg_types = {"this": True}
- __slots__ = ("args", "parent", "arg_key")
+ __slots__ = ("args", "parent", "arg_key", "type")
def __init__(self, **args):
self.args = args
self.parent = None
self.arg_key = None
+ self.type = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
@@ -384,7 +389,7 @@ class Expression(metaclass=_Expression):
'SELECT y FROM tbl'
Args:
- expression (Expression): new node
+ expression (Expression|None): new node
Returns :
the new expression or expressions
@@ -398,6 +403,12 @@ class Expression(metaclass=_Expression):
replace_children(parent, lambda child: expression if child is self else child)
return expression
+ def pop(self):
+ """
+ Remove this expression from its AST.
+ """
+ self.replace(None)
+
def assert_is(self, type_):
"""
Assert that this `Expression` is an instance of `type_`.
@@ -527,9 +538,18 @@ class Create(Expression):
"temporary": False,
"replace": False,
"unique": False,
+ "materialized": False,
}
+class UserDefinedFunction(Expression):
+ arg_types = {"this": True, "expressions": False}
+
+
+class UserDefinedFunctionKwarg(Expression):
+ arg_types = {"this": True, "kind": True, "default": False}
+
+
class CharacterSet(Expression):
arg_types = {"this": True, "default": False}
@@ -887,6 +907,14 @@ class AnonymousProperty(Property):
pass
+class ReturnsProperty(Property):
+ arg_types = {"this": True, "value": True, "is_table": False}
+
+
+class LanguageProperty(Property):
+ pass
+
+
class Properties(Expression):
arg_types = {"expressions": True}
@@ -907,25 +935,9 @@ class Properties(Expression):
expressions = []
for key, value in properties_dict.items():
property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
- expressions.append(property_cls(this=Literal.string(key), value=cls._convert_value(value)))
+ expressions.append(property_cls(this=Literal.string(key), value=convert(value)))
return cls(expressions=expressions)
- @staticmethod
- def _convert_value(value):
- if value is None:
- return NULL
- if isinstance(value, Expression):
- return value
- if isinstance(value, bool):
- return Boolean(this=value)
- if isinstance(value, str):
- return Literal.string(value)
- if isinstance(value, numbers.Number):
- return Literal.number(value)
- if isinstance(value, list):
- return Tuple(expressions=[Properties._convert_value(v) for v in value])
- raise ValueError(f"Unsupported type '{type(value)}' for value '{value}'")
-
class Qualify(Expression):
pass
@@ -1030,6 +1042,7 @@ class Subqueryable:
QUERY_MODIFIERS = {
"laterals": False,
"joins": False,
+ "pivots": False,
"where": False,
"group": False,
"having": False,
@@ -1051,6 +1064,7 @@ class Table(Expression):
"catalog": False,
"laterals": False,
"joins": False,
+ "pivots": False,
}
@@ -1643,6 +1657,16 @@ class TableSample(Expression):
"percent": False,
"rows": False,
"size": False,
+ "seed": False,
+ }
+
+
+class Pivot(Expression):
+ arg_types = {
+ "this": False,
+ "expressions": True,
+ "field": True,
+ "unpivot": True,
}
@@ -1741,7 +1765,8 @@ class DataType(Expression):
SMALLMONEY = auto()
ROWVERSION = auto()
IMAGE = auto()
- SQL_VARIANT = auto()
+ VARIANT = auto()
+ OBJECT = auto()
@classmethod
def build(cls, dtype, **kwargs):
@@ -2124,6 +2149,7 @@ class TryCast(Cast):
class Ceil(Func):
+ arg_types = {"this": True, "decimals": False}
_sql_names = ["CEIL", "CEILING"]
@@ -2254,7 +2280,7 @@ class Explode(Func):
class Floor(Func):
- pass
+ arg_types = {"this": True, "decimals": False}
class Greatest(Func):
@@ -2371,7 +2397,7 @@ class Reduce(Func):
class RegexpLike(Func):
- arg_types = {"this": True, "expression": True}
+ arg_types = {"this": True, "expression": True, "flag": False}
class RegexpSplit(Func):
@@ -2540,6 +2566,8 @@ def _norm_args(expression):
for k, arg in expression.args.items():
if isinstance(arg, list):
arg = [_norm_arg(a) for a in arg]
+ if not arg:
+ arg = None
else:
arg = _norm_arg(arg)
@@ -2553,17 +2581,7 @@ def _norm_arg(arg):
return arg.lower() if isinstance(arg, str) else arg
-def _all_functions():
- return [
- obj
- for _, obj in inspect.getmembers(
- sys.modules[__name__],
- lambda obj: inspect.isclass(obj) and issubclass(obj, Func) and obj not in (AggFunc, Anonymous, Func),
- )
- ]
-
-
-ALL_FUNCTIONS = _all_functions()
+ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
def maybe_parse(
@@ -2793,6 +2811,37 @@ def from_(*expressions, dialect=None, **opts):
return Select().from_(*expressions, dialect=dialect, **opts)
+def update(table, properties, where=None, from_=None, dialect=None, **opts):
+ """
+ Creates an update statement.
+
+ Example:
+ >>> update("my_table", {"x": 1, "y": "2", "z": None}, from_="baz", where="id > 1").sql()
+ "UPDATE my_table SET x = 1, y = '2', z = NULL FROM baz WHERE id > 1"
+
+ Args:
+ *properties (Dict[str, Any]): dictionary of properties to set which are
+ auto converted to sql objects eg None -> NULL
+ where (str): sql conditional parsed into a WHERE statement
+ from_ (str): sql statement parsed into a FROM statement
+ dialect (str): the dialect used to parse the input expressions.
+ **opts: other options to use to parse the input expressions.
+
+ Returns:
+ Update: the syntax tree for the UPDATE statement.
+ """
+ update = Update(this=maybe_parse(table, into=Table, dialect=dialect))
+ update.set(
+ "expressions",
+ [EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) for k, v in properties.items()],
+ )
+ if from_:
+ update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts))
+ if where:
+ update.set("where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts))
+ return update
+
+
def condition(expression, dialect=None, **opts):
"""
Initialize a logical condition expression.
@@ -2980,12 +3029,13 @@ def column(col, table=None, quoted=None):
def table_(table, db=None, catalog=None, quoted=None):
- """
- Build a Table.
+ """Build a Table.
+
Args:
table (str or Expression): column name
db (str or Expression): db name
catalog (str or Expression): catalog name
+
Returns:
Table: table instance
"""
@@ -2996,6 +3046,39 @@ def table_(table, db=None, catalog=None, quoted=None):
)
+def convert(value):
+ """Convert a python value into an expression object.
+
+ Raises an error if a conversion is not possible.
+
+ Args:
+ value (Any): a python object
+
+ Returns:
+ Expression: the equivalent expression object
+ """
+ if isinstance(value, Expression):
+ return value
+ if value is None:
+ return NULL
+ if isinstance(value, bool):
+ return Boolean(this=value)
+ if isinstance(value, str):
+ return Literal.string(value)
+ if isinstance(value, numbers.Number):
+ return Literal.number(value)
+ if isinstance(value, tuple):
+ return Tuple(expressions=[convert(v) for v in value])
+ if isinstance(value, list):
+ return Array(expressions=[convert(v) for v in value])
+ if isinstance(value, dict):
+ return Map(
+ keys=[convert(k) for k in value.keys()],
+ values=[convert(v) for v in value.values()],
+ )
+ raise ValueError(f"Cannot convert {value}")
+
+
def replace_children(expression, fun):
"""
Replace children of an expression with the result of a lambda fun(child) -> exp.