summaryrefslogtreecommitdiffstats
path: root/tests/dialects/test_bigquery.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-10-10 08:53:10 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-10-10 08:53:10 +0000
commitf7cb7fdb0fb5a8e2d053c1aa18dd98462401a64e (patch)
tree75bbd792c82b8d1e70b5561de82a5b270b61867c /tests/dialects/test_bigquery.py
parentAdding upstream version 18.11.2. (diff)
downloadsqlglot-f7cb7fdb0fb5a8e2d053c1aa18dd98462401a64e.tar.xz
sqlglot-f7cb7fdb0fb5a8e2d053c1aa18dd98462401a64e.zip
Adding upstream version 18.11.6.upstream/18.11.6
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/dialects/test_bigquery.py')
-rw-r--r--tests/dialects/test_bigquery.py57
1 files changed, 57 insertions, 0 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index 8d172ea..3cf95a7 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -804,6 +804,63 @@ WHERE
},
)
+ def test_models(self):
+ self.validate_identity(
+ "SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT label, column1, column2 FROM mydataset.mytable))"
+ )
+ self.validate_identity(
+ "SELECT label, predicted_label1, predicted_label AS predicted_label2 FROM ML.PREDICT(MODEL mydataset.mymodel2, (SELECT * EXCEPT (predicted_label), predicted_label AS predicted_label1 FROM ML.PREDICT(MODEL mydataset.mymodel1, TABLE mydataset.mytable)))"
+ )
+ self.validate_identity(
+ "SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT custom_label, column1, column2 FROM mydataset.mytable), STRUCT(0.55 AS threshold))"
+ )
+ self.validate_identity(
+ "SELECT * FROM ML.PREDICT(MODEL `my_project`.my_dataset.my_model, (SELECT * FROM input_data))"
+ )
+ self.validate_identity(
+ "SELECT * FROM ML.PREDICT(MODEL my_dataset.vision_model, (SELECT uri, ML.RESIZE_IMAGE(ML.DECODE_IMAGE(data), 480, 480, FALSE) AS input FROM my_dataset.object_table))"
+ )
+ self.validate_identity(
+ "SELECT * FROM ML.PREDICT(MODEL my_dataset.vision_model, (SELECT uri, ML.CONVERT_COLOR_SPACE(ML.RESIZE_IMAGE(ML.DECODE_IMAGE(data), 224, 280, TRUE), 'YIQ') AS input FROM my_dataset.object_table WHERE content_type = 'image/jpeg'))"
+ )
+ self.validate_identity(
+ "CREATE OR REPLACE MODEL foo OPTIONS (model_type='linear_reg') AS SELECT bla FROM foo WHERE cond"
+ )
+ self.validate_identity(
+ """CREATE OR REPLACE MODEL m
+TRANSFORM(
+ ML.FEATURE_CROSS(STRUCT(f1, f2)) AS cross_f,
+ ML.QUANTILE_BUCKETIZE(f3) OVER () AS buckets,
+ label_col
+)
+OPTIONS (
+ model_type='linear_reg',
+ input_label_cols=['label_col']
+) AS
+SELECT
+ *
+FROM t""",
+ pretty=True,
+ )
+ self.validate_identity(
+ """CREATE MODEL project_id.mydataset.mymodel
+INPUT(
+ f1 INT64,
+ f2 FLOAT64,
+ f3 STRING,
+ f4 ARRAY<INT64>
+)
+OUTPUT(
+ out1 INT64,
+ out2 INT64
+)
+REMOTE WITH CONNECTION myproject.us.test_connection
+OPTIONS (
+ ENDPOINT='https://us-central1-aiplatform.googleapis.com/v1/projects/myproject/locations/us-central1/endpoints/1234'
+)""",
+ pretty=True,
+ )
+
def test_merge(self):
self.validate_all(
"""