summaryrefslogtreecommitdiffstats
path: root/tests/dataframe
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dataframe')
-rw-r--r--tests/dataframe/integration/test_session.py7
-rw-r--r--tests/dataframe/unit/test_column.py4
-rw-r--r--tests/dataframe/unit/test_functions.py22
-rw-r--r--tests/dataframe/unit/test_session.py6
4 files changed, 23 insertions, 16 deletions
diff --git a/tests/dataframe/integration/test_session.py b/tests/dataframe/integration/test_session.py
index ec50034..3bb3e20 100644
--- a/tests/dataframe/integration/test_session.py
+++ b/tests/dataframe/integration/test_session.py
@@ -34,3 +34,10 @@ class TestSessionFunc(DataFrameValidator):
.agg(SF.countDistinct(SF.col("employee_id")))
)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
+
+ def test_nameless_column(self):
+ query = "SELECT MAX(age) FROM employee"
+ df = self.spark.sql(query)
+ dfs = self.sqlglot.sql(query)
+ # Spark will alias the column to `max(age)` while sqlglot will alias to `_col_0` so their schemas will differ
+ self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
diff --git a/tests/dataframe/unit/test_column.py b/tests/dataframe/unit/test_column.py
index 7a12808..833005b 100644
--- a/tests/dataframe/unit/test_column.py
+++ b/tests/dataframe/unit/test_column.py
@@ -150,8 +150,8 @@ class TestDataframeColumn(unittest.TestCase):
F.col("cola").between(datetime.date(2022, 1, 1), datetime.date(2022, 3, 1)).sql(),
)
self.assertEqual(
- "cola BETWEEN CAST('2022-01-01T01:01:01+00:00' AS TIMESTAMP) "
- "AND CAST('2022-03-01T01:01:01+00:00' AS TIMESTAMP)",
+ "cola BETWEEN CAST('2022-01-01 01:01:01+00:00' AS TIMESTAMP) "
+ "AND CAST('2022-03-01 01:01:01+00:00' AS TIMESTAMP)",
F.col("cola")
.between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1))
.sql(),
diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py
index e40d50d..884cded 100644
--- a/tests/dataframe/unit/test_functions.py
+++ b/tests/dataframe/unit/test_functions.py
@@ -29,7 +29,7 @@ class TestFunctions(unittest.TestCase):
test_date = SF.lit(datetime.date(2022, 1, 1))
self.assertEqual("CAST('2022-01-01' AS DATE)", test_date.sql())
test_datetime = SF.lit(datetime.datetime(2022, 1, 1, 1, 1, 1))
- self.assertEqual("CAST('2022-01-01T01:01:01+00:00' AS TIMESTAMP)", test_datetime.sql())
+ self.assertEqual("CAST('2022-01-01 01:01:01+00:00' AS TIMESTAMP)", test_datetime.sql())
test_dict = SF.lit({"cola": 1, "colb": "test"})
self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql())
@@ -51,7 +51,7 @@ class TestFunctions(unittest.TestCase):
test_date = SF.col(datetime.date(2022, 1, 1))
self.assertEqual("CAST('2022-01-01' AS DATE)", test_date.sql())
test_datetime = SF.col(datetime.datetime(2022, 1, 1, 1, 1, 1))
- self.assertEqual("CAST('2022-01-01T01:01:01+00:00' AS TIMESTAMP)", test_datetime.sql())
+ self.assertEqual("CAST('2022-01-01 01:01:01+00:00' AS TIMESTAMP)", test_datetime.sql())
test_dict = SF.col({"cola": 1, "colb": "test"})
self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql())
@@ -250,9 +250,9 @@ class TestFunctions(unittest.TestCase):
def test_log10(self):
col_str = SF.log10("cola")
- self.assertEqual("LOG10(cola)", col_str.sql())
+ self.assertEqual("LOG(10, cola)", col_str.sql())
col = SF.log10(SF.col("cola"))
- self.assertEqual("LOG10(cola)", col.sql())
+ self.assertEqual("LOG(10, cola)", col.sql())
def test_log1p(self):
col_str = SF.log1p("cola")
@@ -262,9 +262,9 @@ class TestFunctions(unittest.TestCase):
def test_log2(self):
col_str = SF.log2("cola")
- self.assertEqual("LOG2(cola)", col_str.sql())
+ self.assertEqual("LOG(2, cola)", col_str.sql())
col = SF.log2(SF.col("cola"))
- self.assertEqual("LOG2(cola)", col.sql())
+ self.assertEqual("LOG(2, cola)", col.sql())
def test_rint(self):
col_str = SF.rint("cola")
@@ -1156,17 +1156,17 @@ class TestFunctions(unittest.TestCase):
def test_regexp_extract(self):
col_str = SF.regexp_extract("cola", r"(\d+)-(\d+)", 1)
- self.assertEqual("REGEXP_EXTRACT(cola, '(\\d+)-(\\d+)', 1)", col_str.sql())
+ self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col_str.sql())
col = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)", 1)
- self.assertEqual("REGEXP_EXTRACT(cola, '(\\d+)-(\\d+)', 1)", col.sql())
+ self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col.sql())
col_no_idx = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)")
- self.assertEqual("REGEXP_EXTRACT(cola, '(\\d+)-(\\d+)')", col_no_idx.sql())
+ self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)')", col_no_idx.sql())
def test_regexp_replace(self):
col_str = SF.regexp_replace("cola", r"(\d+)", "--")
- self.assertEqual("REGEXP_REPLACE(cola, '(\\d+)', '--')", col_str.sql())
+ self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col_str.sql())
col = SF.regexp_replace(SF.col("cola"), r"(\d+)", "--")
- self.assertEqual("REGEXP_REPLACE(cola, '(\\d+)', '--')", col.sql())
+ self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col.sql())
def test_initcap(self):
col_str = SF.initcap("cola")
diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py
index e2ebae4..848c603 100644
--- a/tests/dataframe/unit/test_session.py
+++ b/tests/dataframe/unit/test_session.py
@@ -79,7 +79,7 @@ class TestDataframeSession(DataFrameSQLValidator):
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark")
df = self.spark.sql(query).groupBy(F.col("cola")).agg(F.sum("colb"))
self.assertEqual(
- "WITH t38189 AS (SELECT cola, colb FROM table), t42330 AS (SELECT cola, colb FROM t38189) SELECT cola, SUM(colb) FROM t42330 GROUP BY cola",
+ "WITH t26614 AS (SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`), t23454 AS (SELECT cola, colb FROM t26614) SELECT cola, SUM(colb) FROM t23454 GROUP BY cola",
df.sql(pretty=False, optimize=False)[0],
)
@@ -87,14 +87,14 @@ class TestDataframeSession(DataFrameSQLValidator):
query = "CREATE TABLE new_table AS WITH t1 AS (SELECT cola, colb FROM table) SELECT cola, colb, FROM t1"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark")
df = self.spark.sql(query)
- expected = "CREATE TABLE new_table AS SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
+ expected = "CREATE TABLE `new_table` AS SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
self.compare_sql(df, expected)
def test_sql_insert(self):
query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark")
df = self.spark.sql(query)
- expected = "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
+ expected = "INSERT INTO `new_table` SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
self.compare_sql(df, expected)
def test_session_create_builder_patterns(self):