summaryrefslogtreecommitdiffstats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
authorDaniel Baumann <mail@daniel-baumann.ch>2023-12-10 10:46:01 +0000
committerDaniel Baumann <mail@daniel-baumann.ch>2023-12-10 10:46:01 +0000
commit8fe30fd23dc37ec3516e530a86d1c4b604e71241 (patch)
tree6e2ebbf565b0351fd0f003f488a8339e771ad90c /sqlglot/expressions.py
parentReleasing debian version 19.0.1-1. (diff)
downloadsqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.tar.xz
sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.zip
Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r--sqlglot/expressions.py360
1 files changed, 273 insertions, 87 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 99ebfb3..99722be 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -53,6 +53,7 @@ class _Expression(type):
SQLGLOT_META = "sqlglot.meta"
+TABLE_PARTS = ("this", "db", "catalog")
class Expression(metaclass=_Expression):
@@ -134,7 +135,7 @@ class Expression(metaclass=_Expression):
return self.args.get("expression")
@property
- def expressions(self):
+ def expressions(self) -> t.List[t.Any]:
"""
Retrieves the argument with key "expressions".
"""
@@ -238,6 +239,9 @@ class Expression(metaclass=_Expression):
dtype = DataType.build(dtype)
self._type = dtype # type: ignore
+ def is_type(self, *dtypes) -> bool:
+ return self.type is not None and self.type.is_type(*dtypes)
+
@property
def meta(self) -> t.Dict[str, t.Any]:
if self._meta is None:
@@ -481,7 +485,7 @@ class Expression(metaclass=_Expression):
def flatten(self, unnest=True):
"""
- Returns a generator which yields child nodes who's parents are the same class.
+ Returns a generator which yields child nodes whose parents are the same class.
A AND B AND C -> [A, B, C]
"""
@@ -508,7 +512,7 @@ class Expression(metaclass=_Expression):
"""
from sqlglot.dialects import Dialect
- return Dialect.get_or_raise(dialect)().generate(self, **opts)
+ return Dialect.get_or_raise(dialect).generate(self, **opts)
def _to_s(self, hide_missing: bool = True, level: int = 0) -> str:
indent = "" if not level else "\n"
@@ -821,6 +825,12 @@ class Expression(metaclass=_Expression):
def rlike(self, other: ExpOrStr) -> RegexpLike:
return self._binop(RegexpLike, other)
+ def div(self, other: ExpOrStr, typed: bool = False, safe: bool = False) -> Div:
+ div = self._binop(Div, other)
+ div.args["typed"] = typed
+ div.args["safe"] = safe
+ return div
+
def __lt__(self, other: t.Any) -> LT:
return self._binop(LT, other)
@@ -1000,7 +1010,6 @@ class UDTF(DerivedTable, Unionable):
class Cache(Expression):
arg_types = {
- "with": False,
"this": True,
"lazy": False,
"options": False,
@@ -1012,6 +1021,10 @@ class Uncache(Expression):
arg_types = {"this": True, "exists": False}
+class Refresh(Expression):
+ pass
+
+
class DDL(Expression):
@property
def ctes(self):
@@ -1033,6 +1046,43 @@ class DDL(Expression):
return []
+class DML(Expression):
+ def returning(
+ self,
+ expression: ExpOrStr,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> DML:
+ """
+ Set the RETURNING expression. Not supported by all dialects.
+
+ Example:
+ >>> delete("tbl").returning("*", dialect="postgres").sql()
+ 'DELETE FROM tbl RETURNING *'
+
+ Args:
+ expression: the SQL code strings to parse.
+ If an `Expression` instance is passed, it will be used as-is.
+ dialect: the dialect used to parse the input expressions.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
+
+ Returns:
+ Delete: the modified expression.
+ """
+ return _apply_builder(
+ expression=expression,
+ instance=self,
+ arg="returning",
+ prefix="RETURNING",
+ dialect=dialect,
+ copy=copy,
+ into=Returning,
+ **opts,
+ )
+
+
class Create(DDL):
arg_types = {
"with": False,
@@ -1133,8 +1183,10 @@ class WithinGroup(Expression):
arg_types = {"this": True, "expression": False}
+# clickhouse supports scalar ctes
+# https://clickhouse.com/docs/en/sql-reference/statements/select/with
class CTE(DerivedTable):
- arg_types = {"this": True, "alias": True}
+ arg_types = {"this": True, "alias": True, "scalar": False}
class TableAlias(Expression):
@@ -1297,6 +1349,10 @@ class AutoIncrementColumnConstraint(ColumnConstraintKind):
pass
+class PeriodForSystemTimeConstraint(ColumnConstraintKind):
+ arg_types = {"this": True, "expression": True}
+
+
class CaseSpecificColumnConstraint(ColumnConstraintKind):
arg_types = {"not_": True}
@@ -1351,6 +1407,10 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
}
+class GeneratedAsRowColumnConstraint(ColumnConstraintKind):
+ arg_types = {"start": True, "hidden": False}
+
+
# https://dev.mysql.com/doc/refman/8.0/en/create-table.html
class IndexColumnConstraint(ColumnConstraintKind):
arg_types = {
@@ -1383,6 +1443,11 @@ class OnUpdateColumnConstraint(ColumnConstraintKind):
pass
+# https://docs.snowflake.com/en/sql-reference/sql/create-external-table#optional-parameters
+class TransformColumnConstraint(ColumnConstraintKind):
+ pass
+
+
class PrimaryKeyColumnConstraint(ColumnConstraintKind):
arg_types = {"desc": False}
@@ -1413,7 +1478,7 @@ class Constraint(Expression):
arg_types = {"this": True, "expressions": True}
-class Delete(Expression):
+class Delete(DML):
arg_types = {
"with": False,
"this": False,
@@ -1496,41 +1561,6 @@ class Delete(Expression):
**opts,
)
- def returning(
- self,
- expression: ExpOrStr,
- dialect: DialectType = None,
- copy: bool = True,
- **opts,
- ) -> Delete:
- """
- Set the RETURNING expression. Not supported by all dialects.
-
- Example:
- >>> delete("tbl").returning("*", dialect="postgres").sql()
- 'DELETE FROM tbl RETURNING *'
-
- Args:
- expression: the SQL code strings to parse.
- If an `Expression` instance is passed, it will be used as-is.
- dialect: the dialect used to parse the input expressions.
- copy: if `False`, modify this expression instance in-place.
- opts: other options to use to parse the input expressions.
-
- Returns:
- Delete: the modified expression.
- """
- return _apply_builder(
- expression=expression,
- instance=self,
- arg="returning",
- prefix="RETURNING",
- dialect=dialect,
- copy=copy,
- into=Returning,
- **opts,
- )
-
class Drop(Expression):
arg_types = {
@@ -1648,7 +1678,7 @@ class Index(Expression):
}
-class Insert(DDL):
+class Insert(DDL, DML):
arg_types = {
"with": False,
"this": True,
@@ -2259,6 +2289,11 @@ class WithJournalTableProperty(Property):
arg_types = {"this": True}
+class WithSystemVersioningProperty(Property):
+ # this -> history table name, expression -> data consistency check
+ arg_types = {"this": False, "expression": False}
+
+
class Properties(Expression):
arg_types = {"expressions": True}
@@ -3663,6 +3698,7 @@ class DataType(Expression):
Type.BIGINT,
Type.INT128,
Type.INT256,
+ Type.BIT,
}
FLOAT_TYPES = {
@@ -3692,7 +3728,7 @@ class DataType(Expression):
@classmethod
def build(
cls,
- dtype: str | DataType | DataType.Type,
+ dtype: DATA_TYPE,
dialect: DialectType = None,
udt: bool = False,
**kwargs,
@@ -3733,7 +3769,7 @@ class DataType(Expression):
return DataType(**{**data_type_exp.args, **kwargs})
- def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
+ def is_type(self, *dtypes: DATA_TYPE) -> bool:
"""
Checks whether this DataType matches one of the provided data types. Nested types or precision
will be compared using "structural equivalence" semantics, so e.g. array<int> != array<float>.
@@ -3761,6 +3797,9 @@ class DataType(Expression):
return False
+DATA_TYPE = t.Union[str, DataType, DataType.Type]
+
+
# https://www.postgresql.org/docs/15/datatype-pseudo.html
class PseudoType(DataType):
arg_types = {"this": True}
@@ -3868,7 +3907,7 @@ class BitwiseXor(Binary):
class Div(Binary):
- pass
+ arg_types = {"this": True, "expression": True, "typed": False, "safe": False}
class Overlaps(Binary):
@@ -3892,13 +3931,25 @@ class Dot(Binary):
return t.cast(Dot, reduce(lambda x, y: Dot(this=x, expression=y), expressions))
+ @property
+ def parts(self) -> t.List[Expression]:
+ """Return the parts of a table / column in order catalog, db, table."""
+ this, *parts = self.flatten()
-class DPipe(Binary):
- pass
+ parts.reverse()
+ for arg in ("this", "table", "db", "catalog"):
+ part = this.args.get(arg)
-class SafeDPipe(DPipe):
- pass
+ if isinstance(part, Expression):
+ parts.append(part)
+
+ parts.reverse()
+ return parts
+
+
+class DPipe(Binary):
+ arg_types = {"this": True, "expression": True, "safe": False}
class EQ(Binary, Predicate):
@@ -3913,6 +3964,11 @@ class NullSafeNEQ(Binary, Predicate):
pass
+# Represents e.g. := in DuckDB which is mostly used for setting parameters
+class PropertyEQ(Binary):
+ pass
+
+
class Distance(Binary):
pass
@@ -3981,6 +4037,11 @@ class NEQ(Binary, Predicate):
pass
+# https://www.postgresql.org/docs/current/ddl-schemas.html#DDL-SCHEMAS-PATH
+class Operator(Binary):
+ arg_types = {"this": True, "operator": True, "expression": True}
+
+
class SimilarTo(Binary, Predicate):
pass
@@ -4048,7 +4109,8 @@ class Between(Predicate):
class Bracket(Condition):
- arg_types = {"this": True, "expressions": True}
+ # https://cloud.google.com/bigquery/docs/reference/standard-sql/operators#array_subscript_operator
+ arg_types = {"this": True, "expressions": True, "offset": False, "safe": False}
@property
def output_name(self) -> str:
@@ -4058,10 +4120,6 @@ class Bracket(Condition):
return super().output_name
-class SafeBracket(Bracket):
- """Represents array lookup where OOB index yields NULL instead of causing a failure."""
-
-
class Distinct(Expression):
arg_types = {"expressions": False, "on": False}
@@ -4077,6 +4135,11 @@ class In(Predicate):
}
+# https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#for-in
+class ForIn(Expression):
+ arg_types = {"this": True, "expression": True}
+
+
class TimeUnit(Expression):
"""Automatically converts unit arg into a var."""
@@ -4248,8 +4311,9 @@ class Array(Func):
# https://docs.snowflake.com/en/sql-reference/functions/to_char
+# https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/TO_CHAR-number.html
class ToChar(Func):
- arg_types = {"this": True, "format": False}
+ arg_types = {"this": True, "format": False, "nlsparam": False}
class GenerateSeries(Func):
@@ -4260,6 +4324,10 @@ class ArrayAgg(AggFunc):
pass
+class ArrayUniqueAgg(AggFunc):
+ pass
+
+
class ArrayAll(Func):
arg_types = {"this": True, "expression": True}
@@ -4358,7 +4426,7 @@ class Cast(Func):
def output_name(self) -> str:
return self.name
- def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
+ def is_type(self, *dtypes: DATA_TYPE) -> bool:
"""
Checks whether this Cast's DataType matches one of the provided data types. Nested types
like arrays or structs will be compared using "structural equivalence" semantics, so e.g.
@@ -4403,14 +4471,10 @@ class Chr(Func):
class Concat(Func):
- arg_types = {"expressions": True}
+ arg_types = {"expressions": True, "safe": False, "coalesce": False}
is_var_len_args = True
-class SafeConcat(Concat):
- pass
-
-
class ConcatWs(Concat):
_sql_names = ["CONCAT_WS"]
@@ -4643,6 +4707,10 @@ class If(Func):
arg_types = {"this": True, "true": True, "false": False}
+class Nullif(Func):
+ arg_types = {"this": True, "expression": True}
+
+
class Initcap(Func):
arg_types = {"this": True, "expression": False}
@@ -4651,6 +4719,10 @@ class IsNan(Func):
_sql_names = ["IS_NAN", "ISNAN"]
+class IsInf(Func):
+ _sql_names = ["IS_INF", "ISINF"]
+
+
class FormatJson(Expression):
pass
@@ -4970,10 +5042,6 @@ class SafeDivide(Func):
arg_types = {"this": True, "expression": True}
-class SetAgg(AggFunc):
- pass
-
-
class SHA(Func):
_sql_names = ["SHA", "SHA1"]
@@ -5118,6 +5186,15 @@ class Trim(Func):
class TsOrDsAdd(Func, TimeUnit):
+ # return_type is used to correctly cast the arguments of this expression when transpiling it
+ arg_types = {"this": True, "expression": True, "unit": False, "return_type": False}
+
+ @property
+ def return_type(self) -> DataType:
+ return DataType.build(self.args.get("return_type") or DataType.Type.DATE)
+
+
+class TsOrDsDiff(Func, TimeUnit):
arg_types = {"this": True, "expression": True, "unit": False}
@@ -5149,6 +5226,7 @@ class UnixToTime(Func):
SECONDS = Literal.string("seconds")
MILLIS = Literal.string("millis")
MICROS = Literal.string("micros")
+ NANOS = Literal.string("nanos")
class UnixToTimeStr(Func):
@@ -5202,6 +5280,7 @@ def _norm_arg(arg):
ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
+FUNCTION_BY_NAME = {name: func for func in ALL_FUNCTIONS for name in func.sql_names()}
# Helpers
@@ -5693,7 +5772,9 @@ def delete(
if where:
delete_expr = delete_expr.where(where, dialect=dialect, copy=False, **opts)
if returning:
- delete_expr = delete_expr.returning(returning, dialect=dialect, copy=False, **opts)
+ delete_expr = t.cast(
+ Delete, delete_expr.returning(returning, dialect=dialect, copy=False, **opts)
+ )
return delete_expr
@@ -5702,6 +5783,7 @@ def insert(
into: ExpOrStr,
columns: t.Optional[t.Sequence[ExpOrStr]] = None,
overwrite: t.Optional[bool] = None,
+ returning: t.Optional[ExpOrStr] = None,
dialect: DialectType = None,
copy: bool = True,
**opts,
@@ -5718,6 +5800,7 @@ def insert(
into: the tbl to insert data to.
columns: optionally the table's column names.
overwrite: whether to INSERT OVERWRITE or not.
+ returning: sql conditional parsed into a RETURNING statement
dialect: the dialect used to parse the input expressions.
copy: whether or not to copy the expression.
**opts: other options to use to parse the input expressions.
@@ -5739,7 +5822,12 @@ def insert(
**opts,
)
- return Insert(this=this, expression=expr, overwrite=overwrite)
+ insert = Insert(this=this, expression=expr, overwrite=overwrite)
+
+ if returning:
+ insert = t.cast(Insert, insert.returning(returning, dialect=dialect, copy=False, **opts))
+
+ return insert
def condition(
@@ -5913,7 +6001,7 @@ def to_identifier(name, quoted=None, copy=True):
return identifier
-def parse_identifier(name: str, dialect: DialectType = None) -> Identifier:
+def parse_identifier(name: str | Identifier, dialect: DialectType = None) -> Identifier:
"""
Parses a given string into an identifier.
@@ -5965,7 +6053,7 @@ def to_table(sql_path: None, **kwargs) -> None:
def to_table(
- sql_path: t.Optional[str | Table], dialect: DialectType = None, **kwargs
+ sql_path: t.Optional[str | Table], dialect: DialectType = None, copy: bool = True, **kwargs
) -> t.Optional[Table]:
"""
Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
@@ -5974,13 +6062,14 @@ def to_table(
Args:
sql_path: a `[catalog].[schema].[table]` string.
dialect: the source dialect according to which the table name will be parsed.
+ copy: Whether or not to copy a table if it is passed in.
kwargs: the kwargs to instantiate the resulting `Table` expression with.
Returns:
A table expression.
"""
if sql_path is None or isinstance(sql_path, Table):
- return sql_path
+ return maybe_copy(sql_path, copy=copy)
if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for a table: {type(sql_path)}")
@@ -6123,7 +6212,7 @@ def column(
)
-def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Cast:
+def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast:
"""Cast an expression to a data type.
Example:
@@ -6335,12 +6424,15 @@ def column_table_names(expression: Expression, exclude: str = "") -> t.Set[str]:
}
-def table_name(table: Table | str, dialect: DialectType = None) -> str:
+def table_name(table: Table | str, dialect: DialectType = None, identify: bool = False) -> str:
"""Get the full name of a table as a string.
Args:
table: Table expression node or string.
dialect: The dialect to generate the table name for.
+ identify: Determines when an identifier should be quoted. Possible values are:
+ False (default): Never quote, except in cases where it's mandatory by the dialect.
+ True: Always quote.
Examples:
>>> from sqlglot import exp, parse_one
@@ -6358,37 +6450,68 @@ def table_name(table: Table | str, dialect: DialectType = None) -> str:
return ".".join(
part.sql(dialect=dialect, identify=True)
- if not SAFE_IDENTIFIER_RE.match(part.name)
+ if identify or not SAFE_IDENTIFIER_RE.match(part.name)
else part.name
for part in table.parts
)
-def replace_tables(expression: E, mapping: t.Dict[str, str], copy: bool = True) -> E:
+def normalize_table_name(table: str | Table, dialect: DialectType = None, copy: bool = True) -> str:
+ """Returns a case normalized table name without quotes.
+
+ Args:
+ table: the table to normalize
+ dialect: the dialect to use for normalization rules
+ copy: whether or not to copy the expression.
+
+ Examples:
+ >>> normalize_table_name("`A-B`.c", dialect="bigquery")
+ 'A-B.c'
+ """
+ from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
+
+ return ".".join(
+ p.name
+ for p in normalize_identifiers(
+ to_table(table, dialect=dialect, copy=copy), dialect=dialect
+ ).parts
+ )
+
+
+def replace_tables(
+ expression: E, mapping: t.Dict[str, str], dialect: DialectType = None, copy: bool = True
+) -> E:
"""Replace all tables in expression according to the mapping.
Args:
expression: expression node to be transformed and replaced.
mapping: mapping of table names.
+ dialect: the dialect of the mapping table
copy: whether or not to copy the expression.
Examples:
>>> from sqlglot import exp, parse_one
>>> replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql()
- 'SELECT * FROM c'
+ 'SELECT * FROM c /* a.b */'
Returns:
The mapped expression.
"""
+ mapping = {normalize_table_name(k, dialect=dialect): v for k, v in mapping.items()}
+
def _replace_tables(node: Expression) -> Expression:
if isinstance(node, Table):
- new_name = mapping.get(table_name(node))
+ original = normalize_table_name(node, dialect=dialect)
+ new_name = mapping.get(original)
+
if new_name:
- return to_table(
+ table = to_table(
new_name,
- **{k: v for k, v in node.args.items() if k not in ("this", "db", "catalog")},
+ **{k: v for k, v in node.args.items() if k not in TABLE_PARTS},
)
+ table.add_comments([original])
+ return table
return node
return expression.transform(_replace_tables, copy=copy)
@@ -6431,7 +6554,10 @@ def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression:
def expand(
- expression: Expression, sources: t.Dict[str, Subqueryable], copy: bool = True
+ expression: Expression,
+ sources: t.Dict[str, Subqueryable],
+ dialect: DialectType = None,
+ copy: bool = True,
) -> Expression:
"""Transforms an expression by expanding all referenced sources into subqueries.
@@ -6446,15 +6572,17 @@ def expand(
Args:
expression: The expression to expand.
sources: A dictionary of name to Subqueryables.
+ dialect: The dialect of the sources dict.
copy: Whether or not to copy the expression during transformation. Defaults to True.
Returns:
The transformed expression.
"""
+ sources = {normalize_table_name(k, dialect=dialect): v for k, v in sources.items()}
def _expand(node: Expression):
if isinstance(node, Table):
- name = table_name(node)
+ name = normalize_table_name(node, dialect=dialect)
source = sources.get(name)
if source:
subquery = source.subquery(node.alias or name)
@@ -6465,7 +6593,7 @@ def expand(
return expression.transform(_expand, copy=copy)
-def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
+def func(name: str, *args, copy: bool = True, dialect: DialectType = None, **kwargs) -> Func:
"""
Returns a Func expression.
@@ -6479,6 +6607,7 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
Args:
name: the name of the function to build.
args: the args used to instantiate the function of interest.
+ copy: whether or not to copy the argument expressions.
dialect: the source dialect.
kwargs: the kwargs used to instantiate the function of interest.
@@ -6494,14 +6623,29 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
from sqlglot.dialects.dialect import Dialect
- converted: t.List[Expression] = [maybe_parse(arg, dialect=dialect) for arg in args]
- kwargs = {key: maybe_parse(value, dialect=dialect) for key, value in kwargs.items()}
+ dialect = Dialect.get_or_raise(dialect)
- parser = Dialect.get_or_raise(dialect)().parser()
- from_args_list = parser.FUNCTIONS.get(name.upper())
+ converted: t.List[Expression] = [maybe_parse(arg, dialect=dialect, copy=copy) for arg in args]
+ kwargs = {key: maybe_parse(value, dialect=dialect, copy=copy) for key, value in kwargs.items()}
- if from_args_list:
- function = from_args_list(converted) if converted else from_args_list.__self__(**kwargs) # type: ignore
+ constructor = dialect.parser_class.FUNCTIONS.get(name.upper())
+ if constructor:
+ if converted:
+ if "dialect" in constructor.__code__.co_varnames:
+ function = constructor(converted, dialect=dialect)
+ else:
+ function = constructor(converted)
+ elif constructor.__name__ == "from_arg_list":
+ function = constructor.__self__(**kwargs) # type: ignore
+ else:
+ constructor = FUNCTION_BY_NAME.get(name.upper())
+ if constructor:
+ function = constructor(**kwargs)
+ else:
+ raise ValueError(
+ f"Unable to convert '{name}' into a Func. Either manually construct "
+ "the Func expression of interest or parse the function call."
+ )
else:
kwargs = kwargs or {"expressions": converted}
function = Anonymous(this=name, **kwargs)
@@ -6512,6 +6656,48 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
return function
+def case(
+ expression: t.Optional[ExpOrStr] = None,
+ **opts,
+) -> Case:
+ """
+ Initialize a CASE statement.
+
+ Example:
+ case().when("a = 1", "foo").else_("bar")
+
+ Args:
+ expression: Optionally, the input expression (not all dialects support this)
+ **opts: Extra keyword arguments for parsing `expression`
+ """
+ if expression is not None:
+ this = maybe_parse(expression, **opts)
+ else:
+ this = None
+ return Case(this=this, ifs=[])
+
+
+def cast_unless(
+ expression: ExpOrStr,
+ to: DATA_TYPE,
+ *types: DATA_TYPE,
+ **opts: t.Any,
+) -> Expression | Cast:
+ """
+ Cast an expression to a data type unless it is a specified type.
+
+ Args:
+ expression: The expression to cast.
+ to: The data type to cast to.
+ **types: The types to exclude from casting.
+ **opts: Extra keyword arguments for parsing `expression`
+ """
+ expr = maybe_parse(expression, **opts)
+ if expr.is_type(*types):
+ return expr
+ return cast(expr, to, **opts)
+
+
def true() -> Boolean:
"""
Returns a true Boolean expression.