From 67c28dbe67209effad83d93b850caba5ee1e20e3 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 3 May 2023 11:12:28 +0200 Subject: Merging upstream version 11.7.1. Signed-off-by: Daniel Baumann --- sqlglot/lineage.py | 53 ++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 40 insertions(+), 13 deletions(-) (limited to 'sqlglot/lineage.py') diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index 2e563ae..0eac870 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -20,6 +20,7 @@ class Node: expression: exp.Expression source: exp.Expression downstream: t.List[Node] = field(default_factory=list) + alias: str = "" def walk(self) -> t.Iterator[Node]: yield self @@ -69,14 +70,19 @@ def lineage( optimized = optimize(expression, schema=schema, rules=rules) scope = build_scope(optimized) - tables: t.Dict[str, Node] = {} def to_node( column_name: str, scope: Scope, scope_name: t.Optional[str] = None, upstream: t.Optional[Node] = None, + alias: t.Optional[str] = None, ) -> Node: + aliases = { + dt.alias: dt.comments[0].split()[1] + for dt in scope.derived_tables + if dt.comments and dt.comments[0].startswith("source: ") + } if isinstance(scope.expression, exp.Union): for scope in scope.union_scopes: node = to_node( @@ -84,37 +90,58 @@ def lineage( scope=scope, scope_name=scope_name, upstream=upstream, + alias=aliases.get(scope_name), ) return node - select = next(select for select in scope.selects if select.alias_or_name == column_name) - source = optimize(scope.expression.select(select, append=False), schema=schema, rules=rules) - select = source.selects[0] + # Find the specific select clause that is the source of the column we want. + # This can either be a specific, named select or a generic `*` clause. + select = next( + (select for select in scope.selects if select.alias_or_name == column_name), + exp.Star() if scope.expression.is_star else None, + ) + if not select: + raise ValueError(f"Could not find {column_name} in {scope.expression}") + + if isinstance(scope.expression, exp.Select): + # For better ergonomics in our node labels, replace the full select with + # a version that has only the column we care about. + # "x", SELECT x, y FROM foo + # => "x", SELECT x FROM foo + source = optimize( + scope.expression.select(select, append=False), schema=schema, rules=rules + ) + select = source.selects[0] + else: + source = scope.expression + + # Create the node for this step in the lineage chain, and attach it to the previous one. node = Node( name=f"{scope_name}.{column_name}" if scope_name else column_name, source=source, expression=select, + alias=alias or "", ) - if upstream: upstream.downstream.append(node) + # Find all columns that went into creating this one to list their lineage nodes. for c in set(select.find_all(exp.Column)): table = c.table - source = scope.sources[table] + source = scope.sources.get(table) if isinstance(source, Scope): + # The table itself came from a more specific scope. Recurse into that one using the unaliased column name. to_node( - c.name, - scope=source, - scope_name=table, - upstream=node, + c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table) ) else: - if table not in tables: - tables[table] = Node(name=c.sql(), source=source, expression=source) - node.downstream.append(tables[table]) + # The source is not a scope - we've reached the end of the line. At this point, if a source is not found + # it means this column's lineage is unknown. This can happen if the definition of a source used in a query + # is not passed into the `sources` map. + source = source or exp.Placeholder() + node.downstream.append(Node(name=c.sql(), source=source, expression=source)) return node -- cgit v1.2.3