diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-02-08 04:14:30 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-02-08 04:14:30 +0000 |
commit | 99980f928b5b7be237d108266072e51aa3bb354e (patch) | |
tree | ce6fff00ea2b834bdbe3d84dcac90df1617d4245 /sqlglot/expressions.py | |
parent | Adding upstream version 10.6.0. (diff) | |
download | sqlglot-99980f928b5b7be237d108266072e51aa3bb354e.tar.xz sqlglot-99980f928b5b7be237d108266072e51aa3bb354e.zip |
Adding upstream version 10.6.3.upstream/10.6.3
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r-- | sqlglot/expressions.py | 158 |
1 files changed, 141 insertions, 17 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 7c1a116..6bb083a 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -32,13 +32,7 @@ from sqlglot.helper import ( from sqlglot.tokens import Token if t.TYPE_CHECKING: - from sqlglot.dialects.dialect import Dialect - - IntoType = t.Union[ - str, - t.Type[Expression], - t.Collection[t.Union[str, t.Type[Expression]]], - ] + from sqlglot.dialects.dialect import DialectType class _Expression(type): @@ -427,7 +421,7 @@ class Expression(metaclass=_Expression): def __repr__(self): return self._to_s() - def sql(self, dialect: Dialect | str | None = None, **opts) -> str: + def sql(self, dialect: DialectType = None, **opts) -> str: """ Returns SQL string representation of this tree. @@ -595,6 +589,14 @@ class Expression(metaclass=_Expression): return load(obj) +if t.TYPE_CHECKING: + IntoType = t.Union[ + str, + t.Type[Expression], + t.Collection[t.Union[str, t.Type[Expression]]], + ] + + class Condition(Expression): def and_(self, *expressions, dialect=None, **opts): """ @@ -1285,6 +1287,18 @@ class Property(Expression): arg_types = {"this": True, "value": True} +class AlgorithmProperty(Property): + arg_types = {"this": True} + + +class DefinerProperty(Property): + arg_types = {"this": True} + + +class SqlSecurityProperty(Property): + arg_types = {"definer": True} + + class TableFormatProperty(Property): arg_types = {"this": True} @@ -1425,13 +1439,15 @@ class IsolatedLoadingProperty(Property): class Properties(Expression): - arg_types = {"expressions": True, "before": False} + arg_types = {"expressions": True} NAME_TO_PROPERTY = { + "ALGORITHM": AlgorithmProperty, "AUTO_INCREMENT": AutoIncrementProperty, "CHARACTER SET": CharacterSetProperty, "COLLATE": CollateProperty, "COMMENT": SchemaCommentProperty, + "DEFINER": DefinerProperty, "DISTKEY": DistKeyProperty, "DISTSTYLE": DistStyleProperty, "ENGINE": EngineProperty, @@ -1447,6 +1463,14 @@ class Properties(Expression): PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()} + class Location(AutoName): + POST_CREATE = auto() + PRE_SCHEMA = auto() + POST_INDEX = auto() + POST_SCHEMA_ROOT = auto() + POST_SCHEMA_WITH = auto() + UNSUPPORTED = auto() + @classmethod def from_dict(cls, properties_dict) -> Properties: expressions = [] @@ -1592,6 +1616,7 @@ QUERY_MODIFIERS = { "order": False, "limit": False, "offset": False, + "lock": False, } @@ -1713,6 +1738,12 @@ class Schema(Expression): arg_types = {"this": False, "expressions": False} +# Used to represent the FOR UPDATE and FOR SHARE locking read types. +# https://dev.mysql.com/doc/refman/8.0/en/innodb-locking-reads.html +class Lock(Expression): + arg_types = {"update": True} + + class Select(Subqueryable): arg_types = { "with": False, @@ -2243,6 +2274,30 @@ class Select(Subqueryable): properties=properties_expression, ) + def lock(self, update: bool = True, copy: bool = True) -> Select: + """ + Set the locking read mode for this expression. + + Examples: + >>> Select().select("x").from_("tbl").where("x = 'a'").lock().sql("mysql") + "SELECT x FROM tbl WHERE x = 'a' FOR UPDATE" + + >>> Select().select("x").from_("tbl").where("x = 'a'").lock(update=False).sql("mysql") + "SELECT x FROM tbl WHERE x = 'a' FOR SHARE" + + Args: + update: if `True`, the locking type will be `FOR UPDATE`, else it will be `FOR SHARE`. + copy: if `False`, modify this expression instance in-place. + + Returns: + The modified expression. + """ + + inst = _maybe_copy(self, copy) + inst.set("lock", Lock(update=update)) + + return inst + @property def named_selects(self) -> t.List[str]: return [e.output_name for e in self.expressions if e.alias_or_name] @@ -2456,24 +2511,28 @@ class DataType(Expression): @classmethod def build( - cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs + cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs ) -> DataType: from sqlglot import parse_one if isinstance(dtype, str): - data_type_exp: t.Optional[Expression] if dtype.upper() in cls.Type.__members__: - data_type_exp = DataType(this=DataType.Type[dtype.upper()]) + data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()]) else: data_type_exp = parse_one(dtype, read=dialect, into=DataType) if data_type_exp is None: raise ValueError(f"Unparsable data type value: {dtype}") elif isinstance(dtype, DataType.Type): data_type_exp = DataType(this=dtype) + elif isinstance(dtype, DataType): + return dtype else: raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type") return DataType(**{**data_type_exp.args, **kwargs}) + def is_type(self, dtype: DataType.Type) -> bool: + return self.this == dtype + # https://www.postgresql.org/docs/15/datatype-pseudo.html class PseudoType(Expression): @@ -2840,6 +2899,10 @@ class Array(Func): is_var_len_args = True +class GenerateSeries(Func): + arg_types = {"start": True, "end": True, "step": False} + + class ArrayAgg(AggFunc): pass @@ -2909,6 +2972,9 @@ class Cast(Func): def output_name(self): return self.name + def is_type(self, dtype: DataType.Type) -> bool: + return self.to.is_type(dtype) + class Collate(Binary): pass @@ -2989,6 +3055,22 @@ class DatetimeTrunc(Func, TimeUnit): arg_types = {"this": True, "unit": True, "zone": False} +class DayOfWeek(Func): + _sql_names = ["DAY_OF_WEEK", "DAYOFWEEK"] + + +class DayOfMonth(Func): + _sql_names = ["DAY_OF_MONTH", "DAYOFMONTH"] + + +class DayOfYear(Func): + _sql_names = ["DAY_OF_YEAR", "DAYOFYEAR"] + + +class WeekOfYear(Func): + _sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"] + + class LastDateOfMonth(Func): pass @@ -3239,7 +3321,7 @@ class ReadCSV(Func): class Reduce(Func): - arg_types = {"this": True, "initial": True, "merge": True, "finish": True} + arg_types = {"this": True, "initial": True, "merge": True, "finish": False} class RegexpLike(Func): @@ -3476,7 +3558,7 @@ def maybe_parse( sql_or_expression: str | Expression, *, into: t.Optional[IntoType] = None, - dialect: t.Optional[str] = None, + dialect: DialectType = None, prefix: t.Optional[str] = None, **opts, ) -> Expression: @@ -3959,6 +4041,28 @@ def to_identifier(alias, quoted=None) -> t.Optional[Identifier]: return identifier +INTERVAL_STRING_RE = re.compile(r"\s*([0-9]+)\s*([a-zA-Z]+)\s*") + + +def to_interval(interval: str | Literal) -> Interval: + """Builds an interval expression from a string like '1 day' or '5 months'.""" + if isinstance(interval, Literal): + if not interval.is_string: + raise ValueError("Invalid interval string.") + + interval = interval.this + + interval_parts = INTERVAL_STRING_RE.match(interval) # type: ignore + + if not interval_parts: + raise ValueError("Invalid interval string.") + + return Interval( + this=Literal.string(interval_parts.group(1)), + unit=Var(this=interval_parts.group(2)), + ) + + @t.overload def to_table(sql_path: str | Table, **kwargs) -> Table: ... @@ -4050,7 +4154,8 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts): def subquery(expression, alias=None, dialect=None, **opts): """ Build a subquery expression. - Expample: + + Example: >>> subquery('select x from tbl', 'bar').select('x').sql() 'SELECT x FROM (SELECT x FROM tbl) AS bar' @@ -4072,6 +4177,7 @@ def subquery(expression, alias=None, dialect=None, **opts): def column(col, table=None, quoted=None) -> Column: """ Build a Column. + Args: col (str | Expression): column name table (str | Expression): table name @@ -4084,6 +4190,24 @@ def column(col, table=None, quoted=None) -> Column: ) +def cast(expression: str | Expression, to: str | DataType | DataType.Type, **opts) -> Cast: + """Cast an expression to a data type. + + Example: + >>> cast('x + 1', 'int').sql() + 'CAST(x + 1 AS INT)' + + Args: + expression: The expression to cast. + to: The datatype to cast to. + + Returns: + A cast node. + """ + expression = maybe_parse(expression, **opts) + return Cast(this=expression, to=DataType.build(to, **opts)) + + def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table: """Build a Table. @@ -4137,7 +4261,7 @@ def values( types = list(columns.values()) expressions[0].set( "expressions", - [Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)], + [cast(x, types[i]) for i, x in enumerate(expressions[0].expressions)], ) return Values( expressions=expressions, @@ -4373,7 +4497,7 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True return expression.transform(_expand, copy=copy) -def func(name: str, *args, dialect: t.Optional[Dialect | str] = None, **kwargs) -> Func: +def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func: """ Returns a Func expression. |