summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/lineage.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/lineage.py')
-rw-r--r-- sqlglot/lineage.py 28
1 file changed, 16 insertions, 12 deletions
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py
index 0eac870..04a8073 100644
--- a/sqlglot/lineage.py
+++ b/sqlglot/lineage.py
@@ -5,10 +5,8 @@ import typing as t
from dataclasses import dataclass, field
from sqlglot import Schema, exp, maybe_parse
-from sqlglot.optimizer import Scope, build_scope, optimize
-from sqlglot.optimizer.expand_laterals import expand_laterals
-from sqlglot.optimizer.qualify_columns import qualify_columns
-from sqlglot.optimizer.qualify_tables import qualify_tables
+from sqlglot.errors import SqlglotError
+from sqlglot.optimizer import Scope, build_scope, qualify
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
@@ -40,8 +38,8 @@ def lineage(
sql: str | exp.Expression,
schema: t.Optional[t.Dict | Schema] = None,
sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
- rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns, expand_laterals),
dialect: DialectType = None,
+ **kwargs,
) -> Node:
"""Build the lineage graph for a column of a SQL query.
@@ -50,8 +48,8 @@ def lineage(
sql: The SQL string or expression.
schema: The schema of tables.
sources: A mapping of queries which will be used to continue building lineage.
- rules: Optimizer rules to apply, by default only qualifying tables and columns.
dialect: The dialect of input SQL.
+ **kwargs: Qualification optimizer kwargs.
Returns:
A lineage node.
@@ -68,8 +66,17 @@ def lineage(
},
)
- optimized = optimize(expression, schema=schema, rules=rules)
- scope = build_scope(optimized)
+ qualified = qualify.qualify(
+ expression,
+ dialect=dialect,
+ schema=schema,
+ **{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore
+ )
+
+ scope = build_scope(qualified)
+
+ if not scope:
+ raise SqlglotError("Cannot build lineage, sql must be SELECT")
def to_node(
column_name: str,
@@ -109,10 +116,7 @@ def lineage(
# a version that has only the column we care about.
# "x", SELECT x, y FROM foo
# => "x", SELECT x FROM foo
- source = optimize(
- scope.expression.select(select, append=False), schema=schema, rules=rules
- )
- select = source.selects[0]
+ source = t.cast(exp.Expression, scope.expression.select(select, append=False))
else:
source = scope.expression