Edit on GitHub

sqlglot.lineage

  1from __future__ import annotations
  2
  3import json
  4import logging
  5import typing as t
  6from dataclasses import dataclass, field
  7
  8from sqlglot import Schema, exp, maybe_parse
  9from sqlglot.errors import SqlglotError
 10from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot.dialects.dialect import DialectType
 14
 15logger = logging.getLogger("sqlglot")
 16
 17
 18@dataclass(frozen=True)
 19class Node:
 20    name: str
 21    expression: exp.Expression
 22    source: exp.Expression
 23    downstream: t.List[Node] = field(default_factory=list)
 24    source_name: str = ""
 25    reference_node_name: str = ""
 26
 27    def walk(self) -> t.Iterator[Node]:
 28        yield self
 29
 30        for d in self.downstream:
 31            if isinstance(d, Node):
 32                yield from d.walk()
 33            else:
 34                yield d
 35
 36    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
 37        nodes = {}
 38        edges = []
 39
 40        for node in self.walk():
 41            if isinstance(node.expression, exp.Table):
 42                label = f"FROM {node.expression.this}"
 43                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
 44                group = 1
 45            else:
 46                label = node.expression.sql(pretty=True, dialect=dialect)
 47                source = node.source.transform(
 48                    lambda n: (
 49                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
 50                    ),
 51                    copy=False,
 52                ).sql(pretty=True, dialect=dialect)
 53                title = f"<pre>{source}</pre>"
 54                group = 0
 55
 56            node_id = id(node)
 57
 58            nodes[node_id] = {
 59                "id": node_id,
 60                "label": label,
 61                "title": title,
 62                "group": group,
 63            }
 64
 65            for d in node.downstream:
 66                edges.append({"from": node_id, "to": id(d)})
 67        return GraphHTML(nodes, edges, **opts)
 68
 69
 70def lineage(
 71    column: str | exp.Column,
 72    sql: str | exp.Expression,
 73    schema: t.Optional[t.Dict | Schema] = None,
 74    sources: t.Optional[t.Dict[str, str | exp.Query]] = None,
 75    dialect: DialectType = None,
 76    **kwargs,
 77) -> Node:
 78    """Build the lineage graph for a column of a SQL query.
 79
 80    Args:
 81        column: The column to build the lineage for.
 82        sql: The SQL string or expression.
 83        schema: The schema of tables.
 84        sources: A mapping of queries which will be used to continue building lineage.
 85        dialect: The dialect of input SQL.
 86        **kwargs: Qualification optimizer kwargs.
 87
 88    Returns:
 89        A lineage node.
 90    """
 91
 92    expression = maybe_parse(sql, dialect=dialect)
 93    column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
 94
 95    if sources:
 96        expression = exp.expand(
 97            expression,
 98            {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()},
 99            dialect=dialect,
100        )
101
102    qualified = qualify.qualify(
103        expression,
104        dialect=dialect,
105        schema=schema,
106        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
107    )
108
109    scope = build_scope(qualified)
110
111    if not scope:
112        raise SqlglotError("Cannot build lineage, sql must be SELECT")
113
114    if not any(select.alias_or_name == column for select in scope.expression.selects):
115        raise SqlglotError(f"Cannot find column '{column}' in query.")
116
117    return to_node(column, scope, dialect)
118
119
def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: t.Optional[str] = None,
    upstream: t.Optional[Node] = None,
    source_name: t.Optional[str] = None,
    reference_node_name: t.Optional[str] = None,
) -> Node:
    """Build the lineage node for one select of `scope`, recursing into the scopes it reads from.

    Args:
        column: Name (alias or name) of the select to trace, or its positional
            index in `scope.expression.selects`. The index form is what the
            UNION branch passes when recursing into each union scope.
        scope: The scope whose expression produces the column.
        dialect: Used to render SQL in the "Unknown subquery scope" warning and
            passed along to recursive calls.
        scope_name: When given, the node is named `"{scope_name}.{column}"`;
            otherwise `str(column)`.
        upstream: Node that the newly created node is appended to.
        source_name: Stored on the node (empty string when None) and inherited
            by recursive calls unless a "source: " comment overrides it.
        reference_node_name: Stored on the node (empty string when None).

    Returns:
        The node created for this select. For a UNION scope, the `upstream`
        node instead (a new node named "UNION" when none was passed in).

    Raises:
        ValueError: For a UNION scope, when `column` is a name that matches no
            select and no select is a star.
    """
    # Map derived-table alias -> second whitespace-separated token of its first
    # comment, for derived tables whose first comment starts with "source: ".
    # NOTE(review): presumably these comments are attached by exp.expand() in
    # lineage() - confirm.
    source_names = {
        dt.alias: dt.comments[0].split()[1]
        for dt in scope.derived_tables
        if dt.comments and dt.comments[0].startswith("source: ")
    }

    # Find the specific select clause that is the source of the column we want.
    # This can either be a specific, named select or a generic `*` clause.
    # If no select matches by name, fall back to a bare Star when the scope
    # expression is a star select, else to the scope expression itself.
    select = (
        scope.expression.selects[column]
        if isinstance(column, int)
        else next(
            (select for select in scope.expression.selects if select.alias_or_name == column),
            exp.Star() if scope.expression.is_star else scope.expression,
        )
    )

    if isinstance(scope.expression, exp.Union):
        # A union gets a single "UNION" node (unless an upstream node already
        # exists); every union scope then attaches its own node to it.
        upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)

        # Resolve the column to a position so that each branch can be indexed
        # by position; a star select matches any name.
        index = (
            column
            if isinstance(column, int)
            else next(
                (
                    i
                    for i, select in enumerate(scope.expression.selects)
                    if select.alias_or_name == column or select.is_star
                ),
                -1,  # mypy will not allow a None here, but a negative index should never be returned
            )
        )

        if index == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        for s in scope.union_scopes:
            to_node(
                index,
                scope=s,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
            )

        return upstream

    if isinstance(scope.expression, exp.Select):
        # For better ergonomics in our node labels, replace the full select with
        # a version that has only the column we care about.
        #   "x", SELECT x, y FROM foo
        #     => "x", SELECT x FROM foo
        source = t.cast(exp.Expression, scope.expression.select(select, append=False))
    else:
        source = scope.expression

    # Create the node for this step in the lineage chain, and attach it to the previous one.
    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=source,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )

    if upstream:
        upstream.downstream.append(node)

    # Index this scope's subquery scopes by the identity of their expression so
    # that the subqueries found inside `select` can be matched to their scope.
    subquery_scopes = {
        id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
    }

    for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
        subquery_scope = subquery_scopes.get(id(subquery))
        if not subquery_scope:
            logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}")
            continue

        # Every named select of the subquery becomes a downstream of this node.
        for name in subquery.named_selects:
            to_node(name, scope=subquery_scope, dialect=dialect, upstream=node)

    # if the select is a star add all scope sources as downstreams
    if select.is_star:
        for source in scope.sources.values():
            if isinstance(source, Scope):
                source = source.expression
            node.downstream.append(Node(name=select.sql(), source=source, expression=source))

    # Find all columns that went into creating this one to list their lineage nodes.
    source_columns = set(find_all_in_scope(select, exp.Column))

    # If the source is a UDTF find columns used in the UDTF to generate the table
    # NOTE(review): when the star loop above ran, `source` has been rebound to
    # the last scope source, so that is what this check inspects - confirm intended.
    if isinstance(source, exp.UDTF):
        source_columns |= set(source.find_all(exp.Column))

    for c in source_columns:
        table = c.table
        source = scope.sources.get(table)

        if isinstance(source, Scope):
            selected_node, _ = scope.selected_sources.get(table, (None, None))
            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
            to_node(
                c.name,
                scope=source,
                dialect=dialect,
                scope_name=table,
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=selected_node.name if selected_node else None,
            )
        else:
            # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
            # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
            # is not passed into the `sources` map.
            source = source or exp.Placeholder()
            node.downstream.append(Node(name=c.sql(), source=source, expression=source))

    return node
249
250
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/

    Args:
        nodes: Mapping whose values are the vis.js node dicts to draw.
        edges: List of vis.js edge dicts.
        imports: Whether the generated HTML includes the vis.js script and
            stylesheet tags.
        options: vis.js network options. Top-level keys replace the matching
            default wholesale (the merge is shallow).
    """

    def __init__(
        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
    ):
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` to JSON that is safe inside an inline <script> element.

        Labels and titles carry SQL-derived text. A literal "</script>" or
        "<!--" in that text would end or corrupt the surrounding script block,
        so "<", ">" and "&" are written as \\uXXXX escapes. These characters
        only ever occur inside JSON strings, and JavaScript decodes the escapes
        back to the original characters, so the data itself is unchanged.
        """
        return (
            json.dumps(value)
            .replace("<", "\\u003c")
            .replace(">", "\\u003e")
            .replace("&", "\\u0026")
        )

    def __str__(self) -> str:
        """Return the HTML snippet (container div, optional imports, init script)."""
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Rich-display hook: returns the same HTML as `str(self)`."""
        return self.__str__()
logger = <Logger sqlglot (WARNING)>
@dataclass(frozen=True)
class Node:
19@dataclass(frozen=True)
20class Node:
21    name: str
22    expression: exp.Expression
23    source: exp.Expression
24    downstream: t.List[Node] = field(default_factory=list)
25    source_name: str = ""
26    reference_node_name: str = ""
27
28    def walk(self) -> t.Iterator[Node]:
29        yield self
30
31        for d in self.downstream:
32            if isinstance(d, Node):
33                yield from d.walk()
34            else:
35                yield d
36
37    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
38        nodes = {}
39        edges = []
40
41        for node in self.walk():
42            if isinstance(node.expression, exp.Table):
43                label = f"FROM {node.expression.this}"
44                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
45                group = 1
46            else:
47                label = node.expression.sql(pretty=True, dialect=dialect)
48                source = node.source.transform(
49                    lambda n: (
50                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
51                    ),
52                    copy=False,
53                ).sql(pretty=True, dialect=dialect)
54                title = f"<pre>{source}</pre>"
55                group = 0
56
57            node_id = id(node)
58
59            nodes[node_id] = {
60                "id": node_id,
61                "label": label,
62                "title": title,
63                "group": group,
64            }
65
66            for d in node.downstream:
67                edges.append({"from": node_id, "to": id(d)})
68        return GraphHTML(nodes, edges, **opts)
Node( name: str, expression: sqlglot.expressions.Expression, source: sqlglot.expressions.Expression, downstream: List[Node] = <factory>, source_name: str = '', reference_node_name: str = '')
name: str
downstream: List[Node]
source_name: str = ''
reference_node_name: str = ''
def walk(self) -> Iterator[Node]:
28    def walk(self) -> t.Iterator[Node]:
29        yield self
30
31        for d in self.downstream:
32            if isinstance(d, Node):
33                yield from d.walk()
34            else:
35                yield d
def to_html( self, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **opts) -> GraphHTML:
37    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
38        nodes = {}
39        edges = []
40
41        for node in self.walk():
42            if isinstance(node.expression, exp.Table):
43                label = f"FROM {node.expression.this}"
44                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
45                group = 1
46            else:
47                label = node.expression.sql(pretty=True, dialect=dialect)
48                source = node.source.transform(
49                    lambda n: (
50                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
51                    ),
52                    copy=False,
53                ).sql(pretty=True, dialect=dialect)
54                title = f"<pre>{source}</pre>"
55                group = 0
56
57            node_id = id(node)
58
59            nodes[node_id] = {
60                "id": node_id,
61                "label": label,
62                "title": title,
63                "group": group,
64            }
65
66            for d in node.downstream:
67                edges.append({"from": node_id, "to": id(d)})
68        return GraphHTML(nodes, edges, **opts)
def lineage( column: str | sqlglot.expressions.Column, sql: str | sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, sources: Optional[Dict[str, str | sqlglot.expressions.Query]] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **kwargs) -> Node:
 71def lineage(
 72    column: str | exp.Column,
 73    sql: str | exp.Expression,
 74    schema: t.Optional[t.Dict | Schema] = None,
 75    sources: t.Optional[t.Dict[str, str | exp.Query]] = None,
 76    dialect: DialectType = None,
 77    **kwargs,
 78) -> Node:
 79    """Build the lineage graph for a column of a SQL query.
 80
 81    Args:
 82        column: The column to build the lineage for.
 83        sql: The SQL string or expression.
 84        schema: The schema of tables.
 85        sources: A mapping of queries which will be used to continue building lineage.
 86        dialect: The dialect of input SQL.
 87        **kwargs: Qualification optimizer kwargs.
 88
 89    Returns:
 90        A lineage node.
 91    """
 92
 93    expression = maybe_parse(sql, dialect=dialect)
 94    column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
 95
 96    if sources:
 97        expression = exp.expand(
 98            expression,
 99            {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()},
100            dialect=dialect,
101        )
102
103    qualified = qualify.qualify(
104        expression,
105        dialect=dialect,
106        schema=schema,
107        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
108    )
109
110    scope = build_scope(qualified)
111
112    if not scope:
113        raise SqlglotError("Cannot build lineage, sql must be SELECT")
114
115    if not any(select.alias_or_name == column for select in scope.expression.selects):
116        raise SqlglotError(f"Cannot find column '{column}' in query.")
117
118    return to_node(column, scope, dialect)

Build the lineage graph for a column of a SQL query.

Arguments:
  • column: The column to build the lineage for.
  • sql: The SQL string or expression.
  • schema: The schema of tables.
  • sources: A mapping of queries which will be used to continue building lineage.
  • dialect: The dialect of input SQL.
  • **kwargs: Qualification optimizer kwargs.
Returns:

A lineage node.

def to_node( column: str | int, scope: sqlglot.optimizer.scope.Scope, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType], scope_name: Optional[str] = None, upstream: Optional[Node] = None, source_name: Optional[str] = None, reference_node_name: Optional[str] = None) -> Node:
121def to_node(
122    column: str | int,
123    scope: Scope,
124    dialect: DialectType,
125    scope_name: t.Optional[str] = None,
126    upstream: t.Optional[Node] = None,
127    source_name: t.Optional[str] = None,
128    reference_node_name: t.Optional[str] = None,
129) -> Node:
130    source_names = {
131        dt.alias: dt.comments[0].split()[1]
132        for dt in scope.derived_tables
133        if dt.comments and dt.comments[0].startswith("source: ")
134    }
135
136    # Find the specific select clause that is the source of the column we want.
137    # This can either be a specific, named select or a generic `*` clause.
138    select = (
139        scope.expression.selects[column]
140        if isinstance(column, int)
141        else next(
142            (select for select in scope.expression.selects if select.alias_or_name == column),
143            exp.Star() if scope.expression.is_star else scope.expression,
144        )
145    )
146
147    if isinstance(scope.expression, exp.Union):
148        upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
149
150        index = (
151            column
152            if isinstance(column, int)
153            else next(
154                (
155                    i
156                    for i, select in enumerate(scope.expression.selects)
157                    if select.alias_or_name == column or select.is_star
158                ),
159                -1,  # mypy will not allow a None here, but a negative index should never be returned
160            )
161        )
162
163        if index == -1:
164            raise ValueError(f"Could not find {column} in {scope.expression}")
165
166        for s in scope.union_scopes:
167            to_node(
168                index,
169                scope=s,
170                dialect=dialect,
171                upstream=upstream,
172                source_name=source_name,
173                reference_node_name=reference_node_name,
174            )
175
176        return upstream
177
178    if isinstance(scope.expression, exp.Select):
179        # For better ergonomics in our node labels, replace the full select with
180        # a version that has only the column we care about.
181        #   "x", SELECT x, y FROM foo
182        #     => "x", SELECT x FROM foo
183        source = t.cast(exp.Expression, scope.expression.select(select, append=False))
184    else:
185        source = scope.expression
186
187    # Create the node for this step in the lineage chain, and attach it to the previous one.
188    node = Node(
189        name=f"{scope_name}.{column}" if scope_name else str(column),
190        source=source,
191        expression=select,
192        source_name=source_name or "",
193        reference_node_name=reference_node_name or "",
194    )
195
196    if upstream:
197        upstream.downstream.append(node)
198
199    subquery_scopes = {
200        id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
201    }
202
203    for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
204        subquery_scope = subquery_scopes.get(id(subquery))
205        if not subquery_scope:
206            logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}")
207            continue
208
209        for name in subquery.named_selects:
210            to_node(name, scope=subquery_scope, dialect=dialect, upstream=node)
211
212    # if the select is a star add all scope sources as downstreams
213    if select.is_star:
214        for source in scope.sources.values():
215            if isinstance(source, Scope):
216                source = source.expression
217            node.downstream.append(Node(name=select.sql(), source=source, expression=source))
218
219    # Find all columns that went into creating this one to list their lineage nodes.
220    source_columns = set(find_all_in_scope(select, exp.Column))
221
223    # If the source is a UDTF find columns used in the UDTF to generate the table
223    if isinstance(source, exp.UDTF):
224        source_columns |= set(source.find_all(exp.Column))
225
226    for c in source_columns:
227        table = c.table
228        source = scope.sources.get(table)
229
230        if isinstance(source, Scope):
231            selected_node, _ = scope.selected_sources.get(table, (None, None))
232            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
233            to_node(
234                c.name,
235                scope=source,
236                dialect=dialect,
237                scope_name=table,
238                upstream=node,
239                source_name=source_names.get(table) or source_name,
240                reference_node_name=selected_node.name if selected_node else None,
241            )
242        else:
243            # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
244            # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
245            # is not passed into the `sources` map.
246            source = source or exp.Placeholder()
247            node.downstream.append(Node(name=c.sql(), source=source, expression=source))
248
249    return node
class GraphHTML:
252class GraphHTML:
253    """Node to HTML generator using vis.js.
254
255    https://visjs.github.io/vis-network/docs/network/
256    """
257
258    def __init__(
259        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
260    ):
261        self.imports = imports
262
263        self.options = {
264            "height": "500px",
265            "width": "100%",
266            "layout": {
267                "hierarchical": {
268                    "enabled": True,
269                    "nodeSpacing": 200,
270                    "sortMethod": "directed",
271                },
272            },
273            "interaction": {
274                "dragNodes": False,
275                "selectable": False,
276            },
277            "physics": {
278                "enabled": False,
279            },
280            "edges": {
281                "arrows": "to",
282            },
283            "nodes": {
284                "font": "20px monaco",
285                "shape": "box",
286                "widthConstraint": {
287                    "maximum": 300,
288                },
289            },
290            **(options or {}),
291        }
292
293        self.nodes = nodes
294        self.edges = edges
295
296    def __str__(self):
297        nodes = json.dumps(list(self.nodes.values()))
298        edges = json.dumps(self.edges)
299        options = json.dumps(self.options)
300        imports = (
301            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
302  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
303  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
304            if self.imports
305            else ""
306        )
307
308        return f"""<div>
309  <div id="sqlglot-lineage"></div>
310  {imports}
311  <script type="text/javascript">
312    var nodes = new vis.DataSet({nodes})
313    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])
314
315    new vis.Network(
316        document.getElementById("sqlglot-lineage"),
317        {{
318            nodes: nodes,
319            edges: new vis.DataSet({edges})
320        }},
321        {options},
322    )
323  </script>
324</div>"""
325
326    def _repr_html_(self) -> str:
327        return self.__str__()

Node to HTML generator using vis.js.

https://visjs.github.io/vis-network/docs/network/

GraphHTML( nodes: Dict, edges: List, imports: bool = True, options: Optional[Dict] = None)
258    def __init__(
259        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
260    ):
261        self.imports = imports
262
263        self.options = {
264            "height": "500px",
265            "width": "100%",
266            "layout": {
267                "hierarchical": {
268                    "enabled": True,
269                    "nodeSpacing": 200,
270                    "sortMethod": "directed",
271                },
272            },
273            "interaction": {
274                "dragNodes": False,
275                "selectable": False,
276            },
277            "physics": {
278                "enabled": False,
279            },
280            "edges": {
281                "arrows": "to",
282            },
283            "nodes": {
284                "font": "20px monaco",
285                "shape": "box",
286                "widthConstraint": {
287                    "maximum": 300,
288                },
289            },
290            **(options or {}),
291        }
292
293        self.nodes = nodes
294        self.edges = edges
imports
options
nodes
edges