Edit on GitHub

sqlglot.lineage

  1from __future__ import annotations
  2
  3import json
  4import typing as t
  5from dataclasses import dataclass, field
  6
  7from sqlglot import Schema, exp, maybe_parse
  8from sqlglot.optimizer import Scope, build_scope, optimize
  9from sqlglot.optimizer.expand_laterals import expand_laterals
 10from sqlglot.optimizer.qualify_columns import qualify_columns
 11from sqlglot.optimizer.qualify_tables import qualify_tables
 12
 13if t.TYPE_CHECKING:
 14    from sqlglot.dialects.dialect import DialectType
 15
 16
 17@dataclass(frozen=True)
 18class Node:
 19    name: str
 20    expression: exp.Expression
 21    source: exp.Expression
 22    downstream: t.List[Node] = field(default_factory=list)
 23    alias: str = ""
 24
 25    def walk(self) -> t.Iterator[Node]:
 26        yield self
 27
 28        for d in self.downstream:
 29            if isinstance(d, Node):
 30                yield from d.walk()
 31            else:
 32                yield d
 33
 34    def to_html(self, **opts) -> LineageHTML:
 35        return LineageHTML(self, **opts)
 36
 37
 38def lineage(
 39    column: str | exp.Column,
 40    sql: str | exp.Expression,
 41    schema: t.Optional[t.Dict | Schema] = None,
 42    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
 43    rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns, expand_laterals),
 44    dialect: DialectType = None,
 45) -> Node:
 46    """Build the lineage graph for a column of a SQL query.
 47
 48    Args:
 49        column: The column to build the lineage for.
 50        sql: The SQL string or expression.
 51        schema: The schema of tables.
 52        sources: A mapping of queries which will be used to continue building lineage.
 53        rules: Optimizer rules to apply, by default only qualifying tables and columns.
 54        dialect: The dialect of input SQL.
 55
 56    Returns:
 57        A lineage node.
 58    """
 59
 60    expression = maybe_parse(sql, dialect=dialect)
 61
 62    if sources:
 63        expression = exp.expand(
 64            expression,
 65            {
 66                k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
 67                for k, v in sources.items()
 68            },
 69        )
 70
 71    optimized = optimize(expression, schema=schema, rules=rules)
 72    scope = build_scope(optimized)
 73
 74    def to_node(
 75        column_name: str,
 76        scope: Scope,
 77        scope_name: t.Optional[str] = None,
 78        upstream: t.Optional[Node] = None,
 79        alias: t.Optional[str] = None,
 80    ) -> Node:
 81        aliases = {
 82            dt.alias: dt.comments[0].split()[1]
 83            for dt in scope.derived_tables
 84            if dt.comments and dt.comments[0].startswith("source: ")
 85        }
 86        if isinstance(scope.expression, exp.Union):
 87            for scope in scope.union_scopes:
 88                node = to_node(
 89                    column_name,
 90                    scope=scope,
 91                    scope_name=scope_name,
 92                    upstream=upstream,
 93                    alias=aliases.get(scope_name),
 94                )
 95            return node
 96
 97        # Find the specific select clause that is the source of the column we want.
 98        # This can either be a specific, named select or a generic `*` clause.
 99        select = next(
100            (select for select in scope.selects if select.alias_or_name == column_name),
101            exp.Star() if scope.expression.is_star else None,
102        )
103
104        if not select:
105            raise ValueError(f"Could not find {column_name} in {scope.expression}")
106
107        if isinstance(scope.expression, exp.Select):
108            # For better ergonomics in our node labels, replace the full select with
109            # a version that has only the column we care about.
110            #   "x", SELECT x, y FROM foo
111            #     => "x", SELECT x FROM foo
112            source = optimize(
113                scope.expression.select(select, append=False), schema=schema, rules=rules
114            )
115            select = source.selects[0]
116        else:
117            source = scope.expression
118
119        # Create the node for this step in the lineage chain, and attach it to the previous one.
120        node = Node(
121            name=f"{scope_name}.{column_name}" if scope_name else column_name,
122            source=source,
123            expression=select,
124            alias=alias or "",
125        )
126        if upstream:
127            upstream.downstream.append(node)
128
129        # Find all columns that went into creating this one to list their lineage nodes.
130        for c in set(select.find_all(exp.Column)):
131            table = c.table
132            source = scope.sources.get(table)
133
134            if isinstance(source, Scope):
135                # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
136                to_node(
137                    c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table)
138                )
139            else:
140                # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
141                # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
142                # is not passed into the `sources` map.
143                source = source or exp.Placeholder()
144                node.downstream.append(Node(name=c.sql(), source=source, expression=source))
145
146        return node
147
148    return to_node(column if isinstance(column, str) else column.name, scope)
149
150
class LineageHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        node: Node,
        dialect: DialectType = None,
        imports: bool = True,
        **opts: t.Any,
    ):
        """Collect the vis.js nodes, edges and options for the graph rooted at `node`.

        Args:
            node: Root of the lineage graph to render.
            dialect: Dialect passed to SQL generation for labels and tooltips.
            imports: Whether the rendered HTML includes the vis.js script and stylesheet tags.
            opts: Extra top-level options; they override the defaults defined below.
        """
        self.node = node
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **opts,
        }

        self.nodes = {}
        self.edges = []

        # A dedicated loop variable keeps the `node` argument from being rebound.
        for current in node.walk():
            if isinstance(current.expression, exp.Table):
                label = f"FROM {current.expression.this}"
                title = f"<pre>SELECT {current.name} FROM {current.expression.this}</pre>"
                group = 1
            else:
                label = current.expression.sql(pretty=True, dialect=dialect)
                # The identity check only matches on the original objects, hence copy=False.
                # NOTE(review): this presumably leaves the Tag wrapper inside
                # `current.source` after rendering - confirm that rendering the same
                # graph twice behaves as intended.
                source = current.source.transform(
                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
                    if n is current.expression
                    else n,
                    copy=False,
                ).sql(pretty=True, dialect=dialect)
                title = f"<pre>{source}</pre>"
                group = 0

            # Object identity doubles as the vis.js node id so edges can refer to it.
            node_id = id(current)

            self.nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for d in current.downstream:
                self.edges.append({"from": node_id, "to": id(d)})

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` to JSON that can be inlined in a <script> element."""
        # A "<" can only occur inside JSON string literals. Left as-is, text such as
        # "</script>" would end the script element early; "\u003c" is the same character
        # to JavaScript but is inert to the HTML parser.
        return json.dumps(value).replace("<", "\\u003c")

    def __str__(self):
        """Return the HTML fragment: a container div, optional imports and the setup script."""
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the same HTML as `__str__`."""
        return self.__str__()
@dataclass(frozen=True)
class Node:
@dataclass(frozen=True)
class Node:
    """A single step in a column's lineage graph.

    The dataclass is frozen, so fields cannot be rebound after construction; the
    `downstream` list itself is still appended to while the graph is built.
    """

    # Label for this step.
    name: str
    # The expression this node stands for.
    expression: exp.Expression
    # The expression that `expression` was taken from.
    source: exp.Expression
    # Nodes linked from this one; `walk` follows these links.
    downstream: t.List[Node] = field(default_factory=list)
    alias: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node, then everything reachable through `downstream`, depth first.

        Children are visited in list order and a node's descendants come right after it.
        The traversal keeps an explicit stack of iterators rather than recursing, so the
        depth of the graph is not limited by the interpreter's recursion limit.
        """
        yield self

        pending = [iter(self.downstream)]
        while pending:
            try:
                child = next(pending[-1])
            except StopIteration:
                pending.pop()
                continue

            yield child
            # Anything that is not a Node is yielded as-is and not descended into.
            if isinstance(child, Node):
                pending.append(iter(child.downstream))

    def to_html(self, **opts) -> LineageHTML:
        """Return a `LineageHTML` for this node, forwarding `opts` to its constructor."""
        return LineageHTML(self, **opts)
Node( name: str, expression: sqlglot.expressions.Expression, source: sqlglot.expressions.Expression, downstream: List[sqlglot.lineage.Node] = <factory>, alias: str = '')
def walk(self) -> Iterator[sqlglot.lineage.Node]:
26    def walk(self) -> t.Iterator[Node]:
27        yield self
28
29        for d in self.downstream:
30            if isinstance(d, Node):
31                yield from d.walk()
32            else:
33                yield d
def to_html(self, **opts) -> sqlglot.lineage.LineageHTML:
35    def to_html(self, **opts) -> LineageHTML:
36        return LineageHTML(self, **opts)
def lineage( column: str | sqlglot.expressions.Column, sql: str | sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, sources: Optional[Dict[str, str | sqlglot.expressions.Subqueryable]] = None, rules: Sequence[Callable] = (<function qualify_tables at 0x7fc498defd00>, <function qualify_columns at 0x7fc498deeb90>, <function expand_laterals at 0x7fc498deef80>), dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.lineage.Node:
 39def lineage(
 40    column: str | exp.Column,
 41    sql: str | exp.Expression,
 42    schema: t.Optional[t.Dict | Schema] = None,
 43    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
 44    rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns, expand_laterals),
 45    dialect: DialectType = None,
 46) -> Node:
 47    """Build the lineage graph for a column of a SQL query.
 48
 49    Args:
 50        column: The column to build the lineage for.
 51        sql: The SQL string or expression.
 52        schema: The schema of tables.
 53        sources: A mapping of queries which will be used to continue building lineage.
 54        rules: Optimizer rules to apply, by default only qualifying tables and columns.
 55        dialect: The dialect of input SQL.
 56
 57    Returns:
 58        A lineage node.
 59    """
 60
 61    expression = maybe_parse(sql, dialect=dialect)
 62
 63    if sources:
 64        expression = exp.expand(
 65            expression,
 66            {
 67                k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
 68                for k, v in sources.items()
 69            },
 70        )
 71
 72    optimized = optimize(expression, schema=schema, rules=rules)
 73    scope = build_scope(optimized)
 74
 75    def to_node(
 76        column_name: str,
 77        scope: Scope,
 78        scope_name: t.Optional[str] = None,
 79        upstream: t.Optional[Node] = None,
 80        alias: t.Optional[str] = None,
 81    ) -> Node:
 82        aliases = {
 83            dt.alias: dt.comments[0].split()[1]
 84            for dt in scope.derived_tables
 85            if dt.comments and dt.comments[0].startswith("source: ")
 86        }
 87        if isinstance(scope.expression, exp.Union):
 88            for scope in scope.union_scopes:
 89                node = to_node(
 90                    column_name,
 91                    scope=scope,
 92                    scope_name=scope_name,
 93                    upstream=upstream,
 94                    alias=aliases.get(scope_name),
 95                )
 96            return node
 97
 98        # Find the specific select clause that is the source of the column we want.
 99        # This can either be a specific, named select or a generic `*` clause.
100        select = next(
101            (select for select in scope.selects if select.alias_or_name == column_name),
102            exp.Star() if scope.expression.is_star else None,
103        )
104
105        if not select:
106            raise ValueError(f"Could not find {column_name} in {scope.expression}")
107
108        if isinstance(scope.expression, exp.Select):
109            # For better ergonomics in our node labels, replace the full select with
110            # a version that has only the column we care about.
111            #   "x", SELECT x, y FROM foo
112            #     => "x", SELECT x FROM foo
113            source = optimize(
114                scope.expression.select(select, append=False), schema=schema, rules=rules
115            )
116            select = source.selects[0]
117        else:
118            source = scope.expression
119
120        # Create the node for this step in the lineage chain, and attach it to the previous one.
121        node = Node(
122            name=f"{scope_name}.{column_name}" if scope_name else column_name,
123            source=source,
124            expression=select,
125            alias=alias or "",
126        )
127        if upstream:
128            upstream.downstream.append(node)
129
130        # Find all columns that went into creating this one to list their lineage nodes.
131        for c in set(select.find_all(exp.Column)):
132            table = c.table
133            source = scope.sources.get(table)
134
135            if isinstance(source, Scope):
136                # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
137                to_node(
138                    c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table)
139                )
140            else:
141                # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
142                # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
143                # is not passed into the `sources` map.
144                source = source or exp.Placeholder()
145                node.downstream.append(Node(name=c.sql(), source=source, expression=source))
146
147        return node
148
149    return to_node(column if isinstance(column, str) else column.name, scope)

Build the lineage graph for a column of a SQL query.

Arguments:
  • column: The column to build the lineage for.
  • sql: The SQL string or expression.
  • schema: The schema of tables.
  • sources: A mapping of queries which will be used to continue building lineage.
  • rules: Optimizer rules to apply, by default qualifying tables, qualifying columns and expanding laterals.
  • dialect: The dialect of input SQL.
Returns:

A lineage node.

class LineageHTML:
class LineageHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        node: Node,
        dialect: DialectType = None,
        imports: bool = True,
        **opts: t.Any,
    ):
        """Collect the vis.js nodes, edges and options for the graph rooted at `node`.

        Args:
            node: Root of the lineage graph to render.
            dialect: Dialect passed to SQL generation for labels and tooltips.
            imports: Whether the rendered HTML includes the vis.js script and stylesheet tags.
            opts: Extra top-level options; they override the defaults defined below.
        """
        self.node = node
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **opts,
        }

        self.nodes = {}
        self.edges = []

        # A dedicated loop variable keeps the `node` argument from being rebound.
        for current in node.walk():
            if isinstance(current.expression, exp.Table):
                label = f"FROM {current.expression.this}"
                title = f"<pre>SELECT {current.name} FROM {current.expression.this}</pre>"
                group = 1
            else:
                label = current.expression.sql(pretty=True, dialect=dialect)
                # The identity check only matches on the original objects, hence copy=False.
                # NOTE(review): this presumably leaves the Tag wrapper inside
                # `current.source` after rendering - confirm that rendering the same
                # graph twice behaves as intended.
                source = current.source.transform(
                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
                    if n is current.expression
                    else n,
                    copy=False,
                ).sql(pretty=True, dialect=dialect)
                title = f"<pre>{source}</pre>"
                group = 0

            # Object identity doubles as the vis.js node id so edges can refer to it.
            node_id = id(current)

            self.nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for d in current.downstream:
                self.edges.append({"from": node_id, "to": id(d)})

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` to JSON that can be inlined in a <script> element."""
        # A "<" can only occur inside JSON string literals. Left as-is, text such as
        # "</script>" would end the script element early; "\u003c" is the same character
        # to JavaScript but is inert to the HTML parser.
        return json.dumps(value).replace("<", "\\u003c")

    def __str__(self):
        """Return the HTML fragment: a container div, optional imports and the setup script."""
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the same HTML as `__str__`."""
        return self.__str__()

Node to HTML generator using vis.js.

https://visjs.github.io/vis-network/docs/network/

LineageHTML( node: sqlglot.lineage.Node, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, imports: bool = True, **opts: Any)
158    def __init__(
159        self,
160        node: Node,
161        dialect: DialectType = None,
162        imports: bool = True,
163        **opts: t.Any,
164    ):
165        self.node = node
166        self.imports = imports
167
168        self.options = {
169            "height": "500px",
170            "width": "100%",
171            "layout": {
172                "hierarchical": {
173                    "enabled": True,
174                    "nodeSpacing": 200,
175                    "sortMethod": "directed",
176                },
177            },
178            "interaction": {
179                "dragNodes": False,
180                "selectable": False,
181            },
182            "physics": {
183                "enabled": False,
184            },
185            "edges": {
186                "arrows": "to",
187            },
188            "nodes": {
189                "font": "20px monaco",
190                "shape": "box",
191                "widthConstraint": {
192                    "maximum": 300,
193                },
194            },
195            **opts,
196        }
197
198        self.nodes = {}
199        self.edges = []
200
201        for node in node.walk():
202            if isinstance(node.expression, exp.Table):
203                label = f"FROM {node.expression.this}"
204                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
205                group = 1
206            else:
207                label = node.expression.sql(pretty=True, dialect=dialect)
208                source = node.source.transform(
209                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
210                    if n is node.expression
211                    else n,
212                    copy=False,
213                ).sql(pretty=True, dialect=dialect)
214                title = f"<pre>{source}</pre>"
215                group = 0
216
217            node_id = id(node)
218
219            self.nodes[node_id] = {
220                "id": node_id,
221                "label": label,
222                "title": title,
223                "group": group,
224            }
225
226            for d in node.downstream:
227                self.edges.append({"from": node_id, "to": id(d)})