From 7bdfa559f477eed3bd9d819c4d640ad792cbf0e6 Mon Sep 17 00:00:00 2001 From: "copybara-service[bot]" <56741989+copybara-service[bot]@users.noreply.github.com> Date: Wed, 17 Apr 2024 13:48:54 -0400 Subject: [PATCH] fix: Simplify VertexAI with Suppliers.memoize and avoid accessing private members in tests. (#10694) - Implement lazy init using Suppliers.memoize instead of an explicit lock. - Add a newBuilder method in VertexAI. - Updates unit tests to avoid accessing private fields in VertexAI. PiperOrigin-RevId: 624303836 Co-authored-by: A Vertex SDK engineer --- .../com/google/cloud/vertexai/VertexAI.java | 142 +++++++++++------- .../generativeai/ChatSessionTest.java | 15 +- .../generativeai/GenerativeModelTest.java | 93 +----------- .../it/ITGenerativeModelIntegrationTest.java | 26 ---- 4 files changed, 103 insertions(+), 173 deletions(-) diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java index 1e1cc2ec9449..7623e3b52707 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java @@ -32,11 +32,13 @@ import com.google.cloud.vertexai.api.PredictionServiceClient; import com.google.cloud.vertexai.api.PredictionServiceSettings; import com.google.common.base.Strings; +import com.google.common.base.Supplier; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.io.IOException; import java.util.List; import java.util.Optional; -import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Level; import java.util.logging.Logger; @@ -61,13 +63,12 @@ public class VertexAI implements AutoCloseable { private final String apiEndpoint; private final Transport transport; private 
final CredentialsProvider credentialsProvider; - private final ReentrantLock lock = new ReentrantLock(); - // The clients will be instantiated lazily - private Optional predictionServiceClient = Optional.empty(); - private Optional llmUtilityClient = Optional.empty(); + + private final transient Supplier predictionClientSupplier; + private final transient Supplier llmClientSupplier; /** - * Construct a VertexAI instance. + * Constructs a VertexAI instance. * * @param projectId the default project to use when making API calls * @param location the default location to use when making API calls @@ -78,8 +79,10 @@ public VertexAI(String projectId, String location) { location, Transport.GRPC, ImmutableList.of(), - Optional.empty(), - Optional.empty()); + /* credentials= */ Optional.empty(), + /* apiEndpoint= */ Optional.empty(), + /* predictionClientSupplierOpt= */ Optional.empty(), + /* llmClientSupplierOpt= */ Optional.empty()); } private VertexAI( @@ -88,7 +91,9 @@ private VertexAI( Transport transport, List scopes, Optional credentials, - Optional apiEndpoint) { + Optional apiEndpoint, + Optional> predictionClientSupplierOpt, + Optional> llmClientSupplierOpt) { if (!scopes.isEmpty() && credentials.isPresent()) { throw new IllegalArgumentException( "At most one of Credentials and scopes should be specified."); @@ -113,9 +118,19 @@ private VertexAI( .build(); } + this.predictionClientSupplier = + Suppliers.memoize(predictionClientSupplierOpt.orElse(this::newPredictionServiceClient)); + + this.llmClientSupplier = + Suppliers.memoize(llmClientSupplierOpt.orElse(this::newLlmUtilityClient)); + this.apiEndpoint = apiEndpoint.orElse(String.format("%s-aiplatform.googleapis.com", location)); } + public static Builder builder() { + return new Builder(); + } + /** Builder for {@link VertexAI}. 
*/ public static class Builder { private String projectId; @@ -125,11 +140,25 @@ public static class Builder { private Optional credentials = Optional.empty(); private Optional apiEndpoint = Optional.empty(); + private Supplier predictionClientSupplier; + + private Supplier llmClientSupplier; + + Builder() {} + public VertexAI build() { checkNotNull(projectId, "projectId must be set."); checkNotNull(location, "location must be set."); - return new VertexAI(projectId, location, transport, scopes, credentials, apiEndpoint); + return new VertexAI( + projectId, + location, + transport, + scopes, + credentials, + apiEndpoint, + Optional.ofNullable(predictionClientSupplier), + Optional.ofNullable(llmClientSupplier)); } public Builder setProjectId(String projectId) { @@ -167,6 +196,19 @@ public Builder setCredentials(Credentials credentials) { return this; } + @CanIgnoreReturnValue + public Builder setPredictionClientSupplier( + Supplier predictionClientSupplier) { + this.predictionClientSupplier = predictionClientSupplier; + return this; + } + + @CanIgnoreReturnValue + public Builder setLlmClientSupplier(Supplier llmClientSupplier) { + this.llmClientSupplier = llmClientSupplier; + return this; + } + public Builder setScopes(List scopes) { checkNotNull(scopes, "scopes can't be null"); @@ -228,25 +270,23 @@ public Credentials getCredentials() throws IOException { * method calls that map to the API methods. 
*/ @InternalApi - public PredictionServiceClient getPredictionServiceClient() throws IOException { - if (predictionServiceClient.isPresent()) { - return predictionServiceClient.get(); - } - lock.lock(); + public PredictionServiceClient getPredictionServiceClient() { + return predictionClientSupplier.get(); + } + + private PredictionServiceClient newPredictionServiceClient() { + // Disable the warning message logged in getApplicationDefault + Logger defaultCredentialsProviderLogger = + Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider"); + Level previousLevel = defaultCredentialsProviderLogger.getLevel(); + defaultCredentialsProviderLogger.setLevel(Level.SEVERE); + try { - if (!predictionServiceClient.isPresent()) { - PredictionServiceSettings settings = getPredictionServiceSettings(); - // Disable the warning message logged in getApplicationDefault - Logger defaultCredentialsProviderLogger = - Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider"); - Level previousLevel = defaultCredentialsProviderLogger.getLevel(); - defaultCredentialsProviderLogger.setLevel(Level.SEVERE); - predictionServiceClient = Optional.of(PredictionServiceClient.create(settings)); - defaultCredentialsProviderLogger.setLevel(previousLevel); - } - return predictionServiceClient.get(); + return PredictionServiceClient.create(getPredictionServiceSettings()); + } catch (IOException e) { + throw new IllegalStateException(e); } finally { - lock.unlock(); + defaultCredentialsProviderLogger.setLevel(previousLevel); } } @@ -257,8 +297,8 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept } else { builder = PredictionServiceSettings.newBuilder(); } - builder.setEndpoint(String.format("%s:443", this.apiEndpoint)); - builder.setCredentialsProvider(this.credentialsProvider); + builder.setEndpoint(String.format("%s:443", apiEndpoint)); + builder.setCredentialsProvider(credentialsProvider); HeaderProvider headerProvider = 
FixedHeaderProvider.create( @@ -279,25 +319,23 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept * calls that map to the API methods. */ @InternalApi - public LlmUtilityServiceClient getLlmUtilityClient() throws IOException { - if (llmUtilityClient.isPresent()) { - return llmUtilityClient.get(); - } - lock.lock(); + public LlmUtilityServiceClient getLlmUtilityClient() { + return llmClientSupplier.get(); + } + + private LlmUtilityServiceClient newLlmUtilityClient() { + // Disable the warning message logged in getApplicationDefault + Logger defaultCredentialsProviderLogger = + Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider"); + Level previousLevel = defaultCredentialsProviderLogger.getLevel(); + defaultCredentialsProviderLogger.setLevel(Level.SEVERE); + try { - if (!llmUtilityClient.isPresent()) { - LlmUtilityServiceSettings settings = getLlmUtilityServiceClientSettings(); - // Disable the warning message logged in getApplicationDefault - Logger defaultCredentialsProviderLogger = - Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider"); - Level previousLevel = defaultCredentialsProviderLogger.getLevel(); - defaultCredentialsProviderLogger.setLevel(Level.SEVERE); - llmUtilityClient = Optional.of(LlmUtilityServiceClient.create(settings)); - defaultCredentialsProviderLogger.setLevel(previousLevel); - } - return llmUtilityClient.get(); + return LlmUtilityServiceClient.create(getLlmUtilityServiceClientSettings()); + } catch (IOException e) { + throw new IllegalStateException(e); } finally { - lock.unlock(); + defaultCredentialsProviderLogger.setLevel(previousLevel); } } @@ -308,8 +346,8 @@ private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IO } else { settingsBuilder = LlmUtilityServiceSettings.newBuilder(); } - settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint)); - settingsBuilder.setCredentialsProvider(this.credentialsProvider); + 
settingsBuilder.setEndpoint(String.format("%s:443", apiEndpoint)); + settingsBuilder.setCredentialsProvider(credentialsProvider); HeaderProvider headerProvider = FixedHeaderProvider.create( @@ -325,11 +363,7 @@ private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IO /** Closes the VertexAI instance together with all its instantiated clients. */ @Override public void close() { - if (predictionServiceClient.isPresent()) { - predictionServiceClient.get().close(); - } - if (llmUtilityClient.isPresent()) { - llmUtilityClient.get().close(); - } + predictionClientSupplier.get().close(); + llmClientSupplier.get().close(); } } diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java index a4cabcf068c5..67c64505b226 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java @@ -41,11 +41,9 @@ import com.google.cloud.vertexai.api.Tool; import com.google.cloud.vertexai.api.Type; import java.io.IOException; -import java.lang.reflect.Field; import java.util.Arrays; import java.util.Iterator; import java.util.List; -import java.util.Optional; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -309,12 +307,15 @@ public void sendMessageWithText_throwsIllegalStateExceptionWhenFinishReasonIsNot public void testChatSessionMergeHistoryToRootChatSession() throws Exception { // (Arrange) Set up the return value of the generateContent - VertexAI vertexAi = new VertexAI(PROJECT, LOCATION); - GenerativeModel model = new GenerativeModel("gemini-pro", vertexAi); - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, 
Optional.of(mockPredictionServiceClient)); + VertexAI vertexAi = + VertexAI.builder() + .setProjectId(PROJECT) + .setLocation(LOCATION) + .setPredictionClientSupplier(() -> mockPredictionServiceClient) + .build(); + + GenerativeModel model = new GenerativeModel("gemini-pro", vertexAi); when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.call(any(GenerateContentRequest.class))) diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java index d9b4ff777004..f0545521ab28 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java @@ -47,12 +47,10 @@ import com.google.cloud.vertexai.api.Tool; import com.google.cloud.vertexai.api.Type; import com.google.cloud.vertexai.api.VertexAISearch; -import java.lang.reflect.Field; import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; import java.util.List; -import java.util.Optional; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -68,8 +66,6 @@ public final class GenerativeModelTest { private static final String PROJECT = "test_project"; private static final String LOCATION = "test_location"; private static final String MODEL_NAME = "gemini-pro"; - private static final String MODEL_NAME_2 = "models/gemini-pro"; - private static final String MODEL_NAME_3 = "publishers/google/models/gemini-pro"; private static final GenerationConfig GENERATION_CONFIG = GenerationConfig.newBuilder().setCandidateCount(1).build(); private static final GenerationConfig DEFAULT_GENERATION_CONFIG = @@ -145,12 +141,17 @@ public final class GenerativeModelTest { @Mock private 
ApiFuture mockApiFuture; @Before - public void doBeforeEachTest() { + public void setUp() { + when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); + when(mockUnaryCallable.futureCall(any(GenerateContentRequest.class))).thenReturn(mockApiFuture); + vertexAi = - new VertexAI.Builder() + VertexAI.builder() .setProjectId(PROJECT) .setLocation(LOCATION) .setCredentials(mockGoogleCredentials) + .setLlmClientSupplier(() -> mockLlmUtilityServiceClient) + .setPredictionClientSupplier(() -> mockPredictionServiceClient) .build(); } @@ -240,10 +241,6 @@ public void testInstantiateGenerativeModelwithBuilderMissingVertexAi() { public void testCountTokenswithText() throws Exception { model = new GenerativeModel(MODEL_NAME, vertexAi); - Field field = VertexAI.class.getDeclaredField("llmUtilityClient"); - field.setAccessible(true); - field.set(vertexAi, Optional.of(mockLlmUtilityServiceClient)); - CountTokensResponse unused = model.countTokens(TEXT); ArgumentCaptor request = ArgumentCaptor.forClass(CountTokensRequest.class); @@ -255,10 +252,6 @@ public void testCountTokenswithText() throws Exception { public void testCountTokenswithContent() throws Exception { model = new GenerativeModel(MODEL_NAME, vertexAi); - Field field = VertexAI.class.getDeclaredField("llmUtilityClient"); - field.setAccessible(true); - field.set(vertexAi, Optional.of(mockLlmUtilityServiceClient)); - Content content = ContentMaker.fromString(TEXT); CountTokensResponse unused = model.countTokens(content); @@ -271,10 +264,6 @@ public void testCountTokenswithContent() throws Exception { public void testCountTokenswithContents() throws Exception { model = new GenerativeModel(MODEL_NAME, vertexAi); - Field field = VertexAI.class.getDeclaredField("llmUtilityClient"); - field.setAccessible(true); - field.set(vertexAi, Optional.of(mockLlmUtilityServiceClient)); - Content content = ContentMaker.fromString(TEXT); CountTokensResponse unused = 
model.countTokens(Arrays.asList(content)); @@ -287,10 +276,6 @@ public void testCountTokenswithContents() throws Exception { public void testGenerateContentwithText() throws Exception { model = new GenerativeModel(MODEL_NAME, vertexAi); - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, Optional.of(mockPredictionServiceClient)); - when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.call(any(GenerateContentRequest.class))) .thenReturn(mockGenerateContentResponse); @@ -307,10 +292,6 @@ public void testGenerateContentwithText() throws Exception { public void testGenerateContentwithContent() throws Exception { model = new GenerativeModel(MODEL_NAME, vertexAi); - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, Optional.of(mockPredictionServiceClient)); - when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.call(any(GenerateContentRequest.class))) .thenReturn(mockGenerateContentResponse); @@ -329,10 +310,6 @@ public void testGenerateContentwithContent() throws Exception { public void testGenerateContentwithContents() throws Exception { model = new GenerativeModel(MODEL_NAME, vertexAi); - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, Optional.of(mockPredictionServiceClient)); - when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.call(any(GenerateContentRequest.class))) .thenReturn(mockGenerateContentResponse); @@ -356,10 +333,6 @@ public void testGenerateContentwithDefaultGenerationConfig() throws Exception { .setGenerationConfig(DEFAULT_GENERATION_CONFIG) .build(); - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - 
field.setAccessible(true); - field.set(vertexAi, Optional.of(mockPredictionServiceClient)); - when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.call(any(GenerateContentRequest.class))) .thenReturn(mockGenerateContentResponse); @@ -382,10 +355,6 @@ public void testGenerateContentwithDefaultSafetySettings() throws Exception { .setVertexAi(vertexAi) .build(); - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, Optional.of(mockPredictionServiceClient)); - when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.call(any(GenerateContentRequest.class))) .thenReturn(mockGenerateContentResponse); @@ -408,10 +377,6 @@ public void testGenerateContentwithDefaultTools() throws Exception { .setTools(tools) .build(); - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, Optional.of(mockPredictionServiceClient)); - when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.call(any(GenerateContentRequest.class))) .thenReturn(mockGenerateContentResponse); @@ -429,10 +394,6 @@ public void testGenerateContentwithDefaultTools() throws Exception { public void testGenerateContentwithFluentApi() throws Exception { model = new GenerativeModel(MODEL_NAME, vertexAi); - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, Optional.of(mockPredictionServiceClient)); - when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.call(any(GenerateContentRequest.class))) .thenReturn(mockGenerateContentResponse); @@ -467,10 +428,6 @@ public void generateContent_withNullContents_throws() throws Exception { public void testGenerateContentStreamwithText() 
throws Exception { model = new GenerativeModel(MODEL_NAME, vertexAi); - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, Optional.of(mockPredictionServiceClient)); - when(mockPredictionServiceClient.streamGenerateContentCallable()) .thenReturn(mockServerStreamCallable); when(mockServerStreamCallable.call(any(GenerateContentRequest.class))) @@ -490,10 +447,6 @@ public void testGenerateContentStreamwithText() throws Exception { public void testGenerateContentStreamwithContent() throws Exception { model = new GenerativeModel(MODEL_NAME, vertexAi); - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, Optional.of(mockPredictionServiceClient)); - when(mockPredictionServiceClient.streamGenerateContentCallable()) .thenReturn(mockServerStreamCallable); when(mockServerStreamCallable.call(any(GenerateContentRequest.class))) @@ -515,10 +468,6 @@ public void testGenerateContentStreamwithContent() throws Exception { public void testGenerateContentStreamwithContents() throws Exception { model = new GenerativeModel(MODEL_NAME, vertexAi); - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, Optional.of(mockPredictionServiceClient)); - when(mockPredictionServiceClient.streamGenerateContentCallable()) .thenReturn(mockServerStreamCallable); when(mockServerStreamCallable.call(any(GenerateContentRequest.class))) @@ -545,10 +494,6 @@ public void testGenerateContentStreamwithDefaultGenerationConfig() throws Except .setVertexAi(vertexAi) .build(); - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, Optional.of(mockPredictionServiceClient)); - when(mockPredictionServiceClient.streamGenerateContentCallable()) .thenReturn(mockServerStreamCallable); 
when(mockServerStreamCallable.call(any(GenerateContentRequest.class))) @@ -572,10 +517,6 @@ public void testGenerateContentStreamwithDefaultSafetySettings() throws Exceptio .setVertexAi(vertexAi) .build(); - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, Optional.of(mockPredictionServiceClient)); - when(mockPredictionServiceClient.streamGenerateContentCallable()) .thenReturn(mockServerStreamCallable); when(mockServerStreamCallable.call(any(GenerateContentRequest.class))) @@ -599,10 +540,6 @@ public void testGenerateContentStreamwithDefaultTools() throws Exception { .setTools(tools) .build(); - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, Optional.of(mockPredictionServiceClient)); - when(mockPredictionServiceClient.streamGenerateContentCallable()) .thenReturn(mockServerStreamCallable); when(mockServerStreamCallable.call(any(GenerateContentRequest.class))) @@ -621,10 +558,6 @@ public void testGenerateContentStreamwithDefaultTools() throws Exception { public void testGenerateContentStreamwithFluentApi() throws Exception { model = new GenerativeModel(MODEL_NAME, vertexAi); - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, Optional.of(mockPredictionServiceClient)); - when(mockPredictionServiceClient.streamGenerateContentCallable()) .thenReturn(mockServerStreamCallable); when(mockServerStreamCallable.call(any(GenerateContentRequest.class))) @@ -661,10 +594,6 @@ public void generateContentStream_withEmptyContents_throws() throws Exception { public void generateContentAsync_withText_sendsCorrectRequest() throws Exception { model = new GenerativeModel(MODEL_NAME, vertexAi); - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, Optional.of(mockPredictionServiceClient)); - 
when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.futureCall(any(GenerateContentRequest.class))).thenReturn(mockApiFuture); @@ -682,10 +611,6 @@ public void generateContentAsync_withText_sendsCorrectRequest() throws Exception public void generateContentAsync_withContent_sendsCorrectRequest() throws Exception { model = new GenerativeModel(MODEL_NAME, vertexAi); - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, Optional.of(mockPredictionServiceClient)); - when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.futureCall(any(GenerateContentRequest.class))).thenReturn(mockApiFuture); @@ -703,10 +628,6 @@ public void generateContentAsync_withContent_sendsCorrectRequest() throws Except public void generateContentAsync_withContents_sendsCorrectRequest() throws Exception { model = new GenerativeModel(MODEL_NAME, vertexAi); - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, Optional.of(mockPredictionServiceClient)); - when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.futureCall(any(GenerateContentRequest.class))).thenReturn(mockApiFuture); diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java index eb1cc42ff03d..eb7f3e6993fa 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java @@ -50,7 +50,6 @@ public class ITGenerativeModelIntegrationTest { private 
static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT"); private static final String MODEL_NAME_TEXT = "gemini-pro"; private static final String MODEL_NAME_MULTIMODAL = "gemini-pro-vision"; - private static final String MODEL_NAME_LATEST_GEMINI = "gemini-1.5-pro-preview-0409"; private static final String LOCATION = "us-central1"; // Tested content @@ -68,14 +67,12 @@ public class ITGenerativeModelIntegrationTest { private VertexAI vertexAi; private GenerativeModel textModel; private GenerativeModel multiModalModel; - private GenerativeModel latestGemini; @Before public void setUp() throws IOException { vertexAi = new VertexAI(PROJECT_ID, LOCATION); textModel = new GenerativeModel(MODEL_NAME_TEXT, vertexAi); multiModalModel = new GenerativeModel(MODEL_NAME_MULTIMODAL, vertexAi); - latestGemini = new GenerativeModel(MODEL_NAME_LATEST_GEMINI, vertexAi); } @After @@ -163,29 +160,6 @@ public void generateContentAsync_withPlainText_nonEmptyCandidateList() throws Ex assertNonEmptyAndLogResponse(methodName, TEXT, response); } - @Test - public void generateContent_withContentList_nonEmptyCandidate() throws IOException { - String followupPrompt = "Why do you think these two things are put together?"; - GenerateContentResponse response = - latestGemini.generateContent( - Arrays.asList( - // First Content is from the user (default role is 'user') - ContentMaker.fromMultiModalData( - "Please describe this image.", - PartMaker.fromMimeTypeAndData("image/jpeg", GCS_IMAGE_URI)), - // Second Content is from the model - ContentMaker.forRole("model") - .fromString( - "This is an image of a yellow rubber duck and a blue toy truck. The duck is" - + " on the left and the truck is on the right."), - // Another followup question from the user; "user" is the default role so we omit - // can it. 
Same for `fromMultiModalList` - ContentMaker.fromString(followupPrompt))); - - String methodName = Thread.currentThread().getStackTrace()[1].getMethodName(); - assertNonEmptyAndLogResponse(methodName, followupPrompt, response); - } - @Test public void generateContentStream_withPlainText_nonEmptyCandidateList() throws IOException { ResponseStream stream = textModel.generateContentStream(TEXT);