diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/lineage.py | 28 |
1 files changed, 16 insertions, 12 deletions
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index 0eac870..04a8073 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -5,10 +5,8 @@ import typing as t from dataclasses import dataclass, field from sqlglot import Schema, exp, maybe_parse -from sqlglot.optimizer import Scope, build_scope, optimize -from sqlglot.optimizer.expand_laterals import expand_laterals -from sqlglot.optimizer.qualify_columns import qualify_columns -from sqlglot.optimizer.qualify_tables import qualify_tables +from sqlglot.errors import SqlglotError +from sqlglot.optimizer import Scope, build_scope, qualify if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType @@ -40,8 +38,8 @@ def lineage( sql: str | exp.Expression, schema: t.Optional[t.Dict | Schema] = None, sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None, - rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns, expand_laterals), dialect: DialectType = None, + **kwargs, ) -> Node: """Build the lineage graph for a column of a SQL query. @@ -50,8 +48,8 @@ def lineage( sql: The SQL string or expression. schema: The schema of tables. sources: A mapping of queries which will be used to continue building lineage. - rules: Optimizer rules to apply, by default only qualifying tables and columns. dialect: The dialect of input SQL. + **kwargs: Qualification optimizer kwargs. Returns: A lineage node. @@ -68,8 +66,17 @@ def lineage( }, ) - optimized = optimize(expression, schema=schema, rules=rules) - scope = build_scope(optimized) + qualified = qualify.qualify( + expression, + dialect=dialect, + schema=schema, + **{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore + ) + + scope = build_scope(qualified) + + if not scope: + raise SqlglotError("Cannot build lineage, sql must be SELECT") def to_node( column_name: str, @@ -109,10 +116,7 @@ def lineage( # a version that has only the column we care about. # "x", SELECT x, y FROM foo # => "x", SELECT x FROM foo - source = optimize( - scope.expression.select(select, append=False), schema=schema, rules=rules - ) - select = source.selects[0] + source = t.cast(exp.Expression, scope.expression.select(select, append=False)) else: source = scope.expression |