summaryrefslogtreecommitdiffstats
path: root/sqlglot/lineage.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/lineage.py')
-rw-r--r--sqlglot/lineage.py105
1 files changed, 52 insertions, 53 deletions
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py
index 011a6b8..abcc10f 100644
--- a/sqlglot/lineage.py
+++ b/sqlglot/lineage.py
@@ -6,7 +6,7 @@ from dataclasses import dataclass, field
from sqlglot import Schema, exp, maybe_parse
from sqlglot.errors import SqlglotError
-from sqlglot.optimizer import Scope, build_scope, qualify
+from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, qualify
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
@@ -29,8 +29,38 @@ class Node:
else:
yield d
- def to_html(self, **opts) -> LineageHTML:
- return LineageHTML(self, **opts)
+ def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
+ nodes = {}
+ edges = []
+
+ for node in self.walk():
+ if isinstance(node.expression, exp.Table):
+ label = f"FROM {node.expression.this}"
+ title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
+ group = 1
+ else:
+ label = node.expression.sql(pretty=True, dialect=dialect)
+ source = node.source.transform(
+ lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
+ if n is node.expression
+ else n,
+ copy=False,
+ ).sql(pretty=True, dialect=dialect)
+ title = f"<pre>{source}</pre>"
+ group = 0
+
+ node_id = id(node)
+
+ nodes[node_id] = {
+ "id": node_id,
+ "label": label,
+ "title": title,
+ "group": group,
+ }
+
+ for d in node.downstream:
+ edges.append({"from": node_id, "to": id(d)})
+ return GraphHTML(nodes, edges, **opts)
def lineage(
@@ -64,6 +94,7 @@ def lineage(
k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
for k, v in sources.items()
},
+ dialect=dialect,
)
qualified = qualify.qualify(
@@ -129,17 +160,6 @@ def lineage(
return upstream
- subquery = select.unalias()
-
- if isinstance(subquery, exp.Subquery):
- upstream = upstream or Node(name="SUBQUERY", source=scope.expression, expression=select)
- scope = t.cast(Scope, build_scope(subquery.unnest()))
-
- for select in subquery.named_selects:
- to_node(select, scope=scope, upstream=upstream)
-
- return upstream
-
if isinstance(scope.expression, exp.Select):
# For better ergonomics in our node labels, replace the full select with
# a version that has only the column we care about.
@@ -156,16 +176,28 @@ def lineage(
expression=select,
alias=alias or "",
)
+
if upstream:
upstream.downstream.append(node)
+ subquery_scopes = {
+ id(subquery_scope.expression): subquery_scope
+ for subquery_scope in scope.subquery_scopes
+ }
+
+ for subquery in find_all_in_scope(select, exp.Subqueryable):
+ subquery_scope = subquery_scopes[id(subquery)]
+
+ for name in subquery.named_selects:
+ to_node(name, scope=subquery_scope, upstream=node)
+
# if the select is a star add all scope sources as downstreams
if select.is_star:
for source in scope.sources.values():
node.downstream.append(Node(name=select.sql(), source=source, expression=source))
# Find all columns that went into creating this one to list their lineage nodes.
- source_columns = set(select.find_all(exp.Column))
+ source_columns = set(find_all_in_scope(select, exp.Column))
# If the source is a UDTF find columns used in the UDTF to generate the table
if isinstance(source, exp.UDTF):
@@ -192,20 +224,15 @@ def lineage(
return to_node(column if isinstance(column, str) else column.name, scope)
-class LineageHTML:
+class GraphHTML:
"""Node to HTML generator using vis.js.
https://visjs.github.io/vis-network/docs/network/
"""
def __init__(
- self,
- node: Node,
- dialect: DialectType = None,
- imports: bool = True,
- **opts: t.Any,
+ self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
):
- self.node = node
self.imports = imports
self.options = {
@@ -235,39 +262,11 @@ class LineageHTML:
"maximum": 300,
},
},
- **opts,
+ **(options or {}),
}
- self.nodes = {}
- self.edges = []
-
- for node in node.walk():
- if isinstance(node.expression, exp.Table):
- label = f"FROM {node.expression.this}"
- title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
- group = 1
- else:
- label = node.expression.sql(pretty=True, dialect=dialect)
- source = node.source.transform(
- lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
- if n is node.expression
- else n,
- copy=False,
- ).sql(pretty=True, dialect=dialect)
- title = f"<pre>{source}</pre>"
- group = 0
-
- node_id = id(node)
-
- self.nodes[node_id] = {
- "id": node_id,
- "label": label,
- "title": title,
- "group": group,
- }
-
- for d in node.downstream:
- self.edges.append({"from": node_id, "to": id(d)})
+ self.nodes = nodes
+ self.edges = edges
def __str__(self):
nodes = json.dumps(list(self.nodes.values()))