summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/dialects/test_bigquery.py57
-rw-r--r--tests/dialects/test_mysql.py3
-rw-r--r--tests/dialects/test_postgres.py14
-rw-r--r--tests/dialects/test_presto.py6
-rw-r--r--tests/dialects/test_redshift.py16
-rw-r--r--tests/dialects/test_tsql.py4
6 files changed, 95 insertions, 5 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index 8d172ea..3cf95a7 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -804,6 +804,63 @@ WHERE
},
)
+ def test_models(self):
+ self.validate_identity(
+ "SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT label, column1, column2 FROM mydataset.mytable))"
+ )
+ self.validate_identity(
+ "SELECT label, predicted_label1, predicted_label AS predicted_label2 FROM ML.PREDICT(MODEL mydataset.mymodel2, (SELECT * EXCEPT (predicted_label), predicted_label AS predicted_label1 FROM ML.PREDICT(MODEL mydataset.mymodel1, TABLE mydataset.mytable)))"
+ )
+ self.validate_identity(
+ "SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT custom_label, column1, column2 FROM mydataset.mytable), STRUCT(0.55 AS threshold))"
+ )
+ self.validate_identity(
+ "SELECT * FROM ML.PREDICT(MODEL `my_project`.my_dataset.my_model, (SELECT * FROM input_data))"
+ )
+ self.validate_identity(
+ "SELECT * FROM ML.PREDICT(MODEL my_dataset.vision_model, (SELECT uri, ML.RESIZE_IMAGE(ML.DECODE_IMAGE(data), 480, 480, FALSE) AS input FROM my_dataset.object_table))"
+ )
+ self.validate_identity(
+ "SELECT * FROM ML.PREDICT(MODEL my_dataset.vision_model, (SELECT uri, ML.CONVERT_COLOR_SPACE(ML.RESIZE_IMAGE(ML.DECODE_IMAGE(data), 224, 280, TRUE), 'YIQ') AS input FROM my_dataset.object_table WHERE content_type = 'image/jpeg'))"
+ )
+ self.validate_identity(
+ "CREATE OR REPLACE MODEL foo OPTIONS (model_type='linear_reg') AS SELECT bla FROM foo WHERE cond"
+ )
+ self.validate_identity(
+ """CREATE OR REPLACE MODEL m
+TRANSFORM(
+ ML.FEATURE_CROSS(STRUCT(f1, f2)) AS cross_f,
+ ML.QUANTILE_BUCKETIZE(f3) OVER () AS buckets,
+ label_col
+)
+OPTIONS (
+ model_type='linear_reg',
+ input_label_cols=['label_col']
+) AS
+SELECT
+ *
+FROM t""",
+ pretty=True,
+ )
+ self.validate_identity(
+ """CREATE MODEL project_id.mydataset.mymodel
+INPUT(
+ f1 INT64,
+ f2 FLOAT64,
+ f3 STRING,
+ f4 ARRAY<INT64>
+)
+OUTPUT(
+ out1 INT64,
+ out2 INT64
+)
+REMOTE WITH CONNECTION myproject.us.test_connection
+OPTIONS (
+ ENDPOINT='https://us-central1-aiplatform.googleapis.com/v1/projects/myproject/locations/us-central1/endpoints/1234'
+)""",
+ pretty=True,
+ )
+
def test_merge(self):
self.validate_all(
"""
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index 11f921c..14a864b 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -30,6 +30,9 @@ class TestMySQL(Validator):
self.validate_identity("CREATE TABLE foo (a BIGINT, FULLTEXT INDEX (b))")
self.validate_identity("CREATE TABLE foo (a BIGINT, SPATIAL INDEX (b))")
self.validate_identity(
+ "CREATE TABLE `oauth_consumer` (`key` VARCHAR(32) NOT NULL, UNIQUE `OAUTH_CONSUMER_KEY` (`key`))"
+ )
+ self.validate_identity(
"CREATE TABLE `x` (`username` VARCHAR(200), PRIMARY KEY (`username`(16)))"
)
self.validate_identity(
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index 0ddc106..22bede4 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -648,10 +648,10 @@ class TestPostgres(Validator):
},
)
self.validate_all(
- "merge into x as x using (select id) as y on a = b WHEN matched then update set X.a = y.b",
+ """merge into x as x using (select id) as y on a = b WHEN matched then update set X."A" = y.b""",
write={
- "postgres": "MERGE INTO x AS x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b",
- "snowflake": "MERGE INTO x AS x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET X.a = y.b",
+ "postgres": """MERGE INTO x AS x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET "A" = y.b""",
+ "snowflake": """MERGE INTO x AS x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET X."A" = y.b""",
},
)
self.validate_all(
@@ -724,3 +724,11 @@ class TestPostgres(Validator):
"presto": "CONCAT(CAST(a AS VARCHAR), CAST(b AS VARCHAR))",
},
)
+
+ def test_variance(self):
+ self.validate_all("VAR_SAMP(x)", write={"postgres": "VAR_SAMP(x)"})
+ self.validate_all("VAR_POP(x)", write={"postgres": "VAR_POP(x)"})
+ self.validate_all("VARIANCE(x)", write={"postgres": "VAR_SAMP(x)"})
+ self.validate_all(
+ "VAR_POP(x)", read={"": "VARIANCE_POP(x)"}, write={"postgres": "VAR_POP(x)"}
+ )
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index 8edd31c..fd297d7 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -353,6 +353,12 @@ class TestPresto(Validator):
},
)
self.validate_all(
+ "CAST('2012-10-31 00:00' AS TIMESTAMP) AT TIME ZONE 'America/Sao_Paulo'",
+ read={
+ "spark": "FROM_UTC_TIMESTAMP('2012-10-31 00:00', 'America/Sao_Paulo')",
+ },
+ )
+ self.validate_all(
"CAST(x AS TIMESTAMP)",
write={"presto": "CAST(x AS TIMESTAMP)"},
read={"mysql": "CAST(x AS DATETIME)", "clickhouse": "CAST(x AS DATETIME64)"},
diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py
index ae1b987..9f2761f 100644
--- a/tests/dialects/test_redshift.py
+++ b/tests/dialects/test_redshift.py
@@ -273,6 +273,22 @@ class TestRedshift(Validator):
"SELECT DATE_ADD('day', 1, DATE('2023-01-01'))",
"SELECT DATEADD(day, 1, CAST(DATE('2023-01-01') AS DATE))",
)
+ self.validate_identity(
+ """SELECT
+ c_name,
+ orders.o_orderkey AS orderkey,
+ index AS orderkey_index
+FROM customer_orders_lineitem AS c, c.c_orders AS orders AT index
+ORDER BY
+ orderkey_index""",
+ pretty=True,
+ )
+ self.validate_identity(
+ "SELECT attr AS attr, JSON_TYPEOF(val) AS value_type FROM customer_orders_lineitem AS c, UNPIVOT c.c_orders[0] WHERE c_custkey = 9451"
+ )
+ self.validate_identity(
+ "SELECT attr AS attr, JSON_TYPEOF(val) AS value_type FROM customer_orders_lineitem AS c, UNPIVOT c.c_orders AS val AT attr WHERE c_custkey = 9451"
+ )
def test_values(self):
# Test crazy-sized VALUES clause to UNION ALL conversion to ensure we don't get RecursionError
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index 7d89d06..fbd913d 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -32,7 +32,7 @@ class TestTSQL(Validator):
self.validate_all(
"""CREATE TABLE [dbo].[mytable](
[email] [varchar](255) NOT NULL,
- CONSTRAINT [UN_t_mytable] UNIQUE NONCLUSTERED
+ CONSTRAINT [UN_t_mytable] UNIQUE NONCLUSTERED
(
[email] ASC
)
@@ -343,7 +343,7 @@ class TestTSQL(Validator):
"CAST(x as DOUBLE)",
write={
"spark": "CAST(x AS DOUBLE)",
- "tsql": "CAST(x AS DOUBLE)",
+ "tsql": "CAST(x AS FLOAT)",
},
)