Edit on GitHub

sqlglot.lineage

  1from __future__ import annotations
  2
  3import json
  4import typing as t
  5from dataclasses import dataclass, field
  6
  7from sqlglot import Schema, exp, maybe_parse
  8from sqlglot.errors import SqlglotError
  9from sqlglot.optimizer import Scope, build_scope, qualify
 10
 11if t.TYPE_CHECKING:
 12    from sqlglot.dialects.dialect import DialectType
 13
 14
 15@dataclass(frozen=True)
 16class Node:
 17    name: str
 18    expression: exp.Expression
 19    source: exp.Expression
 20    downstream: t.List[Node] = field(default_factory=list)
 21    alias: str = ""
 22
 23    def walk(self) -> t.Iterator[Node]:
 24        yield self
 25
 26        for d in self.downstream:
 27            if isinstance(d, Node):
 28                yield from d.walk()
 29            else:
 30                yield d
 31
 32    def to_html(self, **opts) -> LineageHTML:
 33        return LineageHTML(self, **opts)
 34
 35
 36def lineage(
 37    column: str | exp.Column,
 38    sql: str | exp.Expression,
 39    schema: t.Optional[t.Dict | Schema] = None,
 40    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
 41    dialect: DialectType = None,
 42    **kwargs,
 43) -> Node:
 44    """Build the lineage graph for a column of a SQL query.
 45
 46    Args:
 47        column: The column to build the lineage for.
 48        sql: The SQL string or expression.
 49        schema: The schema of tables.
 50        sources: A mapping of queries which will be used to continue building lineage.
 51        dialect: The dialect of input SQL.
 52        **kwargs: Qualification optimizer kwargs.
 53
 54    Returns:
 55        A lineage node.
 56    """
 57
 58    expression = maybe_parse(sql, dialect=dialect)
 59
 60    if sources:
 61        expression = exp.expand(
 62            expression,
 63            {
 64                k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
 65                for k, v in sources.items()
 66            },
 67        )
 68
 69    qualified = qualify.qualify(
 70        expression,
 71        dialect=dialect,
 72        schema=schema,
 73        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
 74    )
 75
 76    scope = build_scope(qualified)
 77
 78    if not scope:
 79        raise SqlglotError("Cannot build lineage, sql must be SELECT")
 80
 81    def to_node(
 82        column_name: str,
 83        scope: Scope,
 84        scope_name: t.Optional[str] = None,
 85        upstream: t.Optional[Node] = None,
 86        alias: t.Optional[str] = None,
 87    ) -> Node:
 88        aliases = {
 89            dt.alias: dt.comments[0].split()[1]
 90            for dt in scope.derived_tables
 91            if dt.comments and dt.comments[0].startswith("source: ")
 92        }
 93        if isinstance(scope.expression, exp.Union):
 94            for scope in scope.union_scopes:
 95                node = to_node(
 96                    column_name,
 97                    scope=scope,
 98                    scope_name=scope_name,
 99                    upstream=upstream,
100                    alias=aliases.get(scope_name),
101                )
102            return node
103
104        # Find the specific select clause that is the source of the column we want.
105        # This can either be a specific, named select or a generic `*` clause.
106        select = next(
107            (select for select in scope.expression.selects if select.alias_or_name == column_name),
108            exp.Star() if scope.expression.is_star else None,
109        )
110
111        if not select:
112            raise ValueError(f"Could not find {column_name} in {scope.expression}")
113
114        if isinstance(scope.expression, exp.Select):
115            # For better ergonomics in our node labels, replace the full select with
116            # a version that has only the column we care about.
117            #   "x", SELECT x, y FROM foo
118            #     => "x", SELECT x FROM foo
119            source = t.cast(exp.Expression, scope.expression.select(select, append=False))
120        else:
121            source = scope.expression
122
123        # Create the node for this step in the lineage chain, and attach it to the previous one.
124        node = Node(
125            name=f"{scope_name}.{column_name}" if scope_name else column_name,
126            source=source,
127            expression=select,
128            alias=alias or "",
129        )
130        if upstream:
131            upstream.downstream.append(node)
132
133        # Find all columns that went into creating this one to list their lineage nodes.
134        for c in set(select.find_all(exp.Column)):
135            table = c.table
136            source = scope.sources.get(table)
137
138            if isinstance(source, Scope):
139                # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
140                to_node(
141                    c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table)
142                )
143            else:
144                # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
145                # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
146                # is not passed into the `sources` map.
147                source = source or exp.Placeholder()
148                node.downstream.append(Node(name=c.sql(), source=source, expression=source))
149
150        return node
151
152    return to_node(column if isinstance(column, str) else column.name, scope)
153
154
class LineageHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        node: Node,
        dialect: DialectType = None,
        imports: bool = True,
        **opts: t.Any,
    ):
        """Collect the vis.js nodes, edges and options for a lineage graph.

        Args:
            node: The root lineage node; everything yielded by `node.walk()` is rendered.
            dialect: The dialect used when generating SQL for labels and tooltips.
            imports: Whether the output includes the vis.js script and stylesheet tags.
            **opts: Extra vis.js options; they replace same-named top-level defaults.
        """
        self.node = node
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **opts,
        }

        self.nodes = {}
        self.edges = []

        # NOTE(review): labels and titles are built from raw SQL text without HTML
        # escaping, and __str__ feeds each title to DOMParser - confirm inputs are trusted.
        for current in node.walk():
            if isinstance(current.expression, exp.Table):
                label = f"FROM {current.expression.this}"
                title = f"<pre>SELECT {current.name} FROM {current.expression.this}</pre>"
                group = 1
            else:
                label = current.expression.sql(pretty=True, dialect=dialect)
                # Wrap the node's own expression in <b> tags inside its source query.
                # copy=False means the source tree is modified in place.
                source = current.source.transform(
                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
                    if n is current.expression
                    else n,
                    copy=False,
                ).sql(pretty=True, dialect=dialect)
                title = f"<pre>{source}</pre>"
                group = 0

            # Object identity is the vis.js node id, so edges can refer to children
            # before those children have been visited.
            node_id = id(current)

            self.nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for d in current.downstream:
                self.edges.append({"from": node_id, "to": id(d)})

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` to JSON that can be inlined in a <script> element."""
        # "<" only occurs inside JSON string literals, and JavaScript decodes \u003c
        # back to the same character, so the data is unchanged. Escaping it keeps
        # text such as "</script>" from closing the surrounding element early.
        return json.dumps(value).replace("<", "\\u003c")

    def __str__(self):
        """Return the HTML snippet that renders the graph with vis.js."""
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the same HTML as `__str__`, under the `_repr_html_` name."""
        return self.__str__()
@dataclass(frozen=True)
class Node:
@dataclass(frozen=True)
class Node:
    """One step in a column's lineage graph.

    The dataclass is frozen, so attributes cannot be rebound, but the
    `downstream` list itself is mutated in place while the graph is built.
    """

    # Label of this step: a column name, optionally prefixed with the name of its scope.
    name: str
    # The expression this step represents within `source`.
    expression: exp.Expression
    # The query (or table) that `expression` belongs to.
    source: exp.Expression
    # The nodes this node was derived from.
    downstream: t.List[Node] = field(default_factory=list)
    # Name of the expanded source this node came from, or "" when there is none.
    alias: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node, then every node reachable through `downstream`, depth first."""
        yield self

        for child in self.downstream:
            if not isinstance(child, Node):
                # Not a Node, so there is nothing to descend into; pass it through.
                yield child
                continue
            yield from child.walk()

    def to_html(self, **opts) -> LineageHTML:
        """Wrap this node in a `LineageHTML` renderer, forwarding `opts` to it."""
        renderer = LineageHTML(self, **opts)
        return renderer
Node( name: str, expression: sqlglot.expressions.Expression, source: sqlglot.expressions.Expression, downstream: List[sqlglot.lineage.Node] = <factory>, alias: str = '')
name: str
downstream: List[sqlglot.lineage.Node]
alias: str = ''
def walk(self) -> Iterator[sqlglot.lineage.Node]:
24    def walk(self) -> t.Iterator[Node]:
25        yield self
26
27        for d in self.downstream:
28            if isinstance(d, Node):
29                yield from d.walk()
30            else:
31                yield d
def to_html(self, **opts) -> sqlglot.lineage.LineageHTML:
33    def to_html(self, **opts) -> LineageHTML:
34        return LineageHTML(self, **opts)
def lineage( column: str | sqlglot.expressions.Column, sql: str | sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, sources: Optional[Dict[str, str | sqlglot.expressions.Subqueryable]] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **kwargs) -> sqlglot.lineage.Node:
 37def lineage(
 38    column: str | exp.Column,
 39    sql: str | exp.Expression,
 40    schema: t.Optional[t.Dict | Schema] = None,
 41    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
 42    dialect: DialectType = None,
 43    **kwargs,
 44) -> Node:
 45    """Build the lineage graph for a column of a SQL query.
 46
 47    Args:
 48        column: The column to build the lineage for.
 49        sql: The SQL string or expression.
 50        schema: The schema of tables.
 51        sources: A mapping of queries which will be used to continue building lineage.
 52        dialect: The dialect of input SQL.
 53        **kwargs: Qualification optimizer kwargs.
 54
 55    Returns:
 56        A lineage node.
 57    """
 58
 59    expression = maybe_parse(sql, dialect=dialect)
 60
 61    if sources:
 62        expression = exp.expand(
 63            expression,
 64            {
 65                k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
 66                for k, v in sources.items()
 67            },
 68        )
 69
 70    qualified = qualify.qualify(
 71        expression,
 72        dialect=dialect,
 73        schema=schema,
 74        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
 75    )
 76
 77    scope = build_scope(qualified)
 78
 79    if not scope:
 80        raise SqlglotError("Cannot build lineage, sql must be SELECT")
 81
 82    def to_node(
 83        column_name: str,
 84        scope: Scope,
 85        scope_name: t.Optional[str] = None,
 86        upstream: t.Optional[Node] = None,
 87        alias: t.Optional[str] = None,
 88    ) -> Node:
 89        aliases = {
 90            dt.alias: dt.comments[0].split()[1]
 91            for dt in scope.derived_tables
 92            if dt.comments and dt.comments[0].startswith("source: ")
 93        }
 94        if isinstance(scope.expression, exp.Union):
 95            for scope in scope.union_scopes:
 96                node = to_node(
 97                    column_name,
 98                    scope=scope,
 99                    scope_name=scope_name,
100                    upstream=upstream,
101                    alias=aliases.get(scope_name),
102                )
103            return node
104
105        # Find the specific select clause that is the source of the column we want.
106        # This can either be a specific, named select or a generic `*` clause.
107        select = next(
108            (select for select in scope.expression.selects if select.alias_or_name == column_name),
109            exp.Star() if scope.expression.is_star else None,
110        )
111
112        if not select:
113            raise ValueError(f"Could not find {column_name} in {scope.expression}")
114
115        if isinstance(scope.expression, exp.Select):
116            # For better ergonomics in our node labels, replace the full select with
117            # a version that has only the column we care about.
118            #   "x", SELECT x, y FROM foo
119            #     => "x", SELECT x FROM foo
120            source = t.cast(exp.Expression, scope.expression.select(select, append=False))
121        else:
122            source = scope.expression
123
124        # Create the node for this step in the lineage chain, and attach it to the previous one.
125        node = Node(
126            name=f"{scope_name}.{column_name}" if scope_name else column_name,
127            source=source,
128            expression=select,
129            alias=alias or "",
130        )
131        if upstream:
132            upstream.downstream.append(node)
133
134        # Find all columns that went into creating this one to list their lineage nodes.
135        for c in set(select.find_all(exp.Column)):
136            table = c.table
137            source = scope.sources.get(table)
138
139            if isinstance(source, Scope):
140                # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
141                to_node(
142                    c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table)
143                )
144            else:
145                # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
146                # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
147                # is not passed into the `sources` map.
148                source = source or exp.Placeholder()
149                node.downstream.append(Node(name=c.sql(), source=source, expression=source))
150
151        return node
152
153    return to_node(column if isinstance(column, str) else column.name, scope)

Build the lineage graph for a column of a SQL query.

Arguments:
  • column: The column to build the lineage for.
  • sql: The SQL string or expression.
  • schema: The schema of tables.
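  • sources: A mapping from source name to the query (SQL string or expression) that defines it. Matching sources in `sql` are expanded with these queries so that lineage continues through them.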
  • dialect: The dialect of input SQL.
  • **kwargs: Qualification optimizer kwargs.
Returns:

A lineage node.

class LineageHTML:
class LineageHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        node: Node,
        dialect: DialectType = None,
        imports: bool = True,
        **opts: t.Any,
    ):
        """Collect the vis.js nodes, edges and options for a lineage graph.

        Args:
            node: The root lineage node; everything yielded by `node.walk()` is rendered.
            dialect: The dialect used when generating SQL for labels and tooltips.
            imports: Whether the output includes the vis.js script and stylesheet tags.
            **opts: Extra vis.js options; they replace same-named top-level defaults.
        """
        self.node = node
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **opts,
        }

        self.nodes = {}
        self.edges = []

        # NOTE(review): labels and titles are built from raw SQL text without HTML
        # escaping, and __str__ feeds each title to DOMParser - confirm inputs are trusted.
        for current in node.walk():
            if isinstance(current.expression, exp.Table):
                label = f"FROM {current.expression.this}"
                title = f"<pre>SELECT {current.name} FROM {current.expression.this}</pre>"
                group = 1
            else:
                label = current.expression.sql(pretty=True, dialect=dialect)
                # Wrap the node's own expression in <b> tags inside its source query.
                # copy=False means the source tree is modified in place.
                source = current.source.transform(
                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
                    if n is current.expression
                    else n,
                    copy=False,
                ).sql(pretty=True, dialect=dialect)
                title = f"<pre>{source}</pre>"
                group = 0

            # Object identity is the vis.js node id, so edges can refer to children
            # before those children have been visited.
            node_id = id(current)

            self.nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for d in current.downstream:
                self.edges.append({"from": node_id, "to": id(d)})

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` to JSON that can be inlined in a <script> element."""
        # "<" only occurs inside JSON string literals, and JavaScript decodes \u003c
        # back to the same character, so the data is unchanged. Escaping it keeps
        # text such as "</script>" from closing the surrounding element early.
        return json.dumps(value).replace("<", "\\u003c")

    def __str__(self):
        """Return the HTML snippet that renders the graph with vis.js."""
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the same HTML as `__str__`, under the `_repr_html_` name."""
        return self.__str__()

Node to HTML generator using vis.js.

https://visjs.github.io/vis-network/docs/network/

LineageHTML( node: sqlglot.lineage.Node, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, imports: bool = True, **opts: Any)
162    def __init__(
163        self,
164        node: Node,
165        dialect: DialectType = None,
166        imports: bool = True,
167        **opts: t.Any,
168    ):
169        self.node = node
170        self.imports = imports
171
172        self.options = {
173            "height": "500px",
174            "width": "100%",
175            "layout": {
176                "hierarchical": {
177                    "enabled": True,
178                    "nodeSpacing": 200,
179                    "sortMethod": "directed",
180                },
181            },
182            "interaction": {
183                "dragNodes": False,
184                "selectable": False,
185            },
186            "physics": {
187                "enabled": False,
188            },
189            "edges": {
190                "arrows": "to",
191            },
192            "nodes": {
193                "font": "20px monaco",
194                "shape": "box",
195                "widthConstraint": {
196                    "maximum": 300,
197                },
198            },
199            **opts,
200        }
201
202        self.nodes = {}
203        self.edges = []
204
205        for node in node.walk():
206            if isinstance(node.expression, exp.Table):
207                label = f"FROM {node.expression.this}"
208                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
209                group = 1
210            else:
211                label = node.expression.sql(pretty=True, dialect=dialect)
212                source = node.source.transform(
213                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
214                    if n is node.expression
215                    else n,
216                    copy=False,
217                ).sql(pretty=True, dialect=dialect)
218                title = f"<pre>{source}</pre>"
219                group = 0
220
221            node_id = id(node)
222
223            self.nodes[node_id] = {
224                "id": node_id,
225                "label": label,
226                "title": title,
227                "group": group,
228            }
229
230            for d in node.downstream:
231                self.edges.append({"from": node_id, "to": id(d)})
node
imports
options
nodes
edges