diff options
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r-- | sqlglot/expressions.py | 117 |
1 files changed, 100 insertions, 17 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 711ec4b..d093e29 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -22,6 +22,7 @@ from sqlglot.helper import ( split_num_words, subclasses, ) +from sqlglot.tokens import Token if t.TYPE_CHECKING: from sqlglot.dialects.dialect import Dialect @@ -457,6 +458,23 @@ class Expression(metaclass=_Expression): assert isinstance(self, type_) return self + def dump(self): + """ + Dump this Expression to a JSON-serializable dict. + """ + from sqlglot.serde import dump + + return dump(self) + + @classmethod + def load(cls, obj): + """ + Load a dict (as returned by `Expression.dump`) into an Expression instance. + """ + from sqlglot.serde import load + + return load(obj) + class Condition(Expression): def and_(self, *expressions, dialect=None, **opts): @@ -631,11 +649,15 @@ class Create(Expression): "replace": False, "unique": False, "materialized": False, + "data": False, + "statistics": False, + "no_primary_index": False, + "indexes": False, } class Describe(Expression): - pass + arg_types = {"this": True, "kind": False} class Set(Expression): @@ -731,7 +753,7 @@ class Column(Condition): class ColumnDef(Expression): arg_types = { "this": True, - "kind": True, + "kind": False, "constraints": False, "exists": False, } @@ -879,7 +901,15 @@ class Identifier(Expression): class Index(Expression): - arg_types = {"this": False, "table": False, "where": False, "columns": False} + arg_types = { + "this": False, + "table": False, + "where": False, + "columns": False, + "unique": False, + "primary": False, + "amp": False, # teradata + } class Insert(Expression): @@ -1361,6 +1391,7 @@ class Table(Expression): "laterals": False, "joins": False, "pivots": False, + "hints": False, } @@ -1818,7 +1849,12 @@ class Select(Subqueryable): join.this.replace(join.this.subquery()) if join_type: + natural: t.Optional[Token] + side: t.Optional[Token] + kind: t.Optional[Token] + natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore + if natural: join.set("natural", True) if side: @@ -2111,6 +2147,7 @@ class DataType(Expression): JSON = auto() JSONB = auto() INTERVAL = auto() + TIME = auto() TIMESTAMP = auto() TIMESTAMPTZ = auto() TIMESTAMPLTZ = auto() @@ -2171,11 +2208,24 @@ class DataType(Expression): } @classmethod - def build(cls, dtype, **kwargs) -> DataType: - return DataType( - this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()], - **kwargs, - ) + def build( + cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs + ) -> DataType: + from sqlglot import parse_one + + if isinstance(dtype, str): + data_type_exp: t.Optional[Expression] + if dtype.upper() in cls.Type.__members__: + data_type_exp = DataType(this=DataType.Type[dtype.upper()]) + else: + data_type_exp = parse_one(dtype, read=dialect, into=DataType) + if data_type_exp is None: + raise ValueError(f"Unparsable data type value: {dtype}") + elif isinstance(dtype, DataType.Type): + data_type_exp = DataType(this=dtype) + else: + raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type") + return DataType(**{**data_type_exp.args, **kwargs}) # https://www.postgresql.org/docs/15/datatype-pseudo.html @@ -2429,6 +2479,7 @@ class In(Predicate): "query": False, "unnest": False, "field": False, + "is_global": False, } @@ -2678,6 +2729,10 @@ class DatetimeTrunc(Func, TimeUnit): arg_types = {"this": True, "unit": True, "zone": False} +class LastDateOfMonth(Func): + pass + + class Extract(Func): arg_types = {"this": True, "expression": True} @@ -2815,7 +2870,13 @@ class Length(Func): class Levenshtein(Func): - arg_types = {"this": True, "expression": False} + arg_types = { + "this": True, + "expression": False, + "ins_cost": False, + "del_cost": False, + "sub_cost": False, + } class Ln(Func): @@ -2890,6 +2951,16 @@ class Quantile(AggFunc): arg_types = {"this": True, "quantile": True} +# Clickhouse-specific: +# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/quantiles/#quantiles +class Quantiles(AggFunc): + arg_types = {"parameters": True, "expressions": True} + + +class QuantileIf(AggFunc): + arg_types = {"parameters": True, "expressions": True} + + class ApproxQuantile(Quantile): arg_types = {"this": True, "quantile": True, "accuracy": False} @@ -2962,8 +3033,10 @@ class StrToTime(Func): arg_types = {"this": True, "format": True} +# Spark allows unix_timestamp() +# https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.unix_timestamp.html class StrToUnix(Func): - arg_types = {"this": True, "format": True} + arg_types = {"this": False, "format": False} class NumberToStr(Func): @@ -3131,7 +3204,7 @@ def maybe_parse( dialect=None, prefix=None, **opts, -) -> t.Optional[Expression]: +) -> Expression: """Gracefully handle a possible string or expression. Example: @@ -3627,11 +3700,11 @@ def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]: if not isinstance(sql_path, str): raise ValueError(f"Invalid type provided for a table: {type(sql_path)}") - catalog, db, table_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 3)] + catalog, db, table_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 3)) return Table(this=table_name, db=db, catalog=catalog, **kwargs) -def to_column(sql_path: str, **kwargs) -> Column: +def to_column(sql_path: str | Column, **kwargs) -> Column: """ Create a column from a `[table].[column]` sql path. Schema is optional. @@ -3646,7 +3719,7 @@ def to_column(sql_path: str, **kwargs) -> Column: return sql_path if not isinstance(sql_path, str): raise ValueError(f"Invalid type provided for column: {type(sql_path)}") - table_name, column_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 2)] + table_name, column_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 2)) return Column(this=column_name, table=table_name, **kwargs) @@ -3748,7 +3821,7 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table: def values( values: t.Iterable[t.Tuple[t.Any, ...]], alias: t.Optional[str] = None, - columns: t.Optional[t.Iterable[str]] = None, + columns: t.Optional[t.Iterable[str] | t.Dict[str, DataType]] = None, ) -> Values: """Build VALUES statement. @@ -3759,7 +3832,10 @@ def values( Args: values: values statements that will be converted to SQL alias: optional alias - columns: Optional list of ordered column names. An alias is required when providing column names. + columns: Optional list of ordered column names or ordered dictionary of column names to types. + If either are provided then an alias is also required. + If a dictionary is provided then the first column of the values will be casted to the expected type + in order to help with type inference. Returns: Values: the Values expression object @@ -3771,8 +3847,15 @@ def values( if columns else TableAlias(this=to_identifier(alias) if alias else None) ) + expressions = [convert(tup) for tup in values] + if columns and isinstance(columns, dict): + types = list(columns.values()) + expressions[0].set( + "expressions", + [Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)], + ) return Values( - expressions=[convert(tup) for tup in values], + expressions=expressions, alias=table_alias, ) |