Edit on GitHub

sqlglot.lineage

  1from __future__ import annotations
  2
  3import json
  4import typing as t
  5from dataclasses import dataclass, field
  6
  7from sqlglot import Schema, exp, maybe_parse
  8from sqlglot.optimizer import Scope, build_scope, optimize
  9from sqlglot.optimizer.expand_laterals import expand_laterals
 10from sqlglot.optimizer.qualify_columns import qualify_columns
 11from sqlglot.optimizer.qualify_tables import qualify_tables
 12
 13if t.TYPE_CHECKING:
 14    from sqlglot.dialects.dialect import DialectType
 15
 16
 17@dataclass(frozen=True)
 18class Node:
 19    name: str
 20    expression: exp.Expression
 21    source: exp.Expression
 22    downstream: t.List[Node] = field(default_factory=list)
 23
 24    def walk(self) -> t.Iterator[Node]:
 25        yield self
 26
 27        for d in self.downstream:
 28            if isinstance(d, Node):
 29                yield from d.walk()
 30            else:
 31                yield d
 32
 33    def to_html(self, **opts) -> LineageHTML:
 34        return LineageHTML(self, **opts)
 35
 36
 37def lineage(
 38    column: str | exp.Column,
 39    sql: str | exp.Expression,
 40    schema: t.Optional[t.Dict | Schema] = None,
 41    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
 42    rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns, expand_laterals),
 43    dialect: DialectType = None,
 44) -> Node:
 45    """Build the lineage graph for a column of a SQL query.
 46
 47    Args:
 48        column: The column to build the lineage for.
 49        sql: The SQL string or expression.
 50        schema: The schema of tables.
 51        sources: A mapping of queries which will be used to continue building lineage.
 52        rules: Optimizer rules to apply, by default only qualifying tables and columns.
 53        dialect: The dialect of input SQL.
 54
 55    Returns:
 56        A lineage node.
 57    """
 58
 59    expression = maybe_parse(sql, dialect=dialect)
 60
 61    if sources:
 62        expression = exp.expand(
 63            expression,
 64            {
 65                k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
 66                for k, v in sources.items()
 67            },
 68        )
 69
 70    optimized = optimize(expression, schema=schema, rules=rules)
 71    scope = build_scope(optimized)
 72    tables: t.Dict[str, Node] = {}
 73
 74    def to_node(
 75        column_name: str,
 76        scope: Scope,
 77        scope_name: t.Optional[str] = None,
 78        upstream: t.Optional[Node] = None,
 79    ) -> Node:
 80        if isinstance(scope.expression, exp.Union):
 81            for scope in scope.union_scopes:
 82                node = to_node(
 83                    column_name,
 84                    scope=scope,
 85                    scope_name=scope_name,
 86                    upstream=upstream,
 87                )
 88            return node
 89
 90        select = next(select for select in scope.selects if select.alias_or_name == column_name)
 91        source = optimize(scope.expression.select(select, append=False), schema=schema, rules=rules)
 92        select = source.selects[0]
 93
 94        node = Node(
 95            name=f"{scope_name}.{column_name}" if scope_name else column_name,
 96            source=source,
 97            expression=select,
 98        )
 99
100        if upstream:
101            upstream.downstream.append(node)
102
103        for c in set(select.find_all(exp.Column)):
104            table = c.table
105            source = scope.sources[table]
106
107            if isinstance(source, Scope):
108                to_node(
109                    c.name,
110                    scope=source,
111                    scope_name=table,
112                    upstream=node,
113                )
114            else:
115                if table not in tables:
116                    tables[table] = Node(name=c.sql(), source=source, expression=source)
117                node.downstream.append(tables[table])
118
119        return node
120
121    return to_node(column if isinstance(column, str) else column.name, scope)
122
123
class LineageHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        node: Node,
        dialect: DialectType = None,
        imports: bool = True,
        **opts: t.Any,
    ):
        """Collect the vis.js nodes, edges and options for the graph rooted at `node`.

        Args:
            node: Root of the lineage graph to render.
            dialect: Dialect passed to the SQL generation of labels and tooltips.
            imports: Whether the rendered HTML includes the vis.js script and stylesheet tags.
            **opts: Top-level vis.js options; they replace the defaults below key by key.
        """
        self.node = node
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **opts,
        }

        # Keyed by id() of the lineage node, so a node reached twice is stored once.
        self.nodes = {}
        self.edges = []

        # A distinct loop name keeps the `node` argument from being rebound.
        for current in node.walk():
            expression = current.expression

            if isinstance(expression, exp.Table):
                label = f"FROM {expression.this}"
                title = f"<pre>SELECT {current.name} FROM {expression.this}</pre>"
                group = 1
            else:
                label = expression.sql(pretty=True, dialect=dialect)
                # Wrap the node's own expression in <b>...</b> inside its source query.
                # NOTE(review): copy=False presumably makes transform work on the
                # original tree, which the identity check below relies on.
                source = current.source.transform(
                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
                    if n is expression
                    else n,
                    copy=False,
                ).sql(pretty=True, dialect=dialect)
                title = f"<pre>{source}</pre>"
                group = 0

            node_id = id(current)

            self.nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for d in current.downstream:
                self.edges.append({"from": node_id, "to": id(d)})

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` to JSON that can be inlined in a <script> element."""
        # A literal "</" would let text such as "</script>" close the element
        # early; "<\/" is the same string to JavaScript.
        return json.dumps(value).replace("</", "<\\/")

    def __str__(self):
        """Return the HTML snippet that draws the graph with vis.js."""
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the same HTML as ``str(self)``."""
        return str(self)
@dataclass(frozen=True)
class Node:
@dataclass(frozen=True)
class Node:
    """A single vertex of a lineage graph.

    The dataclass is frozen, so the fields cannot be rebound after
    construction; the ``downstream`` list itself stays mutable.
    """

    # Display name of the node.
    name: str
    # Expression this node stands for.
    expression: exp.Expression
    # Expression that `expression` belongs to.
    source: exp.Expression
    # Nodes this node points to; each instance starts with its own empty list.
    downstream: t.List[Node] = field(default_factory=list)

    def walk(self) -> t.Iterator[Node]:
        """Yield this node first, then everything reachable via ``downstream``, depth first."""
        yield self

        for child in self.downstream:
            if not isinstance(child, Node):
                # Entries that are not nodes are passed through without recursing.
                yield child
                continue
            yield from child.walk()

    def to_html(self, **opts) -> LineageHTML:
        """Create a ``LineageHTML`` renderer for this node, forwarding ``opts`` to it."""
        renderer = LineageHTML(self, **opts)
        return renderer
Node( name: str, expression: sqlglot.expressions.Expression, source: sqlglot.expressions.Expression, downstream: List[sqlglot.lineage.Node] = <factory>)
def walk(self) -> Iterator[sqlglot.lineage.Node]:
25    def walk(self) -> t.Iterator[Node]:
26        yield self
27
28        for d in self.downstream:
29            if isinstance(d, Node):
30                yield from d.walk()
31            else:
32                yield d
def to_html(self, **opts) -> sqlglot.lineage.LineageHTML:
34    def to_html(self, **opts) -> LineageHTML:
35        return LineageHTML(self, **opts)
def lineage( column: str | sqlglot.expressions.Column, sql: str | sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, sources: Optional[Dict[str, str | sqlglot.expressions.Subqueryable]] = None, rules: Sequence[Callable] = (qualify_tables, qualify_columns, expand_laterals), dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.lineage.Node:
 38def lineage(
 39    column: str | exp.Column,
 40    sql: str | exp.Expression,
 41    schema: t.Optional[t.Dict | Schema] = None,
 42    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
 43    rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns, expand_laterals),
 44    dialect: DialectType = None,
 45) -> Node:
 46    """Build the lineage graph for a column of a SQL query.
 47
 48    Args:
 49        column: The column to build the lineage for.
 50        sql: The SQL string or expression.
 51        schema: The schema of tables.
 52        sources: A mapping of queries which will be used to continue building lineage.
 53        rules: Optimizer rules to apply, by default only qualifying tables and columns.
 54        dialect: The dialect of input SQL.
 55
 56    Returns:
 57        A lineage node.
 58    """
 59
 60    expression = maybe_parse(sql, dialect=dialect)
 61
 62    if sources:
 63        expression = exp.expand(
 64            expression,
 65            {
 66                k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
 67                for k, v in sources.items()
 68            },
 69        )
 70
 71    optimized = optimize(expression, schema=schema, rules=rules)
 72    scope = build_scope(optimized)
 73    tables: t.Dict[str, Node] = {}
 74
 75    def to_node(
 76        column_name: str,
 77        scope: Scope,
 78        scope_name: t.Optional[str] = None,
 79        upstream: t.Optional[Node] = None,
 80    ) -> Node:
 81        if isinstance(scope.expression, exp.Union):
 82            for scope in scope.union_scopes:
 83                node = to_node(
 84                    column_name,
 85                    scope=scope,
 86                    scope_name=scope_name,
 87                    upstream=upstream,
 88                )
 89            return node
 90
 91        select = next(select for select in scope.selects if select.alias_or_name == column_name)
 92        source = optimize(scope.expression.select(select, append=False), schema=schema, rules=rules)
 93        select = source.selects[0]
 94
 95        node = Node(
 96            name=f"{scope_name}.{column_name}" if scope_name else column_name,
 97            source=source,
 98            expression=select,
 99        )
100
101        if upstream:
102            upstream.downstream.append(node)
103
104        for c in set(select.find_all(exp.Column)):
105            table = c.table
106            source = scope.sources[table]
107
108            if isinstance(source, Scope):
109                to_node(
110                    c.name,
111                    scope=source,
112                    scope_name=table,
113                    upstream=node,
114                )
115            else:
116                if table not in tables:
117                    tables[table] = Node(name=c.sql(), source=source, expression=source)
118                node.downstream.append(tables[table])
119
120        return node
121
122    return to_node(column if isinstance(column, str) else column.name, scope)

Build the lineage graph for a column of a SQL query.

Arguments:
  • column: The column to build the lineage for.
  • sql: The SQL string or expression.
  • schema: The schema of tables.
  • sources: A mapping of queries which will be used to continue building lineage.
  • rules: Optimizer rules to apply, by default qualifying tables and columns and expanding laterals.
  • dialect: The dialect of input SQL.
Returns:

A lineage node.

class LineageHTML:
class LineageHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        node: Node,
        dialect: DialectType = None,
        imports: bool = True,
        **opts: t.Any,
    ):
        """Collect the vis.js nodes, edges and options for the graph rooted at `node`.

        Args:
            node: Root of the lineage graph to render.
            dialect: Dialect passed to the SQL generation of labels and tooltips.
            imports: Whether the rendered HTML includes the vis.js script and stylesheet tags.
            **opts: Top-level vis.js options; they replace the defaults below key by key.
        """
        self.node = node
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **opts,
        }

        # Keyed by id() of the lineage node, so a node reached twice is stored once.
        self.nodes = {}
        self.edges = []

        # A distinct loop name keeps the `node` argument from being rebound.
        for current in node.walk():
            expression = current.expression

            if isinstance(expression, exp.Table):
                label = f"FROM {expression.this}"
                title = f"<pre>SELECT {current.name} FROM {expression.this}</pre>"
                group = 1
            else:
                label = expression.sql(pretty=True, dialect=dialect)
                # Wrap the node's own expression in <b>...</b> inside its source query.
                # NOTE(review): copy=False presumably makes transform work on the
                # original tree, which the identity check below relies on.
                source = current.source.transform(
                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
                    if n is expression
                    else n,
                    copy=False,
                ).sql(pretty=True, dialect=dialect)
                title = f"<pre>{source}</pre>"
                group = 0

            node_id = id(current)

            self.nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for d in current.downstream:
                self.edges.append({"from": node_id, "to": id(d)})

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` to JSON that can be inlined in a <script> element."""
        # A literal "</" would let text such as "</script>" close the element
        # early; "<\/" is the same string to JavaScript.
        return json.dumps(value).replace("</", "<\\/")

    def __str__(self):
        """Return the HTML snippet that draws the graph with vis.js."""
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the same HTML as ``str(self)``."""
        return str(self)

Node to HTML generator using vis.js.

https://visjs.github.io/vis-network/docs/network/

LineageHTML( node: sqlglot.lineage.Node, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, imports: bool = True, **opts: Any)
131    def __init__(
132        self,
133        node: Node,
134        dialect: DialectType = None,
135        imports: bool = True,
136        **opts: t.Any,
137    ):
138        self.node = node
139        self.imports = imports
140
141        self.options = {
142            "height": "500px",
143            "width": "100%",
144            "layout": {
145                "hierarchical": {
146                    "enabled": True,
147                    "nodeSpacing": 200,
148                    "sortMethod": "directed",
149                },
150            },
151            "interaction": {
152                "dragNodes": False,
153                "selectable": False,
154            },
155            "physics": {
156                "enabled": False,
157            },
158            "edges": {
159                "arrows": "to",
160            },
161            "nodes": {
162                "font": "20px monaco",
163                "shape": "box",
164                "widthConstraint": {
165                    "maximum": 300,
166                },
167            },
168            **opts,
169        }
170
171        self.nodes = {}
172        self.edges = []
173
174        for node in node.walk():
175            if isinstance(node.expression, exp.Table):
176                label = f"FROM {node.expression.this}"
177                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
178                group = 1
179            else:
180                label = node.expression.sql(pretty=True, dialect=dialect)
181                source = node.source.transform(
182                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
183                    if n is node.expression
184                    else n,
185                    copy=False,
186                ).sql(pretty=True, dialect=dialect)
187                title = f"<pre>{source}</pre>"
188                group = 0
189
190            node_id = id(node)
191
192            self.nodes[node_id] = {
193                "id": node_id,
194                "label": label,
195                "title": title,
196                "group": group,
197            }
198
199            for d in node.downstream:
200                self.edges.append({"from": node_id, "to": id(d)})