diff options
Diffstat (limited to 'sqlglot/lineage.py')
-rw-r--r-- | sqlglot/lineage.py | 44 |
1 files changed, 28 insertions, 16 deletions
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index 9f5ae9a..113458f 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -79,7 +79,7 @@ def lineage( raise SqlglotError("Cannot build lineage, sql must be SELECT") def to_node( - column_name: str, + column: str | int, scope: Scope, scope_name: t.Optional[str] = None, upstream: t.Optional[Node] = None, @@ -90,26 +90,38 @@ def lineage( for dt in scope.derived_tables if dt.comments and dt.comments[0].startswith("source: ") } - if isinstance(scope.expression, exp.Union): - for scope in scope.union_scopes: - node = to_node( - column_name, - scope=scope, - scope_name=scope_name, - upstream=upstream, - alias=aliases.get(scope_name), - ) - return node # Find the specific select clause that is the source of the column we want. # This can either be a specific, named select or a generic `*` clause. - select = next( - (select for select in scope.expression.selects if select.alias_or_name == column_name), - exp.Star() if scope.expression.is_star else None, + select = ( + scope.expression.selects[column] + if isinstance(column, int) + else next( + (select for select in scope.expression.selects if select.alias_or_name == column), + exp.Star() if scope.expression.is_star else None, + ) ) if not select: - raise ValueError(f"Could not find {column_name} in {scope.expression}") + raise ValueError(f"Could not find {column} in {scope.expression}") + + if isinstance(scope.expression, exp.Union): + upstream = upstream or Node(name="UNION", source=scope.expression, expression=select) + + index = ( + column + if isinstance(column, int) + else next( + i + for i, select in enumerate(scope.expression.selects) + if select.alias_or_name == column + ) + ) + + for s in scope.union_scopes: + to_node(index, scope=s, upstream=upstream) + + return upstream if isinstance(scope.expression, exp.Select): # For better ergonomics in our node labels, replace the full select with @@ -122,7 +134,7 @@ def lineage( # Create the node for this step in the lineage chain, and attach it to the previous one. node = Node( - name=f"{scope_name}.{column_name}" if scope_name else column_name, + name=f"{scope_name}.{column}" if scope_name else str(column), source=source, expression=select, alias=alias or "", |