summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r-- sqlglot/expressions.py 222
1 file changed, 214 insertions, 8 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 8cdacce..f2ffd12 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -508,7 +508,69 @@ class DerivedTable(Expression):
return [select.alias_or_name for select in self.selects]
-class UDTF(DerivedTable):
+class Unionable:
+ def union(self, expression, distinct=True, dialect=None, **opts):
+ """
+ Builds a UNION expression.
+
+ Example:
+ >>> import sqlglot
+ >>> sqlglot.parse_one("SELECT * FROM foo").union("SELECT * FROM bla").sql()
+ 'SELECT * FROM foo UNION SELECT * FROM bla'
+
+ Args:
+ expression (str or Expression): the SQL code string.
+ If an `Expression` instance is passed, it will be used as-is.
+ distinct (bool): set the DISTINCT flag if and only if this is true.
+ dialect (str): the dialect used to parse the input expression.
+ opts (kwargs): other options to use to parse the input expressions.
+ Returns:
+ Union: the Union expression.
+ """
+ return union(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
+
+ def intersect(self, expression, distinct=True, dialect=None, **opts):
+ """
+ Builds an INTERSECT expression.
+
+ Example:
+ >>> import sqlglot
+ >>> sqlglot.parse_one("SELECT * FROM foo").intersect("SELECT * FROM bla").sql()
+ 'SELECT * FROM foo INTERSECT SELECT * FROM bla'
+
+ Args:
+ expression (str or Expression): the SQL code string.
+ If an `Expression` instance is passed, it will be used as-is.
+ distinct (bool): set the DISTINCT flag if and only if this is true.
+ dialect (str): the dialect used to parse the input expression.
+ opts (kwargs): other options to use to parse the input expressions.
+ Returns:
+ Intersect: the Intersect expression.
+ """
+ return intersect(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
+
+ def except_(self, expression, distinct=True, dialect=None, **opts):
+ """
+ Builds an EXCEPT expression.
+
+ Example:
+ >>> import sqlglot
+ >>> sqlglot.parse_one("SELECT * FROM foo").except_("SELECT * FROM bla").sql()
+ 'SELECT * FROM foo EXCEPT SELECT * FROM bla'
+
+ Args:
+ expression (str or Expression): the SQL code string.
+ If an `Expression` instance is passed, it will be used as-is.
+ distinct (bool): set the DISTINCT flag if and only if this is true.
+ dialect (str): the dialect used to parse the input expression.
+ opts (kwargs): other options to use to parse the input expressions.
+ Returns:
+ Except: the Except expression.
+ """
+ return except_(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
+
+
+class UDTF(DerivedTable, Unionable):
pass
@@ -518,6 +580,10 @@ class Annotation(Expression):
"expression": True,
}
+ @property
+ def alias(self):
+ return self.expression.alias_or_name
+
class Cache(Expression):
arg_types = {
@@ -700,6 +766,10 @@ class Hint(Expression):
arg_types = {"expressions": True}
+class JoinHint(Expression):
+ arg_types = {"this": True, "expressions": True}
+
+
class Identifier(Expression):
arg_types = {"this": True, "quoted": False}
@@ -971,7 +1041,7 @@ class Tuple(Expression):
arg_types = {"expressions": False}
-class Subqueryable:
+class Subqueryable(Unionable):
def subquery(self, alias=None, copy=True):
"""
Convert this expression to an aliased expression that can be used as a Subquery.
@@ -1654,7 +1724,7 @@ class Select(Subqueryable, Expression):
return self.expressions
-class Subquery(DerivedTable):
+class Subquery(DerivedTable, Unionable):
arg_types = {
"this": True,
"alias": False,
@@ -1731,7 +1801,7 @@ class Parameter(Expression):
class Placeholder(Expression):
- arg_types = {}
+ arg_types = {"this": False}
class Null(Condition):
@@ -1791,6 +1861,8 @@ class DataType(Expression):
IMAGE = auto()
VARIANT = auto()
OBJECT = auto()
+ NULL = auto()
+ UNKNOWN = auto() # Sentinel value, useful for type annotation
@classmethod
def build(cls, dtype, **kwargs):
@@ -2007,7 +2079,7 @@ class Distinct(Expression):
class In(Predicate):
- arg_types = {"this": True, "expressions": False, "query": False, "unnest": False}
+ arg_types = {"this": True, "expressions": False, "query": False, "unnest": False, "field": False}
class TimeUnit(Expression):
@@ -2377,6 +2449,11 @@ class Map(Func):
arg_types = {"keys": True, "values": True}
+class VarMap(Func):
+ arg_types = {"keys": True, "values": True}
+ is_var_len_args = True
+
+
class Max(AggFunc):
pass
@@ -2449,7 +2526,7 @@ class Substring(Func):
class StrPosition(Func):
- arg_types = {"this": True, "substr": True, "position": False}
+ arg_types = {"substr": True, "this": True, "position": False}
class StrToDate(Func):
@@ -2785,6 +2862,81 @@ def _wrap_operator(expression):
return expression
+def union(left, right, distinct=True, dialect=None, **opts):
+ """
+ Initializes a syntax tree for a UNION expression.
+
+ Example:
+ >>> union("SELECT * FROM foo", "SELECT * FROM bla").sql()
+ 'SELECT * FROM foo UNION SELECT * FROM bla'
+
+ Args:
+ left (str or Expression): the SQL code string corresponding to the left-hand side.
+ If an `Expression` instance is passed, it will be used as-is.
+ right (str or Expression): the SQL code string corresponding to the right-hand side.
+ If an `Expression` instance is passed, it will be used as-is.
+ distinct (bool): set the DISTINCT flag if and only if this is true.
+ dialect (str): the dialect used to parse the input expression.
+ opts (kwargs): other options to use to parse the input expressions.
+ Returns:
+ Union: the syntax tree for the UNION expression.
+ """
+ left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
+ right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
+
+ return Union(this=left, expression=right, distinct=distinct)
+
+
+def intersect(left, right, distinct=True, dialect=None, **opts):
+ """
+ Initializes a syntax tree for an INTERSECT expression.
+
+ Example:
+ >>> intersect("SELECT * FROM foo", "SELECT * FROM bla").sql()
+ 'SELECT * FROM foo INTERSECT SELECT * FROM bla'
+
+ Args:
+ left (str or Expression): the SQL code string corresponding to the left-hand side.
+ If an `Expression` instance is passed, it will be used as-is.
+ right (str or Expression): the SQL code string corresponding to the right-hand side.
+ If an `Expression` instance is passed, it will be used as-is.
+ distinct (bool): set the DISTINCT flag if and only if this is true.
+ dialect (str): the dialect used to parse the input expression.
+ opts (kwargs): other options to use to parse the input expressions.
+ Returns:
+ Intersect: the syntax tree for the INTERSECT expression.
+ """
+ left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
+ right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
+
+ return Intersect(this=left, expression=right, distinct=distinct)
+
+
+def except_(left, right, distinct=True, dialect=None, **opts):
+ """
+ Initializes a syntax tree for an EXCEPT expression.
+
+ Example:
+ >>> except_("SELECT * FROM foo", "SELECT * FROM bla").sql()
+ 'SELECT * FROM foo EXCEPT SELECT * FROM bla'
+
+ Args:
+ left (str or Expression): the SQL code string corresponding to the left-hand side.
+ If an `Expression` instance is passed, it will be used as-is.
+ right (str or Expression): the SQL code string corresponding to the right-hand side.
+ If an `Expression` instance is passed, it will be used as-is.
+ distinct (bool): set the DISTINCT flag if and only if this is true.
+ dialect (str): the dialect used to parse the input expression.
+ opts (kwargs): other options to use to parse the input expressions.
+ Returns:
+ Except: the syntax tree for the EXCEPT expression.
+ """
+ left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
+ right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
+
+ return Except(this=left, expression=right, distinct=distinct)
+
+
def select(*expressions, dialect=None, **opts):
"""
Initializes a syntax tree from one or multiple SELECT expressions.
@@ -2991,7 +3143,7 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
If an Expression instance is passed, this is used as-is.
alias (str or Identifier): the alias name to use. If the name has
special characters it is quoted.
- table (boolean): create a table alias, default false
+ table (bool): create a table alias, default false
dialect (str): the dialect used to parse the input expression.
**opts: other options to use to parse the input expressions.
@@ -3002,7 +3154,7 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
alias = to_identifier(alias, quoted=quoted)
alias = TableAlias(this=alias) if table else alias
- if "alias" in exp.arg_types:
+ if "alias" in exp.arg_types and not isinstance(exp, Window):
exp = exp.copy()
exp.set("alias", alias)
return exp
@@ -3138,6 +3290,60 @@ def column_table_names(expression):
return list(dict.fromkeys(column.table for column in expression.find_all(Column)))
+def table_name(table):
+ """Get the full name of a table as a string.
+
+ Args:
+ table (exp.Table | str): Table expression node or string.
+
+ Examples:
+ >>> from sqlglot import exp, parse_one
+ >>> table_name(parse_one("select * from a.b.c").find(exp.Table))
+ 'a.b.c'
+
+ Returns:
+ str: the table name
+ """
+
+ table = maybe_parse(table, into=Table)
+
+ return ".".join(
+ part
+ for part in (
+ table.text("catalog"),
+ table.text("db"),
+ table.name,
+ )
+ if part
+ )
+
+
+def replace_tables(expression, mapping):
+ """Replace all tables in expression according to the mapping.
+
+ Args:
+ expression (sqlglot.Expression): Expression node to be transformed and replaced
+ mapping (Dict[str, str]): Mapping of table names
+
+ Examples:
+ >>> from sqlglot import exp, parse_one
+ >>> replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql()
+ 'SELECT * FROM "c"'
+
+ Returns:
+ The mapped expression
+ """
+
+ def _replace_tables(node):
+ if isinstance(node, Table):
+ new_name = mapping.get(table_name(node))
+ if new_name:
+ return table_(*reversed(new_name.split(".")), quoted=True)
+ return node
+
+ return expression.transform(_replace_tables)
+
+
TRUE = Boolean(this=True)
FALSE = Boolean(this=False)
NULL = Null()