diff options
Diffstat (limited to 'sqlglot/lineage.py')
-rw-r--r-- | sqlglot/lineage.py | 105 |
1 files changed, 52 insertions, 53 deletions
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index 011a6b8..abcc10f 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, field from sqlglot import Schema, exp, maybe_parse from sqlglot.errors import SqlglotError -from sqlglot.optimizer import Scope, build_scope, qualify +from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, qualify if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType @@ -29,8 +29,38 @@ class Node: else: yield d - def to_html(self, **opts) -> LineageHTML: - return LineageHTML(self, **opts) + def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML: + nodes = {} + edges = [] + + for node in self.walk(): + if isinstance(node.expression, exp.Table): + label = f"FROM {node.expression.this}" + title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>" + group = 1 + else: + label = node.expression.sql(pretty=True, dialect=dialect) + source = node.source.transform( + lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>") + if n is node.expression + else n, + copy=False, + ).sql(pretty=True, dialect=dialect) + title = f"<pre>{source}</pre>" + group = 0 + + node_id = id(node) + + nodes[node_id] = { + "id": node_id, + "label": label, + "title": title, + "group": group, + } + + for d in node.downstream: + edges.append({"from": node_id, "to": id(d)}) + return GraphHTML(nodes, edges, **opts) def lineage( @@ -64,6 +94,7 @@ def lineage( k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect)) for k, v in sources.items() }, + dialect=dialect, ) qualified = qualify.qualify( @@ -129,17 +160,6 @@ def lineage( return upstream - subquery = select.unalias() - - if isinstance(subquery, exp.Subquery): - upstream = upstream or Node(name="SUBQUERY", source=scope.expression, expression=select) - scope = t.cast(Scope, build_scope(subquery.unnest())) - - for select in subquery.named_selects: - to_node(select, scope=scope, upstream=upstream) - - return upstream - if isinstance(scope.expression, exp.Select): # For better ergonomics in our node labels, replace the full select with # a version that has only the column we care about. @@ -156,16 +176,28 @@ def lineage( expression=select, alias=alias or "", ) + if upstream: upstream.downstream.append(node) + subquery_scopes = { + id(subquery_scope.expression): subquery_scope + for subquery_scope in scope.subquery_scopes + } + + for subquery in find_all_in_scope(select, exp.Subqueryable): + subquery_scope = subquery_scopes[id(subquery)] + + for name in subquery.named_selects: + to_node(name, scope=subquery_scope, upstream=node) + # if the select is a star add all scope sources as downstreams if select.is_star: for source in scope.sources.values(): node.downstream.append(Node(name=select.sql(), source=source, expression=source)) # Find all columns that went into creating this one to list their lineage nodes. - source_columns = set(select.find_all(exp.Column)) + source_columns = set(find_all_in_scope(select, exp.Column)) # If the source is a UDTF find columns used in the UTDF to generate the table if isinstance(source, exp.UDTF): @@ -192,20 +224,15 @@ def lineage( return to_node(column if isinstance(column, str) else column.name, scope) -class LineageHTML: +class GraphHTML: """Node to HTML generator using vis.js. https://visjs.github.io/vis-network/docs/network/ """ def __init__( - self, - node: Node, - dialect: DialectType = None, - imports: bool = True, - **opts: t.Any, + self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None ): - self.node = node self.imports = imports self.options = { @@ -235,39 +262,11 @@ class LineageHTML: "maximum": 300, }, }, - **opts, + **(options or {}), } - self.nodes = {} - self.edges = [] - - for node in node.walk(): - if isinstance(node.expression, exp.Table): - label = f"FROM {node.expression.this}" - title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>" - group = 1 - else: - label = node.expression.sql(pretty=True, dialect=dialect) - source = node.source.transform( - lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>") - if n is node.expression - else n, - copy=False, - ).sql(pretty=True, dialect=dialect) - title = f"<pre>{source}</pre>" - group = 0 - - node_id = id(node) - - self.nodes[node_id] = { - "id": node_id, - "label": label, - "title": title, - "group": group, - } - - for d in node.downstream: - self.edges.append({"from": node_id, "to": id(d)}) + self.nodes = nodes + self.edges = edges def __str__(self): nodes = json.dumps(list(self.nodes.values())) |