diff options
Diffstat (limited to 'tests/dataframe')
-rw-r--r-- | tests/dataframe/integration/test_session.py | 7 | ||||
-rw-r--r-- | tests/dataframe/unit/test_column.py | 4 | ||||
-rw-r--r-- | tests/dataframe/unit/test_functions.py | 22 | ||||
-rw-r--r-- | tests/dataframe/unit/test_session.py | 6 |
4 files changed, 23 insertions, 16 deletions
diff --git a/tests/dataframe/integration/test_session.py b/tests/dataframe/integration/test_session.py index ec50034..3bb3e20 100644 --- a/tests/dataframe/integration/test_session.py +++ b/tests/dataframe/integration/test_session.py @@ -34,3 +34,10 @@ class TestSessionFunc(DataFrameValidator): .agg(SF.countDistinct(SF.col("employee_id"))) ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_nameless_column(self): + query = "SELECT MAX(age) FROM employee" + df = self.spark.sql(query) + dfs = self.sqlglot.sql(query) + # Spark will alias the column to `max(age)` while sqlglot will alias to `_col_0` so their schemas will differ + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) diff --git a/tests/dataframe/unit/test_column.py b/tests/dataframe/unit/test_column.py index 7a12808..833005b 100644 --- a/tests/dataframe/unit/test_column.py +++ b/tests/dataframe/unit/test_column.py @@ -150,8 +150,8 @@ class TestDataframeColumn(unittest.TestCase): F.col("cola").between(datetime.date(2022, 1, 1), datetime.date(2022, 3, 1)).sql(), ) self.assertEqual( - "cola BETWEEN CAST('2022-01-01T01:01:01+00:00' AS TIMESTAMP) " - "AND CAST('2022-03-01T01:01:01+00:00' AS TIMESTAMP)", + "cola BETWEEN CAST('2022-01-01 01:01:01+00:00' AS TIMESTAMP) " + "AND CAST('2022-03-01 01:01:01+00:00' AS TIMESTAMP)", F.col("cola") .between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)) .sql(), diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index e40d50d..884cded 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -29,7 +29,7 @@ class TestFunctions(unittest.TestCase): test_date = SF.lit(datetime.date(2022, 1, 1)) self.assertEqual("CAST('2022-01-01' AS DATE)", test_date.sql()) test_datetime = SF.lit(datetime.datetime(2022, 1, 1, 1, 1, 1)) - self.assertEqual("CAST('2022-01-01T01:01:01+00:00' AS TIMESTAMP)", test_datetime.sql()) + self.assertEqual("CAST('2022-01-01 01:01:01+00:00' AS TIMESTAMP)", test_datetime.sql()) test_dict = SF.lit({"cola": 1, "colb": "test"}) self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql()) @@ -51,7 +51,7 @@ class TestFunctions(unittest.TestCase): test_date = SF.col(datetime.date(2022, 1, 1)) self.assertEqual("CAST('2022-01-01' AS DATE)", test_date.sql()) test_datetime = SF.col(datetime.datetime(2022, 1, 1, 1, 1, 1)) - self.assertEqual("CAST('2022-01-01T01:01:01+00:00' AS TIMESTAMP)", test_datetime.sql()) + self.assertEqual("CAST('2022-01-01 01:01:01+00:00' AS TIMESTAMP)", test_datetime.sql()) test_dict = SF.col({"cola": 1, "colb": "test"}) self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql()) @@ -250,9 +250,9 @@ class TestFunctions(unittest.TestCase): def test_log10(self): col_str = SF.log10("cola") - self.assertEqual("LOG10(cola)", col_str.sql()) + self.assertEqual("LOG(10, cola)", col_str.sql()) col = SF.log10(SF.col("cola")) - self.assertEqual("LOG10(cola)", col.sql()) + self.assertEqual("LOG(10, cola)", col.sql()) def test_log1p(self): col_str = SF.log1p("cola") @@ -262,9 +262,9 @@ class TestFunctions(unittest.TestCase): def test_log2(self): col_str = SF.log2("cola") - self.assertEqual("LOG2(cola)", col_str.sql()) + self.assertEqual("LOG(2, cola)", col_str.sql()) col = SF.log2(SF.col("cola")) - self.assertEqual("LOG2(cola)", col.sql()) + self.assertEqual("LOG(2, cola)", col.sql()) def test_rint(self): col_str = SF.rint("cola") @@ -1156,17 +1156,17 @@ class TestFunctions(unittest.TestCase): def test_regexp_extract(self): col_str = SF.regexp_extract("cola", r"(\d+)-(\d+)", 1) - self.assertEqual("REGEXP_EXTRACT(cola, '(\\d+)-(\\d+)', 1)", col_str.sql()) + self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col_str.sql()) col = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)", 1) - self.assertEqual("REGEXP_EXTRACT(cola, '(\\d+)-(\\d+)', 1)", col.sql()) + self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col.sql()) col_no_idx = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)") - self.assertEqual("REGEXP_EXTRACT(cola, '(\\d+)-(\\d+)')", col_no_idx.sql()) + self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)')", col_no_idx.sql()) def test_regexp_replace(self): col_str = SF.regexp_replace("cola", r"(\d+)", "--") - self.assertEqual("REGEXP_REPLACE(cola, '(\\d+)', '--')", col_str.sql()) + self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col_str.sql()) col = SF.regexp_replace(SF.col("cola"), r"(\d+)", "--") - self.assertEqual("REGEXP_REPLACE(cola, '(\\d+)', '--')", col.sql()) + self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col.sql()) def test_initcap(self): col_str = SF.initcap("cola") diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py index e2ebae4..848c603 100644 --- a/tests/dataframe/unit/test_session.py +++ b/tests/dataframe/unit/test_session.py @@ -79,7 +79,7 @@ class TestDataframeSession(DataFrameSQLValidator): sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark") df = self.spark.sql(query).groupBy(F.col("cola")).agg(F.sum("colb")) self.assertEqual( - "WITH t38189 AS (SELECT cola, colb FROM table), t42330 AS (SELECT cola, colb FROM t38189) SELECT cola, SUM(colb) FROM t42330 GROUP BY cola", + "WITH t26614 AS (SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`), t23454 AS (SELECT cola, colb FROM t26614) SELECT cola, SUM(colb) FROM t23454 GROUP BY cola", df.sql(pretty=False, optimize=False)[0], ) @@ -87,14 +87,14 @@ class TestDataframeSession(DataFrameSQLValidator): query = "CREATE TABLE new_table AS WITH t1 AS (SELECT cola, colb FROM table) SELECT cola, colb, FROM t1" sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark") df = self.spark.sql(query) - expected = "CREATE TABLE new_table AS SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`" + expected = "CREATE TABLE `new_table` AS SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`" self.compare_sql(df, expected) def test_sql_insert(self): query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1" sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark") df = self.spark.sql(query) - expected = "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`" + expected = "INSERT INTO `new_table` SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`" self.compare_sql(df, expected) def test_session_create_builder_patterns(self): |