diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-25 16:01:43 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-25 16:01:43 +0000 |
commit | 29f9a15cce138301cd5a84a1fd4060494a3a65b6 (patch) | |
tree | c593be2f0b0fdc60a43983aa547b34a441170e59 /tests | |
parent | Adding upstream version 9.0.1. (diff) | |
download | sqlglot-29f9a15cce138301cd5a84a1fd4060494a3a65b6.tar.xz sqlglot-29f9a15cce138301cd5a84a1fd4060494a3a65b6.zip |
Adding upstream version 9.0.3.upstream/9.0.3
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
-rw-r--r-- | tests/dataframe/unit/test_functions.py | 54 | ||||
-rw-r--r-- | tests/dialects/test_bigquery.py | 9 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 10 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 3 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 223 | ||||
-rw-r--r-- | tests/fixtures/identity.sql | 1 | ||||
-rw-r--r-- | tests/test_build.py | 4 | ||||
-rw-r--r-- | tests/test_expressions.py | 66 | ||||
-rw-r--r-- | tests/test_time.py | 2 |
9 files changed, 339 insertions, 33 deletions
diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index 10f3b57..97753bd 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -9,7 +9,6 @@ from sqlglot.errors import ErrorLevel class TestFunctions(unittest.TestCase): - @unittest.skip("not yet fixed.") def test_invoke_anonymous(self): for name, func in inspect.getmembers(SF, inspect.isfunction): with self.subTest(f"{name} should not invoke anonymous_function"): @@ -438,13 +437,13 @@ class TestFunctions(unittest.TestCase): def test_pow(self): col_str = SF.pow("cola", "colb") - self.assertEqual("POW(cola, colb)", col_str.sql()) + self.assertEqual("POWER(cola, colb)", col_str.sql()) col = SF.pow(SF.col("cola"), SF.col("colb")) - self.assertEqual("POW(cola, colb)", col.sql()) + self.assertEqual("POWER(cola, colb)", col.sql()) col_float = SF.pow(10.10, "colb") - self.assertEqual("POW(10.1, colb)", col_float.sql()) + self.assertEqual("POWER(10.1, colb)", col_float.sql()) col_float2 = SF.pow("cola", 10.10) - self.assertEqual("POW(cola, 10.1)", col_float2.sql()) + self.assertEqual("POWER(cola, 10.1)", col_float2.sql()) def test_row_number(self): col_str = SF.row_number() @@ -493,6 +492,8 @@ class TestFunctions(unittest.TestCase): self.assertEqual("COALESCE(cola, colb, colc)", col_str.sql()) col = SF.coalesce(SF.col("cola"), "colb", SF.col("colc")) self.assertEqual("COALESCE(cola, colb, colc)", col.sql()) + col_single = SF.coalesce("cola") + self.assertEqual("COALESCE(cola)", col_single.sql()) def test_corr(self): col_str = SF.corr("cola", "colb") @@ -843,8 +844,8 @@ class TestFunctions(unittest.TestCase): self.assertEqual("TO_DATE(cola)", col_str.sql()) col = SF.to_date(SF.col("cola")) self.assertEqual("TO_DATE(cola)", col.sql()) - col_with_format = SF.to_date("cola", "yyyy-MM-dd") - self.assertEqual("TO_DATE(cola, 'yyyy-MM-dd')", col_with_format.sql()) + col_with_format = SF.to_date("cola", "yy-MM-dd") + self.assertEqual("TO_DATE(cola, 'yy-MM-dd')", col_with_format.sql()) def test_to_timestamp(self): col_str = SF.to_timestamp("cola") @@ -883,16 +884,16 @@ class TestFunctions(unittest.TestCase): self.assertEqual("FROM_UNIXTIME(cola)", col_str.sql()) col = SF.from_unixtime(SF.col("cola")) self.assertEqual("FROM_UNIXTIME(cola)", col.sql()) - col_format = SF.from_unixtime("cola", "yyyy-MM-dd HH:mm:ss") - self.assertEqual("FROM_UNIXTIME(cola, 'yyyy-MM-dd HH:mm:ss')", col_format.sql()) + col_format = SF.from_unixtime("cola", "yyyy-MM-dd HH:mm") + self.assertEqual("FROM_UNIXTIME(cola, 'yyyy-MM-dd HH:mm')", col_format.sql()) def test_unix_timestamp(self): col_str = SF.unix_timestamp("cola") self.assertEqual("UNIX_TIMESTAMP(cola)", col_str.sql()) col = SF.unix_timestamp(SF.col("cola")) self.assertEqual("UNIX_TIMESTAMP(cola)", col.sql()) - col_format = SF.unix_timestamp("cola", "yyyy-MM-dd HH:mm:ss") - self.assertEqual("UNIX_TIMESTAMP(cola, 'yyyy-MM-dd HH:mm:ss')", col_format.sql()) + col_format = SF.unix_timestamp("cola", "yyyy-MM-dd HH:mm") + self.assertEqual("UNIX_TIMESTAMP(cola, 'yyyy-MM-dd HH:mm')", col_format.sql()) col_current = SF.unix_timestamp() self.assertEqual("UNIX_TIMESTAMP()", col_current.sql()) @@ -1427,6 +1428,13 @@ class TestFunctions(unittest.TestCase): self.assertEqual("ARRAY_SORT(cola)", col_str.sql()) col = SF.array_sort(SF.col("cola")) self.assertEqual("ARRAY_SORT(cola)", col.sql()) + col_comparator = SF.array_sort( + "cola", lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(SF.length(y) - SF.length(x)) + ) + self.assertEqual( + "ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE LENGTH(y) - LENGTH(x) END)", + col_comparator.sql(), + ) def test_reverse(self): col_str = SF.reverse("cola") @@ -1514,8 +1522,6 @@ class TestFunctions(unittest.TestCase): SF.lit(0), lambda accumulator, target: accumulator + target, lambda accumulator: accumulator * 2, - "accumulator", - "target", ) self.assertEqual( "AGGREGATE(cola, 0, (accumulator, target) -> accumulator + target, accumulator -> accumulator * 2)", @@ -1527,7 +1533,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("TRANSFORM(cola, x -> x * 2)", col_str.sql()) col = SF.transform(SF.col("cola"), lambda x, i: x * i) self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql()) - col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count, "target", "row_count") + col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count) self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql()) @@ -1536,7 +1542,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col_str.sql()) col = SF.exists(SF.col("cola"), lambda x: x % 2 == 0) self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col.sql()) - col_custom_name = SF.exists("cola", lambda target: target > 0, "target") + col_custom_name = SF.exists("cola", lambda target: target > 0) self.assertEqual("EXISTS(cola, target -> target > 0)", col_custom_name.sql()) def test_forall(self): @@ -1544,7 +1550,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col_str.sql()) col = SF.forall(SF.col("cola"), lambda x: x.rlike("foo")) self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col.sql()) - col_custom_name = SF.forall("cola", lambda target: target.rlike("foo"), "target") + col_custom_name = SF.forall("cola", lambda target: target.rlike("foo")) self.assertEqual("FORALL(cola, target -> target RLIKE 'foo')", col_custom_name.sql()) def test_filter(self): @@ -1552,9 +1558,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql()) col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i)) self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql()) - col_custom_names = SF.filter( - "cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count), "target", "row_count" - ) + col_custom_names = SF.filter("cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count)) self.assertEqual( "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql() @@ -1565,7 +1569,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col_str.sql()) col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y)) self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql()) - col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r), "l", "r") + col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r)) self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql()) def test_transform_keys(self): @@ -1573,7 +1577,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col_str.sql()) col = SF.transform_keys(SF.col("cola"), lambda k, v: SF.upper(k)) self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col.sql()) - col_custom_names = SF.transform_keys("cola", lambda key, _: SF.upper(key), "key", "_") + col_custom_names = SF.transform_keys("cola", lambda key, _: SF.upper(key)) self.assertEqual("TRANSFORM_KEYS(cola, (key, _) -> UPPER(key))", col_custom_names.sql()) def test_transform_values(self): @@ -1581,7 +1585,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col_str.sql()) col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v)) self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql()) - col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value), "_", "value") + col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value)) self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql()) def test_map_filter(self): @@ -1589,5 +1593,9 @@ class TestFunctions(unittest.TestCase): self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col_str.sql()) col = SF.map_filter(SF.col("cola"), lambda k, v: k > v) self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col.sql()) - col_custom_names = SF.map_filter("cola", lambda key, value: key > value, "key", "value") + col_custom_names = SF.map_filter("cola", lambda key, value: key > value) self.assertEqual("MAP_FILTER(cola, (key, value) -> key > value)", col_custom_names.sql()) + + def test_map_zip_with(self): + col = SF.map_zip_with("base", "ratio", lambda k, v1, v2: SF.round(v1 * v2, 2)) + self.assertEqual("MAP_ZIP_WITH(base, ratio, (k, v1, v2) -> ROUND(v1 * v2, 2))", col.sql()) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 14fea9d..050d41e 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -106,6 +106,15 @@ class TestBigQuery(Validator): }, ) self.validate_all( + "CURRENT_DATE", + read={ + "tsql": "GETDATE()", + }, + write={ + "tsql": "GETDATE()", + }, + ) + self.validate_all( "current_datetime", write={ "bigquery": "CURRENT_DATETIME()", diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index e1524e9..5d1cf13 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -434,12 +434,7 @@ class TestDialect(Validator): "presto": "DATE_ADD('day', 1, x)", "spark": "DATE_ADD(x, 1)", "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", - }, - ) - self.validate_all( - "DATE_ADD(x, y, 'day')", - write={ - "postgres": UnsupportedError, + "tsql": "DATEADD(day, 1, x)", }, ) self.validate_all( @@ -634,11 +629,13 @@ class TestDialect(Validator): read={ "postgres": "x->'y'", "presto": "JSON_EXTRACT(x, 'y')", + "starrocks": "x->'y'", }, write={ "oracle": "JSON_EXTRACT(x, 'y')", "postgres": "x->'y'", "presto": "JSON_EXTRACT(x, 'y')", + "starrocks": "x->'y'", }, ) self.validate_all( @@ -983,6 +980,7 @@ class TestDialect(Validator): ) def test_limit(self): + self.validate_all("SELECT * FROM data LIMIT 10, 20", write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"}) self.validate_all( "SELECT x FROM y LIMIT 10", write={ diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 339d1a6..8605bd1 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -282,3 +282,6 @@ TBLPROPERTIES ( "spark": "SELECT ARRAY_SORT(x)", }, ) + + def test_iif(self): + self.validate_all("SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"}) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 9a6bc36..2a20163 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -71,3 +71,226 @@ class TestTSQL(Validator): "spark": "LOCATE('sub', 'testsubstring')", }, ) + + def test_len(self): + self.validate_all("LEN(x)", write={"spark": "LENGTH(x)"}) + + def test_replicate(self): + self.validate_all("REPLICATE('x', 2)", write={"spark": "REPEAT('x', 2)"}) + + def test_isnull(self): + self.validate_all("ISNULL(x, y)", write={"spark": "COALESCE(x, y)"}) + + def test_jsonvalue(self): + self.validate_all( + "JSON_VALUE(r.JSON, '$.Attr_INT')", + write={"spark": "GET_JSON_OBJECT(r.JSON, '$.Attr_INT')"}, + ) + + def test_datefromparts(self): + self.validate_all( + "SELECT DATEFROMPARTS('2020', 10, 01)", + write={"spark": "SELECT MAKE_DATE('2020', 10, 01)"}, + ) + + def test_datename(self): + self.validate_all( + "SELECT DATENAME(mm,'01-01-1970')", + write={"spark": "SELECT DATE_FORMAT('01-01-1970', 'MMMM')"}, + ) + self.validate_all( + "SELECT DATENAME(dw,'01-01-1970')", + write={"spark": "SELECT DATE_FORMAT('01-01-1970', 'EEEE')"}, + ) + + def test_datepart(self): + self.validate_all( + "SELECT DATEPART(month,'01-01-1970')", + write={"spark": "SELECT DATE_FORMAT('01-01-1970', 'MM')"}, + ) + + def test_convert_date_format(self): + self.validate_all( + "CONVERT(NVARCHAR(200), x)", + write={ + "spark": "CAST(x AS VARCHAR(200))", + }, + ) + self.validate_all( + "CONVERT(NVARCHAR, x)", + write={ + "spark": "CAST(x AS VARCHAR(30))", + }, + ) + self.validate_all( + "CONVERT(NVARCHAR(MAX), x)", + write={ + "spark": "CAST(x AS STRING)", + }, + ) + self.validate_all( + "CONVERT(VARCHAR(200), x)", + write={ + "spark": "CAST(x AS VARCHAR(200))", + }, + ) + self.validate_all( + "CONVERT(VARCHAR, x)", + write={ + "spark": "CAST(x AS VARCHAR(30))", + }, + ) + self.validate_all( + "CONVERT(VARCHAR(MAX), x)", + write={ + "spark": "CAST(x AS STRING)", + }, + ) + self.validate_all( + "CONVERT(CHAR(40), x)", + write={ + "spark": "CAST(x AS CHAR(40))", + }, + ) + self.validate_all( + "CONVERT(CHAR, x)", + write={ + "spark": "CAST(x AS CHAR(30))", + }, + ) + self.validate_all( + "CONVERT(NCHAR(40), x)", + write={ + "spark": "CAST(x AS CHAR(40))", + }, + ) + self.validate_all( + "CONVERT(NCHAR, x)", + write={ + "spark": "CAST(x AS CHAR(30))", + }, + ) + self.validate_all( + "CONVERT(VARCHAR, x, 121)", + write={ + "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(30))", + }, + ) + self.validate_all( + "CONVERT(VARCHAR(40), x, 121)", + write={ + "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(40))", + }, + ) + self.validate_all( + "CONVERT(VARCHAR(MAX), x, 121)", + write={ + "spark": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')", + }, + ) + self.validate_all( + "CONVERT(NVARCHAR, x, 121)", + write={ + "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(30))", + }, + ) + self.validate_all( + "CONVERT(NVARCHAR(40), x, 121)", + write={ + "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(40))", + }, + ) + self.validate_all( + "CONVERT(NVARCHAR(MAX), x, 121)", + write={ + "spark": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')", + }, + ) + self.validate_all( + "CONVERT(DATE, x, 121)", + write={ + "spark": "TO_DATE(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')", + }, + ) + self.validate_all( + "CONVERT(DATETIME, x, 121)", + write={ + "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')", + }, + ) + self.validate_all( + "CONVERT(DATETIME2, x, 121)", + write={ + "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')", + }, + ) + self.validate_all( + "CONVERT(INT, x)", + write={ + "spark": "CAST(x AS INT)", + }, + ) + self.validate_all( + "CONVERT(INT, x, 121)", + write={ + "spark": "CAST(x AS INT)", + }, + ) + self.validate_all( + "TRY_CONVERT(NVARCHAR, x, 121)", + write={ + "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(30))", + }, + ) + self.validate_all( + "TRY_CONVERT(INT, x)", + write={ + "spark": "CAST(x AS INT)", + }, + ) + self.validate_all( + "TRY_CAST(x AS INT)", + write={ + "spark": "CAST(x AS INT)", + }, + ) + self.validate_all( + "CAST(x AS INT)", + write={ + "spark": "CAST(x AS INT)", + }, + ) + + def test_add_date(self): + self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')") + self.validate_all( + "SELECT DATEADD(year, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"} + ) + self.validate_all("SELECT DATEADD(qq, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"}) + self.validate_all("SELECT DATEADD(wk, 1, '2017/08/25')", write={"spark": "SELECT DATE_ADD('2017/08/25', 7)"}) + + def test_date_diff(self): + self.validate_identity("SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')") + self.validate_all( + "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')", + write={ + "tsql": "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')", + "spark": "SELECT MONTHS_BETWEEN('2021/01/01', '2020/01/01') / 12", + }, + ) + self.validate_all( + "SELECT DATEDIFF(month, 'start','end')", + write={"spark": "SELECT MONTHS_BETWEEN('end', 'start')", "tsql": "SELECT DATEDIFF(month, 'start', 'end')"}, + ) + self.validate_all( + "SELECT DATEDIFF(quarter, 'start', 'end')", write={"spark": "SELECT MONTHS_BETWEEN('end', 'start') / 3"} + ) + + def test_iif(self): + self.validate_identity("SELECT IIF(cond, 'True', 'False')") + self.validate_all( + "SELECT IIF(cond, 'True', 'False');", + write={ + "spark": "SELECT IF(cond, 'True', 'False')", + }, + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 67e4cab..d7084ac 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -149,7 +149,6 @@ SELECT 1 AS count FROM test SELECT 1 AS comment FROM test SELECT 1 AS numeric FROM test SELECT 1 AS number FROM test -SELECT 1 AS number # annotation SELECT t.count SELECT DISTINCT x FROM test SELECT DISTINCT x, y FROM test diff --git a/tests/test_build.py b/tests/test_build.py index a432ef1..f51996d 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -330,6 +330,10 @@ class TestBuild(unittest.TestCase): "UPDATE tbl SET x = 1 WHERE y > 0", ), ( + lambda: exp.update("tbl", {"x": 1}, where=exp.condition("y > 0")), + "UPDATE tbl SET x = 1 WHERE y > 0", + ), + ( lambda: exp.update("tbl", {"x": 1}, from_="tbl2"), "UPDATE tbl SET x = 1 FROM tbl2", ), diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 79b4ee5..9af59d9 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -135,6 +135,53 @@ class TestExpressions(unittest.TestCase): "SELECT * FROM a1 AS a JOIN b.a JOIN c.a2 JOIN d2 JOIN e.a", ) + def test_replace_placeholders(self): + self.assertEqual( + exp.replace_placeholders( + parse_one("select * from :tbl1 JOIN :tbl2 ON :col1 = :col2 WHERE :col3 > 100"), + tbl1="foo", + tbl2="bar", + col1="a", + col2="b", + col3="c", + ).sql(), + "SELECT * FROM foo JOIN bar ON a = b WHERE c > 100", + ) + self.assertEqual( + exp.replace_placeholders( + parse_one("select * from ? JOIN ? ON ? = ? WHERE ? > 100"), + "foo", + "bar", + "a", + "b", + "c", + ).sql(), + "SELECT * FROM foo JOIN bar ON a = b WHERE c > 100", + ) + self.assertEqual( + exp.replace_placeholders( + parse_one("select * from ? WHERE ? > 100"), + "foo", + ).sql(), + "SELECT * FROM foo WHERE ? > 100", + ) + self.assertEqual( + exp.replace_placeholders(parse_one("select * from :name WHERE ? > 100"), another_name="bla").sql(), + "SELECT * FROM :name WHERE ? > 100", + ) + self.assertEqual( + exp.replace_placeholders( + parse_one("select * from (SELECT :col1 FROM ?) WHERE :col2 > 100"), + "tbl1", + "tbl2", + "tbl3", + col1="a", + col2="b", + col3="c", + ).sql(), + "SELECT * FROM (SELECT a FROM tbl1) WHERE b > 100", + ) + def test_named_selects(self): expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz") self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"]) @@ -504,9 +551,24 @@ class TestExpressions(unittest.TestCase): [e.alias_or_name for e in expression.expressions], ["a", "B", "c", "D"], ) - self.assertEqual(expression.sql(), sql) + self.assertEqual(expression.sql(), "SELECT a, b AS B, c, d AS D") self.assertEqual(expression.expressions[2].name, "comment") - self.assertEqual(expression.sql(annotations=False), "SELECT a, b AS B, c, d AS D") + self.assertEqual( + expression.sql(pretty=True, annotations=False), + """SELECT + a, + b AS B, + c, + d AS D""", + ) + self.assertEqual( + expression.sql(pretty=True), + """SELECT + a, + b AS B, + c # comment, + d AS D # another_comment FROM foo""", + ) def test_to_table(self): table_only = exp.to_table("table_name") diff --git a/tests/test_time.py b/tests/test_time.py index 17821c2..bd0e63f 100644 --- a/tests/test_time.py +++ b/tests/test_time.py @@ -5,7 +5,7 @@ from sqlglot.time import format_time class TestTime(unittest.TestCase): def test_format_time(self): - self.assertEqual(format_time("", {}), "") + self.assertEqual(format_time("", {}), None) self.assertEqual(format_time(" ", {}), " ") mapping = {"a": "b", "aa": "c"} self.assertEqual(format_time("a", mapping), "b") |