summaryrefslogtreecommitdiffstats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r--sqlglot/expressions.py81
1 files changed, 65 insertions, 16 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 39f4452..f7717c8 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -11,6 +11,7 @@ from sqlglot.helper import (
camel_to_snake_case,
ensure_list,
list_get,
+ split_num_words,
subclasses,
)
@@ -108,6 +109,8 @@ class Expression(metaclass=_Expression):
@property
def alias_or_name(self):
+ if isinstance(self, Null):
+ return "NULL"
return self.alias or self.name
def __deepcopy__(self, memo):
@@ -659,6 +662,10 @@ class HexString(Condition):
pass
+class ByteString(Condition):
+ pass
+
+
class Column(Condition):
arg_types = {"this": True, "table": False}
@@ -725,7 +732,7 @@ class Constraint(Expression):
class Delete(Expression):
- arg_types = {"with": False, "this": True, "where": False}
+ arg_types = {"with": False, "this": True, "using": False, "where": False}
class Drop(Expression):
@@ -1192,6 +1199,7 @@ QUERY_MODIFIERS = {
class Table(Expression):
arg_types = {
"this": True,
+ "alias": False,
"db": False,
"catalog": False,
"laterals": False,
@@ -1323,6 +1331,7 @@ class Select(Subqueryable):
*expressions (str or Expression): the SQL code strings to parse.
If a `Group` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Group`.
+ If nothing is passed in then a group by is not applied to the expression
append (bool): if `True`, add to any existing expressions.
Otherwise, this flattens all the `Group` expression into a single expression.
dialect (str): the dialect used to parse the input expression.
@@ -1332,6 +1341,8 @@ class Select(Subqueryable):
Returns:
Select: the modified expression.
"""
+ if not expressions:
+ return self if not copy else self.copy()
return _apply_child_list_builder(
*expressions,
instance=self,
@@ -2239,6 +2250,11 @@ class ArrayAny(Func):
arg_types = {"this": True, "expression": True}
+class ArrayConcat(Func):
+ arg_types = {"this": True, "expressions": False}
+ is_var_len_args = True
+
+
class ArrayContains(Func):
arg_types = {"this": True, "expression": True}
@@ -2570,7 +2586,7 @@ class SortArray(Func):
class Split(Func):
- arg_types = {"this": True, "expression": True}
+ arg_types = {"this": True, "expression": True, "limit": False}
# Start may be omitted in the case of postgres
@@ -3209,29 +3225,49 @@ def to_identifier(alias, quoted=None):
return identifier
-def to_table(sql_path, **kwargs):
+def to_table(sql_path: str, **kwargs) -> Table:
"""
Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
- Example:
- >>> to_table('catalog.db.table_name').sql()
- 'catalog.db.table_name'
+
+ If a table is passed in then that table is returned.
Args:
- sql_path(str): `[catalog].[schema].[table]` string
+ sql_path(str|Table): `[catalog].[schema].[table]` string
Returns:
Table: A table expression
"""
- table_parts = sql_path.split(".")
- catalog, db, table_name = [
- to_identifier(x) if x is not None else x for x in [None] * (3 - len(table_parts)) + table_parts
- ]
+ if sql_path is None or isinstance(sql_path, Table):
+ return sql_path
+ if not isinstance(sql_path, str):
+ raise ValueError(f"Invalid type provided for a table: {type(sql_path)}")
+
+ catalog, db, table_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 3)]
return Table(this=table_name, db=db, catalog=catalog, **kwargs)
+def to_column(sql_path: str, **kwargs) -> Column:
+ """
+ Create a column from a `[table].[column]` sql path. Schema is optional.
+
+ If a column is passed in then that column is returned.
+
+ Args:
+ sql_path: `[table].[column]` string
+ Returns:
+ Table: A column expression
+ """
+ if sql_path is None or isinstance(sql_path, Column):
+ return sql_path
+ if not isinstance(sql_path, str):
+ raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
+ table_name, column_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 2)]
+ return Column(this=column_name, table=table_name, **kwargs)
+
+
def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
"""
Create an Alias expression.
- Expample:
+ Example:
>>> alias_('foo', 'bar').sql()
'foo AS bar'
@@ -3249,7 +3285,16 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
"""
exp = maybe_parse(expression, dialect=dialect, **opts)
alias = to_identifier(alias, quoted=quoted)
- alias = TableAlias(this=alias) if table else alias
+
+ if table:
+ expression.set("alias", TableAlias(this=alias))
+ return expression
+
+ # We don't set the "alias" arg for Window expressions, because that would add an IDENTIFIER node in
+ # the AST, representing a "named_window" [1] construct (eg. bigquery). What we want is an ALIAS node
+ # for the complete Window expression.
+ #
+ # [1]: https://cloud.google.com/bigquery/docs/reference/standard-sql/window-function-calls
if "alias" in exp.arg_types and not isinstance(exp, Window):
exp = exp.copy()
@@ -3295,7 +3340,7 @@ def column(col, table=None, quoted=None):
)
-def table_(table, db=None, catalog=None, quoted=None):
+def table_(table, db=None, catalog=None, quoted=None, alias=None):
"""Build a Table.
Args:
@@ -3310,6 +3355,7 @@ def table_(table, db=None, catalog=None, quoted=None):
this=to_identifier(table, quoted=quoted),
db=to_identifier(db, quoted=quoted),
catalog=to_identifier(catalog, quoted=quoted),
+ alias=TableAlias(this=to_identifier(alias)) if alias else None,
)
@@ -3453,7 +3499,7 @@ def replace_tables(expression, mapping):
Examples:
>>> from sqlglot import exp, parse_one
>>> replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql()
- 'SELECT * FROM "c"'
+ 'SELECT * FROM c'
Returns:
The mapped expression
@@ -3463,7 +3509,10 @@ def replace_tables(expression, mapping):
if isinstance(node, Table):
new_name = mapping.get(table_name(node))
if new_name:
- return table_(*reversed(new_name.split(".")), quoted=True)
+ return to_table(
+ new_name,
+ **{k: v for k, v in node.args.items() if k not in ("this", "db", "catalog")},
+ )
return node
return expression.transform(_replace_tables)