diff options
Diffstat (limited to 'sqlglot/lineage.py')
-rw-r--r-- | sqlglot/lineage.py | 36 |
1 files changed, 32 insertions, 4 deletions
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index 113458f..011a6b8 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -112,17 +112,34 @@ def lineage( column if isinstance(column, int) else next( - i - for i, select in enumerate(scope.expression.selects) - if select.alias_or_name == column + ( + i + for i, select in enumerate(scope.expression.selects) + if select.alias_or_name == column or select.is_star + ), + -1, # mypy will not allow a None here, but a negative index should never be returned ) ) + if index == -1: + raise ValueError(f"Could not find {column} in {scope.expression}") + for s in scope.union_scopes: to_node(index, scope=s, upstream=upstream) return upstream + subquery = select.unalias() + + if isinstance(subquery, exp.Subquery): + upstream = upstream or Node(name="SUBQUERY", source=scope.expression, expression=select) + scope = t.cast(Scope, build_scope(subquery.unnest())) + + for select in subquery.named_selects: + to_node(select, scope=scope, upstream=upstream) + + return upstream + if isinstance(scope.expression, exp.Select): # For better ergonomics in our node labels, replace the full select with # a version that has only the column we care about. @@ -142,8 +159,19 @@ def lineage( if upstream: upstream.downstream.append(node) + # if the select is a star add all scope sources as downstreams + if select.is_star: + for source in scope.sources.values(): + node.downstream.append(Node(name=select.sql(), source=source, expression=source)) + # Find all columns that went into creating this one to list their lineage nodes. - for c in set(select.find_all(exp.Column)): + source_columns = set(select.find_all(exp.Column)) + + # If the source is a UDTF find columns used in the UTDF to generate the table + if isinstance(source, exp.UDTF): + source_columns |= set(source.find_all(exp.Column)) + + for c in source_columns: table = c.table source = scope.sources.get(table) |