Edit on GitHub

sqlglot.lineage

  1from __future__ import annotations
  2
  3import json
  4import typing as t
  5from dataclasses import dataclass, field
  6
  7from sqlglot import Schema, exp, maybe_parse
  8from sqlglot.errors import SqlglotError
  9from sqlglot.optimizer import Scope, build_scope, qualify
 10
 11if t.TYPE_CHECKING:
 12    from sqlglot.dialects.dialect import DialectType
 13
 14
 15@dataclass(frozen=True)
 16class Node:
 17    name: str
 18    expression: exp.Expression
 19    source: exp.Expression
 20    downstream: t.List[Node] = field(default_factory=list)
 21    alias: str = ""
 22
 23    def walk(self) -> t.Iterator[Node]:
 24        yield self
 25
 26        for d in self.downstream:
 27            if isinstance(d, Node):
 28                yield from d.walk()
 29            else:
 30                yield d
 31
 32    def to_html(self, **opts) -> LineageHTML:
 33        return LineageHTML(self, **opts)
 34
 35
 36def lineage(
 37    column: str | exp.Column,
 38    sql: str | exp.Expression,
 39    schema: t.Optional[t.Dict | Schema] = None,
 40    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
 41    dialect: DialectType = None,
 42    **kwargs,
 43) -> Node:
 44    """Build the lineage graph for a column of a SQL query.
 45
 46    Args:
 47        column: The column to build the lineage for.
 48        sql: The SQL string or expression.
 49        schema: The schema of tables.
 50        sources: A mapping of queries which will be used to continue building lineage.
 51        dialect: The dialect of input SQL.
 52        **kwargs: Qualification optimizer kwargs.
 53
 54    Returns:
 55        A lineage node.
 56    """
 57
 58    expression = maybe_parse(sql, dialect=dialect)
 59
 60    if sources:
 61        expression = exp.expand(
 62            expression,
 63            {
 64                k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
 65                for k, v in sources.items()
 66            },
 67        )
 68
 69    qualified = qualify.qualify(
 70        expression,
 71        dialect=dialect,
 72        schema=schema,
 73        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
 74    )
 75
 76    scope = build_scope(qualified)
 77
 78    if not scope:
 79        raise SqlglotError("Cannot build lineage, sql must be SELECT")
 80
 81    def to_node(
 82        column: str | int,
 83        scope: Scope,
 84        scope_name: t.Optional[str] = None,
 85        upstream: t.Optional[Node] = None,
 86        alias: t.Optional[str] = None,
 87    ) -> Node:
 88        aliases = {
 89            dt.alias: dt.comments[0].split()[1]
 90            for dt in scope.derived_tables
 91            if dt.comments and dt.comments[0].startswith("source: ")
 92        }
 93
 94        # Find the specific select clause that is the source of the column we want.
 95        # This can either be a specific, named select or a generic `*` clause.
 96        select = (
 97            scope.expression.selects[column]
 98            if isinstance(column, int)
 99            else next(
100                (select for select in scope.expression.selects if select.alias_or_name == column),
101                exp.Star() if scope.expression.is_star else None,
102            )
103        )
104
105        if not select:
106            raise ValueError(f"Could not find {column} in {scope.expression}")
107
108        if isinstance(scope.expression, exp.Union):
109            upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
110
111            index = (
112                column
113                if isinstance(column, int)
114                else next(
115                    i
116                    for i, select in enumerate(scope.expression.selects)
117                    if select.alias_or_name == column
118                )
119            )
120
121            for s in scope.union_scopes:
122                to_node(index, scope=s, upstream=upstream)
123
124            return upstream
125
126        subquery = select.unalias()
127
128        if isinstance(subquery, exp.Subquery):
129            upstream = upstream or Node(name="SUBQUERY", source=scope.expression, expression=select)
130            scope = t.cast(Scope, build_scope(subquery.unnest()))
131
132            for select in subquery.named_selects:
133                to_node(select, scope=scope, upstream=upstream)
134
135            return upstream
136
137        if isinstance(scope.expression, exp.Select):
138            # For better ergonomics in our node labels, replace the full select with
139            # a version that has only the column we care about.
140            #   "x", SELECT x, y FROM foo
141            #     => "x", SELECT x FROM foo
142            source = t.cast(exp.Expression, scope.expression.select(select, append=False))
143        else:
144            source = scope.expression
145
146        # Create the node for this step in the lineage chain, and attach it to the previous one.
147        node = Node(
148            name=f"{scope_name}.{column}" if scope_name else str(column),
149            source=source,
150            expression=select,
151            alias=alias or "",
152        )
153        if upstream:
154            upstream.downstream.append(node)
155
156        # Find all columns that went into creating this one to list their lineage nodes.
157        source_columns = set(select.find_all(exp.Column))
158
159        # If the source is a UDTF find columns used in the UTDF to generate the table
160        if isinstance(source, exp.UDTF):
161            source_columns |= set(source.find_all(exp.Column))
162
163        for c in source_columns:
164            table = c.table
165            source = scope.sources.get(table)
166
167            if isinstance(source, Scope):
168                # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
169                to_node(
170                    c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table)
171                )
172            else:
173                # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
174                # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
175                # is not passed into the `sources` map.
176                source = source or exp.Placeholder()
177                node.downstream.append(Node(name=c.sql(), source=source, expression=source))
178
179        return node
180
181    return to_node(column if isinstance(column, str) else column.name, scope)
182
183
class LineageHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        node: Node,
        dialect: DialectType = None,
        imports: bool = True,
        **opts: t.Any,
    ):
        """Collect the vis.js nodes, edges and options for the graph rooted at `node`.

        Args:
            node: The root lineage node; everything yielded by `node.walk()` is rendered.
            dialect: The dialect passed to `.sql()` when generating labels and tooltips.
            imports: Whether the rendered HTML includes the vis.js script and stylesheet tags.
            **opts: Extra top-level vis.js options; they override the defaults on key clashes.
        """
        self.node = node
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **opts,
        }

        self.nodes = {}
        self.edges = []

        # A separate loop name keeps the `node` argument from being rebound mid-iteration.
        for current in node.walk():
            highlighted = current.expression

            if isinstance(highlighted, exp.Table):
                label = f"FROM {highlighted.this}"
                title = f"<pre>SELECT {current.name} FROM {highlighted.this}</pre>"
                group = 1
            else:
                label = highlighted.sql(pretty=True, dialect=dialect)
                # Wrap the node's own expression in <b> tags inside its source so the
                # tooltip highlights it. NOTE(review): copy=False makes the transform act
                # on `current.source` itself - confirm that repeated rendering is fine.
                source = current.source.transform(
                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
                    if n is highlighted
                    else n,
                    copy=False,
                ).sql(pretty=True, dialect=dialect)
                title = f"<pre>{source}</pre>"
                group = 0

            # Object identity serves as the vis.js node id.
            node_id = id(current)

            self.nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for d in current.downstream:
                self.edges.append({"from": node_id, "to": id(d)})

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` to JSON that can be inlined inside a <script> element."""
        # "<" can only occur inside JSON string literals, so replacing it with its unicode
        # escape leaves the decoded data unchanged while preventing text such as
        # "</script>" in a SQL string from terminating the surrounding script block.
        return json.dumps(value).replace("<", "\\u003c")

    def __str__(self):
        """Render the graph as an HTML snippet that draws itself with vis.js."""
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the same HTML as `str(self)`."""
        return self.__str__()
@dataclass(frozen=True)
class Node:
@dataclass(frozen=True)
class Node:
    """One step in the lineage graph of a column.

    Nodes are linked through `downstream`; `walk` traverses that structure and
    `to_html` hands the node to the HTML renderer.
    """

    # Label of the node.
    name: str
    # Expression this node stands for.
    expression: exp.Expression
    # Expression the node's `expression` was taken from.
    source: exp.Expression
    # Nodes attached below this one; each instance gets its own list.
    downstream: t.List[Node] = field(default_factory=list)
    # Optional alias; empty string when there is none.
    alias: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node, then everything reachable through `downstream`, depth first."""
        yield self

        for child in self.downstream:
            if not isinstance(child, Node):
                # Entries that are not nodes are passed through as they are.
                yield child
                continue

            yield from child.walk()

    def to_html(self, **opts) -> LineageHTML:
        """Wrap this node in a `LineageHTML` renderer, forwarding `opts` to it."""
        renderer = LineageHTML(self, **opts)
        return renderer
Node( name: str, expression: sqlglot.expressions.Expression, source: sqlglot.expressions.Expression, downstream: List[Node] = <factory>, alias: str = '')
name: str
downstream: List[Node]
alias: str = ''
def walk(self) -> Iterator[Node]:
24    def walk(self) -> t.Iterator[Node]:
25        yield self
26
27        for d in self.downstream:
28            if isinstance(d, Node):
29                yield from d.walk()
30            else:
31                yield d
def to_html(self, **opts) -> LineageHTML:
33    def to_html(self, **opts) -> LineageHTML:
34        return LineageHTML(self, **opts)
def lineage( column: str | sqlglot.expressions.Column, sql: str | sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, sources: Optional[Dict[str, str | sqlglot.expressions.Subqueryable]] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **kwargs) -> Node:
 37def lineage(
 38    column: str | exp.Column,
 39    sql: str | exp.Expression,
 40    schema: t.Optional[t.Dict | Schema] = None,
 41    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
 42    dialect: DialectType = None,
 43    **kwargs,
 44) -> Node:
 45    """Build the lineage graph for a column of a SQL query.
 46
 47    Args:
 48        column: The column to build the lineage for.
 49        sql: The SQL string or expression.
 50        schema: The schema of tables.
 51        sources: A mapping of queries which will be used to continue building lineage.
 52        dialect: The dialect of input SQL.
 53        **kwargs: Qualification optimizer kwargs.
 54
 55    Returns:
 56        A lineage node.
 57    """
 58
 59    expression = maybe_parse(sql, dialect=dialect)
 60
 61    if sources:
 62        expression = exp.expand(
 63            expression,
 64            {
 65                k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
 66                for k, v in sources.items()
 67            },
 68        )
 69
 70    qualified = qualify.qualify(
 71        expression,
 72        dialect=dialect,
 73        schema=schema,
 74        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
 75    )
 76
 77    scope = build_scope(qualified)
 78
 79    if not scope:
 80        raise SqlglotError("Cannot build lineage, sql must be SELECT")
 81
 82    def to_node(
 83        column: str | int,
 84        scope: Scope,
 85        scope_name: t.Optional[str] = None,
 86        upstream: t.Optional[Node] = None,
 87        alias: t.Optional[str] = None,
 88    ) -> Node:
 89        aliases = {
 90            dt.alias: dt.comments[0].split()[1]
 91            for dt in scope.derived_tables
 92            if dt.comments and dt.comments[0].startswith("source: ")
 93        }
 94
 95        # Find the specific select clause that is the source of the column we want.
 96        # This can either be a specific, named select or a generic `*` clause.
 97        select = (
 98            scope.expression.selects[column]
 99            if isinstance(column, int)
100            else next(
101                (select for select in scope.expression.selects if select.alias_or_name == column),
102                exp.Star() if scope.expression.is_star else None,
103            )
104        )
105
106        if not select:
107            raise ValueError(f"Could not find {column} in {scope.expression}")
108
109        if isinstance(scope.expression, exp.Union):
110            upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
111
112            index = (
113                column
114                if isinstance(column, int)
115                else next(
116                    i
117                    for i, select in enumerate(scope.expression.selects)
118                    if select.alias_or_name == column
119                )
120            )
121
122            for s in scope.union_scopes:
123                to_node(index, scope=s, upstream=upstream)
124
125            return upstream
126
127        subquery = select.unalias()
128
129        if isinstance(subquery, exp.Subquery):
130            upstream = upstream or Node(name="SUBQUERY", source=scope.expression, expression=select)
131            scope = t.cast(Scope, build_scope(subquery.unnest()))
132
133            for select in subquery.named_selects:
134                to_node(select, scope=scope, upstream=upstream)
135
136            return upstream
137
138        if isinstance(scope.expression, exp.Select):
139            # For better ergonomics in our node labels, replace the full select with
140            # a version that has only the column we care about.
141            #   "x", SELECT x, y FROM foo
142            #     => "x", SELECT x FROM foo
143            source = t.cast(exp.Expression, scope.expression.select(select, append=False))
144        else:
145            source = scope.expression
146
147        # Create the node for this step in the lineage chain, and attach it to the previous one.
148        node = Node(
149            name=f"{scope_name}.{column}" if scope_name else str(column),
150            source=source,
151            expression=select,
152            alias=alias or "",
153        )
154        if upstream:
155            upstream.downstream.append(node)
156
157        # Find all columns that went into creating this one to list their lineage nodes.
158        source_columns = set(select.find_all(exp.Column))
159
160        # If the source is a UDTF find columns used in the UTDF to generate the table
161        if isinstance(source, exp.UDTF):
162            source_columns |= set(source.find_all(exp.Column))
163
164        for c in source_columns:
165            table = c.table
166            source = scope.sources.get(table)
167
168            if isinstance(source, Scope):
169                # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
170                to_node(
171                    c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table)
172                )
173            else:
174                # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
175                # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
176                # is not passed into the `sources` map.
177                source = source or exp.Placeholder()
178                node.downstream.append(Node(name=c.sql(), source=source, expression=source))
179
180        return node
181
182    return to_node(column if isinstance(column, str) else column.name, scope)

Build the lineage graph for a column of a SQL query.

Arguments:
  • column: The column to build the lineage for.
  • sql: The SQL string or expression.
  • schema: The schema of tables.
  • sources: A mapping of queries which will be used to continue building lineage.
  • dialect: The dialect of input SQL.
  • **kwargs: Qualification optimizer kwargs.
Returns:

A lineage node.

class LineageHTML:
class LineageHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        node: Node,
        dialect: DialectType = None,
        imports: bool = True,
        **opts: t.Any,
    ):
        """Collect the vis.js nodes, edges and options for the graph rooted at `node`.

        Args:
            node: The root lineage node; everything yielded by `node.walk()` is rendered.
            dialect: The dialect passed to `.sql()` when generating labels and tooltips.
            imports: Whether the rendered HTML includes the vis.js script and stylesheet tags.
            **opts: Extra top-level vis.js options; they override the defaults on key clashes.
        """
        self.node = node
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **opts,
        }

        self.nodes = {}
        self.edges = []

        # A separate loop name keeps the `node` argument from being rebound mid-iteration.
        for current in node.walk():
            highlighted = current.expression

            if isinstance(highlighted, exp.Table):
                label = f"FROM {highlighted.this}"
                title = f"<pre>SELECT {current.name} FROM {highlighted.this}</pre>"
                group = 1
            else:
                label = highlighted.sql(pretty=True, dialect=dialect)
                # Wrap the node's own expression in <b> tags inside its source so the
                # tooltip highlights it. NOTE(review): copy=False makes the transform act
                # on `current.source` itself - confirm that repeated rendering is fine.
                source = current.source.transform(
                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
                    if n is highlighted
                    else n,
                    copy=False,
                ).sql(pretty=True, dialect=dialect)
                title = f"<pre>{source}</pre>"
                group = 0

            # Object identity serves as the vis.js node id.
            node_id = id(current)

            self.nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for d in current.downstream:
                self.edges.append({"from": node_id, "to": id(d)})

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` to JSON that can be inlined inside a <script> element."""
        # "<" can only occur inside JSON string literals, so replacing it with its unicode
        # escape leaves the decoded data unchanged while preventing text such as
        # "</script>" in a SQL string from terminating the surrounding script block.
        return json.dumps(value).replace("<", "\\u003c")

    def __str__(self):
        """Render the graph as an HTML snippet that draws itself with vis.js."""
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the same HTML as `str(self)`."""
        return self.__str__()

Node to HTML generator using vis.js.

https://visjs.github.io/vis-network/docs/network/

LineageHTML( node: Node, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, imports: bool = True, **opts: Any)
191    def __init__(
192        self,
193        node: Node,
194        dialect: DialectType = None,
195        imports: bool = True,
196        **opts: t.Any,
197    ):
198        self.node = node
199        self.imports = imports
200
201        self.options = {
202            "height": "500px",
203            "width": "100%",
204            "layout": {
205                "hierarchical": {
206                    "enabled": True,
207                    "nodeSpacing": 200,
208                    "sortMethod": "directed",
209                },
210            },
211            "interaction": {
212                "dragNodes": False,
213                "selectable": False,
214            },
215            "physics": {
216                "enabled": False,
217            },
218            "edges": {
219                "arrows": "to",
220            },
221            "nodes": {
222                "font": "20px monaco",
223                "shape": "box",
224                "widthConstraint": {
225                    "maximum": 300,
226                },
227            },
228            **opts,
229        }
230
231        self.nodes = {}
232        self.edges = []
233
234        for node in node.walk():
235            if isinstance(node.expression, exp.Table):
236                label = f"FROM {node.expression.this}"
237                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
238                group = 1
239            else:
240                label = node.expression.sql(pretty=True, dialect=dialect)
241                source = node.source.transform(
242                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
243                    if n is node.expression
244                    else n,
245                    copy=False,
246                ).sql(pretty=True, dialect=dialect)
247                title = f"<pre>{source}</pre>"
248                group = 0
249
250            node_id = id(node)
251
252            self.nodes[node_id] = {
253                "id": node_id,
254                "label": label,
255                "title": title,
256                "group": group,
257            }
258
259            for d in node.downstream:
260                self.edges.append({"from": node_id, "to": id(d)})
node
imports
options
nodes
edges