summaryrefslogtreecommitdiffstats
path: root/sqlglot/lineage.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--sqlglot/lineage.py36
1 files changed, 32 insertions, 4 deletions
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py
index 113458f..011a6b8 100644
--- a/sqlglot/lineage.py
+++ b/sqlglot/lineage.py
@@ -112,17 +112,34 @@ def lineage(
column
if isinstance(column, int)
else next(
- i
- for i, select in enumerate(scope.expression.selects)
- if select.alias_or_name == column
+ (
+ i
+ for i, select in enumerate(scope.expression.selects)
+ if select.alias_or_name == column or select.is_star
+ ),
+ -1, # mypy will not allow a None here, but a negative index should never be returned
)
)
+ if index == -1:
+ raise ValueError(f"Could not find {column} in {scope.expression}")
+
for s in scope.union_scopes:
to_node(index, scope=s, upstream=upstream)
return upstream
+ subquery = select.unalias()
+
+ if isinstance(subquery, exp.Subquery):
+ upstream = upstream or Node(name="SUBQUERY", source=scope.expression, expression=select)
+ scope = t.cast(Scope, build_scope(subquery.unnest()))
+
+ for select in subquery.named_selects:
+ to_node(select, scope=scope, upstream=upstream)
+
+ return upstream
+
if isinstance(scope.expression, exp.Select):
# For better ergonomics in our node labels, replace the full select with
# a version that has only the column we care about.
@@ -142,8 +159,19 @@ def lineage(
if upstream:
upstream.downstream.append(node)
+ # if the select is a star add all scope sources as downstreams
+ if select.is_star:
+ for source in scope.sources.values():
+ node.downstream.append(Node(name=select.sql(), source=source, expression=source))
+
# Find all columns that went into creating this one to list their lineage nodes.
- for c in set(select.find_all(exp.Column)):
+ source_columns = set(select.find_all(exp.Column))
+
+ # If the source is a UDTF find columns used in the UTDF to generate the table
+ if isinstance(source, exp.UDTF):
+ source_columns |= set(source.find_all(exp.Column))
+
+ for c in source_columns:
table = c.table
source = scope.sources.get(table)