summaryrefslogtreecommitdiffstats
path: root/tests/dialects/test_bigquery.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dialects/test_bigquery.py')
-rw-r--r--tests/dialects/test_bigquery.py57
1 files changed, 57 insertions, 0 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index 8d172ea..3cf95a7 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -804,6 +804,63 @@ WHERE
},
)
+ def test_models(self):
+ self.validate_identity(
+ "SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT label, column1, column2 FROM mydataset.mytable))"
+ )
+ self.validate_identity(
+ "SELECT label, predicted_label1, predicted_label AS predicted_label2 FROM ML.PREDICT(MODEL mydataset.mymodel2, (SELECT * EXCEPT (predicted_label), predicted_label AS predicted_label1 FROM ML.PREDICT(MODEL mydataset.mymodel1, TABLE mydataset.mytable)))"
+ )
+ self.validate_identity(
+ "SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT custom_label, column1, column2 FROM mydataset.mytable), STRUCT(0.55 AS threshold))"
+ )
+ self.validate_identity(
+ "SELECT * FROM ML.PREDICT(MODEL `my_project`.my_dataset.my_model, (SELECT * FROM input_data))"
+ )
+ self.validate_identity(
+ "SELECT * FROM ML.PREDICT(MODEL my_dataset.vision_model, (SELECT uri, ML.RESIZE_IMAGE(ML.DECODE_IMAGE(data), 480, 480, FALSE) AS input FROM my_dataset.object_table))"
+ )
+ self.validate_identity(
+ "SELECT * FROM ML.PREDICT(MODEL my_dataset.vision_model, (SELECT uri, ML.CONVERT_COLOR_SPACE(ML.RESIZE_IMAGE(ML.DECODE_IMAGE(data), 224, 280, TRUE), 'YIQ') AS input FROM my_dataset.object_table WHERE content_type = 'image/jpeg'))"
+ )
+ self.validate_identity(
+ "CREATE OR REPLACE MODEL foo OPTIONS (model_type='linear_reg') AS SELECT bla FROM foo WHERE cond"
+ )
+ self.validate_identity(
+ """CREATE OR REPLACE MODEL m
+TRANSFORM(
+ ML.FEATURE_CROSS(STRUCT(f1, f2)) AS cross_f,
+ ML.QUANTILE_BUCKETIZE(f3) OVER () AS buckets,
+ label_col
+)
+OPTIONS (
+ model_type='linear_reg',
+ input_label_cols=['label_col']
+) AS
+SELECT
+ *
+FROM t""",
+ pretty=True,
+ )
+ self.validate_identity(
+ """CREATE MODEL project_id.mydataset.mymodel
+INPUT(
+ f1 INT64,
+ f2 FLOAT64,
+ f3 STRING,
+ f4 ARRAY<INT64>
+)
+OUTPUT(
+ out1 INT64,
+ out2 INT64
+)
+REMOTE WITH CONNECTION myproject.us.test_connection
+OPTIONS (
+ ENDPOINT='https://us-central1-aiplatform.googleapis.com/v1/projects/myproject/locations/us-central1/endpoints/1234'
+)""",
+ pretty=True,
+ )
+
def test_merge(self):
self.validate_all(
"""