Skip to content

Commit b6fa1b9

Browse files
nnegreyShabirmean
authored andcommitted
samples: automl: separate batch predict test, verify model is deployed before prediction (#1931)
* automl: create separate batch prediction test, verify models are deployed before trying to predict * remove bom from automl until bom is released with v1 of client library * Fix typo * Remove score threshold * Rename files from IT to Test * Fix GCS path typo * lint: import order * use fake dataset for export
1 parent 3250995 commit b6fa1b9

8 files changed

+237
-185
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2019 Google LLC
2+
* Copyright 2020 Google LLC
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -20,6 +20,10 @@
2020
import static junit.framework.TestCase.assertNotNull;
2121

2222
import com.google.api.gax.paging.Page;
23+
import com.google.cloud.automl.v1.AutoMlClient;
24+
import com.google.cloud.automl.v1.DeployModelRequest;
25+
import com.google.cloud.automl.v1.Model;
26+
import com.google.cloud.automl.v1.ModelName;
2327
import com.google.cloud.storage.Blob;
2428
import com.google.cloud.storage.Storage;
2529
import com.google.cloud.storage.StorageOptions;
@@ -31,70 +35,54 @@
3135
import org.junit.After;
3236
import org.junit.Before;
3337
import org.junit.BeforeClass;
34-
import org.junit.Ignore;
3538
import org.junit.Test;
3639
import org.junit.runner.RunWith;
3740
import org.junit.runners.JUnit4;
3841

39-
// Tests for automl natural language entity extraction "Predict" sample.
4042
@RunWith(JUnit4.class)
4143
@SuppressWarnings("checkstyle:abbreviationaswordinname")
42-
public class LanguageEntityExtractionPredictIT {
44+
public class BatchPredictTest {
4345
private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID");
44-
private static final String BUCKET_ID = System.getenv("GOOGLE_CLOUD_PROJECT") + "-lcm";
45-
private static final String modelId = System.getenv("ENTITY_EXTRACTION_MODEL_ID");
46+
private static final String BUCKET_ID = PROJECT_ID + "-lcm";
47+
private static final String MODEL_ID = System.getenv("ENTITY_EXTRACTION_MODEL_ID");
4648
private ByteArrayOutputStream bout;
4749
private PrintStream out;
4850

4951
private static void requireEnvVar(String varName) {
5052
assertNotNull(
51-
System.getenv(varName),
52-
"Environment variable '%s' is required to perform these tests.".format(varName)
53-
);
53+
System.getenv(varName),
54+
"Environment variable '%s' is required to perform these tests.".format(varName));
5455
}
5556

5657
@BeforeClass
5758
public static void checkRequirements() {
5859
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
59-
requireEnvVar("GOOGLE_CLOUD_PROJECT");
6060
requireEnvVar("AUTOML_PROJECT_ID");
6161
requireEnvVar("ENTITY_EXTRACTION_MODEL_ID");
6262
}
6363

6464
@Before
65-
public void setUp() {
65+
public void setUp() throws IOException, ExecutionException, InterruptedException {
66+
// Verify that the model is deployed for prediction
67+
try (AutoMlClient client = AutoMlClient.create()) {
68+
ModelName modelFullId = ModelName.of(PROJECT_ID, "us-central1", MODEL_ID);
69+
Model model = client.getModel(modelFullId);
70+
if (model.getDeploymentState() == Model.DeploymentState.UNDEPLOYED) {
71+
// Deploy the model if not deployed
72+
DeployModelRequest request =
73+
DeployModelRequest.newBuilder().setName(modelFullId.toString()).build();
74+
client.deployModelAsync(request).get();
75+
}
76+
}
77+
6678
bout = new ByteArrayOutputStream();
6779
out = new PrintStream(bout);
6880
System.setOut(out);
6981
}
7082

7183
@After
7284
public void tearDown() {
73-
System.setOut(null);
74-
}
75-
76-
@Test
77-
public void testPredict() throws IOException {
78-
String text = "Constitutional mutations in the WT1 gene in patients with Denys-Drash syndrome.";
79-
// Act
80-
LanguageEntityExtractionPredict.predict(PROJECT_ID, modelId, text);
81-
82-
// Assert
83-
String got = bout.toString();
84-
assertThat(got).contains("Text Extract Entity Type:");
85-
}
86-
87-
@Ignore
88-
public void testBatchPredict() throws IOException, ExecutionException, InterruptedException {
89-
String inputUri = String.format("gs://%s/entity_extraction/input.jsonl", BUCKET_ID);
90-
String outputUri = String.format("gs://%s/TEST_BATCH_PREDICT/", BUCKET_ID);
91-
// Act
92-
BatchPredict.batchPredict(PROJECT_ID, modelId, inputUri, outputUri);
93-
94-
// Assert
95-
String got = bout.toString();
96-
assertThat(got).contains("Batch Prediction results saved to specified Cloud Storage bucket");
97-
85+
// Delete the created files from GCS
9886
Storage storage = StorageOptions.getDefaultInstance().getService();
9987
Page<Blob> blobs =
10088
storage.list(
@@ -114,5 +102,19 @@ public void testBatchPredict() throws IOException, ExecutionException, Interrupt
114102
}
115103
}
116104
}
105+
106+
System.setOut(null);
107+
}
108+
109+
@Test
110+
public void testBatchPredict() throws IOException, ExecutionException, InterruptedException {
111+
String inputUri = String.format("gs://%s/entity-extraction/input.jsonl", BUCKET_ID);
112+
String outputUri = String.format("gs://%s/TEST_BATCH_PREDICT/", BUCKET_ID);
113+
// Act
114+
BatchPredict.batchPredict(PROJECT_ID, MODEL_ID, inputUri, outputUri);
115+
116+
// Assert
117+
String got = bout.toString();
118+
assertThat(got).contains("Batch Prediction results saved to specified Cloud Storage bucket");
117119
}
118120
}

automl/snippets/src/test/java/com/example/automl/ExportDatasetTest.java

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
public class ExportDatasetTest {
4242

4343
private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID");
44-
private static final String DATASET_ID = System.getenv("ENTITY_EXTRACTION_DATASET_ID");
44+
private static final String DATASET_ID = "TEN0000000000000000000";
4545
private static final String BUCKET_ID = PROJECT_ID + "-lcm";
4646
private static final String BUCKET = "gs://" + BUCKET_ID;
4747
private ByteArrayOutputStream bout;
@@ -69,34 +69,21 @@ public void setUp() {
6969

7070
@After
7171
public void tearDown() {
72-
// Delete the created files from GCS
73-
Storage storage = StorageOptions.getDefaultInstance().getService();
74-
Page<Blob> blobs =
75-
storage.list(
76-
BUCKET_ID,
77-
Storage.BlobListOption.currentDirectory(),
78-
Storage.BlobListOption.prefix("TEST_EXPORT_OUTPUT/"));
79-
80-
for (Blob blob : blobs.iterateAll()) {
81-
Page<Blob> fileBlobs =
82-
storage.list(
83-
BUCKET_ID,
84-
Storage.BlobListOption.currentDirectory(),
85-
Storage.BlobListOption.prefix(blob.getName()));
86-
for (Blob fileBlob : fileBlobs.iterateAll()) {
87-
if (!fileBlob.isDirectory()) {
88-
fileBlob.delete();
89-
}
90-
}
91-
}
92-
9372
System.setOut(null);
9473
}
9574

9675
@Test
9776
public void testExportDataset() throws IOException, ExecutionException, InterruptedException {
98-
ExportDataset.exportDataset(PROJECT_ID, DATASET_ID, BUCKET + "/TEST_EXPORT_OUTPUT/");
99-
String got = bout.toString();
100-
assertThat(got).contains("Dataset exported.");
77+
// As exporting a dataset can take a long time and only one operation can be run on a dataset
78+
// at once. Try to export a nonexistent dataset and confirm that the dataset was not found, but
79+
// other elements of the request were valid.
80+
try {
81+
ExportDataset.exportDataset(PROJECT_ID, DATASET_ID, BUCKET + "/TEST_EXPORT_OUTPUT/");
82+
String got = bout.toString();
83+
assertThat(got).contains("The Dataset doesn't exist or is inaccessible for use with AutoMl.");
84+
} catch (IOException | ExecutionException | InterruptedException e) {
85+
assertThat(e.getMessage())
86+
.contains("The Dataset doesn't exist or is inaccessible for use with AutoMl.");
87+
}
10188
}
10289
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Copyright 2019 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.example.automl;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static junit.framework.TestCase.assertNotNull;
21+
22+
import com.google.cloud.automl.v1.AutoMlClient;
23+
import com.google.cloud.automl.v1.DeployModelRequest;
24+
import com.google.cloud.automl.v1.Model;
25+
import com.google.cloud.automl.v1.ModelName;
26+
27+
import java.io.ByteArrayOutputStream;
28+
import java.io.IOException;
29+
import java.io.PrintStream;
30+
import java.util.concurrent.ExecutionException;
31+
32+
import org.junit.After;
33+
import org.junit.Before;
34+
import org.junit.BeforeClass;
35+
import org.junit.Test;
36+
import org.junit.runner.RunWith;
37+
import org.junit.runners.JUnit4;
38+
39+
@RunWith(JUnit4.class)
40+
@SuppressWarnings("checkstyle:abbreviationaswordinname")
41+
public class LanguageEntityExtractionPredictTest {
42+
private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID");
43+
private static final String MODEL_ID = System.getenv("ENTITY_EXTRACTION_MODEL_ID");
44+
private ByteArrayOutputStream bout;
45+
private PrintStream out;
46+
47+
private static void requireEnvVar(String varName) {
48+
assertNotNull(
49+
System.getenv(varName),
50+
"Environment variable '%s' is required to perform these tests.".format(varName));
51+
}
52+
53+
@BeforeClass
54+
public static void checkRequirements() {
55+
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
56+
requireEnvVar("GOOGLE_CLOUD_PROJECT");
57+
requireEnvVar("AUTOML_PROJECT_ID");
58+
requireEnvVar("ENTITY_EXTRACTION_MODEL_ID");
59+
}
60+
61+
@Before
62+
public void setUp() throws IOException, ExecutionException, InterruptedException {
63+
// Verify that the model is deployed for prediction
64+
try (AutoMlClient client = AutoMlClient.create()) {
65+
ModelName modelFullId = ModelName.of(PROJECT_ID, "us-central1", MODEL_ID);
66+
Model model = client.getModel(modelFullId);
67+
if (model.getDeploymentState() == Model.DeploymentState.UNDEPLOYED) {
68+
// Deploy the model if not deployed
69+
DeployModelRequest request =
70+
DeployModelRequest.newBuilder().setName(modelFullId.toString()).build();
71+
client.deployModelAsync(request).get();
72+
}
73+
}
74+
75+
bout = new ByteArrayOutputStream();
76+
out = new PrintStream(bout);
77+
System.setOut(out);
78+
}
79+
80+
@After
81+
public void tearDown() {
82+
System.setOut(null);
83+
}
84+
85+
@Test
86+
public void testPredict() throws IOException {
87+
String text = "Constitutional mutations in the WT1 gene in patients with Denys-Drash syndrome.";
88+
LanguageEntityExtractionPredict.predict(PROJECT_ID, MODEL_ID, text);
89+
String got = bout.toString();
90+
assertThat(got).contains("Text Extract Entity Type:");
91+
}
92+
}

automl/snippets/src/test/java/com/example/automl/LanguageSentimentAnalysisPredictIT.java renamed to automl/snippets/src/test/java/com/example/automl/LanguageSentimentAnalysisPredictTest.java

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,15 @@
1919
import static com.google.common.truth.Truth.assertThat;
2020
import static junit.framework.TestCase.assertNotNull;
2121

22+
import com.google.cloud.automl.v1.AutoMlClient;
23+
import com.google.cloud.automl.v1.DeployModelRequest;
24+
import com.google.cloud.automl.v1.Model;
25+
import com.google.cloud.automl.v1.ModelName;
26+
2227
import java.io.ByteArrayOutputStream;
2328
import java.io.IOException;
2429
import java.io.PrintStream;
30+
import java.util.concurrent.ExecutionException;
2531

2632
import org.junit.After;
2733
import org.junit.Before;
@@ -30,20 +36,18 @@
3036
import org.junit.runner.RunWith;
3137
import org.junit.runners.JUnit4;
3238

33-
// Tests for automl natural language sentiment analysis "Predict" sample.
3439
@RunWith(JUnit4.class)
3540
@SuppressWarnings("checkstyle:abbreviationaswordinname")
36-
public class LanguageSentimentAnalysisPredictIT {
41+
public class LanguageSentimentAnalysisPredictTest {
3742
private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID");
38-
private static final String modelId = System.getenv("SENTIMENT_ANALYSIS_MODEL_ID");
43+
private static final String MODEL_ID = System.getenv("SENTIMENT_ANALYSIS_MODEL_ID");
3944
private ByteArrayOutputStream bout;
4045
private PrintStream out;
4146

4247
private static void requireEnvVar(String varName) {
4348
assertNotNull(
44-
System.getenv(varName),
45-
"Environment variable '%s' is required to perform these tests.".format(varName)
46-
);
49+
System.getenv(varName),
50+
"Environment variable '%s' is required to perform these tests.".format(varName));
4751
}
4852

4953
@BeforeClass
@@ -54,7 +58,19 @@ public static void checkRequirements() {
5458
}
5559

5660
@Before
57-
public void setUp() {
61+
public void setUp() throws IOException, ExecutionException, InterruptedException {
62+
// Verify that the model is deployed for prediction
63+
try (AutoMlClient client = AutoMlClient.create()) {
64+
ModelName modelFullId = ModelName.of(PROJECT_ID, "us-central1", MODEL_ID);
65+
Model model = client.getModel(modelFullId);
66+
if (model.getDeploymentState() == Model.DeploymentState.UNDEPLOYED) {
67+
// Deploy the model if not deployed
68+
DeployModelRequest request =
69+
DeployModelRequest.newBuilder().setName(modelFullId.toString()).build();
70+
client.deployModelAsync(request).get();
71+
}
72+
}
73+
5874
bout = new ByteArrayOutputStream();
5975
out = new PrintStream(bout);
6076
System.setOut(out);
@@ -68,10 +84,7 @@ public void tearDown() {
6884
@Test
6985
public void testPredict() throws IOException {
7086
String text = "Hopefully this Claritin kicks in soon";
71-
// Act
72-
LanguageSentimentAnalysisPredict.predict(PROJECT_ID, modelId, text);
73-
74-
// Assert
87+
LanguageSentimentAnalysisPredict.predict(PROJECT_ID, MODEL_ID, text);
7588
String got = bout.toString();
7689
assertThat(got).contains("Predicted sentiment score:");
7790
}

0 commit comments

Comments
 (0)