diff options
author | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:46:01 +0000 |
---|---|---|
committer | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:46:01 +0000 |
commit | 8fe30fd23dc37ec3516e530a86d1c4b604e71241 (patch) | |
tree | 6e2ebbf565b0351fd0f003f488a8339e771ad90c /sqlglot/expressions.py | |
parent | Releasing debian version 19.0.1-1. (diff) | |
download | sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.tar.xz sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.zip |
Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r-- | sqlglot/expressions.py | 360 |
1 files changed, 273 insertions, 87 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 99ebfb3..99722be 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -53,6 +53,7 @@ class _Expression(type): SQLGLOT_META = "sqlglot.meta" +TABLE_PARTS = ("this", "db", "catalog") class Expression(metaclass=_Expression): @@ -134,7 +135,7 @@ class Expression(metaclass=_Expression): return self.args.get("expression") @property - def expressions(self): + def expressions(self) -> t.List[t.Any]: """ Retrieves the argument with key "expressions". """ @@ -238,6 +239,9 @@ class Expression(metaclass=_Expression): dtype = DataType.build(dtype) self._type = dtype # type: ignore + def is_type(self, *dtypes) -> bool: + return self.type is not None and self.type.is_type(*dtypes) + @property def meta(self) -> t.Dict[str, t.Any]: if self._meta is None: @@ -481,7 +485,7 @@ class Expression(metaclass=_Expression): def flatten(self, unnest=True): """ - Returns a generator which yields child nodes who's parents are the same class. + Returns a generator which yields child nodes whose parents are the same class. A AND B AND C -> [A, B, C] """ @@ -508,7 +512,7 @@ class Expression(metaclass=_Expression): """ from sqlglot.dialects import Dialect - return Dialect.get_or_raise(dialect)().generate(self, **opts) + return Dialect.get_or_raise(dialect).generate(self, **opts) def _to_s(self, hide_missing: bool = True, level: int = 0) -> str: indent = "" if not level else "\n" @@ -821,6 +825,12 @@ class Expression(metaclass=_Expression): def rlike(self, other: ExpOrStr) -> RegexpLike: return self._binop(RegexpLike, other) + def div(self, other: ExpOrStr, typed: bool = False, safe: bool = False) -> Div: + div = self._binop(Div, other) + div.args["typed"] = typed + div.args["safe"] = safe + return div + def __lt__(self, other: t.Any) -> LT: return self._binop(LT, other) @@ -1000,7 +1010,6 @@ class UDTF(DerivedTable, Unionable): class Cache(Expression): arg_types = { - "with": False, "this": True, "lazy": False, "options": False, @@ -1012,6 +1021,10 @@ class Uncache(Expression): arg_types = {"this": True, "exists": False} +class Refresh(Expression): + pass + + class DDL(Expression): @property def ctes(self): @@ -1033,6 +1046,43 @@ class DDL(Expression): return [] +class DML(Expression): + def returning( + self, + expression: ExpOrStr, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> DML: + """ + Set the RETURNING expression. Not supported by all dialects. + + Example: + >>> delete("tbl").returning("*", dialect="postgres").sql() + 'DELETE FROM tbl RETURNING *' + + Args: + expression: the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + Delete: the modified expression. + """ + return _apply_builder( + expression=expression, + instance=self, + arg="returning", + prefix="RETURNING", + dialect=dialect, + copy=copy, + into=Returning, + **opts, + ) + + class Create(DDL): arg_types = { "with": False, @@ -1133,8 +1183,10 @@ class WithinGroup(Expression): arg_types = {"this": True, "expression": False} +# clickhouse supports scalar ctes +# https://clickhouse.com/docs/en/sql-reference/statements/select/with class CTE(DerivedTable): - arg_types = {"this": True, "alias": True} + arg_types = {"this": True, "alias": True, "scalar": False} class TableAlias(Expression): @@ -1297,6 +1349,10 @@ class AutoIncrementColumnConstraint(ColumnConstraintKind): pass +class PeriodForSystemTimeConstraint(ColumnConstraintKind): + arg_types = {"this": True, "expression": True} + + class CaseSpecificColumnConstraint(ColumnConstraintKind): arg_types = {"not_": True} @@ -1351,6 +1407,10 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): } +class GeneratedAsRowColumnConstraint(ColumnConstraintKind): + arg_types = {"start": True, "hidden": False} + + # https://dev.mysql.com/doc/refman/8.0/en/create-table.html class IndexColumnConstraint(ColumnConstraintKind): arg_types = { @@ -1383,6 +1443,11 @@ class OnUpdateColumnConstraint(ColumnConstraintKind): pass +# https://docs.snowflake.com/en/sql-reference/sql/create-external-table#optional-parameters +class TransformColumnConstraint(ColumnConstraintKind): + pass + + class PrimaryKeyColumnConstraint(ColumnConstraintKind): arg_types = {"desc": False} @@ -1413,7 +1478,7 @@ class Constraint(Expression): arg_types = {"this": True, "expressions": True} -class Delete(Expression): +class Delete(DML): arg_types = { "with": False, "this": False, @@ -1496,41 +1561,6 @@ class Delete(Expression): **opts, ) - def returning( - self, - expression: ExpOrStr, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Delete: - """ - Set the RETURNING expression. Not supported by all dialects. - - Example: - >>> delete("tbl").returning("*", dialect="postgres").sql() - 'DELETE FROM tbl RETURNING *' - - Args: - expression: the SQL code strings to parse. - If an `Expression` instance is passed, it will be used as-is. - dialect: the dialect used to parse the input expressions. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - Delete: the modified expression. - """ - return _apply_builder( - expression=expression, - instance=self, - arg="returning", - prefix="RETURNING", - dialect=dialect, - copy=copy, - into=Returning, - **opts, - ) - class Drop(Expression): arg_types = { @@ -1648,7 +1678,7 @@ class Index(Expression): } -class Insert(DDL): +class Insert(DDL, DML): arg_types = { "with": False, "this": True, @@ -2259,6 +2289,11 @@ class WithJournalTableProperty(Property): arg_types = {"this": True} +class WithSystemVersioningProperty(Property): + # this -> history table name, expression -> data consistency check + arg_types = {"this": False, "expression": False} + + class Properties(Expression): arg_types = {"expressions": True} @@ -3663,6 +3698,7 @@ class DataType(Expression): Type.BIGINT, Type.INT128, Type.INT256, + Type.BIT, } FLOAT_TYPES = { @@ -3692,7 +3728,7 @@ class DataType(Expression): @classmethod def build( cls, - dtype: str | DataType | DataType.Type, + dtype: DATA_TYPE, dialect: DialectType = None, udt: bool = False, **kwargs, @@ -3733,7 +3769,7 @@ class DataType(Expression): return DataType(**{**data_type_exp.args, **kwargs}) - def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool: + def is_type(self, *dtypes: DATA_TYPE) -> bool: """ Checks whether this DataType matches one of the provided data types. Nested types or precision will be compared using "structural equivalence" semantics, so e.g. array<int> != array<float>. @@ -3761,6 +3797,9 @@ class DataType(Expression): return False +DATA_TYPE = t.Union[str, DataType, DataType.Type] + + # https://www.postgresql.org/docs/15/datatype-pseudo.html class PseudoType(DataType): arg_types = {"this": True} @@ -3868,7 +3907,7 @@ class BitwiseXor(Binary): class Div(Binary): - pass + arg_types = {"this": True, "expression": True, "typed": False, "safe": False} class Overlaps(Binary): @@ -3892,13 +3931,25 @@ class Dot(Binary): return t.cast(Dot, reduce(lambda x, y: Dot(this=x, expression=y), expressions)) + @property + def parts(self) -> t.List[Expression]: + """Return the parts of a table / column in order catalog, db, table.""" + this, *parts = self.flatten() -class DPipe(Binary): - pass + parts.reverse() + for arg in ("this", "table", "db", "catalog"): + part = this.args.get(arg) -class SafeDPipe(DPipe): - pass + if isinstance(part, Expression): + parts.append(part) + + parts.reverse() + return parts + + +class DPipe(Binary): + arg_types = {"this": True, "expression": True, "safe": False} class EQ(Binary, Predicate): @@ -3913,6 +3964,11 @@ class NullSafeNEQ(Binary, Predicate): pass +# Represents e.g. := in DuckDB which is mostly used for setting parameters +class PropertyEQ(Binary): + pass + + class Distance(Binary): pass @@ -3981,6 +4037,11 @@ class NEQ(Binary, Predicate): pass +# https://www.postgresql.org/docs/current/ddl-schemas.html#DDL-SCHEMAS-PATH +class Operator(Binary): + arg_types = {"this": True, "operator": True, "expression": True} + + class SimilarTo(Binary, Predicate): pass @@ -4048,7 +4109,8 @@ class Between(Predicate): class Bracket(Condition): - arg_types = {"this": True, "expressions": True} + # https://cloud.google.com/bigquery/docs/reference/standard-sql/operators#array_subscript_operator + arg_types = {"this": True, "expressions": True, "offset": False, "safe": False} @property def output_name(self) -> str: @@ -4058,10 +4120,6 @@ class Bracket(Condition): return super().output_name -class SafeBracket(Bracket): - """Represents array lookup where OOB index yields NULL instead of causing a failure.""" - - class Distinct(Expression): arg_types = {"expressions": False, "on": False} @@ -4077,6 +4135,11 @@ class In(Predicate): } +# https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#for-in +class ForIn(Expression): + arg_types = {"this": True, "expression": True} + + class TimeUnit(Expression): """Automatically converts unit arg into a var.""" @@ -4248,8 +4311,9 @@ class Array(Func): # https://docs.snowflake.com/en/sql-reference/functions/to_char +# https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/TO_CHAR-number.html class ToChar(Func): - arg_types = {"this": True, "format": False} + arg_types = {"this": True, "format": False, "nlsparam": False} class GenerateSeries(Func): @@ -4260,6 +4324,10 @@ class ArrayAgg(AggFunc): pass +class ArrayUniqueAgg(AggFunc): + pass + + class ArrayAll(Func): arg_types = {"this": True, "expression": True} @@ -4358,7 +4426,7 @@ class Cast(Func): def output_name(self) -> str: return self.name - def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool: + def is_type(self, *dtypes: DATA_TYPE) -> bool: """ Checks whether this Cast's DataType matches one of the provided data types. Nested types like arrays or structs will be compared using "structural equivalence" semantics, so e.g. @@ -4403,14 +4471,10 @@ class Chr(Func): class Concat(Func): - arg_types = {"expressions": True} + arg_types = {"expressions": True, "safe": False, "coalesce": False} is_var_len_args = True -class SafeConcat(Concat): - pass - - class ConcatWs(Concat): _sql_names = ["CONCAT_WS"] @@ -4643,6 +4707,10 @@ class If(Func): arg_types = {"this": True, "true": True, "false": False} +class Nullif(Func): + arg_types = {"this": True, "expression": True} + + class Initcap(Func): arg_types = {"this": True, "expression": False} @@ -4651,6 +4719,10 @@ class IsNan(Func): _sql_names = ["IS_NAN", "ISNAN"] +class IsInf(Func): + _sql_names = ["IS_INF", "ISINF"] + + class FormatJson(Expression): pass @@ -4970,10 +5042,6 @@ class SafeDivide(Func): arg_types = {"this": True, "expression": True} -class SetAgg(AggFunc): - pass - - class SHA(Func): _sql_names = ["SHA", "SHA1"] @@ -5118,6 +5186,15 @@ class Trim(Func): class TsOrDsAdd(Func, TimeUnit): + # return_type is used to correctly cast the arguments of this expression when transpiling it + arg_types = {"this": True, "expression": True, "unit": False, "return_type": False} + + @property + def return_type(self) -> DataType: + return DataType.build(self.args.get("return_type") or DataType.Type.DATE) + + +class TsOrDsDiff(Func, TimeUnit): arg_types = {"this": True, "expression": True, "unit": False} @@ -5149,6 +5226,7 @@ class UnixToTime(Func): SECONDS = Literal.string("seconds") MILLIS = Literal.string("millis") MICROS = Literal.string("micros") + NANOS = Literal.string("nanos") class UnixToTimeStr(Func): @@ -5202,6 +5280,7 @@ def _norm_arg(arg): ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func)) +FUNCTION_BY_NAME = {name: func for func in ALL_FUNCTIONS for name in func.sql_names()} # Helpers @@ -5693,7 +5772,9 @@ def delete( if where: delete_expr = delete_expr.where(where, dialect=dialect, copy=False, **opts) if returning: - delete_expr = delete_expr.returning(returning, dialect=dialect, copy=False, **opts) + delete_expr = t.cast( + Delete, delete_expr.returning(returning, dialect=dialect, copy=False, **opts) + ) return delete_expr @@ -5702,6 +5783,7 @@ def insert( into: ExpOrStr, columns: t.Optional[t.Sequence[ExpOrStr]] = None, overwrite: t.Optional[bool] = None, + returning: t.Optional[ExpOrStr] = None, dialect: DialectType = None, copy: bool = True, **opts, @@ -5718,6 +5800,7 @@ def insert( into: the tbl to insert data to. columns: optionally the table's column names. overwrite: whether to INSERT OVERWRITE or not. + returning: sql conditional parsed into a RETURNING statement dialect: the dialect used to parse the input expressions. copy: whether or not to copy the expression. **opts: other options to use to parse the input expressions. @@ -5739,7 +5822,12 @@ def insert( **opts, ) - return Insert(this=this, expression=expr, overwrite=overwrite) + insert = Insert(this=this, expression=expr, overwrite=overwrite) + + if returning: + insert = t.cast(Insert, insert.returning(returning, dialect=dialect, copy=False, **opts)) + + return insert def condition( @@ -5913,7 +6001,7 @@ def to_identifier(name, quoted=None, copy=True): return identifier -def parse_identifier(name: str, dialect: DialectType = None) -> Identifier: +def parse_identifier(name: str | Identifier, dialect: DialectType = None) -> Identifier: """ Parses a given string into an identifier. @@ -5965,7 +6053,7 @@ def to_table(sql_path: None, **kwargs) -> None: def to_table( - sql_path: t.Optional[str | Table], dialect: DialectType = None, **kwargs + sql_path: t.Optional[str | Table], dialect: DialectType = None, copy: bool = True, **kwargs ) -> t.Optional[Table]: """ Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional. @@ -5974,13 +6062,14 @@ def to_table( Args: sql_path: a `[catalog].[schema].[table]` string. dialect: the source dialect according to which the table name will be parsed. + copy: Whether or not to copy a table if it is passed in. kwargs: the kwargs to instantiate the resulting `Table` expression with. Returns: A table expression. """ if sql_path is None or isinstance(sql_path, Table): - return sql_path + return maybe_copy(sql_path, copy=copy) if not isinstance(sql_path, str): raise ValueError(f"Invalid type provided for a table: {type(sql_path)}") @@ -6123,7 +6212,7 @@ def column( ) -def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Cast: +def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast: """Cast an expression to a data type. Example: @@ -6335,12 +6424,15 @@ def column_table_names(expression: Expression, exclude: str = "") -> t.Set[str]: } -def table_name(table: Table | str, dialect: DialectType = None) -> str: +def table_name(table: Table | str, dialect: DialectType = None, identify: bool = False) -> str: """Get the full name of a table as a string. Args: table: Table expression node or string. dialect: The dialect to generate the table name for. + identify: Determines when an identifier should be quoted. Possible values are: + False (default): Never quote, except in cases where it's mandatory by the dialect. + True: Always quote. Examples: >>> from sqlglot import exp, parse_one @@ -6358,37 +6450,68 @@ def table_name(table: Table | str, dialect: DialectType = None) -> str: return ".".join( part.sql(dialect=dialect, identify=True) - if not SAFE_IDENTIFIER_RE.match(part.name) + if identify or not SAFE_IDENTIFIER_RE.match(part.name) else part.name for part in table.parts ) -def replace_tables(expression: E, mapping: t.Dict[str, str], copy: bool = True) -> E: +def normalize_table_name(table: str | Table, dialect: DialectType = None, copy: bool = True) -> str: + """Returns a case normalized table name without quotes. + + Args: + table: the table to normalize + dialect: the dialect to use for normalization rules + copy: whether or not to copy the expression. + + Examples: + >>> normalize_table_name("`A-B`.c", dialect="bigquery") + 'A-B.c' + """ + from sqlglot.optimizer.normalize_identifiers import normalize_identifiers + + return ".".join( + p.name + for p in normalize_identifiers( + to_table(table, dialect=dialect, copy=copy), dialect=dialect + ).parts + ) + + +def replace_tables( + expression: E, mapping: t.Dict[str, str], dialect: DialectType = None, copy: bool = True +) -> E: """Replace all tables in expression according to the mapping. Args: expression: expression node to be transformed and replaced. mapping: mapping of table names. + dialect: the dialect of the mapping table copy: whether or not to copy the expression. Examples: >>> from sqlglot import exp, parse_one >>> replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql() - 'SELECT * FROM c' + 'SELECT * FROM c /* a.b */' Returns: The mapped expression. """ + mapping = {normalize_table_name(k, dialect=dialect): v for k, v in mapping.items()} + def _replace_tables(node: Expression) -> Expression: if isinstance(node, Table): - new_name = mapping.get(table_name(node)) + original = normalize_table_name(node, dialect=dialect) + new_name = mapping.get(original) + if new_name: - return to_table( + table = to_table( new_name, - **{k: v for k, v in node.args.items() if k not in ("this", "db", "catalog")}, + **{k: v for k, v in node.args.items() if k not in TABLE_PARTS}, ) + table.add_comments([original]) + return table return node return expression.transform(_replace_tables, copy=copy) @@ -6431,7 +6554,10 @@ def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression: def expand( - expression: Expression, sources: t.Dict[str, Subqueryable], copy: bool = True + expression: Expression, + sources: t.Dict[str, Subqueryable], + dialect: DialectType = None, + copy: bool = True, ) -> Expression: """Transforms an expression by expanding all referenced sources into subqueries. @@ -6446,15 +6572,17 @@ def expand( Args: expression: The expression to expand. sources: A dictionary of name to Subqueryables. + dialect: The dialect of the sources dict. copy: Whether or not to copy the expression during transformation. Defaults to True. Returns: The transformed expression. """ + sources = {normalize_table_name(k, dialect=dialect): v for k, v in sources.items()} def _expand(node: Expression): if isinstance(node, Table): - name = table_name(node) + name = normalize_table_name(node, dialect=dialect) source = sources.get(name) if source: subquery = source.subquery(node.alias or name) @@ -6465,7 +6593,7 @@ def expand( return expression.transform(_expand, copy=copy) -def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func: +def func(name: str, *args, copy: bool = True, dialect: DialectType = None, **kwargs) -> Func: """ Returns a Func expression. @@ -6479,6 +6607,7 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func: Args: name: the name of the function to build. args: the args used to instantiate the function of interest. + copy: whether or not to copy the argument expressions. dialect: the source dialect. kwargs: the kwargs used to instantiate the function of interest. @@ -6494,14 +6623,29 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func: from sqlglot.dialects.dialect import Dialect - converted: t.List[Expression] = [maybe_parse(arg, dialect=dialect) for arg in args] - kwargs = {key: maybe_parse(value, dialect=dialect) for key, value in kwargs.items()} + dialect = Dialect.get_or_raise(dialect) - parser = Dialect.get_or_raise(dialect)().parser() - from_args_list = parser.FUNCTIONS.get(name.upper()) + converted: t.List[Expression] = [maybe_parse(arg, dialect=dialect, copy=copy) for arg in args] + kwargs = {key: maybe_parse(value, dialect=dialect, copy=copy) for key, value in kwargs.items()} - if from_args_list: - function = from_args_list(converted) if converted else from_args_list.__self__(**kwargs) # type: ignore + constructor = dialect.parser_class.FUNCTIONS.get(name.upper()) + if constructor: + if converted: + if "dialect" in constructor.__code__.co_varnames: + function = constructor(converted, dialect=dialect) + else: + function = constructor(converted) + elif constructor.__name__ == "from_arg_list": + function = constructor.__self__(**kwargs) # type: ignore + else: + constructor = FUNCTION_BY_NAME.get(name.upper()) + if constructor: + function = constructor(**kwargs) + else: + raise ValueError( + f"Unable to convert '{name}' into a Func. Either manually construct " + "the Func expression of interest or parse the function call." + ) else: kwargs = kwargs or {"expressions": converted} function = Anonymous(this=name, **kwargs) @@ -6512,6 +6656,48 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func: return function +def case( + expression: t.Optional[ExpOrStr] = None, + **opts, +) -> Case: + """ + Initialize a CASE statement. + + Example: + case().when("a = 1", "foo").else_("bar") + + Args: + expression: Optionally, the input expression (not all dialects support this) + **opts: Extra keyword arguments for parsing `expression` + """ + if expression is not None: + this = maybe_parse(expression, **opts) + else: + this = None + return Case(this=this, ifs=[]) + + +def cast_unless( + expression: ExpOrStr, + to: DATA_TYPE, + *types: DATA_TYPE, + **opts: t.Any, +) -> Expression | Cast: + """ + Cast an expression to a data type unless it is a specified type. + + Args: + expression: The expression to cast. + to: The data type to cast to. + **types: The types to exclude from casting. + **opts: Extra keyword arguments for parsing `expression` + """ + expr = maybe_parse(expression, **opts) + if expr.is_type(*types): + return expr + return cast(expr, to, **opts) + + def true() -> Boolean: """ Returns a true Boolean expression. |