summaryrefslogtreecommitdiffstats
path: root/tests/test_lineage.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_lineage.py')
-rw-r--r--tests/test_lineage.py72
1 files changed, 68 insertions, 4 deletions
diff --git a/tests/test_lineage.py b/tests/test_lineage.py
index c782d9a..036f146 100644
--- a/tests/test_lineage.py
+++ b/tests/test_lineage.py
@@ -224,16 +224,50 @@ class TestLineage(unittest.TestCase):
downstream.source.sql(dialect="snowflake"),
"LATERAL FLATTEN(INPUT => TEST_TABLE.RESULT, OUTER => TRUE) AS FLATTENED(SEQ, KEY, PATH, INDEX, VALUE, THIS)",
)
- self.assertEqual(
- downstream.expression.sql(dialect="snowflake"),
- "VALUE",
- )
+ self.assertEqual(downstream.expression.sql(dialect="snowflake"), "VALUE")
self.assertEqual(len(downstream.downstream), 1)
downstream = downstream.downstream[0]
self.assertEqual(downstream.name, "TEST_TABLE.RESULT")
self.assertEqual(downstream.source.sql(dialect="snowflake"), "TEST_TABLE AS TEST_TABLE")
+ node = lineage(
+ "FIELD",
+ "SELECT FLATTENED.VALUE:field::text AS FIELD FROM SNOWFLAKE.SCHEMA.MODEL AS MODEL_ALIAS, LATERAL FLATTEN(INPUT => MODEL_ALIAS.A) AS FLATTENED",
+ schema={"SNOWFLAKE": {"SCHEMA": {"TABLE": {"A": "integer"}}}},
+ sources={"SNOWFLAKE.SCHEMA.MODEL": "SELECT A FROM SNOWFLAKE.SCHEMA.TABLE"},
+ dialect="snowflake",
+ )
+ self.assertEqual(node.name, "FIELD")
+
+ downstream = node.downstream[0]
+ self.assertEqual(downstream.name, "FLATTENED.VALUE")
+ self.assertEqual(
+ downstream.source.sql(dialect="snowflake"),
+ "LATERAL FLATTEN(INPUT => MODEL_ALIAS.A) AS FLATTENED(SEQ, KEY, PATH, INDEX, VALUE, THIS)",
+ )
+ self.assertEqual(downstream.expression.sql(dialect="snowflake"), "VALUE")
+ self.assertEqual(len(downstream.downstream), 1)
+
+ downstream = downstream.downstream[0]
+ self.assertEqual(downstream.name, "MODEL_ALIAS.A")
+ self.assertEqual(downstream.source_name, "SNOWFLAKE.SCHEMA.MODEL")
+ self.assertEqual(
+ downstream.source.sql(dialect="snowflake"),
+ "SELECT TABLE.A AS A FROM SNOWFLAKE.SCHEMA.TABLE AS TABLE",
+ )
+ self.assertEqual(downstream.expression.sql(dialect="snowflake"), "TABLE.A AS A")
+ self.assertEqual(len(downstream.downstream), 1)
+
+ downstream = downstream.downstream[0]
+ self.assertEqual(downstream.name, "TABLE.A")
+ self.assertEqual(
+ downstream.source.sql(dialect="snowflake"), "SNOWFLAKE.SCHEMA.TABLE AS TABLE"
+ )
+ self.assertEqual(
+ downstream.expression.sql(dialect="snowflake"), "SNOWFLAKE.SCHEMA.TABLE AS TABLE"
+ )
+
def test_subquery(self) -> None:
node = lineage(
"output",
@@ -266,6 +300,7 @@ class TestLineage(unittest.TestCase):
self.assertEqual(node.name, "a")
node = node.downstream[0]
self.assertEqual(node.name, "cte.a")
+ self.assertEqual(node.reference_node_name, "cte")
node = node.downstream[0]
self.assertEqual(node.name, "z.a")
@@ -304,6 +339,27 @@ class TestLineage(unittest.TestCase):
node = a.downstream[0]
self.assertEqual(node.name, "foo.a")
+ # Select from derived table
+ node = lineage(
+ "a",
+ "SELECT a FROM (SELECT a FROM x) subquery",
+ )
+ self.assertEqual(node.name, "a")
+ self.assertEqual(len(node.downstream), 1)
+ node = node.downstream[0]
+ self.assertEqual(node.name, "subquery.a")
+ self.assertEqual(node.reference_node_name, "subquery")
+
+ node = lineage(
+ "a",
+ "SELECT a FROM (SELECT a FROM x)",
+ )
+ self.assertEqual(node.name, "a")
+ self.assertEqual(len(node.downstream), 1)
+ node = node.downstream[0]
+ self.assertEqual(node.name, "_q_0.a")
+ self.assertEqual(node.reference_node_name, "_q_0")
+
def test_lineage_cte_union(self) -> None:
query = """
WITH dataset AS (
@@ -431,3 +487,11 @@ class TestLineage(unittest.TestCase):
downstream = node.downstream[0]
self.assertEqual(downstream.name, "z.a")
self.assertEqual(downstream.source.sql(), "SELECT y.a AS a, y.b AS b, y.c AS c FROM y AS y")
+
+ def test_node_name_doesnt_contain_comment(self) -> None:
+ sql = "SELECT * FROM (SELECT x /* c */ FROM t1) AS t2"
+ node = lineage("x", sql)
+
+ self.assertEqual(len(node.downstream), 1)
+ self.assertEqual(len(node.downstream[0].downstream), 1)
+ self.assertEqual(node.downstream[0].downstream[0].name, "t1.x")