Edit on GitHub

sqlglot.lineage

  1from __future__ import annotations
  2
  3import json
  4import typing as t
  5from dataclasses import dataclass, field
  6
  7from sqlglot import Schema, exp, maybe_parse
  8from sqlglot.errors import SqlglotError
  9from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, qualify
 10
 11if t.TYPE_CHECKING:
 12    from sqlglot.dialects.dialect import DialectType
 13
 14
 15@dataclass(frozen=True)
 16class Node:
 17    name: str
 18    expression: exp.Expression
 19    source: exp.Expression
 20    downstream: t.List[Node] = field(default_factory=list)
 21    alias: str = ""
 22
 23    def walk(self) -> t.Iterator[Node]:
 24        yield self
 25
 26        for d in self.downstream:
 27            if isinstance(d, Node):
 28                yield from d.walk()
 29            else:
 30                yield d
 31
 32    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
 33        nodes = {}
 34        edges = []
 35
 36        for node in self.walk():
 37            if isinstance(node.expression, exp.Table):
 38                label = f"FROM {node.expression.this}"
 39                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
 40                group = 1
 41            else:
 42                label = node.expression.sql(pretty=True, dialect=dialect)
 43                source = node.source.transform(
 44                    lambda n: (
 45                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
 46                    ),
 47                    copy=False,
 48                ).sql(pretty=True, dialect=dialect)
 49                title = f"<pre>{source}</pre>"
 50                group = 0
 51
 52            node_id = id(node)
 53
 54            nodes[node_id] = {
 55                "id": node_id,
 56                "label": label,
 57                "title": title,
 58                "group": group,
 59            }
 60
 61            for d in node.downstream:
 62                edges.append({"from": node_id, "to": id(d)})
 63        return GraphHTML(nodes, edges, **opts)
 64
 65
 66def lineage(
 67    column: str | exp.Column,
 68    sql: str | exp.Expression,
 69    schema: t.Optional[t.Dict | Schema] = None,
 70    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
 71    dialect: DialectType = None,
 72    **kwargs,
 73) -> Node:
 74    """Build the lineage graph for a column of a SQL query.
 75
 76    Args:
 77        column: The column to build the lineage for.
 78        sql: The SQL string or expression.
 79        schema: The schema of tables.
 80        sources: A mapping of queries which will be used to continue building lineage.
 81        dialect: The dialect of input SQL.
 82        **kwargs: Qualification optimizer kwargs.
 83
 84    Returns:
 85        A lineage node.
 86    """
 87
 88    expression = maybe_parse(sql, dialect=dialect)
 89
 90    if sources:
 91        expression = exp.expand(
 92            expression,
 93            {
 94                k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
 95                for k, v in sources.items()
 96            },
 97            dialect=dialect,
 98        )
 99
100    qualified = qualify.qualify(
101        expression,
102        dialect=dialect,
103        schema=schema,
104        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
105    )
106
107    scope = build_scope(qualified)
108
109    if not scope:
110        raise SqlglotError("Cannot build lineage, sql must be SELECT")
111
112    def to_node(
113        column: str | int,
114        scope: Scope,
115        scope_name: t.Optional[str] = None,
116        upstream: t.Optional[Node] = None,
117        alias: t.Optional[str] = None,
118    ) -> Node:
119        aliases = {
120            dt.alias: dt.comments[0].split()[1]
121            for dt in scope.derived_tables
122            if dt.comments and dt.comments[0].startswith("source: ")
123        }
124
125        # Find the specific select clause that is the source of the column we want.
126        # This can either be a specific, named select or a generic `*` clause.
127        select = (
128            scope.expression.selects[column]
129            if isinstance(column, int)
130            else next(
131                (select for select in scope.expression.selects if select.alias_or_name == column),
132                exp.Star() if scope.expression.is_star else scope.expression,
133            )
134        )
135
136        if isinstance(scope.expression, exp.Union):
137            upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
138
139            index = (
140                column
141                if isinstance(column, int)
142                else next(
143                    (
144                        i
145                        for i, select in enumerate(scope.expression.selects)
146                        if select.alias_or_name == column or select.is_star
147                    ),
148                    -1,  # mypy will not allow a None here, but a negative index should never be returned
149                )
150            )
151
152            if index == -1:
153                raise ValueError(f"Could not find {column} in {scope.expression}")
154
155            for s in scope.union_scopes:
156                to_node(index, scope=s, upstream=upstream, alias=alias)
157
158            return upstream
159
160        if isinstance(scope.expression, exp.Select):
161            # For better ergonomics in our node labels, replace the full select with
162            # a version that has only the column we care about.
163            #   "x", SELECT x, y FROM foo
164            #     => "x", SELECT x FROM foo
165            source = t.cast(exp.Expression, scope.expression.select(select, append=False))
166        else:
167            source = scope.expression
168
169        # Create the node for this step in the lineage chain, and attach it to the previous one.
170        node = Node(
171            name=f"{scope_name}.{column}" if scope_name else str(column),
172            source=source,
173            expression=select,
174            alias=alias or "",
175        )
176
177        if upstream:
178            upstream.downstream.append(node)
179
180        subquery_scopes = {
181            id(subquery_scope.expression): subquery_scope
182            for subquery_scope in scope.subquery_scopes
183        }
184
185        for subquery in find_all_in_scope(select, exp.Subqueryable):
186            subquery_scope = subquery_scopes[id(subquery)]
187
188            for name in subquery.named_selects:
189                to_node(name, scope=subquery_scope, upstream=node)
190
191        # if the select is a star add all scope sources as downstreams
192        if select.is_star:
193            for source in scope.sources.values():
194                if isinstance(source, Scope):
195                    source = source.expression
196                node.downstream.append(Node(name=select.sql(), source=source, expression=source))
197
198        # Find all columns that went into creating this one to list their lineage nodes.
199        source_columns = set(find_all_in_scope(select, exp.Column))
200
201        # If the source is a UDTF find columns used in the UTDF to generate the table
202        if isinstance(source, exp.UDTF):
203            source_columns |= set(source.find_all(exp.Column))
204
205        for c in source_columns:
206            table = c.table
207            source = scope.sources.get(table)
208
209            if isinstance(source, Scope):
210                # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
211                to_node(
212                    c.name,
213                    scope=source,
214                    scope_name=table,
215                    upstream=node,
216                    alias=aliases.get(table) or alias,
217                )
218            else:
219                # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
220                # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
221                # is not passed into the `sources` map.
222                source = source or exp.Placeholder()
223                node.downstream.append(Node(name=c.sql(), source=source, expression=source))
224
225        return node
226
227    return to_node(column if isinstance(column, str) else column.name, scope)
228
229
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
    ):
        """Store the graph data and the vis.js network options.

        Args:
            nodes: vis.js node dicts keyed by node id; only the values are rendered.
            edges: vis.js edge dicts (``{"from": ..., "to": ...}``).
            imports: Whether to emit the tags that load vis.js from unpkg.
            options: vis.js network options merged shallowly over the defaults below;
                a top-level key given here replaces that default entirely.
        """
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` as JSON that is safe to inline in a <script> element.

        ``<``, ``>`` and ``&`` only ever occur inside JSON string values, so replacing
        them with ``\\u`` escapes yields the same JavaScript values while making it
        impossible for data (e.g. SQL containing ``</script>``) to close the script early.
        """
        return (
            json.dumps(value)
            .replace("<", "\\u003c")
            .replace(">", "\\u003e")
            .replace("&", "\\u0026")
        )

    def __str__(self):
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        # Node titles are HTML strings; the inline script parses them into DOM nodes.
        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Jupyter rich-display hook; returns the same HTML as ``str(self)``."""
        return self.__str__()
@dataclass(frozen=True)
class Node:
@dataclass(frozen=True)
class Node:
    """A single step in a column's lineage graph.

    Attributes:
        name: Display name of the column at this step (e.g. ``"x"`` or ``"scope.x"``).
        expression: The expression that produces the column at this step: a select
            projection, a table, or a placeholder when the source is unknown.
        source: The expression the column is read from (for a SELECT, a copy trimmed
            down to just this projection).
        downstream: Nodes this column is derived from.
        alias: Name taken from a ``source: <name>`` comment on the derived table this
            step came from, or ``""`` (presumably set when ``sources`` were expanded).
    """

    name: str
    expression: exp.Expression
    source: exp.Expression
    downstream: t.List[Node] = field(default_factory=list)
    alias: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node, then every node below it, depth-first and in order."""
        yield self

        for d in self.downstream:
            if isinstance(d, Node):
                yield from d.walk()
            else:
                # Defensive: non-Node entries are yielded as-is instead of recursed into.
                yield d

    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
        """Render this lineage graph as an interactive vis.js HTML widget.

        Args:
            dialect: Dialect used to generate the SQL shown in labels and tooltips.
            **opts: Forwarded to `GraphHTML` (``imports``, ``options``).

        Returns:
            A `GraphHTML` whose ``str()`` / ``_repr_html_()`` is the HTML snippet.
        """
        import html

        nodes = {}
        edges = []

        for node in self.walk():
            if isinstance(node.expression, exp.Table):
                # Labels are plain text in vis.js; only the HTML tooltip is escaped.
                label = f"FROM {node.expression.this}"
                name = html.escape(node.name, quote=False)
                table = html.escape(str(node.expression.this), quote=False)
                title = f"<pre>SELECT {name} FROM {table}</pre>"
                group = 1
            else:
                label = node.expression.sql(pretty=True, dialect=dialect)
                # Wrap the traced expression in <b> tags inside its source query.
                # NOTE(review): with copy=False this presumably rewrites node.source in
                # place; copy=False is required for the `is` identity check to match.
                source = node.source.transform(
                    lambda n: (
                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
                    ),
                    copy=False,
                ).sql(pretty=True, dialect=dialect)
                # Escape the SQL text (string literals may contain "<" or "&"), then
                # restore just the highlight markup added above.
                highlighted = (
                    html.escape(source, quote=False)
                    .replace("&lt;b&gt;", "<b>")
                    .replace("&lt;/b&gt;", "</b>")
                )
                title = f"<pre>{highlighted}</pre>"
                group = 0

            # Object identity doubles as a unique vis.js node id.
            node_id = id(node)

            nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            edges.extend({"from": node_id, "to": id(d)} for d in node.downstream)
        return GraphHTML(nodes, edges, **opts)
Node( name: str, expression: sqlglot.expressions.Expression, source: sqlglot.expressions.Expression, downstream: List[Node] = <factory>, alias: str = '')
name: str
downstream: List[Node]
alias: str = ''
def walk(self) -> Iterator[Node]:
24    def walk(self) -> t.Iterator[Node]:
25        yield self
26
27        for d in self.downstream:
28            if isinstance(d, Node):
29                yield from d.walk()
30            else:
31                yield d
def to_html( self, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **opts) -> GraphHTML:
33    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
34        nodes = {}
35        edges = []
36
37        for node in self.walk():
38            if isinstance(node.expression, exp.Table):
39                label = f"FROM {node.expression.this}"
40                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
41                group = 1
42            else:
43                label = node.expression.sql(pretty=True, dialect=dialect)
44                source = node.source.transform(
45                    lambda n: (
46                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
47                    ),
48                    copy=False,
49                ).sql(pretty=True, dialect=dialect)
50                title = f"<pre>{source}</pre>"
51                group = 0
52
53            node_id = id(node)
54
55            nodes[node_id] = {
56                "id": node_id,
57                "label": label,
58                "title": title,
59                "group": group,
60            }
61
62            for d in node.downstream:
63                edges.append({"from": node_id, "to": id(d)})
64        return GraphHTML(nodes, edges, **opts)
def lineage( column: str | sqlglot.expressions.Column, sql: str | sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, sources: Optional[Dict[str, str | sqlglot.expressions.Subqueryable]] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **kwargs) -> Node:
 67def lineage(
 68    column: str | exp.Column,
 69    sql: str | exp.Expression,
 70    schema: t.Optional[t.Dict | Schema] = None,
 71    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
 72    dialect: DialectType = None,
 73    **kwargs,
 74) -> Node:
 75    """Build the lineage graph for a column of a SQL query.
 76
 77    Args:
 78        column: The column to build the lineage for.
 79        sql: The SQL string or expression.
 80        schema: The schema of tables.
 81        sources: A mapping of queries which will be used to continue building lineage.
 82        dialect: The dialect of input SQL.
 83        **kwargs: Qualification optimizer kwargs.
 84
 85    Returns:
 86        A lineage node.
 87    """
 88
 89    expression = maybe_parse(sql, dialect=dialect)
 90
 91    if sources:
 92        expression = exp.expand(
 93            expression,
 94            {
 95                k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
 96                for k, v in sources.items()
 97            },
 98            dialect=dialect,
 99        )
100
101    qualified = qualify.qualify(
102        expression,
103        dialect=dialect,
104        schema=schema,
105        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
106    )
107
108    scope = build_scope(qualified)
109
110    if not scope:
111        raise SqlglotError("Cannot build lineage, sql must be SELECT")
112
113    def to_node(
114        column: str | int,
115        scope: Scope,
116        scope_name: t.Optional[str] = None,
117        upstream: t.Optional[Node] = None,
118        alias: t.Optional[str] = None,
119    ) -> Node:
120        aliases = {
121            dt.alias: dt.comments[0].split()[1]
122            for dt in scope.derived_tables
123            if dt.comments and dt.comments[0].startswith("source: ")
124        }
125
126        # Find the specific select clause that is the source of the column we want.
127        # This can either be a specific, named select or a generic `*` clause.
128        select = (
129            scope.expression.selects[column]
130            if isinstance(column, int)
131            else next(
132                (select for select in scope.expression.selects if select.alias_or_name == column),
133                exp.Star() if scope.expression.is_star else scope.expression,
134            )
135        )
136
137        if isinstance(scope.expression, exp.Union):
138            upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
139
140            index = (
141                column
142                if isinstance(column, int)
143                else next(
144                    (
145                        i
146                        for i, select in enumerate(scope.expression.selects)
147                        if select.alias_or_name == column or select.is_star
148                    ),
149                    -1,  # mypy will not allow a None here, but a negative index should never be returned
150                )
151            )
152
153            if index == -1:
154                raise ValueError(f"Could not find {column} in {scope.expression}")
155
156            for s in scope.union_scopes:
157                to_node(index, scope=s, upstream=upstream, alias=alias)
158
159            return upstream
160
161        if isinstance(scope.expression, exp.Select):
162            # For better ergonomics in our node labels, replace the full select with
163            # a version that has only the column we care about.
164            #   "x", SELECT x, y FROM foo
165            #     => "x", SELECT x FROM foo
166            source = t.cast(exp.Expression, scope.expression.select(select, append=False))
167        else:
168            source = scope.expression
169
170        # Create the node for this step in the lineage chain, and attach it to the previous one.
171        node = Node(
172            name=f"{scope_name}.{column}" if scope_name else str(column),
173            source=source,
174            expression=select,
175            alias=alias or "",
176        )
177
178        if upstream:
179            upstream.downstream.append(node)
180
181        subquery_scopes = {
182            id(subquery_scope.expression): subquery_scope
183            for subquery_scope in scope.subquery_scopes
184        }
185
186        for subquery in find_all_in_scope(select, exp.Subqueryable):
187            subquery_scope = subquery_scopes[id(subquery)]
188
189            for name in subquery.named_selects:
190                to_node(name, scope=subquery_scope, upstream=node)
191
192        # if the select is a star add all scope sources as downstreams
193        if select.is_star:
194            for source in scope.sources.values():
195                if isinstance(source, Scope):
196                    source = source.expression
197                node.downstream.append(Node(name=select.sql(), source=source, expression=source))
198
199        # Find all columns that went into creating this one to list their lineage nodes.
200        source_columns = set(find_all_in_scope(select, exp.Column))
201
202        # If the source is a UDTF find columns used in the UTDF to generate the table
203        if isinstance(source, exp.UDTF):
204            source_columns |= set(source.find_all(exp.Column))
205
206        for c in source_columns:
207            table = c.table
208            source = scope.sources.get(table)
209
210            if isinstance(source, Scope):
211                # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
212                to_node(
213                    c.name,
214                    scope=source,
215                    scope_name=table,
216                    upstream=node,
217                    alias=aliases.get(table) or alias,
218                )
219            else:
220                # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
221                # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
222                # is not passed into the `sources` map.
223                source = source or exp.Placeholder()
224                node.downstream.append(Node(name=c.sql(), source=source, expression=source))
225
226        return node
227
228    return to_node(column if isinstance(column, str) else column.name, scope)

Build the lineage graph for a column of a SQL query.

Arguments:
  • column: The column to build the lineage for.
  • sql: The SQL string or expression.
  • schema: The schema of tables.
  • sources: A mapping of queries which will be used to continue building lineage.
  • dialect: The dialect of input SQL.
  • **kwargs: Qualification optimizer kwargs.
Returns:

A lineage node.

class GraphHTML:
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
    ):
        """Store the graph data and the vis.js network options.

        Args:
            nodes: vis.js node dicts keyed by node id; only the values are rendered.
            edges: vis.js edge dicts (``{"from": ..., "to": ...}``).
            imports: Whether to emit the tags that load vis.js from unpkg.
            options: vis.js network options merged shallowly over the defaults below;
                a top-level key given here replaces that default entirely.
        """
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` as JSON that is safe to inline in a <script> element.

        ``<``, ``>`` and ``&`` only ever occur inside JSON string values, so replacing
        them with ``\\u`` escapes yields the same JavaScript values while making it
        impossible for data (e.g. SQL containing ``</script>``) to close the script early.
        """
        return (
            json.dumps(value)
            .replace("<", "\\u003c")
            .replace(">", "\\u003e")
            .replace("&", "\\u0026")
        )

    def __str__(self):
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        # Node titles are HTML strings; the inline script parses them into DOM nodes.
        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Jupyter rich-display hook; returns the same HTML as ``str(self)``."""
        return self.__str__()

Node to HTML generator using vis.js.

https://visjs.github.io/vis-network/docs/network/

GraphHTML( nodes: Dict, edges: List, imports: bool = True, options: Optional[Dict] = None)
237    def __init__(
238        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
239    ):
240        self.imports = imports
241
242        self.options = {
243            "height": "500px",
244            "width": "100%",
245            "layout": {
246                "hierarchical": {
247                    "enabled": True,
248                    "nodeSpacing": 200,
249                    "sortMethod": "directed",
250                },
251            },
252            "interaction": {
253                "dragNodes": False,
254                "selectable": False,
255            },
256            "physics": {
257                "enabled": False,
258            },
259            "edges": {
260                "arrows": "to",
261            },
262            "nodes": {
263                "font": "20px monaco",
264                "shape": "box",
265                "widthConstraint": {
266                    "maximum": 300,
267                },
268            },
269            **(options or {}),
270        }
271
272        self.nodes = nodes
273        self.edges = edges
imports
options
nodes
edges