sqlglot.lineage
from __future__ import annotations

import json
import typing as t
from dataclasses import dataclass, field

from sqlglot import Schema, exp, maybe_parse
from sqlglot.errors import SqlglotError
from sqlglot.optimizer import Scope, build_scope, qualify

if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType


@dataclass(frozen=True)
class Node:
    """One step in a column's lineage graph.

    The dataclass is frozen, but `downstream` is a list that the lineage builder
    appends child nodes to after construction.
    """

    # Label of this step: a column name, "<scope>.<column>", "UNION", "SUBQUERY", ...
    name: str
    # The expression this step describes (a projection, a table, a placeholder, ...).
    expression: exp.Expression
    # The expression `expression` was taken from; shown as context when rendering.
    source: exp.Expression
    # Nodes this step was derived from.
    downstream: t.List[Node] = field(default_factory=list)
    # Name from a "source: <name>" comment on the derived table this step was reached through.
    alias: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node, then everything downstream of it, depth-first and in order."""
        yield self

        for d in self.downstream:
            if isinstance(d, Node):
                yield from d.walk()
            else:
                yield d

    def to_html(self, **opts) -> LineageHTML:
        """Return an HTML renderer for the graph rooted at this node; `opts` go to `LineageHTML`."""
        return LineageHTML(self, **opts)


def lineage(
    column: str | exp.Column,
    sql: str | exp.Expression,
    schema: t.Optional[t.Dict | Schema] = None,
    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
    dialect: DialectType = None,
    **kwargs,
) -> Node:
    """Build the lineage graph for a column of a SQL query.

    Args:
        column: The column to build the lineage for. For an `exp.Column`, its name is used.
        sql: The SQL string or expression.
        schema: The schema of the tables referenced in the query, used when qualifying it.
        sources: A mapping from table names to queries (SQL strings or expressions).
            References to these tables in `sql` are expanded with the given queries so
            lineage can be traced through them.
        dialect: The dialect of input SQL.
        **kwargs: Extra keyword arguments for the qualification optimizer. They override
            the defaults `validate_qualify_columns=False` and `identify=False`.

    Returns:
        The lineage node for `column`; its `downstream` nodes trace where it comes from.

    Raises:
        SqlglotError: If no scope can be built for `sql` (it is not a SELECT).
        ValueError: If the column cannot be found in one of the traversed scopes.
    """

    expression = maybe_parse(sql, dialect=dialect)

    if sources:
        expression = exp.expand(
            expression,
            {
                k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
                for k, v in sources.items()
            },
        )

    qualified = qualify.qualify(
        expression,
        dialect=dialect,
        schema=schema,
        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
    )

    scope = build_scope(qualified)

    if not scope:
        raise SqlglotError("Cannot build lineage, sql must be SELECT")

    def to_node(
        column: str | int,
        scope: Scope,
        scope_name: t.Optional[str] = None,
        upstream: t.Optional[Node] = None,
        alias: t.Optional[str] = None,
    ) -> Node:
        """Create the lineage node for `column` within `scope` and recurse into its sources.

        Args:
            column: The projection to trace, by output name or by position in the select list.
            scope: The scope whose projection is traced.
            scope_name: Name under which `scope` is referenced by its parent; prefixes the node name.
            upstream: The node this step feeds into; the new node is appended to its `downstream`.
            alias: Alias stored on the created node.

        Returns:
            The created node, or `upstream` for UNION and subquery projections, whose
            children are attached to it directly.
        """
        # Derived tables carrying a "source: <name>" comment map their alias to <name>.
        aliases = {
            dt.alias: dt.comments[0].split()[1]
            for dt in scope.derived_tables
            if dt.comments and dt.comments[0].startswith("source: ")
        }

        # Find the specific select clause that is the source of the column we want.
        # This can either be a specific, named select or a generic `*` clause.
        select = (
            scope.expression.selects[column]
            if isinstance(column, int)
            else next(
                (select for select in scope.expression.selects if select.alias_or_name == column),
                exp.Star() if scope.expression.is_star else None,
            )
        )

        if select is None:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        if isinstance(scope.expression, exp.Union):
            upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)

            # Each branch of the union is traced by the column's position, not its name.
            if isinstance(column, int):
                index = column
            else:
                projections = scope.expression.selects
                # Prefer an exact name match and fall back to a `*` projection, mirroring how
                # `select` was resolved above. Without the fallback a union of `SELECT *`
                # queries would leak a bare StopIteration here.
                index = next(
                    (i for i, p in enumerate(projections) if p.alias_or_name == column),
                    next((i for i, p in enumerate(projections) if p.is_star), None),
                )
                if index is None:
                    raise ValueError(f"Could not find {column} in {scope.expression}")

            for union_scope in scope.union_scopes:
                to_node(index, scope=union_scope, upstream=upstream)

            return upstream

        subquery = select.unalias()

        if isinstance(subquery, exp.Subquery):
            upstream = upstream or Node(name="SUBQUERY", source=scope.expression, expression=select)
            subquery_scope = t.cast(Scope, build_scope(subquery.unnest()))

            for name in subquery.named_selects:
                to_node(name, scope=subquery_scope, upstream=upstream)

            return upstream

        if isinstance(scope.expression, exp.Select):
            # For better ergonomics in our node labels, replace the full select with
            # a version that has only the column we care about.
            #   "x", SELECT x, y FROM foo
            #     => "x", SELECT x FROM foo
            source = t.cast(exp.Expression, scope.expression.select(select, append=False))
        else:
            source = scope.expression

        # Create the node for this step in the lineage chain, and attach it to the previous one.
        node = Node(
            name=f"{scope_name}.{column}" if scope_name else str(column),
            source=source,
            expression=select,
            alias=alias or "",
        )
        if upstream is not None:
            upstream.downstream.append(node)

        # Find all columns that went into creating this one to list their lineage nodes.
        # A dict deduplicates exactly like a set but keeps traversal order, so the order of
        # `downstream` does not depend on hash randomization between runs.
        source_columns = dict.fromkeys(select.find_all(exp.Column))

        # If the source is a UDTF, find the columns used in the UDTF to generate the table.
        if isinstance(source, exp.UDTF):
            source_columns.update(dict.fromkeys(source.find_all(exp.Column)))

        for c in source_columns:
            table = c.table
            table_source = scope.sources.get(table)

            if isinstance(table_source, Scope):
                # The table itself came from a more specific scope. Recurse into that one
                # using the unaliased column name.
                to_node(
                    c.name,
                    scope=table_source,
                    scope_name=table,
                    upstream=node,
                    alias=aliases.get(table),
                )
            else:
                # The source is not a scope - we've reached the end of the line. At this point, if a
                # source is not found it means this column's lineage is unknown. This can happen if
                # the definition of a source used in a query is not passed into the `sources` map.
                leaf = table_source or exp.Placeholder()
                node.downstream.append(Node(name=c.sql(), source=leaf, expression=leaf))

        return node

    return to_node(column if isinstance(column, str) else column.name, scope)


class LineageHTML:
    """Render a lineage `Node` graph as HTML using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        node: Node,
        dialect: DialectType = None,
        imports: bool = True,
        **opts: t.Any,
    ):
        """Collect vis.js nodes and edges for the graph rooted at `node`.

        Args:
            node: Root of the lineage graph to render.
            dialect: Dialect used to generate the SQL shown in labels and tooltips.
            imports: Whether `__str__` emits the vis.js script and stylesheet tags.
            **opts: vis-network options merged over the defaults below. The merge is
                shallow: a nested dict replaces the whole default entry.
        """
        self.node = node
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **opts,
        }

        self.nodes: t.Dict[int, t.Dict[str, t.Any]] = {}
        self.edges: t.List[t.Dict[str, int]] = []

        for current in node.walk():
            if isinstance(current.expression, exp.Table):
                label = f"FROM {current.expression.this}"
                title = f"<pre>SELECT {current.name} FROM {current.expression.this}</pre>"
                group = 1
            else:
                label = current.expression.sql(pretty=True, dialect=dialect)
                # The highlighted expression is found by identity (`is`), so the tree must not
                # be copied first. NOTE(review): with copy=False this presumably rewrites
                # `current.source` in place; confirm repeated rendering does not double-wrap.
                source = current.source.transform(
                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
                    if n is current.expression
                    else n,
                    copy=False,
                ).sql(pretty=True, dialect=dialect)
                # NOTE(review): the SQL is not HTML-escaped, and the generated script parses
                # the title as HTML; confirm this is acceptable for untrusted SQL.
                title = f"<pre>{source}</pre>"
                group = 0

            node_id = id(current)

            self.nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for d in current.downstream:
                self.edges.append({"from": node_id, "to": id(d)})

    @staticmethod
    def _to_script_json(value: t.Any) -> str:
        """Serialize `value` as JSON that is safe to inline in a <script> element.

        Escaping "</" as "<\\/" stops strings such as "</script>" (e.g. inside SQL string
        literals) from terminating the script element early. JSON and JavaScript both
        decode "\\/" as "/", so the values seen by the page are unchanged.
        """
        return json.dumps(value).replace("</", "<\\/")

    def __str__(self):
        nodes = self._to_script_json(list(self.nodes.values()))
        edges = self._to_script_json(self.edges)
        options = self._to_script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        # NOTE(review): the container id is fixed, so several graphs rendered into one page
        # would all be drawn into the first element with this id.
        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Notebook (IPython) rich-display hook; same output as `str(self)`."""
        return self.__str__()
@dataclass(frozen=True)
class
Node:
@dataclass(frozen=True)
class Node:
    """One step in a column's lineage graph.

    The dataclass is frozen, but `downstream` is a list that the lineage builder
    appends child nodes to after construction.
    """

    # Label of this step: a column name, "<scope>.<column>", "UNION", "SUBQUERY", ...
    name: str
    # The expression this step describes (a projection, a table, a placeholder, ...).
    expression: exp.Expression
    # The expression `expression` was taken from; shown as context when rendering.
    source: exp.Expression
    # Nodes this step was derived from.
    downstream: t.List[Node] = field(default_factory=list)
    # Name from a "source: <name>" comment on the derived table this step was reached through.
    alias: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node, then everything downstream of it, depth-first and in order."""
        yield self

        for d in self.downstream:
            if isinstance(d, Node):
                yield from d.walk()
            else:
                yield d

    def to_html(self, **opts) -> LineageHTML:
        """Return an HTML renderer for the graph rooted at this node; `opts` go to `LineageHTML`."""
        return LineageHTML(self, **opts)
Node( name: str, expression: sqlglot.expressions.Expression, source: sqlglot.expressions.Expression, downstream: List[Node] = <factory>, alias: str = '')
expression: sqlglot.expressions.Expression
source: sqlglot.expressions.Expression
downstream: List[Node]
def
lineage( column: str | sqlglot.expressions.Column, sql: str | sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, sources: Optional[Dict[str, str | sqlglot.expressions.Subqueryable]] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **kwargs) -> Node:
def lineage(
    column: str | exp.Column,
    sql: str | exp.Expression,
    schema: t.Optional[t.Dict | Schema] = None,
    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
    dialect: DialectType = None,
    **kwargs,
) -> Node:
    """Build the lineage graph for a column of a SQL query.

    Args:
        column: The column to build the lineage for. For an `exp.Column`, its name is used.
        sql: The SQL string or expression.
        schema: The schema of the tables referenced in the query, used when qualifying it.
        sources: A mapping from table names to queries (SQL strings or expressions).
            References to these tables in `sql` are expanded with the given queries so
            lineage can be traced through them.
        dialect: The dialect of input SQL.
        **kwargs: Extra keyword arguments for the qualification optimizer. They override
            the defaults `validate_qualify_columns=False` and `identify=False`.

    Returns:
        The lineage node for `column`; its `downstream` nodes trace where it comes from.

    Raises:
        SqlglotError: If no scope can be built for `sql` (it is not a SELECT).
        ValueError: If the column cannot be found in one of the traversed scopes.
    """

    expression = maybe_parse(sql, dialect=dialect)

    if sources:
        expression = exp.expand(
            expression,
            {
                k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
                for k, v in sources.items()
            },
        )

    qualified = qualify.qualify(
        expression,
        dialect=dialect,
        schema=schema,
        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
    )

    scope = build_scope(qualified)

    if not scope:
        raise SqlglotError("Cannot build lineage, sql must be SELECT")

    def to_node(
        column: str | int,
        scope: Scope,
        scope_name: t.Optional[str] = None,
        upstream: t.Optional[Node] = None,
        alias: t.Optional[str] = None,
    ) -> Node:
        """Create the lineage node for `column` within `scope` and recurse into its sources.

        Args:
            column: The projection to trace, by output name or by position in the select list.
            scope: The scope whose projection is traced.
            scope_name: Name under which `scope` is referenced by its parent; prefixes the node name.
            upstream: The node this step feeds into; the new node is appended to its `downstream`.
            alias: Alias stored on the created node.

        Returns:
            The created node, or `upstream` for UNION and subquery projections, whose
            children are attached to it directly.
        """
        # Derived tables carrying a "source: <name>" comment map their alias to <name>.
        aliases = {
            dt.alias: dt.comments[0].split()[1]
            for dt in scope.derived_tables
            if dt.comments and dt.comments[0].startswith("source: ")
        }

        # Find the specific select clause that is the source of the column we want.
        # This can either be a specific, named select or a generic `*` clause.
        select = (
            scope.expression.selects[column]
            if isinstance(column, int)
            else next(
                (select for select in scope.expression.selects if select.alias_or_name == column),
                exp.Star() if scope.expression.is_star else None,
            )
        )

        if select is None:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        if isinstance(scope.expression, exp.Union):
            upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)

            # Each branch of the union is traced by the column's position, not its name.
            if isinstance(column, int):
                index = column
            else:
                projections = scope.expression.selects
                # Prefer an exact name match and fall back to a `*` projection, mirroring how
                # `select` was resolved above. Without the fallback a union of `SELECT *`
                # queries would leak a bare StopIteration here.
                index = next(
                    (i for i, p in enumerate(projections) if p.alias_or_name == column),
                    next((i for i, p in enumerate(projections) if p.is_star), None),
                )
                if index is None:
                    raise ValueError(f"Could not find {column} in {scope.expression}")

            for union_scope in scope.union_scopes:
                to_node(index, scope=union_scope, upstream=upstream)

            return upstream

        subquery = select.unalias()

        if isinstance(subquery, exp.Subquery):
            upstream = upstream or Node(name="SUBQUERY", source=scope.expression, expression=select)
            subquery_scope = t.cast(Scope, build_scope(subquery.unnest()))

            for name in subquery.named_selects:
                to_node(name, scope=subquery_scope, upstream=upstream)

            return upstream

        if isinstance(scope.expression, exp.Select):
            # For better ergonomics in our node labels, replace the full select with
            # a version that has only the column we care about.
            #   "x", SELECT x, y FROM foo
            #     => "x", SELECT x FROM foo
            source = t.cast(exp.Expression, scope.expression.select(select, append=False))
        else:
            source = scope.expression

        # Create the node for this step in the lineage chain, and attach it to the previous one.
        node = Node(
            name=f"{scope_name}.{column}" if scope_name else str(column),
            source=source,
            expression=select,
            alias=alias or "",
        )
        if upstream is not None:
            upstream.downstream.append(node)

        # Find all columns that went into creating this one to list their lineage nodes.
        # A dict deduplicates exactly like a set but keeps traversal order, so the order of
        # `downstream` does not depend on hash randomization between runs.
        source_columns = dict.fromkeys(select.find_all(exp.Column))

        # If the source is a UDTF, find the columns used in the UDTF to generate the table.
        if isinstance(source, exp.UDTF):
            source_columns.update(dict.fromkeys(source.find_all(exp.Column)))

        for c in source_columns:
            table = c.table
            table_source = scope.sources.get(table)

            if isinstance(table_source, Scope):
                # The table itself came from a more specific scope. Recurse into that one
                # using the unaliased column name.
                to_node(
                    c.name,
                    scope=table_source,
                    scope_name=table,
                    upstream=node,
                    alias=aliases.get(table),
                )
            else:
                # The source is not a scope - we've reached the end of the line. At this point, if a
                # source is not found it means this column's lineage is unknown. This can happen if
                # the definition of a source used in a query is not passed into the `sources` map.
                leaf = table_source or exp.Placeholder()
                node.downstream.append(Node(name=c.sql(), source=leaf, expression=leaf))

        return node

    return to_node(column if isinstance(column, str) else column.name, scope)
Build the lineage graph for a column of a SQL query.
Arguments:
- column: The column to build the lineage for.
- sql: The SQL string or expression.
- schema: The schema of the tables referenced in the query, used when qualifying it.
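- sources: A mapping from table names to queries (SQL strings or expressions); references to these tables in the query are expanded with the given queries so lineage can be traced through them.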
- dialect: The dialect of input SQL.
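- **kwargs: Extra keyword arguments for the qualification optimizer; they override the defaults `validate_qualify_columns=False` and `identify=False`.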
Returns:
The lineage node for the requested column; its `downstream` nodes trace where the column comes from.
class
LineageHTML:
class LineageHTML:
    """Render a lineage `Node` graph as HTML using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        node: Node,
        dialect: DialectType = None,
        imports: bool = True,
        **opts: t.Any,
    ):
        """Collect vis.js nodes and edges for the graph rooted at `node`.

        Args:
            node: Root of the lineage graph to render.
            dialect: Dialect used to generate the SQL shown in labels and tooltips.
            imports: Whether `__str__` emits the vis.js script and stylesheet tags.
            **opts: vis-network options merged over the defaults below. The merge is
                shallow: a nested dict replaces the whole default entry.
        """
        self.node = node
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **opts,
        }

        self.nodes: t.Dict[int, t.Dict[str, t.Any]] = {}
        self.edges: t.List[t.Dict[str, int]] = []

        for current in node.walk():
            if isinstance(current.expression, exp.Table):
                label = f"FROM {current.expression.this}"
                title = f"<pre>SELECT {current.name} FROM {current.expression.this}</pre>"
                group = 1
            else:
                label = current.expression.sql(pretty=True, dialect=dialect)
                # The highlighted expression is found by identity (`is`), so the tree must not
                # be copied first. NOTE(review): with copy=False this presumably rewrites
                # `current.source` in place; confirm repeated rendering does not double-wrap.
                source = current.source.transform(
                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
                    if n is current.expression
                    else n,
                    copy=False,
                ).sql(pretty=True, dialect=dialect)
                # NOTE(review): the SQL is not HTML-escaped, and the generated script parses
                # the title as HTML; confirm this is acceptable for untrusted SQL.
                title = f"<pre>{source}</pre>"
                group = 0

            node_id = id(current)

            self.nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for d in current.downstream:
                self.edges.append({"from": node_id, "to": id(d)})

    @staticmethod
    def _to_script_json(value: t.Any) -> str:
        """Serialize `value` as JSON that is safe to inline in a <script> element.

        Escaping "</" as "<\\/" stops strings such as "</script>" (e.g. inside SQL string
        literals) from terminating the script element early. JSON and JavaScript both
        decode "\\/" as "/", so the values seen by the page are unchanged.
        """
        return json.dumps(value).replace("</", "<\\/")

    def __str__(self):
        nodes = self._to_script_json(list(self.nodes.values()))
        edges = self._to_script_json(self.edges)
        options = self._to_script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        # NOTE(review): the container id is fixed, so several graphs rendered into one page
        # would all be drawn into the first element with this id.
        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Notebook (IPython) rich-display hook; same output as `str(self)`."""
        return self.__str__()
Renders a lineage `Node` graph as HTML using vis.js.
LineageHTML( node: Node, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, imports: bool = True, **opts: Any)
def __init__(
    self,
    node: Node,
    dialect: DialectType = None,
    imports: bool = True,
    **opts: t.Any,
):
    """Collect vis.js nodes and edges for the graph rooted at `node`.

    Args:
        node: Root of the lineage graph to render.
        dialect: Dialect used to generate the SQL shown in labels and tooltips.
        imports: Whether the vis.js script and stylesheet tags are emitted when rendering.
        **opts: vis-network options merged over the defaults below. The merge is
            shallow: a nested dict replaces the whole default entry.
    """
    self.node = node
    self.imports = imports

    self.options = {
        "height": "500px",
        "width": "100%",
        "layout": {
            "hierarchical": {
                "enabled": True,
                "nodeSpacing": 200,
                "sortMethod": "directed",
            },
        },
        "interaction": {
            "dragNodes": False,
            "selectable": False,
        },
        "physics": {
            "enabled": False,
        },
        "edges": {
            "arrows": "to",
        },
        "nodes": {
            "font": "20px monaco",
            "shape": "box",
            "widthConstraint": {
                "maximum": 300,
            },
        },
        **opts,
    }

    self.nodes: t.Dict[int, t.Dict[str, t.Any]] = {}
    self.edges: t.List[t.Dict[str, int]] = []

    for current in node.walk():
        if isinstance(current.expression, exp.Table):
            label = f"FROM {current.expression.this}"
            title = f"<pre>SELECT {current.name} FROM {current.expression.this}</pre>"
            group = 1
        else:
            label = current.expression.sql(pretty=True, dialect=dialect)
            # The highlighted expression is found by identity (`is`), so the tree must not
            # be copied first. NOTE(review): with copy=False this presumably rewrites
            # `current.source` in place; confirm repeated rendering does not double-wrap.
            source = current.source.transform(
                lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
                if n is current.expression
                else n,
                copy=False,
            ).sql(pretty=True, dialect=dialect)
            title = f"<pre>{source}</pre>"
            group = 0

        node_id = id(current)

        self.nodes[node_id] = {
            "id": node_id,
            "label": label,
            "title": title,
            "group": group,
        }

        for d in current.downstream:
            self.edges.append({"from": node_id, "to": id(d)})