11/*
2- * Copyright 2019 Google LLC
2+ * Copyright 2020 Google LLC
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
2020import static junit .framework .TestCase .assertNotNull ;
2121
2222import com .google .api .gax .paging .Page ;
23+ import com .google .cloud .automl .v1 .AutoMlClient ;
24+ import com .google .cloud .automl .v1 .DeployModelRequest ;
25+ import com .google .cloud .automl .v1 .Model ;
26+ import com .google .cloud .automl .v1 .ModelName ;
2327import com .google .cloud .storage .Blob ;
2428import com .google .cloud .storage .Storage ;
2529import com .google .cloud .storage .StorageOptions ;
3135import org .junit .After ;
3236import org .junit .Before ;
3337import org .junit .BeforeClass ;
34- import org .junit .Ignore ;
3538import org .junit .Test ;
3639import org .junit .runner .RunWith ;
3740import org .junit .runners .JUnit4 ;
3841
39- // Tests for automl natural language entity extraction "Predict" sample.
4042@ RunWith (JUnit4 .class )
4143@ SuppressWarnings ("checkstyle:abbreviationaswordinname" )
42- public class LanguageEntityExtractionPredictIT {
44+ public class BatchPredictTest {
4345 private static final String PROJECT_ID = System .getenv ("AUTOML_PROJECT_ID" );
44- private static final String BUCKET_ID = System . getenv ( "GOOGLE_CLOUD_PROJECT" ) + "-lcm" ;
45- private static final String modelId = System .getenv ("ENTITY_EXTRACTION_MODEL_ID" );
46+ private static final String BUCKET_ID = PROJECT_ID + "-lcm" ;
47+ private static final String MODEL_ID = System .getenv ("ENTITY_EXTRACTION_MODEL_ID" );
4648 private ByteArrayOutputStream bout ;
4749 private PrintStream out ;
4850
4951 private static void requireEnvVar (String varName ) {
5052 assertNotNull (
51- System .getenv (varName ),
52- "Environment variable '%s' is required to perform these tests." .format (varName )
53- );
53+ System .getenv (varName ),
54+ "Environment variable '%s' is required to perform these tests." .format (varName ));
5455 }
5556
5657 @ BeforeClass
5758 public static void checkRequirements () {
5859 requireEnvVar ("GOOGLE_APPLICATION_CREDENTIALS" );
59- requireEnvVar ("GOOGLE_CLOUD_PROJECT" );
6060 requireEnvVar ("AUTOML_PROJECT_ID" );
6161 requireEnvVar ("ENTITY_EXTRACTION_MODEL_ID" );
6262 }
6363
6464 @ Before
65- public void setUp () {
65+ public void setUp () throws IOException , ExecutionException , InterruptedException {
66+ // Verify that the model is deployed for prediction
67+ try (AutoMlClient client = AutoMlClient .create ()) {
68+ ModelName modelFullId = ModelName .of (PROJECT_ID , "us-central1" , MODEL_ID );
69+ Model model = client .getModel (modelFullId );
70+ if (model .getDeploymentState () == Model .DeploymentState .UNDEPLOYED ) {
71+ // Deploy the model if not deployed
72+ DeployModelRequest request =
73+ DeployModelRequest .newBuilder ().setName (modelFullId .toString ()).build ();
74+ client .deployModelAsync (request ).get ();
75+ }
76+ }
77+
6678 bout = new ByteArrayOutputStream ();
6779 out = new PrintStream (bout );
6880 System .setOut (out );
6981 }
7082
7183 @ After
7284 public void tearDown () {
73- System .setOut (null );
74- }
75-
76- @ Test
77- public void testPredict () throws IOException {
78- String text = "Constitutional mutations in the WT1 gene in patients with Denys-Drash syndrome." ;
79- // Act
80- LanguageEntityExtractionPredict .predict (PROJECT_ID , modelId , text );
81-
82- // Assert
83- String got = bout .toString ();
84- assertThat (got ).contains ("Text Extract Entity Type:" );
85- }
86-
87- @ Ignore
88- public void testBatchPredict () throws IOException , ExecutionException , InterruptedException {
89- String inputUri = String .format ("gs://%s/entity_extraction/input.jsonl" , BUCKET_ID );
90- String outputUri = String .format ("gs://%s/TEST_BATCH_PREDICT/" , BUCKET_ID );
91- // Act
92- BatchPredict .batchPredict (PROJECT_ID , modelId , inputUri , outputUri );
93-
94- // Assert
95- String got = bout .toString ();
96- assertThat (got ).contains ("Batch Prediction results saved to specified Cloud Storage bucket" );
97-
85+ // Delete the created files from GCS
9886 Storage storage = StorageOptions .getDefaultInstance ().getService ();
9987 Page <Blob > blobs =
10088 storage .list (
@@ -114,5 +102,19 @@ public void testBatchPredict() throws IOException, ExecutionException, Interrupt
114102 }
115103 }
116104 }
105+
106+ System .setOut (null );
107+ }
108+
109+ @ Test
110+ public void testBatchPredict () throws IOException , ExecutionException , InterruptedException {
111+ String inputUri = String .format ("gs://%s/entity-extraction/input.jsonl" , BUCKET_ID );
112+ String outputUri = String .format ("gs://%s/TEST_BATCH_PREDICT/" , BUCKET_ID );
113+ // Act
114+ BatchPredict .batchPredict (PROJECT_ID , MODEL_ID , inputUri , outputUri );
115+
116+ // Assert
117+ String got = bout .toString ();
118+ assertThat (got ).contains ("Batch Prediction results saved to specified Cloud Storage bucket" );
117119 }
118120}
0 commit comments