summaryrefslogtreecommitdiffstats
path: root/sqlglot/lineage.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-03-03 14:11:07 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-03-03 14:11:07 +0000
commit42a1548cecf48d18233f56e3385cf9c89abcb9c2 (patch)
tree5e0fff4ecbd1fd7dd1022a7580139038df2a824c /sqlglot/lineage.py
parentReleasing debian version 21.1.2-1. (diff)
downloadsqlglot-42a1548cecf48d18233f56e3385cf9c89abcb9c2.tar.xz
sqlglot-42a1548cecf48d18233f56e3385cf9c89abcb9c2.zip
Merging upstream version 22.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/lineage.py')
-rw-r--r--sqlglot/lineage.py237
1 files changed, 129 insertions, 108 deletions
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py
index f10fbb9..eb428dc 100644
--- a/sqlglot/lineage.py
+++ b/sqlglot/lineage.py
@@ -1,16 +1,19 @@
from __future__ import annotations
import json
+import logging
import typing as t
from dataclasses import dataclass, field
from sqlglot import Schema, exp, maybe_parse
from sqlglot.errors import SqlglotError
-from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, qualify
+from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
+logger = logging.getLogger("sqlglot")
+
@dataclass(frozen=True)
class Node:
@@ -18,7 +21,8 @@ class Node:
expression: exp.Expression
source: exp.Expression
downstream: t.List[Node] = field(default_factory=list)
- alias: str = ""
+ source_name: str = ""
+ reference_node_name: str = ""
def walk(self) -> t.Iterator[Node]:
yield self
@@ -67,7 +71,7 @@ def lineage(
column: str | exp.Column,
sql: str | exp.Expression,
schema: t.Optional[t.Dict | Schema] = None,
- sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
+ sources: t.Optional[t.Dict[str, str | exp.Query]] = None,
dialect: DialectType = None,
**kwargs,
) -> Node:
@@ -86,14 +90,12 @@ def lineage(
"""
expression = maybe_parse(sql, dialect=dialect)
+ column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
if sources:
expression = exp.expand(
expression,
- {
- k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
- for k, v in sources.items()
- },
+ {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()},
dialect=dialect,
)
@@ -109,122 +111,141 @@ def lineage(
if not scope:
raise SqlglotError("Cannot build lineage, sql must be SELECT")
- def to_node(
- column: str | int,
- scope: Scope,
- scope_name: t.Optional[str] = None,
- upstream: t.Optional[Node] = None,
- alias: t.Optional[str] = None,
- ) -> Node:
- aliases = {
- dt.alias: dt.comments[0].split()[1]
- for dt in scope.derived_tables
- if dt.comments and dt.comments[0].startswith("source: ")
- }
+ if not any(select.alias_or_name == column for select in scope.expression.selects):
+ raise SqlglotError(f"Cannot find column '{column}' in query.")
- # Find the specific select clause that is the source of the column we want.
- # This can either be a specific, named select or a generic `*` clause.
- select = (
- scope.expression.selects[column]
- if isinstance(column, int)
- else next(
- (select for select in scope.expression.selects if select.alias_or_name == column),
- exp.Star() if scope.expression.is_star else scope.expression,
- )
- )
+ return to_node(column, scope, dialect)
- if isinstance(scope.expression, exp.Union):
- upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
-
- index = (
- column
- if isinstance(column, int)
- else next(
- (
- i
- for i, select in enumerate(scope.expression.selects)
- if select.alias_or_name == column or select.is_star
- ),
- -1, # mypy will not allow a None here, but a negative index should never be returned
- )
- )
-
- if index == -1:
- raise ValueError(f"Could not find {column} in {scope.expression}")
- for s in scope.union_scopes:
- to_node(index, scope=s, upstream=upstream, alias=alias)
-
- return upstream
-
- if isinstance(scope.expression, exp.Select):
- # For better ergonomics in our node labels, replace the full select with
- # a version that has only the column we care about.
- # "x", SELECT x, y FROM foo
- # => "x", SELECT x FROM foo
- source = t.cast(exp.Expression, scope.expression.select(select, append=False))
- else:
- source = scope.expression
-
- # Create the node for this step in the lineage chain, and attach it to the previous one.
- node = Node(
- name=f"{scope_name}.{column}" if scope_name else str(column),
- source=source,
- expression=select,
- alias=alias or "",
+def to_node(
+ column: str | int,
+ scope: Scope,
+ dialect: DialectType,
+ scope_name: t.Optional[str] = None,
+ upstream: t.Optional[Node] = None,
+ source_name: t.Optional[str] = None,
+ reference_node_name: t.Optional[str] = None,
+) -> Node:
+ source_names = {
+ dt.alias: dt.comments[0].split()[1]
+ for dt in scope.derived_tables
+ if dt.comments and dt.comments[0].startswith("source: ")
+ }
+
+ # Find the specific select clause that is the source of the column we want.
+ # This can either be a specific, named select or a generic `*` clause.
+ select = (
+ scope.expression.selects[column]
+ if isinstance(column, int)
+ else next(
+ (select for select in scope.expression.selects if select.alias_or_name == column),
+ exp.Star() if scope.expression.is_star else scope.expression,
)
+ )
- if upstream:
- upstream.downstream.append(node)
+ if isinstance(scope.expression, exp.Union):
+ upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
- subquery_scopes = {
- id(subquery_scope.expression): subquery_scope
- for subquery_scope in scope.subquery_scopes
- }
+ index = (
+ column
+ if isinstance(column, int)
+ else next(
+ (
+ i
+ for i, select in enumerate(scope.expression.selects)
+ if select.alias_or_name == column or select.is_star
+ ),
+ -1, # mypy will not allow a None here, but a negative index should never be returned
+ )
+ )
- for subquery in find_all_in_scope(select, exp.Subqueryable):
- subquery_scope = subquery_scopes[id(subquery)]
+ if index == -1:
+ raise ValueError(f"Could not find {column} in {scope.expression}")
+
+ for s in scope.union_scopes:
+ to_node(
+ index,
+ scope=s,
+ dialect=dialect,
+ upstream=upstream,
+ source_name=source_name,
+ reference_node_name=reference_node_name,
+ )
- for name in subquery.named_selects:
- to_node(name, scope=subquery_scope, upstream=node)
+ return upstream
+
+ if isinstance(scope.expression, exp.Select):
+ # For better ergonomics in our node labels, replace the full select with
+ # a version that has only the column we care about.
+ # "x", SELECT x, y FROM foo
+ # => "x", SELECT x FROM foo
+ source = t.cast(exp.Expression, scope.expression.select(select, append=False))
+ else:
+ source = scope.expression
+
+ # Create the node for this step in the lineage chain, and attach it to the previous one.
+ node = Node(
+ name=f"{scope_name}.{column}" if scope_name else str(column),
+ source=source,
+ expression=select,
+ source_name=source_name or "",
+ reference_node_name=reference_node_name or "",
+ )
- # if the select is a star add all scope sources as downstreams
- if select.is_star:
- for source in scope.sources.values():
- if isinstance(source, Scope):
- source = source.expression
- node.downstream.append(Node(name=select.sql(), source=source, expression=source))
+ if upstream:
+ upstream.downstream.append(node)
- # Find all columns that went into creating this one to list their lineage nodes.
- source_columns = set(find_all_in_scope(select, exp.Column))
+ subquery_scopes = {
+ id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
+ }
- # If the source is a UDTF find columns used in the UTDF to generate the table
- if isinstance(source, exp.UDTF):
- source_columns |= set(source.find_all(exp.Column))
+ for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
+ subquery_scope = subquery_scopes.get(id(subquery))
+ if not subquery_scope:
+ logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}")
+ continue
- for c in source_columns:
- table = c.table
- source = scope.sources.get(table)
+ for name in subquery.named_selects:
+ to_node(name, scope=subquery_scope, dialect=dialect, upstream=node)
+ # if the select is a star add all scope sources as downstreams
+ if select.is_star:
+ for source in scope.sources.values():
if isinstance(source, Scope):
- # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
- to_node(
- c.name,
- scope=source,
- scope_name=table,
- upstream=node,
- alias=aliases.get(table) or alias,
- )
- else:
- # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
- # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
- # is not passed into the `sources` map.
- source = source or exp.Placeholder()
- node.downstream.append(Node(name=c.sql(), source=source, expression=source))
-
- return node
+ source = source.expression
+ node.downstream.append(Node(name=select.sql(), source=source, expression=source))
+
+ # Find all columns that went into creating this one to list their lineage nodes.
+ source_columns = set(find_all_in_scope(select, exp.Column))
+
+ # If the source is a UDTF find columns used in the UTDF to generate the table
+ if isinstance(source, exp.UDTF):
+ source_columns |= set(source.find_all(exp.Column))
+
+ for c in source_columns:
+ table = c.table
+ source = scope.sources.get(table)
+
+ if isinstance(source, Scope):
+ selected_node, _ = scope.selected_sources.get(table, (None, None))
+ # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
+ to_node(
+ c.name,
+ scope=source,
+ dialect=dialect,
+ scope_name=table,
+ upstream=node,
+ source_name=source_names.get(table) or source_name,
+ reference_node_name=selected_node.name if selected_node else None,
+ )
+ else:
+ # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
+ # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
+ # is not passed into the `sources` map.
+ source = source or exp.Placeholder()
+ node.downstream.append(Node(name=c.sql(), source=source, expression=source))
- return to_node(column if isinstance(column, str) else column.name, scope)
+ return node
class GraphHTML: