diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-09-30 05:07:28 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-09-30 05:07:28 +0000 |
commit | 5a674d94c3ab243e2dd6a00f9edf6cc50b018512 (patch) | |
tree | 0b6fe74b5b346f0b048162b56a12885f1a2c2912 /sqlglot/expressions.py | |
parent | Releasing debian version 6.2.1-1. (diff) | |
download | sqlglot-5a674d94c3ab243e2dd6a00f9edf6cc50b018512.tar.xz sqlglot-5a674d94c3ab243e2dd6a00f9edf6cc50b018512.zip |
Merging upstream version 6.2.6.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r-- | sqlglot/expressions.py | 159 |
1 files changed, 121 insertions, 38 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index de615d6..599c7db 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1,13 +1,17 @@ -import inspect import numbers import re -import sys from collections import deque from copy import deepcopy from enum import auto from sqlglot.errors import ParseError -from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list, list_get +from sqlglot.helper import ( + AutoName, + camel_to_snake_case, + ensure_list, + list_get, + subclasses, +) class _Expression(type): @@ -31,12 +35,13 @@ class Expression(metaclass=_Expression): key = None arg_types = {"this": True} - __slots__ = ("args", "parent", "arg_key") + __slots__ = ("args", "parent", "arg_key", "type") def __init__(self, **args): self.args = args self.parent = None self.arg_key = None + self.type = None for arg_key, value in self.args.items(): self._set_parent(arg_key, value) @@ -384,7 +389,7 @@ class Expression(metaclass=_Expression): 'SELECT y FROM tbl' Args: - expression (Expression): new node + expression (Expression|None): new node Returns : the new expression or expressions @@ -398,6 +403,12 @@ class Expression(metaclass=_Expression): replace_children(parent, lambda child: expression if child is self else child) return expression + def pop(self): + """ + Remove this expression from its AST. + """ + self.replace(None) + def assert_is(self, type_): """ Assert that this `Expression` is an instance of `type_`. @@ -527,9 +538,18 @@ class Create(Expression): "temporary": False, "replace": False, "unique": False, + "materialized": False, } +class UserDefinedFunction(Expression): + arg_types = {"this": True, "expressions": False} + + +class UserDefinedFunctionKwarg(Expression): + arg_types = {"this": True, "kind": True, "default": False} + + class CharacterSet(Expression): arg_types = {"this": True, "default": False} @@ -887,6 +907,14 @@ class AnonymousProperty(Property): pass +class ReturnsProperty(Property): + arg_types = {"this": True, "value": True, "is_table": False} + + +class LanguageProperty(Property): + pass + + class Properties(Expression): arg_types = {"expressions": True} @@ -907,25 +935,9 @@ class Properties(Expression): expressions = [] for key, value in properties_dict.items(): property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty) - expressions.append(property_cls(this=Literal.string(key), value=cls._convert_value(value))) + expressions.append(property_cls(this=Literal.string(key), value=convert(value))) return cls(expressions=expressions) - @staticmethod - def _convert_value(value): - if value is None: - return NULL - if isinstance(value, Expression): - return value - if isinstance(value, bool): - return Boolean(this=value) - if isinstance(value, str): - return Literal.string(value) - if isinstance(value, numbers.Number): - return Literal.number(value) - if isinstance(value, list): - return Tuple(expressions=[Properties._convert_value(v) for v in value]) - raise ValueError(f"Unsupported type '{type(value)}' for value '{value}'") - class Qualify(Expression): pass @@ -1030,6 +1042,7 @@ class Subqueryable: QUERY_MODIFIERS = { "laterals": False, "joins": False, + "pivots": False, "where": False, "group": False, "having": False, @@ -1051,6 +1064,7 @@ class Table(Expression): "catalog": False, "laterals": False, "joins": False, + "pivots": False, } @@ -1643,6 +1657,16 @@ class TableSample(Expression): "percent": False, "rows": False, "size": False, + "seed": False, + } + + +class Pivot(Expression): + arg_types = { + "this": False, + "expressions": True, + "field": True, + "unpivot": True, } @@ -1741,7 +1765,8 @@ class DataType(Expression): SMALLMONEY = auto() ROWVERSION = auto() IMAGE = auto() - SQL_VARIANT = auto() + VARIANT = auto() + OBJECT = auto() @classmethod def build(cls, dtype, **kwargs): @@ -2124,6 +2149,7 @@ class TryCast(Cast): class Ceil(Func): + arg_types = {"this": True, "decimals": False} _sql_names = ["CEIL", "CEILING"] @@ -2254,7 +2280,7 @@ class Explode(Func): class Floor(Func): - pass + arg_types = {"this": True, "decimals": False} class Greatest(Func): @@ -2371,7 +2397,7 @@ class Reduce(Func): class RegexpLike(Func): - arg_types = {"this": True, "expression": True} + arg_types = {"this": True, "expression": True, "flag": False} class RegexpSplit(Func): @@ -2540,6 +2566,8 @@ def _norm_args(expression): for k, arg in expression.args.items(): if isinstance(arg, list): arg = [_norm_arg(a) for a in arg] + if not arg: + arg = None else: arg = _norm_arg(arg) @@ -2553,17 +2581,7 @@ def _norm_arg(arg): return arg.lower() if isinstance(arg, str) else arg -def _all_functions(): - return [ - obj - for _, obj in inspect.getmembers( - sys.modules[__name__], - lambda obj: inspect.isclass(obj) and issubclass(obj, Func) and obj not in (AggFunc, Anonymous, Func), - ) - ] - - -ALL_FUNCTIONS = _all_functions() +ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func)) def maybe_parse( @@ -2793,6 +2811,37 @@ def from_(*expressions, dialect=None, **opts): return Select().from_(*expressions, dialect=dialect, **opts) +def update(table, properties, where=None, from_=None, dialect=None, **opts): + """ + Creates an update statement. + + Example: + >>> update("my_table", {"x": 1, "y": "2", "z": None}, from_="baz", where="id > 1").sql() + "UPDATE my_table SET x = 1, y = '2', z = NULL FROM baz WHERE id > 1" + + Args: + *properties (Dict[str, Any]): dictionary of properties to set which are + auto converted to sql objects eg None -> NULL + where (str): sql conditional parsed into a WHERE statement + from_ (str): sql statement parsed into a FROM statement + dialect (str): the dialect used to parse the input expressions. + **opts: other options to use to parse the input expressions. + + Returns: + Update: the syntax tree for the UPDATE statement. + """ + update = Update(this=maybe_parse(table, into=Table, dialect=dialect)) + update.set( + "expressions", + [EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) for k, v in properties.items()], + ) + if from_: + update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts)) + if where: + update.set("where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts)) + return update + + def condition(expression, dialect=None, **opts): """ Initialize a logical condition expression. @@ -2980,12 +3029,13 @@ def column(col, table=None, quoted=None): def table_(table, db=None, catalog=None, quoted=None): - """ - Build a Table. + """Build a Table. + Args: table (str or Expression): column name db (str or Expression): db name catalog (str or Expression): catalog name + Returns: Table: table instance """ @@ -2996,6 +3046,39 @@ def table_(table, db=None, catalog=None, quoted=None): ) +def convert(value): + """Convert a python value into an expression object. + + Raises an error if a conversion is not possible. + + Args: + value (Any): a python object + + Returns: + Expression: the equivalent expression object + """ + if isinstance(value, Expression): + return value + if value is None: + return NULL + if isinstance(value, bool): + return Boolean(this=value) + if isinstance(value, str): + return Literal.string(value) + if isinstance(value, numbers.Number): + return Literal.number(value) + if isinstance(value, tuple): + return Tuple(expressions=[convert(v) for v in value]) + if isinstance(value, list): + return Array(expressions=[convert(v) for v in value]) + if isinstance(value, dict): + return Map( + keys=[convert(k) for k in value.keys()], + values=[convert(v) for v in value.values()], + ) + raise ValueError(f"Cannot convert {value}") + + def replace_children(expression, fun): """ Replace children of an expression with the result of a lambda fun(child) -> exp. |