diff options
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r-- | sqlglot/expressions.py | 90 |
1 files changed, 74 insertions, 16 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index da4a4ed..c7d4664 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1500,6 +1500,7 @@ class Index(Expression): arg_types = { "this": False, "table": False, + "using": False, "where": False, "columns": False, "unique": False, @@ -1623,7 +1624,7 @@ class Lambda(Expression): class Limit(Expression): - arg_types = {"this": False, "expression": True} + arg_types = {"this": False, "expression": True, "offset": False} class Literal(Condition): @@ -1869,6 +1870,10 @@ class EngineProperty(Property): arg_types = {"this": True} +class ToTableProperty(Property): + arg_types = {"this": True} + + class ExecuteAsProperty(Property): arg_types = {"this": True} @@ -3072,12 +3077,35 @@ class Select(Subqueryable): Returns: The modified expression. """ - inst = _maybe_copy(self, copy) inst.set("locks", [Lock(update=update)]) return inst + def hint(self, *hints: ExpOrStr, dialect: DialectType = None, copy: bool = True) -> Select: + """ + Set hints for this expression. + + Examples: + >>> Select().select("x").from_("tbl").hint("BROADCAST(y)").sql(dialect="spark") + 'SELECT /*+ BROADCAST(y) */ x FROM tbl' + + Args: + hints: The SQL code strings to parse as the hints. + If an `Expression` instance is passed, it will be used as-is. + dialect: The dialect used to parse the hints. + copy: If `False`, modify this expression instance in-place. + + Returns: + The modified expression. + """ + inst = _maybe_copy(self, copy) + inst.set( + "hint", Hint(expressions=[maybe_parse(h, copy=copy, dialect=dialect) for h in hints]) + ) + + return inst + @property def named_selects(self) -> t.List[str]: return [e.output_name for e in self.expressions if e.alias_or_name] @@ -3244,6 +3272,7 @@ class DataType(Expression): DATE = auto() DATETIME = auto() DATETIME64 = auto() + ENUM = auto() INT4RANGE = auto() INT4MULTIRANGE = auto() INT8RANGE = auto() @@ -3284,6 +3313,7 @@ class DataType(Expression): OBJECT = auto() ROWVERSION = auto() SERIAL = auto() + SET = auto() SMALLINT = auto() SMALLMONEY = auto() SMALLSERIAL = auto() @@ -3334,6 +3364,7 @@ class DataType(Expression): NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES} TEMPORAL_TYPES = { + Type.TIME, Type.TIMESTAMP, Type.TIMESTAMPTZ, Type.TIMESTAMPLTZ, @@ -3342,6 +3373,8 @@ class DataType(Expression): Type.DATETIME64, } + META_TYPES = {"UNKNOWN", "NULL"} + @classmethod def build( cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs @@ -3349,8 +3382,9 @@ class DataType(Expression): from sqlglot import parse_one if isinstance(dtype, str): - if dtype.upper() in cls.Type.__members__: - data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()]) + upper = dtype.upper() + if upper in DataType.META_TYPES: + data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[upper]) else: data_type_exp = parse_one(dtype, read=dialect, into=DataType) @@ -3483,6 +3517,10 @@ class Dot(Binary): def name(self) -> str: return self.expression.name + @property + def output_name(self) -> str: + return self.name + @classmethod def build(self, expressions: t.Sequence[Expression]) -> Dot: """Build a Dot object with a sequence of expressions.""" @@ -3502,6 +3540,10 @@ class DPipe(Binary): pass +class SafeDPipe(DPipe): + pass + + class EQ(Binary, Predicate): pass @@ -3615,6 +3657,10 @@ class Not(Unary): class Paren(Unary): arg_types = {"this": True, "with": False} + @property + def output_name(self) -> str: + return self.this.name + class Neg(Unary): pass @@ -3904,6 +3950,7 @@ class Ceil(Func): class Coalesce(Func): arg_types = {"this": True, "expressions": False} is_var_len_args = True + _sql_names = ["COALESCE", "IFNULL", "NVL"] class Concat(Func): @@ -3911,12 +3958,17 @@ class Concat(Func): is_var_len_args = True +class SafeConcat(Concat): + pass + + class ConcatWs(Concat): _sql_names = ["CONCAT_WS"] class Count(AggFunc): - arg_types = {"this": False} + arg_types = {"this": False, "expressions": False} + is_var_len_args = True class CountIf(AggFunc): @@ -4049,6 +4101,11 @@ class DateToDi(Func): pass +class Date(Func): + arg_types = {"expressions": True} + is_var_len_args = True + + class Day(Func): pass @@ -4102,11 +4159,6 @@ class If(Func): arg_types = {"this": True, "true": True, "false": False} -class IfNull(Func): - arg_types = {"this": True, "expression": False} - _sql_names = ["IFNULL", "NVL"] - - class Initcap(Func): arg_types = {"this": True, "expression": False} @@ -5608,22 +5660,27 @@ def replace_children(expression: Expression, fun: t.Callable, *args, **kwargs) - expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0) -def column_table_names(expression: Expression) -> t.List[str]: +def column_table_names(expression: Expression, exclude: str = "") -> t.Set[str]: """ Return all table names referenced through columns in an expression. Example: >>> import sqlglot - >>> column_table_names(sqlglot.parse_one("a.b AND c.d AND c.e")) - ['c', 'a'] + >>> sorted(column_table_names(sqlglot.parse_one("a.b AND c.d AND c.e"))) + ['a', 'c'] Args: expression: expression to find table names. + exclude: a table name to exclude Returns: A list of unique names. """ - return list(dict.fromkeys(column.table for column in expression.find_all(Column))) + return { + table + for table in (column.table for column in expression.find_all(Column)) + if table and table != exclude + } def table_name(table: Table | str) -> str: @@ -5649,12 +5706,13 @@ def table_name(table: Table | str) -> str: return ".".join(part for part in (table.text("catalog"), table.text("db"), table.name) if part) -def replace_tables(expression: E, mapping: t.Dict[str, str]) -> E: +def replace_tables(expression: E, mapping: t.Dict[str, str], copy: bool = True) -> E: """Replace all tables in expression according to the mapping. Args: expression: expression node to be transformed and replaced. mapping: mapping of table names. + copy: whether or not to copy the expression. Examples: >>> from sqlglot import exp, parse_one @@ -5675,7 +5733,7 @@ def replace_tables(expression: E, mapping: t.Dict[str, str]) -> E: ) return node - return expression.transform(_replace_tables) + return expression.transform(_replace_tables, copy=copy) def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression: |