diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/dialects/test_bigquery.py | 57 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 3 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 14 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 6 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 16 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 4 |
6 files changed, 95 insertions, 5 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 8d172ea..3cf95a7 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -804,6 +804,63 @@ WHERE }, ) + def test_models(self): + self.validate_identity( + "SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT label, column1, column2 FROM mydataset.mytable))" + ) + self.validate_identity( + "SELECT label, predicted_label1, predicted_label AS predicted_label2 FROM ML.PREDICT(MODEL mydataset.mymodel2, (SELECT * EXCEPT (predicted_label), predicted_label AS predicted_label1 FROM ML.PREDICT(MODEL mydataset.mymodel1, TABLE mydataset.mytable)))" + ) + self.validate_identity( + "SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT custom_label, column1, column2 FROM mydataset.mytable), STRUCT(0.55 AS threshold))" + ) + self.validate_identity( + "SELECT * FROM ML.PREDICT(MODEL `my_project`.my_dataset.my_model, (SELECT * FROM input_data))" + ) + self.validate_identity( + "SELECT * FROM ML.PREDICT(MODEL my_dataset.vision_model, (SELECT uri, ML.RESIZE_IMAGE(ML.DECODE_IMAGE(data), 480, 480, FALSE) AS input FROM my_dataset.object_table))" + ) + self.validate_identity( + "SELECT * FROM ML.PREDICT(MODEL my_dataset.vision_model, (SELECT uri, ML.CONVERT_COLOR_SPACE(ML.RESIZE_IMAGE(ML.DECODE_IMAGE(data), 224, 280, TRUE), 'YIQ') AS input FROM my_dataset.object_table WHERE content_type = 'image/jpeg'))" + ) + self.validate_identity( + "CREATE OR REPLACE MODEL foo OPTIONS (model_type='linear_reg') AS SELECT bla FROM foo WHERE cond" + ) + self.validate_identity( + """CREATE OR REPLACE MODEL m +TRANSFORM( + ML.FEATURE_CROSS(STRUCT(f1, f2)) AS cross_f, + ML.QUANTILE_BUCKETIZE(f3) OVER () AS buckets, + label_col +) +OPTIONS ( + model_type='linear_reg', + input_label_cols=['label_col'] +) AS +SELECT + * +FROM t""", + pretty=True, + ) + self.validate_identity( + """CREATE MODEL project_id.mydataset.mymodel +INPUT( + f1 INT64, + f2 FLOAT64, + f3 STRING, + f4 ARRAY<INT64> +) +OUTPUT( + out1 INT64, + out2 INT64 +) +REMOTE WITH CONNECTION myproject.us.test_connection +OPTIONS ( + ENDPOINT='https://us-central1-aiplatform.googleapis.com/v1/projects/myproject/locations/us-central1/endpoints/1234' +)""", + pretty=True, + ) + def test_merge(self): self.validate_all( """ diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 11f921c..14a864b 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -30,6 +30,9 @@ class TestMySQL(Validator): self.validate_identity("CREATE TABLE foo (a BIGINT, FULLTEXT INDEX (b))") self.validate_identity("CREATE TABLE foo (a BIGINT, SPATIAL INDEX (b))") self.validate_identity( + "CREATE TABLE `oauth_consumer` (`key` VARCHAR(32) NOT NULL, UNIQUE `OAUTH_CONSUMER_KEY` (`key`))" + ) + self.validate_identity( "CREATE TABLE `x` (`username` VARCHAR(200), PRIMARY KEY (`username`(16)))" ) self.validate_identity( diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 0ddc106..22bede4 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -648,10 +648,10 @@ class TestPostgres(Validator): }, ) self.validate_all( - "merge into x as x using (select id) as y on a = b WHEN matched then update set X.a = y.b", + """merge into x as x using (select id) as y on a = b WHEN matched then update set X."A" = y.b""", write={ - "postgres": "MERGE INTO x AS x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b", - "snowflake": "MERGE INTO x AS x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET X.a = y.b", + "postgres": """MERGE INTO x AS x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET "A" = y.b""", + "snowflake": """MERGE INTO x AS x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET X."A" = y.b""", }, ) self.validate_all( @@ -724,3 +724,11 @@ class TestPostgres(Validator): "presto": "CONCAT(CAST(a AS VARCHAR), CAST(b AS VARCHAR))", }, ) + + def test_variance(self): + self.validate_all("VAR_SAMP(x)", write={"postgres": "VAR_SAMP(x)"}) + self.validate_all("VAR_POP(x)", write={"postgres": "VAR_POP(x)"}) + self.validate_all("VARIANCE(x)", write={"postgres": "VAR_SAMP(x)"}) + self.validate_all( + "VAR_POP(x)", read={"": "VARIANCE_POP(x)"}, write={"postgres": "VAR_POP(x)"} + ) diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 8edd31c..fd297d7 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -353,6 +353,12 @@ class TestPresto(Validator): }, ) self.validate_all( + "CAST('2012-10-31 00:00' AS TIMESTAMP) AT TIME ZONE 'America/Sao_Paulo'", + read={ + "spark": "FROM_UTC_TIMESTAMP('2012-10-31 00:00', 'America/Sao_Paulo')", + }, + ) + self.validate_all( "CAST(x AS TIMESTAMP)", write={"presto": "CAST(x AS TIMESTAMP)"}, read={"mysql": "CAST(x AS DATETIME)", "clickhouse": "CAST(x AS DATETIME64)"}, diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index ae1b987..9f2761f 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -273,6 +273,22 @@ class TestRedshift(Validator): "SELECT DATE_ADD('day', 1, DATE('2023-01-01'))", "SELECT DATEADD(day, 1, CAST(DATE('2023-01-01') AS DATE))", ) + self.validate_identity( + """SELECT + c_name, + orders.o_orderkey AS orderkey, + index AS orderkey_index +FROM customer_orders_lineitem AS c, c.c_orders AS orders AT index +ORDER BY + orderkey_index""", + pretty=True, + ) + self.validate_identity( + "SELECT attr AS attr, JSON_TYPEOF(val) AS value_type FROM customer_orders_lineitem AS c, UNPIVOT c.c_orders[0] WHERE c_custkey = 9451" + ) + self.validate_identity( + "SELECT attr AS attr, JSON_TYPEOF(val) AS value_type FROM customer_orders_lineitem AS c, UNPIVOT c.c_orders AS val AT attr WHERE c_custkey = 9451" + ) def test_values(self): # Test crazy-sized VALUES clause to UNION ALL conversion to ensure we don't get RecursionError diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 7d89d06..fbd913d 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -32,7 +32,7 @@ class TestTSQL(Validator): self.validate_all( """CREATE TABLE [dbo].[mytable]( [email] [varchar](255) NOT NULL, - CONSTRAINT [UN_t_mytable] UNIQUE NONCLUSTERED + CONSTRAINT [UN_t_mytable] UNIQUE NONCLUSTERED ( [email] ASC ) @@ -343,7 +343,7 @@ class TestTSQL(Validator): "CAST(x as DOUBLE)", write={ "spark": "CAST(x AS DOUBLE)", - "tsql": "CAST(x AS DOUBLE)", + "tsql": "CAST(x AS FLOAT)", }, ) |