diff options
Diffstat (limited to 'tests/dialects/test_bigquery.py')
-rw-r--r-- | tests/dialects/test_bigquery.py | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 8d172ea..3cf95a7 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -804,6 +804,63 @@ WHERE }, ) + def test_models(self): + self.validate_identity( + "SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT label, column1, column2 FROM mydataset.mytable))" + ) + self.validate_identity( + "SELECT label, predicted_label1, predicted_label AS predicted_label2 FROM ML.PREDICT(MODEL mydataset.mymodel2, (SELECT * EXCEPT (predicted_label), predicted_label AS predicted_label1 FROM ML.PREDICT(MODEL mydataset.mymodel1, TABLE mydataset.mytable)))" + ) + self.validate_identity( + "SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT custom_label, column1, column2 FROM mydataset.mytable), STRUCT(0.55 AS threshold))" + ) + self.validate_identity( + "SELECT * FROM ML.PREDICT(MODEL `my_project`.my_dataset.my_model, (SELECT * FROM input_data))" + ) + self.validate_identity( + "SELECT * FROM ML.PREDICT(MODEL my_dataset.vision_model, (SELECT uri, ML.RESIZE_IMAGE(ML.DECODE_IMAGE(data), 480, 480, FALSE) AS input FROM my_dataset.object_table))" + ) + self.validate_identity( + "SELECT * FROM ML.PREDICT(MODEL my_dataset.vision_model, (SELECT uri, ML.CONVERT_COLOR_SPACE(ML.RESIZE_IMAGE(ML.DECODE_IMAGE(data), 224, 280, TRUE), 'YIQ') AS input FROM my_dataset.object_table WHERE content_type = 'image/jpeg'))" + ) + self.validate_identity( + "CREATE OR REPLACE MODEL foo OPTIONS (model_type='linear_reg') AS SELECT bla FROM foo WHERE cond" + ) + self.validate_identity( + """CREATE OR REPLACE MODEL m +TRANSFORM( + ML.FEATURE_CROSS(STRUCT(f1, f2)) AS cross_f, + ML.QUANTILE_BUCKETIZE(f3) OVER () AS buckets, + label_col +) +OPTIONS ( + model_type='linear_reg', + input_label_cols=['label_col'] +) AS +SELECT + * +FROM t""", + pretty=True, + ) + self.validate_identity( + """CREATE MODEL project_id.mydataset.mymodel +INPUT( + f1 INT64, + f2 FLOAT64, + f3 STRING, + f4 ARRAY<INT64> +) +OUTPUT( + out1 INT64, + out2 INT64 +) +REMOTE WITH CONNECTION myproject.us.test_connection +OPTIONS ( + ENDPOINT='https://us-central1-aiplatform.googleapis.com/v1/projects/myproject/locations/us-central1/endpoints/1234' +)""", + pretty=True, + ) + def test_merge(self): self.validate_all( """ |