diff options
Diffstat (limited to 'tests/test_schema.py')
-rw-r--r-- | tests/test_schema.py | 345 |
1 files changed, 118 insertions, 227 deletions
diff --git a/tests/test_schema.py b/tests/test_schema.py index bab97d8..cc0e3d1 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,281 +1,141 @@ import unittest -from sqlglot import table -from sqlglot.dataframe.sql import types as df_types +from sqlglot import exp, to_table +from sqlglot.errors import SchemaError from sqlglot.schema import MappingSchema, ensure_schema class TestSchema(unittest.TestCase): - def test_schema(self): - schema = ensure_schema( - { - "x": { - "a": "uint64", - } - } - ) - self.assertEqual( - schema.column_names( - table( - "x", - ) - ), - ["a"], - ) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db", catalog="c")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db")) - with self.assertRaises(ValueError): - schema.column_names(table("x2")) + def assert_column_names(self, schema, *table_results): + for table, result in table_results: + with self.subTest(f"{table} -> {result}"): + self.assertEqual(schema.column_names(to_table(table)), result) - with self.assertRaises(ValueError): - schema.add_table(table("y", db="db"), {"b": "string"}) - with self.assertRaises(ValueError): - schema.add_table(table("y", db="db", catalog="c"), {"b": "string"}) + def assert_column_names_raises(self, schema, *tables): + for table in tables: + with self.subTest(table): + with self.assertRaises(SchemaError): + schema.column_names(to_table(table)) - schema.add_table(table("y"), {"b": "string"}) - schema_with_y = { - "x": { - "a": "uint64", - }, - "y": { - "b": "string", - }, - } - self.assertEqual(schema.schema, schema_with_y) - - new_schema = schema.copy() - new_schema.add_table(table("z"), {"c": "string"}) - self.assertEqual(schema.schema, schema_with_y) - self.assertEqual( - new_schema.schema, - { - "x": { - "a": "uint64", - }, - "y": { - "b": "string", - }, - "z": { - "c": "string", - }, - }, - ) - schema.add_table(table("m"), {"d": "string"}) - schema.add_table(table("n"), {"e": "string"}) - schema_with_m_n = { - "x": { - "a": "uint64", - }, - "y": { - "b": "string", - }, - "m": { - "d": "string", - }, - "n": { - "e": "string", - }, - } - self.assertEqual(schema.schema, schema_with_m_n) - new_schema = schema.copy() - new_schema.add_table(table("o"), {"f": "string"}) - new_schema.add_table(table("p"), {"g": "string"}) - self.assertEqual(schema.schema, schema_with_m_n) - self.assertEqual( - new_schema.schema, + def test_schema(self): + schema = ensure_schema( { "x": { "a": "uint64", }, "y": { - "b": "string", - }, - "m": { - "d": "string", - }, - "n": { - "e": "string", - }, - "o": { - "f": "string", - }, - "p": { - "g": "string", + "b": "uint64", + "c": "uint64", }, }, ) - schema = ensure_schema( - { - "db": { - "x": { - "a": "uint64", - } - } - } + self.assert_column_names( + schema, + ("x", ["a"]), + ("y", ["b", "c"]), + ("z.x", ["a"]), + ("z.x.y", ["b", "c"]), ) - self.assertEqual(schema.column_names(table("x", db="db")), ["a"]) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db", catalog="c")) - with self.assertRaises(ValueError): - schema.column_names(table("x")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db2")) - with self.assertRaises(ValueError): - schema.column_names(table("x2", db="db")) - with self.assertRaises(ValueError): - schema.add_table(table("y"), {"b": "string"}) - with self.assertRaises(ValueError): - schema.add_table(table("y", db="db", catalog="c"), {"b": "string"}) + self.assert_column_names_raises( + schema, + "z", + "z.z", + "z.z.z", + ) - schema.add_table(table("y", db="db"), {"b": "string"}) - self.assertEqual( - schema.schema, + def test_schema_db(self): + schema = ensure_schema( { - "db": { + "d1": { "x": { "a": "uint64", }, "y": { - "b": "string", + "b": "uint64", + }, + }, + "d2": { + "x": { + "c": "uint64", }, - } + }, }, ) - schema = ensure_schema( - { - "c": { - "db": { - "x": { - "a": "uint64", - } - } - } - } + self.assert_column_names( + schema, + ("d1.x", ["a"]), + ("d2.x", ["c"]), + ("y", ["b"]), + ("d1.y", ["b"]), + ("z.d1.y", ["b"]), ) - self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"]) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db")) - with self.assertRaises(ValueError): - schema.column_names(table("x")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db", catalog="c2")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db2")) - with self.assertRaises(ValueError): - schema.column_names(table("x2", db="db")) - with self.assertRaises(ValueError): - schema.add_table(table("x"), {"b": "string"}) - with self.assertRaises(ValueError): - schema.add_table(table("x", db="db"), {"b": "string"}) + self.assert_column_names_raises( + schema, + "x", + "z.x", + "z.y", + ) - schema.add_table(table("y", db="db", catalog="c"), {"a": "string", "b": "int"}) - self.assertEqual( - schema.schema, + def test_schema_catalog(self): + schema = ensure_schema( { - "c": { - "db": { + "c1": { + "d1": { "x": { "a": "uint64", }, "y": { - "a": "string", - "b": "int", + "b": "uint64", }, - } - } - }, - ) - schema.add_table(table("z", db="db2", catalog="c"), {"c": "string", "d": "int"}) - self.assertEqual( - schema.schema, - { - "c": { - "db": { - "x": { - "a": "uint64", + "z": { + "c": "uint64", }, + }, + }, + "c2": { + "d1": { "y": { - "a": "string", - "b": "int", + "d": "uint64", }, - }, - "db2": { "z": { - "c": "string", - "d": "int", - } - }, - } - }, - ) - schema.add_table(table("m", db="db2", catalog="c2"), {"e": "string", "f": "int"}) - self.assertEqual( - schema.schema, - { - "c": { - "db": { - "x": { - "a": "uint64", - }, - "y": { - "a": "string", - "b": "int", + "e": "uint64", }, }, - "db2": { + "d2": { "z": { - "c": "string", - "d": "int", - } + "f": "uint64", + }, }, }, - "c2": { - "db2": { - "m": { - "e": "string", - "f": "int", - } - } - }, - }, - ) - - schema = ensure_schema( - { - "x": { - "a": "uint64", - } } ) - self.assertEqual(schema.column_names(table("x")), ["a"]) - schema = MappingSchema() - schema.add_table(table("x"), {"a": "string"}) - self.assertEqual( - schema.schema, - { - "x": { - "a": "string", - } - }, + self.assert_column_names( + schema, + ("x", ["a"]), + ("d1.x", ["a"]), + ("c1.d1.x", ["a"]), + ("c1.d1.y", ["b"]), + ("c1.d1.z", ["c"]), + ("c2.d1.y", ["d"]), + ("c2.d1.z", ["e"]), + ("d2.z", ["f"]), + ("c2.d2.z", ["f"]), ) - schema.add_table(table("y"), df_types.StructType([df_types.StructField("b", df_types.StringType())])) - self.assertEqual( - schema.schema, - { - "x": { - "a": "string", - }, - "y": { - "b": "string", - }, - }, + + self.assert_column_names_raises( + schema, + "q", + "d2.x", + "y", + "z", + "d1.y", + "d1.z", + "a.b.c", ) def test_schema_add_table_with_and_without_mapping(self): @@ -288,3 +148,34 @@ class TestSchema(unittest.TestCase): self.assertEqual(schema.column_names("test"), ["x", "y"]) schema.add_table("test") self.assertEqual(schema.column_names("test"), ["x", "y"]) + + def test_schema_get_column_type(self): + schema = MappingSchema({"a": {"b": "varchar"}}) + self.assertEqual(schema.get_column_type("a", "b"), exp.DataType.Type.VARCHAR) + self.assertEqual( + schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")), + exp.DataType.Type.VARCHAR, + ) + self.assertEqual( + schema.get_column_type("a", exp.Column(this="b")), exp.DataType.Type.VARCHAR + ) + self.assertEqual( + schema.get_column_type(exp.Table(this="a"), "b"), exp.DataType.Type.VARCHAR + ) + schema = MappingSchema({"a": {"b": {"c": "varchar"}}}) + self.assertEqual( + schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")), + exp.DataType.Type.VARCHAR, + ) + self.assertEqual( + schema.get_column_type(exp.Table(this="b", db="a"), "c"), exp.DataType.Type.VARCHAR + ) + schema = MappingSchema({"a": {"b": {"c": {"d": "varchar"}}}}) + self.assertEqual( + schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d")), + exp.DataType.Type.VARCHAR, + ) + self.assertEqual( + schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d"), + exp.DataType.Type.VARCHAR, + ) |