Edit on GitHub

sqlglot.lineage

  1from __future__ import annotations
  2
  3import json
  4import typing as t
  5from dataclasses import dataclass, field
  6
  7from sqlglot import Schema, exp, maybe_parse
  8from sqlglot.errors import SqlglotError
  9from sqlglot.optimizer import Scope, build_scope, optimize
 10from sqlglot.optimizer.lower_identities import lower_identities
 11from sqlglot.optimizer.qualify_columns import qualify_columns
 12from sqlglot.optimizer.qualify_tables import qualify_tables
 13
 14if t.TYPE_CHECKING:
 15    from sqlglot.dialects.dialect import DialectType
 16
 17
 18@dataclass(frozen=True)
 19class Node:
 20    name: str
 21    expression: exp.Expression
 22    source: exp.Expression
 23    downstream: t.List[Node] = field(default_factory=list)
 24    alias: str = ""
 25
 26    def walk(self) -> t.Iterator[Node]:
 27        yield self
 28
 29        for d in self.downstream:
 30            if isinstance(d, Node):
 31                yield from d.walk()
 32            else:
 33                yield d
 34
 35    def to_html(self, **opts) -> LineageHTML:
 36        return LineageHTML(self, **opts)
 37
 38
 39def lineage(
 40    column: str | exp.Column,
 41    sql: str | exp.Expression,
 42    schema: t.Optional[t.Dict | Schema] = None,
 43    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
 44    rules: t.Sequence[t.Callable] = (lower_identities, qualify_tables, qualify_columns),
 45    dialect: DialectType = None,
 46) -> Node:
 47    """Build the lineage graph for a column of a SQL query.
 48
 49    Args:
 50        column: The column to build the lineage for.
 51        sql: The SQL string or expression.
 52        schema: The schema of tables.
 53        sources: A mapping of queries which will be used to continue building lineage.
 54        rules: Optimizer rules to apply, by default only qualifying tables and columns.
 55        dialect: The dialect of input SQL.
 56
 57    Returns:
 58        A lineage node.
 59    """
 60
 61    expression = maybe_parse(sql, dialect=dialect)
 62
 63    if sources:
 64        expression = exp.expand(
 65            expression,
 66            {
 67                k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
 68                for k, v in sources.items()
 69            },
 70        )
 71
 72    optimized = optimize(expression, schema=schema, rules=rules)
 73    scope = build_scope(optimized)
 74
 75    if not scope:
 76        raise SqlglotError("Cannot build lineage, sql must be SELECT")
 77
 78    def to_node(
 79        column_name: str,
 80        scope: Scope,
 81        scope_name: t.Optional[str] = None,
 82        upstream: t.Optional[Node] = None,
 83        alias: t.Optional[str] = None,
 84    ) -> Node:
 85        aliases = {
 86            dt.alias: dt.comments[0].split()[1]
 87            for dt in scope.derived_tables
 88            if dt.comments and dt.comments[0].startswith("source: ")
 89        }
 90        if isinstance(scope.expression, exp.Union):
 91            for scope in scope.union_scopes:
 92                node = to_node(
 93                    column_name,
 94                    scope=scope,
 95                    scope_name=scope_name,
 96                    upstream=upstream,
 97                    alias=aliases.get(scope_name),
 98                )
 99            return node
100
101        # Find the specific select clause that is the source of the column we want.
102        # This can either be a specific, named select or a generic `*` clause.
103        select = next(
104            (select for select in scope.selects if select.alias_or_name == column_name),
105            exp.Star() if scope.expression.is_star else None,
106        )
107
108        if not select:
109            raise ValueError(f"Could not find {column_name} in {scope.expression}")
110
111        if isinstance(scope.expression, exp.Select):
112            # For better ergonomics in our node labels, replace the full select with
113            # a version that has only the column we care about.
114            #   "x", SELECT x, y FROM foo
115            #     => "x", SELECT x FROM foo
116            source = t.cast(exp.Expression, scope.expression.select(select, append=False))
117        else:
118            source = scope.expression
119
120        # Create the node for this step in the lineage chain, and attach it to the previous one.
121        node = Node(
122            name=f"{scope_name}.{column_name}" if scope_name else column_name,
123            source=source,
124            expression=select,
125            alias=alias or "",
126        )
127        if upstream:
128            upstream.downstream.append(node)
129
130        # Find all columns that went into creating this one to list their lineage nodes.
131        for c in set(select.find_all(exp.Column)):
132            table = c.table
133            source = scope.sources.get(table)
134
135            if isinstance(source, Scope):
136                # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
137                to_node(
138                    c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table)
139                )
140            else:
141                # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
142                # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
143                # is not passed into the `sources` map.
144                source = source or exp.Placeholder()
145                node.downstream.append(Node(name=c.sql(), source=source, expression=source))
146
147        return node
148
149    return to_node(column if isinstance(column, str) else column.name, scope)
150
151
class LineageHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        node: Node,
        dialect: DialectType = None,
        imports: bool = True,
        **opts: t.Any,
    ):
        """Collect vis.js nodes and edges for every node reachable from `node`.

        Args:
            node: Root of the lineage graph to render.
            dialect: Dialect used to generate the SQL shown in labels and tooltips.
            imports: Whether `__str__` emits the vis.js `<script>`/`<link>` tags.
            **opts: vis.js network options; they shallowly override the top-level
                default option keys below.
        """
        self.node = node
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **opts,
        }

        # vis.js node dicts keyed by Python object id, and directed parent -> child edges.
        self.nodes = {}
        self.edges = []

        for lineage_node in node.walk():
            if isinstance(lineage_node.expression, exp.Table):
                # Leaf reading straight from a table.
                label = f"FROM {lineage_node.expression.this}"
                title = f"<pre>SELECT {lineage_node.name} FROM {lineage_node.expression.this}</pre>"
                group = 1
            else:
                label = lineage_node.expression.sql(pretty=True, dialect=dialect)
                # Render the surrounding query with the tracked expression in bold.
                # NOTE(review): with copy=False the Tag wrapper is presumably spliced into
                # `lineage_node.source` in place, so rendering the same graph twice may
                # nest <b> tags — confirm against `Expression.transform`.
                source = lineage_node.source.transform(
                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
                    if n is lineage_node.expression
                    else n,
                    copy=False,
                ).sql(pretty=True, dialect=dialect)
                title = f"<pre>{source}</pre>"
                group = 0

            node_id = id(lineage_node)

            self.nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for d in lineage_node.downstream:
                self.edges.append({"from": node_id, "to": id(d)})

    def __str__(self):
        """Return a self-contained HTML snippet that draws the graph with vis.js."""

        def to_json(value: t.Any) -> str:
            # The JSON is pasted verbatim into an inline <script>. A literal "</" inside
            # any string value (e.g. "</script>" in a SQL string literal) would end the
            # script element early. "<\/" decodes to the same text in JSON and JavaScript,
            # and "<" never appears outside string values in json.dumps output.
            return json.dumps(value).replace("</", "<\\/")

        nodes = to_json(list(self.nodes.values()))
        edges = to_json(self.edges)
        options = to_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Jupyter rich-display hook; same output as `str(self)`."""
        return self.__str__()
@dataclass(frozen=True)
class Node:
@dataclass(frozen=True)
class Node:
    """One step in the lineage graph of a column.

    `lineage` creates one of these per scope a column passes through, plus leaf
    nodes for the tables (or unknown sources) the column ultimately reads from.
    """

    # Display name: "<scope>.<column>", a bare column name, or a leaf column's SQL.
    name: str
    # The expression this step tracks (a projection, a table, or a placeholder).
    expression: exp.Expression
    # The enclosing expression `expression` was taken from (shown in HTML tooltips).
    source: exp.Expression
    # Child steps this node's value was derived from.
    downstream: t.List[Node] = field(default_factory=list)
    # Name of the `sources` query this step came from, if any; "" otherwise.
    alias: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node, then every node below it, depth-first in pre-order.

        Downstream entries that are not `Node` instances are yielded as-is and not
        descended into.
        """
        yield self

        for child in self.downstream:
            yield from (child.walk() if isinstance(child, Node) else (child,))

    def to_html(self, **opts) -> LineageHTML:
        """Wrap this node in a `LineageHTML` renderer, forwarding `opts` unchanged."""
        renderer = LineageHTML(self, **opts)
        return renderer
Node( name: str, expression: sqlglot.expressions.Expression, source: sqlglot.expressions.Expression, downstream: List[sqlglot.lineage.Node] = <factory>, alias: str = '')
def walk(self) -> Iterator[sqlglot.lineage.Node]:
27    def walk(self) -> t.Iterator[Node]:
28        yield self
29
30        for d in self.downstream:
31            if isinstance(d, Node):
32                yield from d.walk()
33            else:
34                yield d
def to_html(self, **opts) -> sqlglot.lineage.LineageHTML:
36    def to_html(self, **opts) -> LineageHTML:
37        return LineageHTML(self, **opts)
def lineage( column: str | sqlglot.expressions.Column, sql: str | sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, sources: Optional[Dict[str, str | sqlglot.expressions.Subqueryable]] = None, rules: Sequence[Callable] = (<function lower_identities at 0x7f0af2fc76d0>, <function qualify_tables at 0x7f0af2fa1d80>, <function qualify_columns at 0x7f0af2fa0c10>), dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.lineage.Node:
 40def lineage(
 41    column: str | exp.Column,
 42    sql: str | exp.Expression,
 43    schema: t.Optional[t.Dict | Schema] = None,
 44    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
 45    rules: t.Sequence[t.Callable] = (lower_identities, qualify_tables, qualify_columns),
 46    dialect: DialectType = None,
 47) -> Node:
 48    """Build the lineage graph for a column of a SQL query.
 49
 50    Args:
 51        column: The column to build the lineage for.
 52        sql: The SQL string or expression.
 53        schema: The schema of tables.
 54        sources: A mapping of queries which will be used to continue building lineage.
 55        rules: Optimizer rules to apply, by default only qualifying tables and columns.
 56        dialect: The dialect of input SQL.
 57
 58    Returns:
 59        A lineage node.
 60    """
 61
 62    expression = maybe_parse(sql, dialect=dialect)
 63
 64    if sources:
 65        expression = exp.expand(
 66            expression,
 67            {
 68                k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
 69                for k, v in sources.items()
 70            },
 71        )
 72
 73    optimized = optimize(expression, schema=schema, rules=rules)
 74    scope = build_scope(optimized)
 75
 76    if not scope:
 77        raise SqlglotError("Cannot build lineage, sql must be SELECT")
 78
 79    def to_node(
 80        column_name: str,
 81        scope: Scope,
 82        scope_name: t.Optional[str] = None,
 83        upstream: t.Optional[Node] = None,
 84        alias: t.Optional[str] = None,
 85    ) -> Node:
 86        aliases = {
 87            dt.alias: dt.comments[0].split()[1]
 88            for dt in scope.derived_tables
 89            if dt.comments and dt.comments[0].startswith("source: ")
 90        }
 91        if isinstance(scope.expression, exp.Union):
 92            for scope in scope.union_scopes:
 93                node = to_node(
 94                    column_name,
 95                    scope=scope,
 96                    scope_name=scope_name,
 97                    upstream=upstream,
 98                    alias=aliases.get(scope_name),
 99                )
100            return node
101
102        # Find the specific select clause that is the source of the column we want.
103        # This can either be a specific, named select or a generic `*` clause.
104        select = next(
105            (select for select in scope.selects if select.alias_or_name == column_name),
106            exp.Star() if scope.expression.is_star else None,
107        )
108
109        if not select:
110            raise ValueError(f"Could not find {column_name} in {scope.expression}")
111
112        if isinstance(scope.expression, exp.Select):
113            # For better ergonomics in our node labels, replace the full select with
114            # a version that has only the column we care about.
115            #   "x", SELECT x, y FROM foo
116            #     => "x", SELECT x FROM foo
117            source = t.cast(exp.Expression, scope.expression.select(select, append=False))
118        else:
119            source = scope.expression
120
121        # Create the node for this step in the lineage chain, and attach it to the previous one.
122        node = Node(
123            name=f"{scope_name}.{column_name}" if scope_name else column_name,
124            source=source,
125            expression=select,
126            alias=alias or "",
127        )
128        if upstream:
129            upstream.downstream.append(node)
130
131        # Find all columns that went into creating this one to list their lineage nodes.
132        for c in set(select.find_all(exp.Column)):
133            table = c.table
134            source = scope.sources.get(table)
135
136            if isinstance(source, Scope):
137                # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
138                to_node(
139                    c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table)
140                )
141            else:
142                # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
143                # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
144                # is not passed into the `sources` map.
145                source = source or exp.Placeholder()
146                node.downstream.append(Node(name=c.sql(), source=source, expression=source))
147
148        return node
149
150    return to_node(column if isinstance(column, str) else column.name, scope)

Build the lineage graph for a column of a SQL query.

Arguments:
  • column: The column to build the lineage for.
  • sql: The SQL string or expression.
  • schema: The schema of tables.
  • sources: A mapping of queries which will be used to continue building lineage.
  • rules: Optimizer rules to apply; by default lowering identifiers and qualifying tables and columns (lower_identities, qualify_tables, qualify_columns).
  • dialect: The dialect of input SQL.
Returns:

A lineage node.

class LineageHTML:
class LineageHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        node: Node,
        dialect: DialectType = None,
        imports: bool = True,
        **opts: t.Any,
    ):
        """Collect vis.js nodes and edges for every node reachable from `node`.

        Args:
            node: Root of the lineage graph to render.
            dialect: Dialect used to generate the SQL shown in labels and tooltips.
            imports: Whether `__str__` emits the vis.js `<script>`/`<link>` tags.
            **opts: vis.js network options; they shallowly override the top-level
                default option keys below.
        """
        self.node = node
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **opts,
        }

        # vis.js node dicts keyed by Python object id, and directed parent -> child edges.
        self.nodes = {}
        self.edges = []

        for lineage_node in node.walk():
            if isinstance(lineage_node.expression, exp.Table):
                # Leaf reading straight from a table.
                label = f"FROM {lineage_node.expression.this}"
                title = f"<pre>SELECT {lineage_node.name} FROM {lineage_node.expression.this}</pre>"
                group = 1
            else:
                label = lineage_node.expression.sql(pretty=True, dialect=dialect)
                # Render the surrounding query with the tracked expression in bold.
                # NOTE(review): with copy=False the Tag wrapper is presumably spliced into
                # `lineage_node.source` in place, so rendering the same graph twice may
                # nest <b> tags — confirm against `Expression.transform`.
                source = lineage_node.source.transform(
                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
                    if n is lineage_node.expression
                    else n,
                    copy=False,
                ).sql(pretty=True, dialect=dialect)
                title = f"<pre>{source}</pre>"
                group = 0

            node_id = id(lineage_node)

            self.nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for d in lineage_node.downstream:
                self.edges.append({"from": node_id, "to": id(d)})

    def __str__(self):
        """Return a self-contained HTML snippet that draws the graph with vis.js."""

        def to_json(value: t.Any) -> str:
            # The JSON is pasted verbatim into an inline <script>. A literal "</" inside
            # any string value (e.g. "</script>" in a SQL string literal) would end the
            # script element early. "<\/" decodes to the same text in JSON and JavaScript,
            # and "<" never appears outside string values in json.dumps output.
            return json.dumps(value).replace("</", "<\\/")

        nodes = to_json(list(self.nodes.values()))
        edges = to_json(self.edges)
        options = to_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Jupyter rich-display hook; same output as `str(self)`."""
        return self.__str__()

Node to HTML generator using vis.js.

https://visjs.github.io/vis-network/docs/network/

LineageHTML( node: sqlglot.lineage.Node, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, imports: bool = True, **opts: Any)
159    def __init__(
160        self,
161        node: Node,
162        dialect: DialectType = None,
163        imports: bool = True,
164        **opts: t.Any,
165    ):
166        self.node = node
167        self.imports = imports
168
169        self.options = {
170            "height": "500px",
171            "width": "100%",
172            "layout": {
173                "hierarchical": {
174                    "enabled": True,
175                    "nodeSpacing": 200,
176                    "sortMethod": "directed",
177                },
178            },
179            "interaction": {
180                "dragNodes": False,
181                "selectable": False,
182            },
183            "physics": {
184                "enabled": False,
185            },
186            "edges": {
187                "arrows": "to",
188            },
189            "nodes": {
190                "font": "20px monaco",
191                "shape": "box",
192                "widthConstraint": {
193                    "maximum": 300,
194                },
195            },
196            **opts,
197        }
198
199        self.nodes = {}
200        self.edges = []
201
202        for node in node.walk():
203            if isinstance(node.expression, exp.Table):
204                label = f"FROM {node.expression.this}"
205                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
206                group = 1
207            else:
208                label = node.expression.sql(pretty=True, dialect=dialect)
209                source = node.source.transform(
210                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
211                    if n is node.expression
212                    else n,
213                    copy=False,
214                ).sql(pretty=True, dialect=dialect)
215                title = f"<pre>{source}</pre>"
216                group = 0
217
218            node_id = id(node)
219
220            self.nodes[node_id] = {
221                "id": node_id,
222                "label": label,
223                "title": title,
224                "group": group,
225            }
226
227            for d in node.downstream:
228                self.edges.append({"from": node_id, "to": id(d)})