From 7a2201963d5b03bd1828d350ccaecb4eda30d30c Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 4 Jan 2023 08:24:08 +0100 Subject: Merging upstream version 10.4.2. Signed-off-by: Daniel Baumann --- sqlglot/expressions.py | 104 ++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 90 insertions(+), 14 deletions(-) (limited to 'sqlglot/expressions.py') diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index aeed218..711ec4b 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1,6 +1,11 @@ +""" +.. include:: ../pdoc/docs/expressions.md +""" + from __future__ import annotations import datetime +import math import numbers import re import typing as t @@ -682,6 +687,10 @@ class CharacterSet(Expression): class With(Expression): arg_types = {"expressions": True, "recursive": False} + @property + def recursive(self) -> bool: + return bool(self.args.get("recursive")) + class WithinGroup(Expression): arg_types = {"this": True, "expression": False} @@ -724,6 +733,18 @@ class ColumnDef(Expression): "this": True, "kind": True, "constraints": False, + "exists": False, + } + + +class AlterColumn(Expression): + arg_types = { + "this": True, + "dtype": False, + "collate": False, + "using": False, + "default": False, + "drop": False, } @@ -877,6 +898,11 @@ class Introducer(Expression): arg_types = {"this": True, "expression": True} +# national char, like n'utf8' +class National(Expression): + pass + + class LoadData(Expression): arg_types = { "this": True, @@ -894,7 +920,7 @@ class Partition(Expression): class Fetch(Expression): - arg_types = {"direction": False, "count": True} + arg_types = {"direction": False, "count": False} class Group(Expression): @@ -1316,7 +1342,7 @@ QUERY_MODIFIERS = { "group": False, "having": False, "qualify": False, - "window": False, + "windows": False, "distribute": False, "sort": False, "cluster": False, @@ -1353,7 +1379,7 @@ class Union(Subqueryable): Example: >>> select("1").union(select("1")).limit(1).sql() - 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS "_l_0" LIMIT 1' + 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1' Args: expression (str | int | Expression): the SQL code string to parse. @@ -1889,6 +1915,18 @@ class Select(Subqueryable): **opts, ) + def window(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + return _apply_list_builder( + *expressions, + instance=self, + arg="windows", + append=append, + into=Window, + dialect=dialect, + copy=copy, + **opts, + ) + def distinct(self, distinct=True, copy=True) -> Select: """ Set the OFFSET expression. @@ -2140,6 +2178,11 @@ class DataType(Expression): ) +# https://www.postgresql.org/docs/15/datatype-pseudo.html +class PseudoType(Expression): + pass + + class StructKwarg(Expression): arg_types = {"this": True, "expression": True} @@ -2167,18 +2210,26 @@ class Command(Expression): arg_types = {"this": True, "expression": False} -class Transaction(Command): +class Transaction(Expression): arg_types = {"this": False, "modes": False} -class Commit(Command): +class Commit(Expression): arg_types = {"chain": False} -class Rollback(Command): +class Rollback(Expression): arg_types = {"savepoint": False} +class AlterTable(Expression): + arg_types = { + "this": True, + "actions": True, + "exists": False, + } + + # Binary expressions like (ADD a b) class Binary(Expression): arg_types = {"this": True, "expression": True} @@ -2312,6 +2363,10 @@ class SimilarTo(Binary, Predicate): pass +class Slice(Binary): + arg_types = {"this": False, "expression": False} + + class Sub(Binary): pass @@ -2392,7 +2447,7 @@ class TimeUnit(Expression): class Interval(TimeUnit): - arg_types = {"this": True, "unit": False} + arg_types = {"this": False, "unit": False} class IgnoreNulls(Expression): @@ -2730,8 +2785,11 @@ class Initcap(Func): pass -class JSONExtract(Func): - arg_types = {"this": True, "path": True} +class JSONBContains(Binary): + _sql_names = ["JSONB_CONTAINS"] + + +class JSONExtract(Binary, Func): _sql_names = ["JSON_EXTRACT"] @@ -2776,6 +2834,10 @@ class Log10(Func): pass +class LogicalOr(AggFunc): + _sql_names = ["LOGICAL_OR", "BOOL_OR"] + + class Lower(Func): _sql_names = ["LOWER", "LCASE"] @@ -2846,6 +2908,10 @@ class RegexpLike(Func): arg_types = {"this": True, "expression": True, "flag": False} +class RegexpILike(Func): + arg_types = {"this": True, "expression": True, "flag": False} + + class RegexpSplit(Func): arg_types = {"this": True, "expression": True} @@ -3388,11 +3454,17 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts) -> U ], ) if from_: - update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts)) + update.set( + "from", + maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts), + ) if isinstance(where, Condition): where = Where(this=where) if where: - update.set("where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts)) + update.set( + "where", + maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts), + ) return update @@ -3522,7 +3594,7 @@ def paren(expression) -> Paren: return Paren(this=expression) -SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$") +SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$") def to_identifier(alias, quoted=None) -> t.Optional[Identifier]: @@ -3724,6 +3796,8 @@ def convert(value) -> Expression: return Boolean(this=value) if isinstance(value, str): return Literal.string(value) + if isinstance(value, float) and math.isnan(value): + return NULL if isinstance(value, numbers.Number): return Literal.number(value) if isinstance(value, tuple): @@ -3732,11 +3806,13 @@ def convert(value) -> Expression: return Array(expressions=[convert(v) for v in value]) if isinstance(value, dict): return Map( - keys=[convert(k) for k in value.keys()], + keys=[convert(k) for k in value], values=[convert(v) for v in value.values()], ) if isinstance(value, datetime.datetime): - datetime_literal = Literal.string(value.strftime("%Y-%m-%d %H:%M:%S.%f%z")) + datetime_literal = Literal.string( + (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat() + ) return TimeStrToTime(this=datetime_literal) if isinstance(value, datetime.date): date_literal = Literal.string(value.strftime("%Y-%m-%d")) -- cgit v1.2.3