summaryrefslogtreecommitdiffstats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r--sqlglot/expressions.py90
1 files changed, 74 insertions, 16 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index da4a4ed..c7d4664 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -1500,6 +1500,7 @@ class Index(Expression):
arg_types = {
"this": False,
"table": False,
+ "using": False,
"where": False,
"columns": False,
"unique": False,
@@ -1623,7 +1624,7 @@ class Lambda(Expression):
class Limit(Expression):
- arg_types = {"this": False, "expression": True}
+ arg_types = {"this": False, "expression": True, "offset": False}
class Literal(Condition):
@@ -1869,6 +1870,10 @@ class EngineProperty(Property):
arg_types = {"this": True}
+class ToTableProperty(Property):
+ arg_types = {"this": True}
+
+
class ExecuteAsProperty(Property):
arg_types = {"this": True}
@@ -3072,12 +3077,35 @@ class Select(Subqueryable):
Returns:
The modified expression.
"""
-
inst = _maybe_copy(self, copy)
inst.set("locks", [Lock(update=update)])
return inst
+ def hint(self, *hints: ExpOrStr, dialect: DialectType = None, copy: bool = True) -> Select:
+ """
+ Set hints for this expression.
+
+ Examples:
+ >>> Select().select("x").from_("tbl").hint("BROADCAST(y)").sql(dialect="spark")
+ 'SELECT /*+ BROADCAST(y) */ x FROM tbl'
+
+ Args:
+ hints: The SQL code strings to parse as the hints.
+ If an `Expression` instance is passed, it will be used as-is.
+ dialect: The dialect used to parse the hints.
+ copy: If `False`, modify this expression instance in-place.
+
+ Returns:
+ The modified expression.
+ """
+ inst = _maybe_copy(self, copy)
+ inst.set(
+ "hint", Hint(expressions=[maybe_parse(h, copy=copy, dialect=dialect) for h in hints])
+ )
+
+ return inst
+
@property
def named_selects(self) -> t.List[str]:
return [e.output_name for e in self.expressions if e.alias_or_name]
@@ -3244,6 +3272,7 @@ class DataType(Expression):
DATE = auto()
DATETIME = auto()
DATETIME64 = auto()
+ ENUM = auto()
INT4RANGE = auto()
INT4MULTIRANGE = auto()
INT8RANGE = auto()
@@ -3284,6 +3313,7 @@ class DataType(Expression):
OBJECT = auto()
ROWVERSION = auto()
SERIAL = auto()
+ SET = auto()
SMALLINT = auto()
SMALLMONEY = auto()
SMALLSERIAL = auto()
@@ -3334,6 +3364,7 @@ class DataType(Expression):
NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES}
TEMPORAL_TYPES = {
+ Type.TIME,
Type.TIMESTAMP,
Type.TIMESTAMPTZ,
Type.TIMESTAMPLTZ,
@@ -3342,6 +3373,8 @@ class DataType(Expression):
Type.DATETIME64,
}
+ META_TYPES = {"UNKNOWN", "NULL"}
+
@classmethod
def build(
cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs
@@ -3349,8 +3382,9 @@ class DataType(Expression):
from sqlglot import parse_one
if isinstance(dtype, str):
- if dtype.upper() in cls.Type.__members__:
- data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()])
+ upper = dtype.upper()
+ if upper in DataType.META_TYPES:
+ data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[upper])
else:
data_type_exp = parse_one(dtype, read=dialect, into=DataType)
@@ -3483,6 +3517,10 @@ class Dot(Binary):
def name(self) -> str:
return self.expression.name
+ @property
+ def output_name(self) -> str:
+ return self.name
+
@classmethod
def build(self, expressions: t.Sequence[Expression]) -> Dot:
"""Build a Dot object with a sequence of expressions."""
@@ -3502,6 +3540,10 @@ class DPipe(Binary):
pass
+class SafeDPipe(DPipe):
+ pass
+
+
class EQ(Binary, Predicate):
pass
@@ -3615,6 +3657,10 @@ class Not(Unary):
class Paren(Unary):
arg_types = {"this": True, "with": False}
+ @property
+ def output_name(self) -> str:
+ return self.this.name
+
class Neg(Unary):
pass
@@ -3904,6 +3950,7 @@ class Ceil(Func):
class Coalesce(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
+ _sql_names = ["COALESCE", "IFNULL", "NVL"]
class Concat(Func):
@@ -3911,12 +3958,17 @@ class Concat(Func):
is_var_len_args = True
+class SafeConcat(Concat):
+ pass
+
+
class ConcatWs(Concat):
_sql_names = ["CONCAT_WS"]
class Count(AggFunc):
- arg_types = {"this": False}
+ arg_types = {"this": False, "expressions": False}
+ is_var_len_args = True
class CountIf(AggFunc):
@@ -4049,6 +4101,11 @@ class DateToDi(Func):
pass
+class Date(Func):
+ arg_types = {"expressions": True}
+ is_var_len_args = True
+
+
class Day(Func):
pass
@@ -4102,11 +4159,6 @@ class If(Func):
arg_types = {"this": True, "true": True, "false": False}
-class IfNull(Func):
- arg_types = {"this": True, "expression": False}
- _sql_names = ["IFNULL", "NVL"]
-
-
class Initcap(Func):
arg_types = {"this": True, "expression": False}
@@ -5608,22 +5660,27 @@ def replace_children(expression: Expression, fun: t.Callable, *args, **kwargs) -
expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0)
-def column_table_names(expression: Expression) -> t.List[str]:
+def column_table_names(expression: Expression, exclude: str = "") -> t.Set[str]:
"""
Return all table names referenced through columns in an expression.
Example:
>>> import sqlglot
- >>> column_table_names(sqlglot.parse_one("a.b AND c.d AND c.e"))
- ['c', 'a']
+ >>> sorted(column_table_names(sqlglot.parse_one("a.b AND c.d AND c.e")))
+ ['a', 'c']
Args:
expression: expression to find table names.
+ exclude: a table name to exclude
Returns:
A list of unique names.
"""
- return list(dict.fromkeys(column.table for column in expression.find_all(Column)))
+ return {
+ table
+ for table in (column.table for column in expression.find_all(Column))
+ if table and table != exclude
+ }
def table_name(table: Table | str) -> str:
@@ -5649,12 +5706,13 @@ def table_name(table: Table | str) -> str:
return ".".join(part for part in (table.text("catalog"), table.text("db"), table.name) if part)
-def replace_tables(expression: E, mapping: t.Dict[str, str]) -> E:
+def replace_tables(expression: E, mapping: t.Dict[str, str], copy: bool = True) -> E:
"""Replace all tables in expression according to the mapping.
Args:
expression: expression node to be transformed and replaced.
mapping: mapping of table names.
+ copy: whether or not to copy the expression.
Examples:
>>> from sqlglot import exp, parse_one
@@ -5675,7 +5733,7 @@ def replace_tables(expression: E, mapping: t.Dict[str, str]) -> E:
)
return node
- return expression.transform(_replace_tables)
+ return expression.transform(_replace_tables, copy=copy)
def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression: