summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/lineage.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/lineage.py')
-rw-r--r-- sqlglot/lineage.py 44
1 file changed, 28 insertions, 16 deletions
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py
index 9f5ae9a..113458f 100644
--- a/sqlglot/lineage.py
+++ b/sqlglot/lineage.py
@@ -79,7 +79,7 @@ def lineage(
raise SqlglotError("Cannot build lineage, sql must be SELECT")
def to_node(
- column_name: str,
+ column: str | int,
scope: Scope,
scope_name: t.Optional[str] = None,
upstream: t.Optional[Node] = None,
@@ -90,26 +90,38 @@ def lineage(
for dt in scope.derived_tables
if dt.comments and dt.comments[0].startswith("source: ")
}
- if isinstance(scope.expression, exp.Union):
- for scope in scope.union_scopes:
- node = to_node(
- column_name,
- scope=scope,
- scope_name=scope_name,
- upstream=upstream,
- alias=aliases.get(scope_name),
- )
- return node
# Find the specific select clause that is the source of the column we want.
# This can either be a specific, named select or a generic `*` clause.
- select = next(
- (select for select in scope.expression.selects if select.alias_or_name == column_name),
- exp.Star() if scope.expression.is_star else None,
+ select = (
+ scope.expression.selects[column]
+ if isinstance(column, int)
+ else next(
+ (select for select in scope.expression.selects if select.alias_or_name == column),
+ exp.Star() if scope.expression.is_star else None,
+ )
)
if not select:
- raise ValueError(f"Could not find {column_name} in {scope.expression}")
+ raise ValueError(f"Could not find {column} in {scope.expression}")
+
+ if isinstance(scope.expression, exp.Union):
+ upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
+
+ index = (
+ column
+ if isinstance(column, int)
+ else next(
+ i
+ for i, select in enumerate(scope.expression.selects)
+ if select.alias_or_name == column
+ )
+ )
+
+ for s in scope.union_scopes:
+ to_node(index, scope=s, upstream=upstream)
+
+ return upstream
if isinstance(scope.expression, exp.Select):
# For better ergonomics in our node labels, replace the full select with
@@ -122,7 +134,7 @@ def lineage(
# Create the node for this step in the lineage chain, and attach it to the previous one.
node = Node(
- name=f"{scope_name}.{column_name}" if scope_name else column_name,
+ name=f"{scope_name}.{column}" if scope_name else str(column),
source=source,
expression=select,
alias=alias or "",