diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-10 11:29:05 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-10 11:29:05 +0000 |
commit | f818ab3b896d52e874634b7c4db3533078c1887f (patch) | |
tree | 8d0f7e4b7f165f33f49da74cb34eb31a0a2d147b /sqlglot/expressions.py | |
parent | Releasing debian version 6.2.8-1. (diff) | |
download | sqlglot-f818ab3b896d52e874634b7c4db3533078c1887f.tar.xz sqlglot-f818ab3b896d52e874634b7c4db3533078c1887f.zip |
Merging upstream version 6.3.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r-- | sqlglot/expressions.py | 222 |
1 files changed, 214 insertions, 8 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 8cdacce..f2ffd12 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -508,7 +508,69 @@ class DerivedTable(Expression): return [select.alias_or_name for select in self.selects] -class UDTF(DerivedTable): +class Unionable: + def union(self, expression, distinct=True, dialect=None, **opts): + """ + Builds a UNION expression. + + Example: + >>> import sqlglot + >>> sqlglot.parse_one("SELECT * FROM foo").union("SELECT * FROM bla").sql() + 'SELECT * FROM foo UNION SELECT * FROM bla' + + Args: + expression (str or Expression): the SQL code string. + If an `Expression` instance is passed, it will be used as-is. + distinct (bool): set the DISTINCT flag if and only if this is true. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + Returns: + Union: the Union expression. + """ + return union(left=self, right=expression, distinct=distinct, dialect=dialect, **opts) + + def intersect(self, expression, distinct=True, dialect=None, **opts): + """ + Builds an INTERSECT expression. + + Example: + >>> import sqlglot + >>> sqlglot.parse_one("SELECT * FROM foo").intersect("SELECT * FROM bla").sql() + 'SELECT * FROM foo INTERSECT SELECT * FROM bla' + + Args: + expression (str or Expression): the SQL code string. + If an `Expression` instance is passed, it will be used as-is. + distinct (bool): set the DISTINCT flag if and only if this is true. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + Returns: + Intersect: the Intersect expression + """ + return intersect(left=self, right=expression, distinct=distinct, dialect=dialect, **opts) + + def except_(self, expression, distinct=True, dialect=None, **opts): + """ + Builds an EXCEPT expression. + + Example: + >>> import sqlglot + >>> sqlglot.parse_one("SELECT * FROM foo").except_("SELECT * FROM bla").sql() + 'SELECT * FROM foo EXCEPT SELECT * FROM bla' + + Args: + expression (str or Expression): the SQL code string. + If an `Expression` instance is passed, it will be used as-is. + distinct (bool): set the DISTINCT flag if and only if this is true. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + Returns: + Except: the Except expression + """ + return except_(left=self, right=expression, distinct=distinct, dialect=dialect, **opts) + + +class UDTF(DerivedTable, Unionable): pass @@ -518,6 +580,10 @@ class Annotation(Expression): "expression": True, } + @property + def alias(self): + return self.expression.alias_or_name + class Cache(Expression): arg_types = { @@ -700,6 +766,10 @@ class Hint(Expression): arg_types = {"expressions": True} +class JoinHint(Expression): + arg_types = {"this": True, "expressions": True} + + class Identifier(Expression): arg_types = {"this": True, "quoted": False} @@ -971,7 +1041,7 @@ class Tuple(Expression): arg_types = {"expressions": False} -class Subqueryable: +class Subqueryable(Unionable): def subquery(self, alias=None, copy=True): """ Convert this expression to an aliased expression that can be used as a Subquery. @@ -1654,7 +1724,7 @@ class Select(Subqueryable, Expression): return self.expressions -class Subquery(DerivedTable): +class Subquery(DerivedTable, Unionable): arg_types = { "this": True, "alias": False, @@ -1731,7 +1801,7 @@ class Parameter(Expression): class Placeholder(Expression): - arg_types = {} + arg_types = {"this": False} class Null(Condition): @@ -1791,6 +1861,8 @@ class DataType(Expression): IMAGE = auto() VARIANT = auto() OBJECT = auto() + NULL = auto() + UNKNOWN = auto() # Sentinel value, useful for type annotation @classmethod def build(cls, dtype, **kwargs): @@ -2007,7 +2079,7 @@ class Distinct(Expression): class In(Predicate): - arg_types = {"this": True, "expressions": False, "query": False, "unnest": False} + arg_types = {"this": True, "expressions": False, "query": False, "unnest": False, "field": False} class TimeUnit(Expression): @@ -2377,6 +2449,11 @@ class Map(Func): arg_types = {"keys": True, "values": True} +class VarMap(Func): + arg_types = {"keys": True, "values": True} + is_var_len_args = True + + class Max(AggFunc): pass @@ -2449,7 +2526,7 @@ class Substring(Func): class StrPosition(Func): - arg_types = {"this": True, "substr": True, "position": False} + arg_types = {"substr": True, "this": True, "position": False} class StrToDate(Func): @@ -2785,6 +2862,81 @@ def _wrap_operator(expression): return expression +def union(left, right, distinct=True, dialect=None, **opts): + """ + Initializes a syntax tree from one UNION expression. + + Example: + >>> union("SELECT * FROM foo", "SELECT * FROM bla").sql() + 'SELECT * FROM foo UNION SELECT * FROM bla' + + Args: + left (str or Expression): the SQL code string corresponding to the left-hand side. + If an `Expression` instance is passed, it will be used as-is. + right (str or Expression): the SQL code string corresponding to the right-hand side. + If an `Expression` instance is passed, it will be used as-is. + distinct (bool): set the DISTINCT flag if and only if this is true. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + Returns: + Union: the syntax tree for the UNION expression. + """ + left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts) + right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts) + + return Union(this=left, expression=right, distinct=distinct) + + +def intersect(left, right, distinct=True, dialect=None, **opts): + """ + Initializes a syntax tree from one INTERSECT expression. + + Example: + >>> intersect("SELECT * FROM foo", "SELECT * FROM bla").sql() + 'SELECT * FROM foo INTERSECT SELECT * FROM bla' + + Args: + left (str or Expression): the SQL code string corresponding to the left-hand side. + If an `Expression` instance is passed, it will be used as-is. + right (str or Expression): the SQL code string corresponding to the right-hand side. + If an `Expression` instance is passed, it will be used as-is. + distinct (bool): set the DISTINCT flag if and only if this is true. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + Returns: + Intersect: the syntax tree for the INTERSECT expression. + """ + left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts) + right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts) + + return Intersect(this=left, expression=right, distinct=distinct) + + +def except_(left, right, distinct=True, dialect=None, **opts): + """ + Initializes a syntax tree from one EXCEPT expression. + + Example: + >>> except_("SELECT * FROM foo", "SELECT * FROM bla").sql() + 'SELECT * FROM foo EXCEPT SELECT * FROM bla' + + Args: + left (str or Expression): the SQL code string corresponding to the left-hand side. + If an `Expression` instance is passed, it will be used as-is. + right (str or Expression): the SQL code string corresponding to the right-hand side. + If an `Expression` instance is passed, it will be used as-is. + distinct (bool): set the DISTINCT flag if and only if this is true. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + Returns: + Except: the syntax tree for the EXCEPT statement. + """ + left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts) + right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts) + + return Except(this=left, expression=right, distinct=distinct) + + def select(*expressions, dialect=None, **opts): """ Initializes a syntax tree from one or multiple SELECT expressions. @@ -2991,7 +3143,7 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts): If an Expression instance is passed, this is used as-is. alias (str or Identifier): the alias name to use. If the name has special characters it is quoted. - table (boolean): create a table alias, default false + table (bool): create a table alias, default false dialect (str): the dialect used to parse the input expression. **opts: other options to use to parse the input expressions. @@ -3002,7 +3154,7 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts): alias = to_identifier(alias, quoted=quoted) alias = TableAlias(this=alias) if table else alias - if "alias" in exp.arg_types: + if "alias" in exp.arg_types and not isinstance(exp, Window): exp = exp.copy() exp.set("alias", alias) return exp @@ -3138,6 +3290,60 @@ def column_table_names(expression): return list(dict.fromkeys(column.table for column in expression.find_all(Column))) +def table_name(table): + """Get the full name of a table as a string. + + Args: + table (exp.Table | str): Table expression node or string. + + Examples: + >>> from sqlglot import exp, parse_one + >>> table_name(parse_one("select * from a.b.c").find(exp.Table)) + 'a.b.c' + + Returns: + str: the table name + """ + + table = maybe_parse(table, into=Table) + + return ".".join( + part + for part in ( + table.text("catalog"), + table.text("db"), + table.name, + ) + if part + ) + + +def replace_tables(expression, mapping): + """Replace all tables in expression according to the mapping. + + Args: + expression (sqlglot.Expression): Expression node to be transformed and replaced + mapping (Dict[str, str]): Mapping of table names + + Examples: + >>> from sqlglot import exp, parse_one + >>> replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql() + 'SELECT * FROM "c"' + + Returns: + The mapped expression + """ + + def _replace_tables(node): + if isinstance(node, Table): + new_name = mapping.get(table_name(node)) + if new_name: + return table_(*reversed(new_name.split(".")), quoted=True) + return node + + return expression.transform(_replace_tables) + + TRUE = Boolean(this=True) FALSE = Boolean(this=False) NULL = Null() |