summaryrefslogtreecommitdiffstats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--sqlglot/expressions.py80
1 files changed, 67 insertions, 13 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index b94b1e1..5b012b1 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -23,7 +23,7 @@ from enum import auto
from functools import reduce
from sqlglot._typing import E
-from sqlglot.errors import ParseError
+from sqlglot.errors import ErrorLevel, ParseError
from sqlglot.helper import (
AutoName,
camel_to_snake_case,
@@ -120,14 +120,14 @@ class Expression(metaclass=_Expression):
return hash((self.__class__, self.hashable_args))
@property
- def this(self):
+ def this(self) -> t.Any:
"""
Retrieves the argument with key "this".
"""
return self.args.get("this")
@property
- def expression(self):
+ def expression(self) -> t.Any:
"""
Retrieves the argument with key "expression".
"""
@@ -1235,6 +1235,10 @@ class RenameTable(Expression):
pass
+class SwapTable(Expression):
+ pass
+
+
class Comment(Expression):
arg_types = {"this": True, "kind": True, "expression": True, "exists": False}
@@ -1979,7 +1983,7 @@ class ChecksumProperty(Property):
class CollateProperty(Property):
- arg_types = {"this": True}
+ arg_types = {"this": True, "default": False}
class CopyGrantsProperty(Property):
@@ -2607,11 +2611,11 @@ class Union(Subqueryable):
return self.this.unnest().selects
@property
- def left(self):
+ def left(self) -> Expression:
return self.this
@property
- def right(self):
+ def right(self) -> Expression:
return self.expression
@@ -3700,7 +3704,9 @@ class DataType(Expression):
return DataType(this=DataType.Type.UNKNOWN, **kwargs)
try:
- data_type_exp = parse_one(dtype, read=dialect, into=DataType)
+ data_type_exp = parse_one(
+ dtype, read=dialect, into=DataType, error_level=ErrorLevel.IGNORE
+ )
except ParseError:
if udt:
return DataType(this=DataType.Type.USERDEFINED, kind=dtype, **kwargs)
@@ -3804,11 +3810,11 @@ class Binary(Condition):
arg_types = {"this": True, "expression": True}
@property
- def left(self):
+ def left(self) -> Expression:
return self.this
@property
- def right(self):
+ def right(self) -> Expression:
return self.expression
@@ -4063,10 +4069,25 @@ class TimeUnit(Expression):
arg_types = {"unit": False}
+ UNABBREVIATED_UNIT_NAME = {
+ "d": "day",
+ "h": "hour",
+ "m": "minute",
+ "ms": "millisecond",
+ "ns": "nanosecond",
+ "q": "quarter",
+ "s": "second",
+ "us": "microsecond",
+ "w": "week",
+ "y": "year",
+ }
+
+ VAR_LIKE = (Column, Literal, Var)
+
def __init__(self, **args):
unit = args.get("unit")
- if isinstance(unit, (Column, Literal)):
- args["unit"] = Var(this=unit.name)
+ if isinstance(unit, self.VAR_LIKE):
+ args["unit"] = Var(this=self.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name)
elif isinstance(unit, Week):
unit.set("this", Var(this=unit.this.name))
@@ -4168,6 +4189,24 @@ class Abs(Func):
pass
+class ArgMax(AggFunc):
+ arg_types = {"this": True, "expression": True, "count": False}
+ _sql_names = ["ARG_MAX", "ARGMAX", "MAX_BY"]
+
+
+class ArgMin(AggFunc):
+ arg_types = {"this": True, "expression": True, "count": False}
+ _sql_names = ["ARG_MIN", "ARGMIN", "MIN_BY"]
+
+
+class ApproxTopK(AggFunc):
+ arg_types = {"this": True, "expression": False, "counters": False}
+
+
+class Flatten(Func):
+ pass
+
+
# https://spark.apache.org/docs/latest/api/sql/index.html#transform
class Transform(Func):
arg_types = {"this": True, "expression": True}
@@ -4540,8 +4579,10 @@ class Exp(Func):
pass
+# https://docs.snowflake.com/en/sql-reference/functions/flatten
class Explode(Func):
- pass
+ arg_types = {"this": True, "expressions": False}
+ is_var_len_args = True
class ExplodeOuter(Explode):
@@ -4698,6 +4739,8 @@ class JSONArrayContains(Binary, Predicate, Func):
class ParseJSON(Func):
# BigQuery, Snowflake have PARSE_JSON, Presto has JSON_PARSE
_sql_names = ["PARSE_JSON", "JSON_PARSE"]
+ arg_types = {"this": True, "expressions": False}
+ is_var_len_args = True
class Least(Func):
@@ -4758,6 +4801,16 @@ class Lower(Func):
class Map(Func):
arg_types = {"keys": False, "values": False}
+ @property
+ def keys(self) -> t.List[Expression]:
+ keys = self.args.get("keys")
+ return keys.expressions if keys else []
+
+ @property
+ def values(self) -> t.List[Expression]:
+ values = self.args.get("values")
+ return values.expressions if values else []
+
class MapFromEntries(Func):
pass
@@ -4870,6 +4923,7 @@ class RegexpReplace(Func):
"position": False,
"occurrence": False,
"parameters": False,
+ "modifiers": False,
}
@@ -4877,7 +4931,7 @@ class RegexpLike(Binary, Func):
arg_types = {"this": True, "expression": True, "flag": False}
-class RegexpILike(Func):
+class RegexpILike(Binary, Func):
arg_types = {"this": True, "expression": True, "flag": False}