summaryrefslogtreecommitdiffstats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r--sqlglot/expressions.py104
1 files changed, 90 insertions, 14 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index aeed218..711ec4b 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -1,6 +1,11 @@
+"""
+.. include:: ../pdoc/docs/expressions.md
+"""
+
from __future__ import annotations
import datetime
+import math
import numbers
import re
import typing as t
@@ -682,6 +687,10 @@ class CharacterSet(Expression):
class With(Expression):
arg_types = {"expressions": True, "recursive": False}
+ @property
+ def recursive(self) -> bool:
+ return bool(self.args.get("recursive"))
+
class WithinGroup(Expression):
arg_types = {"this": True, "expression": False}
@@ -724,6 +733,18 @@ class ColumnDef(Expression):
"this": True,
"kind": True,
"constraints": False,
+ "exists": False,
+ }
+
+
+class AlterColumn(Expression):
+ arg_types = {
+ "this": True,
+ "dtype": False,
+ "collate": False,
+ "using": False,
+ "default": False,
+ "drop": False,
}
@@ -877,6 +898,11 @@ class Introducer(Expression):
arg_types = {"this": True, "expression": True}
+# national char, like n'utf8'
+class National(Expression):
+ pass
+
+
class LoadData(Expression):
arg_types = {
"this": True,
@@ -894,7 +920,7 @@ class Partition(Expression):
class Fetch(Expression):
- arg_types = {"direction": False, "count": True}
+ arg_types = {"direction": False, "count": False}
class Group(Expression):
@@ -1316,7 +1342,7 @@ QUERY_MODIFIERS = {
"group": False,
"having": False,
"qualify": False,
- "window": False,
+ "windows": False,
"distribute": False,
"sort": False,
"cluster": False,
@@ -1353,7 +1379,7 @@ class Union(Subqueryable):
Example:
>>> select("1").union(select("1")).limit(1).sql()
- 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS "_l_0" LIMIT 1'
+ 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1'
Args:
expression (str | int | Expression): the SQL code string to parse.
@@ -1889,6 +1915,18 @@ class Select(Subqueryable):
**opts,
)
+ def window(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
+ return _apply_list_builder(
+ *expressions,
+ instance=self,
+ arg="windows",
+ append=append,
+ into=Window,
+ dialect=dialect,
+ copy=copy,
+ **opts,
+ )
+
def distinct(self, distinct=True, copy=True) -> Select:
"""
Set the OFFSET expression.
@@ -2140,6 +2178,11 @@ class DataType(Expression):
)
+# https://www.postgresql.org/docs/15/datatype-pseudo.html
+class PseudoType(Expression):
+ pass
+
+
class StructKwarg(Expression):
arg_types = {"this": True, "expression": True}
@@ -2167,18 +2210,26 @@ class Command(Expression):
arg_types = {"this": True, "expression": False}
-class Transaction(Command):
+class Transaction(Expression):
arg_types = {"this": False, "modes": False}
-class Commit(Command):
+class Commit(Expression):
arg_types = {"chain": False}
-class Rollback(Command):
+class Rollback(Expression):
arg_types = {"savepoint": False}
+class AlterTable(Expression):
+ arg_types = {
+ "this": True,
+ "actions": True,
+ "exists": False,
+ }
+
+
# Binary expressions like (ADD a b)
class Binary(Expression):
arg_types = {"this": True, "expression": True}
@@ -2312,6 +2363,10 @@ class SimilarTo(Binary, Predicate):
pass
+class Slice(Binary):
+ arg_types = {"this": False, "expression": False}
+
+
class Sub(Binary):
pass
@@ -2392,7 +2447,7 @@ class TimeUnit(Expression):
class Interval(TimeUnit):
- arg_types = {"this": True, "unit": False}
+ arg_types = {"this": False, "unit": False}
class IgnoreNulls(Expression):
@@ -2730,8 +2785,11 @@ class Initcap(Func):
pass
-class JSONExtract(Func):
- arg_types = {"this": True, "path": True}
+class JSONBContains(Binary):
+ _sql_names = ["JSONB_CONTAINS"]
+
+
+class JSONExtract(Binary, Func):
_sql_names = ["JSON_EXTRACT"]
@@ -2776,6 +2834,10 @@ class Log10(Func):
pass
+class LogicalOr(AggFunc):
+ _sql_names = ["LOGICAL_OR", "BOOL_OR"]
+
+
class Lower(Func):
_sql_names = ["LOWER", "LCASE"]
@@ -2846,6 +2908,10 @@ class RegexpLike(Func):
arg_types = {"this": True, "expression": True, "flag": False}
+class RegexpILike(Func):
+ arg_types = {"this": True, "expression": True, "flag": False}
+
+
class RegexpSplit(Func):
arg_types = {"this": True, "expression": True}
@@ -3388,11 +3454,17 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts) -> U
],
)
if from_:
- update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts))
+ update.set(
+ "from",
+ maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts),
+ )
if isinstance(where, Condition):
where = Where(this=where)
if where:
- update.set("where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts))
+ update.set(
+ "where",
+ maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts),
+ )
return update
@@ -3522,7 +3594,7 @@ def paren(expression) -> Paren:
return Paren(this=expression)
-SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$")
+SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
@@ -3724,6 +3796,8 @@ def convert(value) -> Expression:
return Boolean(this=value)
if isinstance(value, str):
return Literal.string(value)
+ if isinstance(value, float) and math.isnan(value):
+ return NULL
if isinstance(value, numbers.Number):
return Literal.number(value)
if isinstance(value, tuple):
@@ -3732,11 +3806,13 @@ def convert(value) -> Expression:
return Array(expressions=[convert(v) for v in value])
if isinstance(value, dict):
return Map(
- keys=[convert(k) for k in value.keys()],
+ keys=[convert(k) for k in value],
values=[convert(v) for v in value.values()],
)
if isinstance(value, datetime.datetime):
- datetime_literal = Literal.string(value.strftime("%Y-%m-%d %H:%M:%S.%f%z"))
+ datetime_literal = Literal.string(
+ (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
+ )
return TimeStrToTime(this=datetime_literal)
if isinstance(value, datetime.date):
date_literal = Literal.string(value.strftime("%Y-%m-%d"))