Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r--  sqlglot/expressions.py  117
1 file changed, 100 insertions(+), 17 deletions(-)
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 711ec4b..d093e29 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -22,6 +22,7 @@ from sqlglot.helper import (
split_num_words,
subclasses,
)
+from sqlglot.tokens import Token
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import Dialect
@@ -457,6 +458,23 @@ class Expression(metaclass=_Expression):
assert isinstance(self, type_)
return self
+ def dump(self):
+ """
+ Dump this Expression to a JSON-serializable dict.
+ """
+ from sqlglot.serde import dump
+
+ return dump(self)
+
+ @classmethod
+ def load(cls, obj):
+ """
+ Load a dict (as returned by `Expression.dump`) into an Expression instance.
+ """
+ from sqlglot.serde import load
+
+ return load(obj)
+
class Condition(Expression):
def and_(self, *expressions, dialect=None, **opts):
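The new dump/load hooks round-trip an expression tree through JSON-serializable primitives via sqlglot.serde. A minimal usage sketch (the exact dict layout is defined in serde and not shown in this diff):
>>> from sqlglot import exp, parse_one
>>> ast = parse_one("SELECT a FROM t")
>>> payload = ast.dump()                # nested dicts/lists/strings, safe to pass to json.dumps
>>> exp.Expression.load(payload).sql()  # load is a classmethod that rebuilds an equivalent tree
'SELECT a FROM t'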
@@ -631,11 +649,15 @@ class Create(Expression):
"replace": False,
"unique": False,
"materialized": False,
+ "data": False,
+ "statistics": False,
+ "no_primary_index": False,
+ "indexes": False,
}
class Describe(Expression):
- pass
+ arg_types = {"this": True, "kind": False}
class Set(Expression):
@@ -731,7 +753,7 @@ class Column(Condition):
class ColumnDef(Expression):
arg_types = {
"this": True,
- "kind": True,
+ "kind": False,
"constraints": False,
"exists": False,
}
@@ -879,7 +901,15 @@ class Identifier(Expression):
class Index(Expression):
- arg_types = {"this": False, "table": False, "where": False, "columns": False}
+ arg_types = {
+ "this": False,
+ "table": False,
+ "where": False,
+ "columns": False,
+ "unique": False,
+ "primary": False,
+ "amp": False, # teradata
+ }
class Insert(Expression):
@@ -1361,6 +1391,7 @@ class Table(Expression):
"laterals": False,
"joins": False,
"pivots": False,
+ "hints": False,
}
@@ -1818,7 +1849,12 @@ class Select(Subqueryable):
join.this.replace(join.this.subquery())
if join_type:
+ natural: t.Optional[Token]
+ side: t.Optional[Token]
+ kind: t.Optional[Token]
+
natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore
+
if natural:
join.set("natural", True)
if side:
@@ -2111,6 +2147,7 @@ class DataType(Expression):
JSON = auto()
JSONB = auto()
INTERVAL = auto()
+ TIME = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
@@ -2171,11 +2208,24 @@ class DataType(Expression):
}
@classmethod
- def build(cls, dtype, **kwargs) -> DataType:
- return DataType(
- this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()],
- **kwargs,
- )
+ def build(
+ cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs
+ ) -> DataType:
+ from sqlglot import parse_one
+
+ if isinstance(dtype, str):
+ data_type_exp: t.Optional[Expression]
+ if dtype.upper() in cls.Type.__members__:
+ data_type_exp = DataType(this=DataType.Type[dtype.upper()])
+ else:
+ data_type_exp = parse_one(dtype, read=dialect, into=DataType)
+ if data_type_exp is None:
+ raise ValueError(f"Unparsable data type value: {dtype}")
+ elif isinstance(dtype, DataType.Type):
+ data_type_exp = DataType(this=dtype)
+ else:
+ raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")
+ return DataType(**{**data_type_exp.args, **kwargs})
# https://www.postgresql.org/docs/15/datatype-pseudo.html
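A quick sketch of what the reworked DataType.build accepts; the first two forms hit the enum fast path, while the last is not a bare Type name and falls through to parse_one(..., into=DataType). Outputs are what the default generator would be expected to emit:
>>> from sqlglot import exp
>>> exp.DataType.build("int").sql()
'INT'
>>> exp.DataType.build(exp.DataType.Type.TEXT).sql()
'TEXT'
>>> exp.DataType.build("array<int>").sql()  # parsed as a full type expression
'ARRAY<INT>'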
@@ -2429,6 +2479,7 @@ class In(Predicate):
"query": False,
"unnest": False,
"field": False,
+ "is_global": False,
}
@@ -2678,6 +2729,10 @@ class DatetimeTrunc(Func, TimeUnit):
arg_types = {"this": True, "unit": True, "zone": False}
+class LastDateOfMonth(Func):
+ pass
+
+
class Extract(Func):
arg_types = {"this": True, "expression": True}
@@ -2815,7 +2870,13 @@ class Length(Func):
class Levenshtein(Func):
- arg_types = {"this": True, "expression": False}
+ arg_types = {
+ "this": True,
+ "expression": False,
+ "ins_cost": False,
+ "del_cost": False,
+ "sub_cost": False,
+ }
class Ln(Func):
@@ -2890,6 +2951,16 @@ class Quantile(AggFunc):
arg_types = {"this": True, "quantile": True}
+# Clickhouse-specific:
+# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/quantiles/#quantiles
+class Quantiles(AggFunc):
+ arg_types = {"parameters": True, "expressions": True}
+
+
+class QuantileIf(AggFunc):
+ arg_types = {"parameters": True, "expressions": True}
+
+
class ApproxQuantile(Quantile):
arg_types = {"this": True, "quantile": True, "accuracy": False}
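Quantiles and QuantileIf model ClickHouse's parametrized aggregates, where the quantile levels and the aggregated column sit in separate parenthesized lists. A hedged sketch, assuming the ClickHouse dialect parses that form into these nodes:
>>> from sqlglot import exp, parse_one
>>> q = parse_one("SELECT quantiles(0.5, 0.9)(latency) FROM requests", read="clickhouse")
>>> agg = q.find(exp.Quantiles)  # "parameters" holds the levels, "expressions" the aggregated column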
@@ -2962,8 +3033,10 @@ class StrToTime(Func):
arg_types = {"this": True, "format": True}
+# Spark allows unix_timestamp()
+# https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.unix_timestamp.html
class StrToUnix(Func):
- arg_types = {"this": True, "format": True}
+ arg_types = {"this": False, "format": False}
class NumberToStr(Func):
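With both arguments now optional, Spark's zero-argument unix_timestamp() can be represented. A hedged sketch, assuming the Spark dialect routes unix_timestamp to StrToUnix as the comment above suggests (the dialect may still fill in a default format):
>>> from sqlglot import exp, parse_one
>>> e = parse_one("SELECT unix_timestamp()", read="spark")
>>> node = e.find(exp.StrToUnix)  # "this" is absent here; previously both args were required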
@@ -3131,7 +3204,7 @@ def maybe_parse(
dialect=None,
prefix=None,
**opts,
-) -> t.Optional[Expression]:
+) -> Expression:
"""Gracefully handle a possible string or expression.
Example:
@@ -3627,11 +3700,11 @@ def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for a table: {type(sql_path)}")
- catalog, db, table_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 3)]
+ catalog, db, table_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 3))
return Table(this=table_name, db=db, catalog=catalog, **kwargs)
-def to_column(sql_path: str, **kwargs) -> Column:
+def to_column(sql_path: str | Column, **kwargs) -> Column:
"""
Create a column from a `[table].[column]` sql path. Schema is optional.
@@ -3646,7 +3719,7 @@ def to_column(sql_path: str, **kwargs) -> Column:
return sql_path
if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
- table_name, column_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 2)]
+ table_name, column_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 2))
return Column(this=column_name, table=table_name, **kwargs)
@@ -3748,7 +3821,7 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
def values(
values: t.Iterable[t.Tuple[t.Any, ...]],
alias: t.Optional[str] = None,
- columns: t.Optional[t.Iterable[str]] = None,
+ columns: t.Optional[t.Iterable[str] | t.Dict[str, DataType]] = None,
) -> Values:
"""Build VALUES statement.
@@ -3759,7 +3832,10 @@ def values(
Args:
values: values statements that will be converted to SQL
alias: optional alias
- columns: Optional list of ordered column names. An alias is required when providing column names.
+ columns: Optional list of ordered column names or ordered dictionary of column names to types.
+ If either is provided then an alias is also required.
+ If a dictionary is provided then the values in the first row will be cast to the expected types
+ in order to help with type inference.
Returns:
Values: the Values expression object
@@ -3771,8 +3847,15 @@ def values(
if columns
else TableAlias(this=to_identifier(alias) if alias else None)
)
+ expressions = [convert(tup) for tup in values]
+ if columns and isinstance(columns, dict):
+ types = list(columns.values())
+ expressions[0].set(
+ "expressions",
+ [Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)],
+ )
return Values(
- expressions=[convert(tup) for tup in values],
+ expressions=expressions,
alias=table_alias,
)
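A hedged sketch of the typed-columns path; the column names and types below are made up, and the types are passed as DataType instances since that is what the Cast nodes wrap:
>>> from sqlglot import exp
>>> v = exp.values(
...     [(1, "a"), (2, "b")],
...     alias="t",
...     columns={"id": exp.DataType.build("int"), "name": exp.DataType.build("text")},
... )
>>> v.expressions[0].sql()  # only the first row is wrapped in casts; later rows rely on inference
"(CAST(1 AS INT), CAST('a' AS TEXT))"
>>> v.expressions[1].sql()
"(2, 'b')"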