diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-03-03 14:11:07 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-03-03 14:11:07 +0000 |
commit | 42a1548cecf48d18233f56e3385cf9c89abcb9c2 (patch) | |
tree | 5e0fff4ecbd1fd7dd1022a7580139038df2a824c /sqlglot/lineage.py | |
parent | Releasing debian version 21.1.2-1. (diff) | |
download | sqlglot-42a1548cecf48d18233f56e3385cf9c89abcb9c2.tar.xz sqlglot-42a1548cecf48d18233f56e3385cf9c89abcb9c2.zip |
Merging upstream version 22.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/lineage.py')
-rw-r--r-- | sqlglot/lineage.py | 237 |
1 files changed, 129 insertions, 108 deletions
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index f10fbb9..eb428dc 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -1,16 +1,19 @@ from __future__ import annotations import json +import logging import typing as t from dataclasses import dataclass, field from sqlglot import Schema, exp, maybe_parse from sqlglot.errors import SqlglotError -from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, qualify +from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType +logger = logging.getLogger("sqlglot") + @dataclass(frozen=True) class Node: @@ -18,7 +21,8 @@ class Node: expression: exp.Expression source: exp.Expression downstream: t.List[Node] = field(default_factory=list) - alias: str = "" + source_name: str = "" + reference_node_name: str = "" def walk(self) -> t.Iterator[Node]: yield self @@ -67,7 +71,7 @@ def lineage( column: str | exp.Column, sql: str | exp.Expression, schema: t.Optional[t.Dict | Schema] = None, - sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None, + sources: t.Optional[t.Dict[str, str | exp.Query]] = None, dialect: DialectType = None, **kwargs, ) -> Node: @@ -86,14 +90,12 @@ def lineage( """ expression = maybe_parse(sql, dialect=dialect) + column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name if sources: expression = exp.expand( expression, - { - k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect)) - for k, v in sources.items() - }, + {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()}, dialect=dialect, ) @@ -109,122 +111,141 @@ def lineage( if not scope: raise SqlglotError("Cannot build lineage, sql must be SELECT") - def to_node( - column: str | int, - scope: Scope, - scope_name: t.Optional[str] = None, - upstream: t.Optional[Node] = None, - alias: t.Optional[str] = None, - ) -> Node: - aliases = { - dt.alias: dt.comments[0].split()[1] - for dt in scope.derived_tables - if dt.comments and dt.comments[0].startswith("source: ") - } + if not any(select.alias_or_name == column for select in scope.expression.selects): + raise SqlglotError(f"Cannot find column '{column}' in query.") - # Find the specific select clause that is the source of the column we want. - # This can either be a specific, named select or a generic `*` clause. - select = ( - scope.expression.selects[column] - if isinstance(column, int) - else next( - (select for select in scope.expression.selects if select.alias_or_name == column), - exp.Star() if scope.expression.is_star else scope.expression, - ) - ) + return to_node(column, scope, dialect) - if isinstance(scope.expression, exp.Union): - upstream = upstream or Node(name="UNION", source=scope.expression, expression=select) - - index = ( - column - if isinstance(column, int) - else next( - ( - i - for i, select in enumerate(scope.expression.selects) - if select.alias_or_name == column or select.is_star - ), - -1, # mypy will not allow a None here, but a negative index should never be returned - ) - ) - - if index == -1: - raise ValueError(f"Could not find {column} in {scope.expression}") - for s in scope.union_scopes: - to_node(index, scope=s, upstream=upstream, alias=alias) - - return upstream - - if isinstance(scope.expression, exp.Select): - # For better ergonomics in our node labels, replace the full select with - # a version that has only the column we care about. - # "x", SELECT x, y FROM foo - # => "x", SELECT x FROM foo - source = t.cast(exp.Expression, scope.expression.select(select, append=False)) - else: - source = scope.expression - - # Create the node for this step in the lineage chain, and attach it to the previous one. - node = Node( - name=f"{scope_name}.{column}" if scope_name else str(column), - source=source, - expression=select, - alias=alias or "", +def to_node( + column: str | int, + scope: Scope, + dialect: DialectType, + scope_name: t.Optional[str] = None, + upstream: t.Optional[Node] = None, + source_name: t.Optional[str] = None, + reference_node_name: t.Optional[str] = None, +) -> Node: + source_names = { + dt.alias: dt.comments[0].split()[1] + for dt in scope.derived_tables + if dt.comments and dt.comments[0].startswith("source: ") + } + + # Find the specific select clause that is the source of the column we want. + # This can either be a specific, named select or a generic `*` clause. + select = ( + scope.expression.selects[column] + if isinstance(column, int) + else next( + (select for select in scope.expression.selects if select.alias_or_name == column), + exp.Star() if scope.expression.is_star else scope.expression, ) + ) - if upstream: - upstream.downstream.append(node) + if isinstance(scope.expression, exp.Union): + upstream = upstream or Node(name="UNION", source=scope.expression, expression=select) - subquery_scopes = { - id(subquery_scope.expression): subquery_scope - for subquery_scope in scope.subquery_scopes - } + index = ( + column + if isinstance(column, int) + else next( + ( + i + for i, select in enumerate(scope.expression.selects) + if select.alias_or_name == column or select.is_star + ), + -1, # mypy will not allow a None here, but a negative index should never be returned + ) + ) - for subquery in find_all_in_scope(select, exp.Subqueryable): - subquery_scope = subquery_scopes[id(subquery)] + if index == -1: + raise ValueError(f"Could not find {column} in {scope.expression}") + + for s in scope.union_scopes: + to_node( + index, + scope=s, + dialect=dialect, + upstream=upstream, + source_name=source_name, + reference_node_name=reference_node_name, + ) - for name in subquery.named_selects: - to_node(name, scope=subquery_scope, upstream=node) + return upstream + + if isinstance(scope.expression, exp.Select): + # For better ergonomics in our node labels, replace the full select with + # a version that has only the column we care about. + # "x", SELECT x, y FROM foo + # => "x", SELECT x FROM foo + source = t.cast(exp.Expression, scope.expression.select(select, append=False)) + else: + source = scope.expression + + # Create the node for this step in the lineage chain, and attach it to the previous one. + node = Node( + name=f"{scope_name}.{column}" if scope_name else str(column), + source=source, + expression=select, + source_name=source_name or "", + reference_node_name=reference_node_name or "", + ) - # if the select is a star add all scope sources as downstreams - if select.is_star: - for source in scope.sources.values(): - if isinstance(source, Scope): - source = source.expression - node.downstream.append(Node(name=select.sql(), source=source, expression=source)) + if upstream: + upstream.downstream.append(node) - # Find all columns that went into creating this one to list their lineage nodes. - source_columns = set(find_all_in_scope(select, exp.Column)) + subquery_scopes = { + id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes + } - # If the source is a UDTF find columns used in the UTDF to generate the table - if isinstance(source, exp.UDTF): - source_columns |= set(source.find_all(exp.Column)) + for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES): + subquery_scope = subquery_scopes.get(id(subquery)) + if not subquery_scope: + logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}") + continue - for c in source_columns: - table = c.table - source = scope.sources.get(table) + for name in subquery.named_selects: + to_node(name, scope=subquery_scope, dialect=dialect, upstream=node) + # if the select is a star add all scope sources as downstreams + if select.is_star: + for source in scope.sources.values(): if isinstance(source, Scope): - # The table itself came from a more specific scope. Recurse into that one using the unaliased column name. - to_node( - c.name, - scope=source, - scope_name=table, - upstream=node, - alias=aliases.get(table) or alias, - ) - else: - # The source is not a scope - we've reached the end of the line. At this point, if a source is not found - # it means this column's lineage is unknown. This can happen if the definition of a source used in a query - # is not passed into the `sources` map. - source = source or exp.Placeholder() - node.downstream.append(Node(name=c.sql(), source=source, expression=source)) - - return node + source = source.expression + node.downstream.append(Node(name=select.sql(), source=source, expression=source)) + + # Find all columns that went into creating this one to list their lineage nodes. + source_columns = set(find_all_in_scope(select, exp.Column)) + + # If the source is a UDTF find columns used in the UTDF to generate the table + if isinstance(source, exp.UDTF): + source_columns |= set(source.find_all(exp.Column)) + + for c in source_columns: + table = c.table + source = scope.sources.get(table) + + if isinstance(source, Scope): + selected_node, _ = scope.selected_sources.get(table, (None, None)) + # The table itself came from a more specific scope. Recurse into that one using the unaliased column name. + to_node( + c.name, + scope=source, + dialect=dialect, + scope_name=table, + upstream=node, + source_name=source_names.get(table) or source_name, + reference_node_name=selected_node.name if selected_node else None, + ) + else: + # The source is not a scope - we've reached the end of the line. At this point, if a source is not found + # it means this column's lineage is unknown. This can happen if the definition of a source used in a query + # is not passed into the `sources` map. + source = source or exp.Placeholder() + node.downstream.append(Node(name=c.sql(), source=source, expression=source)) - return to_node(column if isinstance(column, str) else column.name, scope) + return node class GraphHTML: |