summaryrefslogtreecommitdiffstats
path: root/sqlglot/lineage.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/lineage.py')
-rw-r--r--sqlglot/lineage.py53
1 files changed, 40 insertions, 13 deletions
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py
index 2e563ae..0eac870 100644
--- a/sqlglot/lineage.py
+++ b/sqlglot/lineage.py
@@ -20,6 +20,7 @@ class Node:
expression: exp.Expression
source: exp.Expression
downstream: t.List[Node] = field(default_factory=list)
+ alias: str = ""
def walk(self) -> t.Iterator[Node]:
yield self
@@ -69,14 +70,19 @@ def lineage(
optimized = optimize(expression, schema=schema, rules=rules)
scope = build_scope(optimized)
- tables: t.Dict[str, Node] = {}
def to_node(
column_name: str,
scope: Scope,
scope_name: t.Optional[str] = None,
upstream: t.Optional[Node] = None,
+ alias: t.Optional[str] = None,
) -> Node:
+ aliases = {
+ dt.alias: dt.comments[0].split()[1]
+ for dt in scope.derived_tables
+ if dt.comments and dt.comments[0].startswith("source: ")
+ }
if isinstance(scope.expression, exp.Union):
for scope in scope.union_scopes:
node = to_node(
@@ -84,37 +90,58 @@ def lineage(
scope=scope,
scope_name=scope_name,
upstream=upstream,
+ alias=aliases.get(scope_name),
)
return node
- select = next(select for select in scope.selects if select.alias_or_name == column_name)
- source = optimize(scope.expression.select(select, append=False), schema=schema, rules=rules)
- select = source.selects[0]
+ # Find the specific select clause that is the source of the column we want.
+ # This can either be a specific, named select or a generic `*` clause.
+ select = next(
+ (select for select in scope.selects if select.alias_or_name == column_name),
+ exp.Star() if scope.expression.is_star else None,
+ )
+ if not select:
+ raise ValueError(f"Could not find {column_name} in {scope.expression}")
+
+ if isinstance(scope.expression, exp.Select):
+ # For better ergonomics in our node labels, replace the full select with
+ # a version that has only the column we care about.
+ # "x", SELECT x, y FROM foo
+ # => "x", SELECT x FROM foo
+ source = optimize(
+ scope.expression.select(select, append=False), schema=schema, rules=rules
+ )
+ select = source.selects[0]
+ else:
+ source = scope.expression
+
+ # Create the node for this step in the lineage chain, and attach it to the previous one.
node = Node(
name=f"{scope_name}.{column_name}" if scope_name else column_name,
source=source,
expression=select,
+ alias=alias or "",
)
-
if upstream:
upstream.downstream.append(node)
+ # Find all columns that went into creating this one to list their lineage nodes.
for c in set(select.find_all(exp.Column)):
table = c.table
- source = scope.sources[table]
+ source = scope.sources.get(table)
if isinstance(source, Scope):
+ # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
to_node(
- c.name,
- scope=source,
- scope_name=table,
- upstream=node,
+ c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table)
)
else:
- if table not in tables:
- tables[table] = Node(name=c.sql(), source=source, expression=source)
- node.downstream.append(tables[table])
+ # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
+ # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
+ # is not passed into the `sources` map.
+ source = source or exp.Placeholder()
+ node.downstream.append(Node(name=c.sql(), source=source, expression=source))
return node