summaryrefslogtreecommitdiffstats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r--sqlglot/expressions.py158
1 files changed, 141 insertions, 17 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 7c1a116..6bb083a 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -32,13 +32,7 @@ from sqlglot.helper import (
from sqlglot.tokens import Token
if t.TYPE_CHECKING:
- from sqlglot.dialects.dialect import Dialect
-
- IntoType = t.Union[
- str,
- t.Type[Expression],
- t.Collection[t.Union[str, t.Type[Expression]]],
- ]
+ from sqlglot.dialects.dialect import DialectType
class _Expression(type):
@@ -427,7 +421,7 @@ class Expression(metaclass=_Expression):
def __repr__(self):
return self._to_s()
- def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
+ def sql(self, dialect: DialectType = None, **opts) -> str:
"""
Returns SQL string representation of this tree.
@@ -595,6 +589,14 @@ class Expression(metaclass=_Expression):
return load(obj)
+if t.TYPE_CHECKING:
+ IntoType = t.Union[
+ str,
+ t.Type[Expression],
+ t.Collection[t.Union[str, t.Type[Expression]]],
+ ]
+
+
class Condition(Expression):
def and_(self, *expressions, dialect=None, **opts):
"""
@@ -1285,6 +1287,18 @@ class Property(Expression):
arg_types = {"this": True, "value": True}
+class AlgorithmProperty(Property):
+ arg_types = {"this": True}
+
+
+class DefinerProperty(Property):
+ arg_types = {"this": True}
+
+
+class SqlSecurityProperty(Property):
+ arg_types = {"definer": True}
+
+
class TableFormatProperty(Property):
arg_types = {"this": True}
@@ -1425,13 +1439,15 @@ class IsolatedLoadingProperty(Property):
class Properties(Expression):
- arg_types = {"expressions": True, "before": False}
+ arg_types = {"expressions": True}
NAME_TO_PROPERTY = {
+ "ALGORITHM": AlgorithmProperty,
"AUTO_INCREMENT": AutoIncrementProperty,
"CHARACTER SET": CharacterSetProperty,
"COLLATE": CollateProperty,
"COMMENT": SchemaCommentProperty,
+ "DEFINER": DefinerProperty,
"DISTKEY": DistKeyProperty,
"DISTSTYLE": DistStyleProperty,
"ENGINE": EngineProperty,
@@ -1447,6 +1463,14 @@ class Properties(Expression):
PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()}
+ class Location(AutoName):
+ POST_CREATE = auto()
+ PRE_SCHEMA = auto()
+ POST_INDEX = auto()
+ POST_SCHEMA_ROOT = auto()
+ POST_SCHEMA_WITH = auto()
+ UNSUPPORTED = auto()
+
@classmethod
def from_dict(cls, properties_dict) -> Properties:
expressions = []
@@ -1592,6 +1616,7 @@ QUERY_MODIFIERS = {
"order": False,
"limit": False,
"offset": False,
+ "lock": False,
}
@@ -1713,6 +1738,12 @@ class Schema(Expression):
arg_types = {"this": False, "expressions": False}
+# Used to represent the FOR UPDATE and FOR SHARE locking read types.
+# https://dev.mysql.com/doc/refman/8.0/en/innodb-locking-reads.html
+class Lock(Expression):
+ arg_types = {"update": True}
+
+
class Select(Subqueryable):
arg_types = {
"with": False,
@@ -2243,6 +2274,30 @@ class Select(Subqueryable):
properties=properties_expression,
)
+ def lock(self, update: bool = True, copy: bool = True) -> Select:
+ """
+ Set the locking read mode for this expression.
+
+ Examples:
+ >>> Select().select("x").from_("tbl").where("x = 'a'").lock().sql("mysql")
+ "SELECT x FROM tbl WHERE x = 'a' FOR UPDATE"
+
+ >>> Select().select("x").from_("tbl").where("x = 'a'").lock(update=False).sql("mysql")
+ "SELECT x FROM tbl WHERE x = 'a' FOR SHARE"
+
+ Args:
+ update: if `True`, the locking type will be `FOR UPDATE`, else it will be `FOR SHARE`.
+ copy: if `False`, modify this expression instance in-place.
+
+ Returns:
+ The modified expression.
+ """
+
+ inst = _maybe_copy(self, copy)
+ inst.set("lock", Lock(update=update))
+
+ return inst
+
@property
def named_selects(self) -> t.List[str]:
return [e.output_name for e in self.expressions if e.alias_or_name]
@@ -2456,24 +2511,28 @@ class DataType(Expression):
@classmethod
def build(
- cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs
+ cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs
) -> DataType:
from sqlglot import parse_one
if isinstance(dtype, str):
- data_type_exp: t.Optional[Expression]
if dtype.upper() in cls.Type.__members__:
- data_type_exp = DataType(this=DataType.Type[dtype.upper()])
+ data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()])
else:
data_type_exp = parse_one(dtype, read=dialect, into=DataType)
if data_type_exp is None:
raise ValueError(f"Unparsable data type value: {dtype}")
elif isinstance(dtype, DataType.Type):
data_type_exp = DataType(this=dtype)
+ elif isinstance(dtype, DataType):
+ return dtype
else:
raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")
return DataType(**{**data_type_exp.args, **kwargs})
+ def is_type(self, dtype: DataType.Type) -> bool:
+ return self.this == dtype
+
# https://www.postgresql.org/docs/15/datatype-pseudo.html
class PseudoType(Expression):
@@ -2840,6 +2899,10 @@ class Array(Func):
is_var_len_args = True
+class GenerateSeries(Func):
+ arg_types = {"start": True, "end": True, "step": False}
+
+
class ArrayAgg(AggFunc):
pass
@@ -2909,6 +2972,9 @@ class Cast(Func):
def output_name(self):
return self.name
+ def is_type(self, dtype: DataType.Type) -> bool:
+ return self.to.is_type(dtype)
+
class Collate(Binary):
pass
@@ -2989,6 +3055,22 @@ class DatetimeTrunc(Func, TimeUnit):
arg_types = {"this": True, "unit": True, "zone": False}
+class DayOfWeek(Func):
+ _sql_names = ["DAY_OF_WEEK", "DAYOFWEEK"]
+
+
+class DayOfMonth(Func):
+ _sql_names = ["DAY_OF_MONTH", "DAYOFMONTH"]
+
+
+class DayOfYear(Func):
+ _sql_names = ["DAY_OF_YEAR", "DAYOFYEAR"]
+
+
+class WeekOfYear(Func):
+ _sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"]
+
+
class LastDateOfMonth(Func):
pass
@@ -3239,7 +3321,7 @@ class ReadCSV(Func):
class Reduce(Func):
- arg_types = {"this": True, "initial": True, "merge": True, "finish": True}
+ arg_types = {"this": True, "initial": True, "merge": True, "finish": False}
class RegexpLike(Func):
@@ -3476,7 +3558,7 @@ def maybe_parse(
sql_or_expression: str | Expression,
*,
into: t.Optional[IntoType] = None,
- dialect: t.Optional[str] = None,
+ dialect: DialectType = None,
prefix: t.Optional[str] = None,
**opts,
) -> Expression:
@@ -3959,6 +4041,28 @@ def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
return identifier
+INTERVAL_STRING_RE = re.compile(r"\s*([0-9]+)\s*([a-zA-Z]+)\s*")
+
+
+def to_interval(interval: str | Literal) -> Interval:
+ """Builds an interval expression from a string like '1 day' or '5 months'."""
+ if isinstance(interval, Literal):
+ if not interval.is_string:
+ raise ValueError("Invalid interval string.")
+
+ interval = interval.this
+
+ interval_parts = INTERVAL_STRING_RE.match(interval) # type: ignore
+
+ if not interval_parts:
+ raise ValueError("Invalid interval string.")
+
+ return Interval(
+ this=Literal.string(interval_parts.group(1)),
+ unit=Var(this=interval_parts.group(2)),
+ )
+
+
@t.overload
def to_table(sql_path: str | Table, **kwargs) -> Table:
...
@@ -4050,7 +4154,8 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
def subquery(expression, alias=None, dialect=None, **opts):
"""
Build a subquery expression.
- Expample:
+
+ Example:
>>> subquery('select x from tbl', 'bar').select('x').sql()
'SELECT x FROM (SELECT x FROM tbl) AS bar'
@@ -4072,6 +4177,7 @@ def subquery(expression, alias=None, dialect=None, **opts):
def column(col, table=None, quoted=None) -> Column:
"""
Build a Column.
+
Args:
col (str | Expression): column name
table (str | Expression): table name
@@ -4084,6 +4190,24 @@ def column(col, table=None, quoted=None) -> Column:
)
+def cast(expression: str | Expression, to: str | DataType | DataType.Type, **opts) -> Cast:
+ """Cast an expression to a data type.
+
+ Example:
+ >>> cast('x + 1', 'int').sql()
+ 'CAST(x + 1 AS INT)'
+
+ Args:
+ expression: The expression to cast.
+ to: The datatype to cast to.
+
+ Returns:
+ A cast node.
+ """
+ expression = maybe_parse(expression, **opts)
+ return Cast(this=expression, to=DataType.build(to, **opts))
+
+
def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
"""Build a Table.
@@ -4137,7 +4261,7 @@ def values(
types = list(columns.values())
expressions[0].set(
"expressions",
- [Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)],
+ [cast(x, types[i]) for i, x in enumerate(expressions[0].expressions)],
)
return Values(
expressions=expressions,
@@ -4373,7 +4497,7 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True
return expression.transform(_expand, copy=copy)
-def func(name: str, *args, dialect: t.Optional[Dialect | str] = None, **kwargs) -> Func:
+def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
"""
Returns a Func expression.