From 31f5e12e889ac03d510de4a5b3dcdae0b2b27d0a Mon Sep 17 00:00:00 2001 From: "copybara-service[bot]" <56741989+copybara-service[bot]@users.noreply.github.com> Date: Wed, 20 Mar 2024 10:23:36 -0700 Subject: [PATCH] BREAKING_CHANGE: [vertexai] remove GenerateContentConfig (#10576) PiperOrigin-RevId: 617275853 Co-authored-by: Jaycee Li --- .../vertexai/generativeai/ChatSession.java | 67 +-------- .../generativeai/GenerateContentConfig.java | 122 ----------------- .../generativeai/GenerativeModel.java | 129 ++---------------- .../generativeai/ChatSessionTest.java | 20 +-- .../GenerateContentConfigTest.java | 79 ----------- .../generativeai/GenerativeModelTest.java | 61 --------- 6 files changed, 20 insertions(+), 458 deletions(-) delete mode 100644 java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerateContentConfig.java delete mode 100644 java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerateContentConfigTest.java diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java index 5f20909649b6..31a920d9f542 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java @@ -53,22 +53,7 @@ public ChatSession(GenerativeModel model) { */ @BetaApi public ResponseStream sendMessageStream(String text) throws IOException { - return sendMessageStream(text, GenerateContentConfig.newBuilder().build()); - } - - /** - * Sends a message to the model and returns a stream of responses. - * - * @param text the message to be sent. - * @param config a {@link GenerateContentConfig} that contains all the configs for sending message - * in a chat session. - * @return an iterable in which each element is a GenerateContentResponse. Can be converted to - * stream by stream() method. - */ - @BetaApi - public ResponseStream sendMessageStream( - String text, GenerateContentConfig config) throws IOException { - return sendMessageStream(ContentMaker.fromString(text), config); + return sendMessageStream(ContentMaker.fromString(text)); } /** @@ -81,25 +66,9 @@ public ResponseStream sendMessageStream( @BetaApi public ResponseStream sendMessageStream(Content content) throws IOException, IllegalArgumentException { - return sendMessageStream(content, GenerateContentConfig.newBuilder().build()); - } - - /** - * Sends a message to the model and returns a stream of responses. - * - * @param content the content to be sent. - * @param config a {@link GenerateContentConfig} that contains all the configs for sending message - * in a chat session. - * @return an iterable in which each element is a GenerateContentResponse. Can be converted to - * stream by stream() method. - */ - @BetaApi - public ResponseStream sendMessageStream( - Content content, GenerateContentConfig config) throws IOException { checkLastResponseAndEditHistory(); history.add(content); - ResponseStream respStream = - model.generateContentStream(history, config); + ResponseStream respStream = model.generateContentStream(history); currentResponseStream = respStream; currentResponse = null; return respStream; @@ -113,21 +82,7 @@ public ResponseStream sendMessageStream( */ @BetaApi public GenerateContentResponse sendMessage(String text) throws IOException { - return sendMessage(text, GenerateContentConfig.newBuilder().build()); - } - - /** - * Sends a message to the model and returns a response. - * - * @param text the message to be sent. - * @param config a {@link GenerateContentConfig} that contains all the configs for sending message - * in a chat session. - * @return a response. - */ - @BetaApi - public GenerateContentResponse sendMessage(String text, GenerateContentConfig config) - throws IOException { - return sendMessage(ContentMaker.fromString(text), config); + return sendMessage(ContentMaker.fromString(text)); } /** @@ -138,23 +93,9 @@ public GenerateContentResponse sendMessage(String text, GenerateContentConfig co */ @BetaApi public GenerateContentResponse sendMessage(Content content) throws IOException { - return sendMessage(content, GenerateContentConfig.newBuilder().build()); - } - - /** - * Sends a message to the model and returns a response. - * - * @param content the content to be sent. - * @param config a {@link GenerateContentConfig} that contains all the configs for sending message - * in a chat session. - * @return a response. - */ - @BetaApi - public GenerateContentResponse sendMessage(Content content, GenerateContentConfig config) - throws IOException { checkLastResponseAndEditHistory(); history.add(content); - GenerateContentResponse response = model.generateContent(history, config); + GenerateContentResponse response = model.generateContent(history); currentResponse = response; currentResponseStream = null; return response; diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerateContentConfig.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerateContentConfig.java deleted file mode 100644 index 05ebcfdb72da..000000000000 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerateContentConfig.java +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.cloud.vertexai.generativeai; - -import com.google.api.core.BetaApi; -import com.google.cloud.vertexai.api.GenerationConfig; -import com.google.cloud.vertexai.api.SafetySetting; -import com.google.cloud.vertexai.api.Tool; -import com.google.common.collect.ImmutableList; -import java.util.List; - -/** This class holds all the configs when making a generate content API call */ -public class GenerateContentConfig { - private final GenerationConfig generationConfig; - private final ImmutableList safetySettings; - private final ImmutableList tools; - - /** Creates a builder for the GenerateContentConfig. */ - public static Builder newBuilder() { - return new Builder(); - } - - private GenerateContentConfig(Builder builder) { - if (builder.generationConfig != null) { - this.generationConfig = builder.generationConfig; - } else { - this.generationConfig = null; - } - if (builder.safetySettings != null) { - this.safetySettings = builder.safetySettings; - } else { - this.safetySettings = ImmutableList.of(); - } - if (builder.tools != null) { - this.tools = builder.tools; - } else { - this.tools = ImmutableList.of(); - } - } - - /** Builder class for {@link GenerateContentConfig}. */ - public static class Builder { - private GenerationConfig generationConfig; - private ImmutableList safetySettings; - private ImmutableList tools; - - private Builder() {} - - /** Builds a GenerateContentConfig instance. */ - public GenerateContentConfig build() { - return new GenerateContentConfig(this); - } - - /** - * Sets {@link com.google.cloud.vertexai.api.GenerationConfig} that will be used in the generate - * content API call. - * - * @return builder for the GenerateContentConfig - */ - @BetaApi - public Builder setGenerationConfig(GenerationConfig generationConfig) { - this.generationConfig = generationConfig; - return this; - } - - /** - * Sets a list of {@link com.google.cloud.vertexai.api.SafetySetting} that will be used in the - * generate content API call. - * - * @return builder for the GenerateContentConfig - */ - @BetaApi - public Builder setSafetySettings(List safetySettings) { - this.safetySettings = ImmutableList.copyOf(safetySettings); - return this; - } - - /** - * Sets a list of {@link com.google.cloud.vertexai.api.Tool} that will be used in the generate - * content API call. - * - * @return builder for the GenerateContentConfig - */ - @BetaApi - public Builder setTools(List tools) { - this.tools = ImmutableList.copyOf(tools); - return this; - } - } - - /** Returns the {@link com.google.cloud.vertexai.api.GenerationConfig} of this config. */ - @BetaApi - public GenerationConfig getGenerationConfig() { - return this.generationConfig; - } - - /** Returns a list of {@link com.google.cloud.vertexai.api.SafetySettings} of this config. */ - @BetaApi("safetySettings is a preview feature.") - public ImmutableList getSafetySettings() { - return this.safetySettings; - } - - /** Returns a list of {@link com.google.cloud.vertexai.api.Tool} of this config. */ - @BetaApi("tools is a preview feature.") - public ImmutableList getTools() { - return this.tools; - } -} diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java index 6411700ec8bd..281fe21f98e6 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java @@ -300,23 +300,7 @@ private CountTokensResponse countTokensFromRequest(CountTokensRequest request) */ @BetaApi public GenerateContentResponse generateContent(String text) throws IOException { - return generateContent(text, GenerateContentConfig.newBuilder().build()); - } - - /** - * Generates content from generative model given a text and configs. - * - * @param text a text message to send to the generative model - * @param config a {@link GenerateContentConfig} that contains all the configs in making a - * generate content api call - * @return a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance that contains - * response contents and other metadata - * @throws IOException if an I/O error occurs while making the API call - */ - @BetaApi - public GenerateContentResponse generateContent(String text, GenerateContentConfig config) - throws IOException { - return generateContent(ContentMaker.fromString(text), config); + return generateContent(ContentMaker.fromString(text)); } /** @@ -329,23 +313,7 @@ public GenerateContentResponse generateContent(String text, GenerateContentConfi */ @BetaApi("generateContent is a preview feature.") public GenerateContentResponse generateContent(Content content) throws IOException { - return generateContent(content, GenerateContentConfig.newBuilder().build()); - } - - /** - * Generates content from generative model given a single content and configs. - * - * @param content a {@link com.google.cloud.vertexai.api.Content} to send to the generative model - * @param config a {@link GenerateContentConfig} that contains all the configs in making a - * generate content api call - * @return a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance that contains - * response contents and other metadata - * @throws IOException if an I/O error occurs while making the API call - */ - @BetaApi - public GenerateContentResponse generateContent(Content content, GenerateContentConfig config) - throws IOException { - return generateContent(Arrays.asList(content), config); + return generateContent(Arrays.asList(content)); } /** @@ -359,38 +327,15 @@ public GenerateContentResponse generateContent(Content content, GenerateContentC */ @BetaApi("generateContent is a preview feature.") public GenerateContentResponse generateContent(List contents) throws IOException { - return generateContent(contents, GenerateContentConfig.newBuilder().build()); - } - - /** - * Generates content from generative model given a list of contents and configs. - * - * @param contents a list of {@link com.google.cloud.vertexai.api.Content} to send to the - * generative model - * @param config a {@link GenerateContentConfig} that contains all the configs in making a - * generate content api call - * @return a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance that contains - * response contents and other metadata - * @throws IOException if an I/O error occurs while making the API call - */ - @BetaApi - public GenerateContentResponse generateContent( - List contents, GenerateContentConfig config) throws IOException { GenerateContentRequest.Builder requestBuilder = GenerateContentRequest.newBuilder().setModel(this.resourceName).addAllContents(contents); - if (config.getGenerationConfig() != null) { - requestBuilder.setGenerationConfig(config.getGenerationConfig()); - } else if (this.generationConfig != null) { + if (this.generationConfig != null) { requestBuilder.setGenerationConfig(this.generationConfig); } - if (config.getSafetySettings().isEmpty() == false) { - requestBuilder.addAllSafetySettings(config.getSafetySettings()); - } else if (this.safetySettings != null) { + if (this.safetySettings != null) { requestBuilder.addAllSafetySettings(this.safetySettings); } - if (config.getTools().isEmpty() == false) { - requestBuilder.addAllTools(config.getTools()); - } else if (this.tools != null) { + if (this.tools != null) { requestBuilder.addAllTools(this.tools); } @@ -420,22 +365,7 @@ private GenerateContentResponse generateContent(GenerateContentRequest request) */ public ResponseStream generateContentStream(String text) throws IOException { - return generateContentStream(text, GenerateContentConfig.newBuilder().build()); - } - - /** - * Generates content with streaming support from generative model given a text and configs. - * - * @param text a text message to send to the generative model - * @param config a {@link GenerateContentConfig} that contains all the configs in making a - * generate content api call - * @return a {@link ResponseStream} that contains a streaming of {@link - * com.google.cloud.vertexai.api.GenerateContentResponse} - * @throws IOException if an I/O error occurs while making the API call - */ - public ResponseStream generateContentStream( - String text, GenerateContentConfig config) throws IOException { - return generateContentStream(ContentMaker.fromString(text), config); + return generateContentStream(ContentMaker.fromString(text)); } /** @@ -449,23 +379,7 @@ public ResponseStream generateContentStream( */ public ResponseStream generateContentStream(Content content) throws IOException { - return generateContentStream(content, GenerateContentConfig.newBuilder().build()); - } - - /** - * Generates content with streaming support from generative model given a single content and - * configs. - * - * @param content a {@link com.google.cloud.vertexai.api.Content} to send to the generative model - * @param config a {@link GenerateContentConfig} that contains all the configs in making a - * generate content api call - * @return a {@link ResponseStream} that contains a streaming of {@link - * com.google.cloud.vertexai.api.GenerateContentResponse} - * @throws IOException if an I/O error occurs while making the API call - */ - public ResponseStream generateContentStream( - Content content, GenerateContentConfig config) throws IOException { - return generateContentStream(Arrays.asList(content), config); + return generateContentStream(Arrays.asList(content)); } /** @@ -479,38 +393,15 @@ public ResponseStream generateContentStream( */ public ResponseStream generateContentStream(List contents) throws IOException { - return generateContentStream(contents, GenerateContentConfig.newBuilder().build()); - } - - /** - * Generates content with streaming support from generative model given a list of contents and - * configs. - * - * @param contents a list of {@link com.google.cloud.vertexai.api.Content} to send to the - * generative model - * @param config a {@link GenerateContentConfig} that contains all the configs in making a - * generate content api call - * @return a {@link ResponseStream} that contains a streaming of {@link - * com.google.cloud.vertexai.api.GenerateContentResponse} - * @throws IOException if an I/O error occurs while making the API call - */ - public ResponseStream generateContentStream( - List contents, GenerateContentConfig config) throws IOException { GenerateContentRequest.Builder requestBuilder = GenerateContentRequest.newBuilder().setModel(this.resourceName).addAllContents(contents); - if (config.getGenerationConfig() != null) { - requestBuilder.setGenerationConfig(config.getGenerationConfig()); - } else if (this.generationConfig != null) { + if (this.generationConfig != null) { requestBuilder.setGenerationConfig(this.generationConfig); } - if (config.getSafetySettings().isEmpty() == false) { - requestBuilder.addAllSafetySettings(config.getSafetySettings()); - } else if (this.safetySettings != null) { + if (this.safetySettings != null) { requestBuilder.addAllSafetySettings(this.safetySettings); } - if (config.getTools().isEmpty() == false) { - requestBuilder.addAllTools(config.getTools()); - } else if (this.tools != null) { + if (this.tools != null) { requestBuilder.addAllTools(this.tools); } diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java index 5410904e1984..e3b7ba11a40d 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java @@ -18,8 +18,6 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.when; import com.google.cloud.vertexai.api.Candidate; @@ -115,8 +113,7 @@ public void sendMessageStreamWithText_historyContainsTwoTurns() throws IOExcepti // (Arrange) Set up the return value of the generateContentStream when(mockGenerativeModel.generateContentStream( - eq(Arrays.asList(ContentMaker.fromString(SAMPLE_MESSAGE1))), - any(GenerateContentConfig.class))) + Arrays.asList(ContentMaker.fromString(SAMPLE_MESSAGE1)))) .thenReturn(responseStream); // (Act) send request, consume response and get history @@ -141,8 +138,7 @@ public void sendMessageStreamWithText_throwsIllegalStateExceptionIfResponseNotCo // (Arrange) Set up the return value of the generateContentStream when(mockGenerativeModel.generateContentStream( - eq(Arrays.asList(ContentMaker.fromString(SAMPLE_MESSAGE1))), - any(GenerateContentConfig.class))) + Arrays.asList(ContentMaker.fromString(SAMPLE_MESSAGE1)))) .thenReturn(responseStream); // (Act & Assert) Send request, consume response and get history, but not consume the response @@ -160,8 +156,7 @@ public void sendMessageWithText_historyContainsTwoTurns() throws IOException { // (Arrange) Set up the return value of the generateContent when(mockGenerativeModel.generateContent( - eq(Arrays.asList(ContentMaker.fromString(SAMPLE_MESSAGE1))), - any(GenerateContentConfig.class))) + Arrays.asList(ContentMaker.fromString(SAMPLE_MESSAGE1)))) .thenReturn(RESPONSE_FROM_UNARY_CALL); // (Act) Send text message via sendMessage and get the history. @@ -179,8 +174,7 @@ public void sendMessageWithTextThenModifyHistory_historyChangedToNewContentList( // (Arrange) Set up the return value of the generateContent when(mockGenerativeModel.generateContent( - eq(Arrays.asList(ContentMaker.fromString(SAMPLE_MESSAGE1))), - any(GenerateContentConfig.class))) + Arrays.asList(ContentMaker.fromString(SAMPLE_MESSAGE1)))) .thenReturn(RESPONSE_FROM_UNARY_CALL); // (Act) Send text message via sendMessage and get the history. @@ -210,8 +204,7 @@ public void sendMessageStreamWithText_throwsIllegalStateExceptionWhenFinishReaso // (Arrange) Set up the return value of the generateContentStream when(mockGenerativeModel.generateContentStream( - eq(Arrays.asList(ContentMaker.fromString(SAMPLE_MESSAGE1))), - any(GenerateContentConfig.class))) + Arrays.asList(ContentMaker.fromString(SAMPLE_MESSAGE1)))) .thenReturn(responseStream); // (Act) send request, consume response @@ -235,8 +228,7 @@ public void sendMessageWithText_throwsIllegalStateExceptionWhenFinishReasonIsNot throws IOException { // (Arrange) Set up the return value of the generateContent when(mockGenerativeModel.generateContent( - eq(Arrays.asList(ContentMaker.fromString(SAMPLE_MESSAGE1))), - any(GenerateContentConfig.class))) + Arrays.asList(ContentMaker.fromString(SAMPLE_MESSAGE1)))) .thenReturn(RESPONSE_FROM_UNARY_CALL_WITH_OTHER_FINISH_REASON); // (Act) Send text message via sendMessage diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerateContentConfigTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerateContentConfigTest.java deleted file mode 100644 index 05d835e983b0..000000000000 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerateContentConfigTest.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.cloud.vertexai.generativeai; - -import static com.google.common.truth.Truth.assertThat; - -import com.google.cloud.vertexai.api.FunctionDeclaration; -import com.google.cloud.vertexai.api.GenerationConfig; -import com.google.cloud.vertexai.api.HarmCategory; -import com.google.cloud.vertexai.api.SafetySetting; -import com.google.cloud.vertexai.api.SafetySetting.HarmBlockThreshold; -import com.google.cloud.vertexai.api.Schema; -import com.google.cloud.vertexai.api.Tool; -import com.google.cloud.vertexai.api.Type; -import java.util.Arrays; -import java.util.List; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public final class GenerateContentConfigTest { - private static final GenerationConfig GENERATION_CONFIG = - GenerationConfig.newBuilder().setCandidateCount(1).build(); - private static final SafetySetting SAFETY_SETTING = - SafetySetting.newBuilder() - .setCategory(HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT) - .setThreshold(HarmBlockThreshold.BLOCK_LOW_AND_ABOVE) - .build(); - private static final Tool TOOL = - Tool.newBuilder() - .addFunctionDeclarations( - FunctionDeclaration.newBuilder() - .setName("getCurrentWeather") - .setDescription("Get the current weather in a given location") - .setParameters( - Schema.newBuilder() - .setType(Type.OBJECT) - .putProperties( - "location", - Schema.newBuilder() - .setType(Type.STRING) - .setDescription("location") - .build()) - .addRequired("location"))) - .build(); - - private List safetySettings = Arrays.asList(SAFETY_SETTING); - private List tools = Arrays.asList(TOOL); - - private GenerateContentConfig config; - - @Test - public void testInstantiateGenerateContentConfigWithBuilder() { - config = - GenerateContentConfig.newBuilder() - .setGenerationConfig(GENERATION_CONFIG) - .setSafetySettings(safetySettings) - .setTools(tools) - .build(); - assertThat(config.getGenerationConfig()).isEqualTo(GENERATION_CONFIG); - assertThat(config.getSafetySettings()).isEqualTo(safetySettings); - assertThat(config.getTools()).isEqualTo(tools); - } -} diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java index 14864a378d1e..7bd6b3153205 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java @@ -425,36 +425,6 @@ public void testGenerateContentwithDefaultTools() throws Exception { assertThat(request.getValue().getTools(0)).isEqualTo(TOOL); } - @Test - public void testGenerateContentwithGenerateContentConfig() throws Exception { - model = new GenerativeModel(MODEL_NAME, vertexAi); - GenerateContentConfig config = - GenerateContentConfig.newBuilder() - .setGenerationConfig(GENERATION_CONFIG) - .setSafetySettings(safetySettings) - .setTools(tools) - .build(); - - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); - - when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); - when(mockUnaryCallable.call(any(GenerateContentRequest.class))) - .thenReturn(mockGenerateContentResponse); - - GenerateContentResponse unused = model.generateContent(TEXT, config); - - ArgumentCaptor request = - ArgumentCaptor.forClass(GenerateContentRequest.class); - verify(mockUnaryCallable).call(request.capture()); - - assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT); - assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG); - assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING); - assertThat(request.getValue().getTools(0)).isEqualTo(TOOL); - } - @Test public void testGenerateContentStreamwithText() throws Exception { model = new GenerativeModel(MODEL_NAME, vertexAi); @@ -598,35 +568,4 @@ public void testGenerateContentStreamwithDefaultTools() throws Exception { verify(mockServerStreamCallable).call(request.capture()); assertThat(request.getValue().getTools(0)).isEqualTo(TOOL); } - - @Test - public void testGenerateContentStreamwithGenerateContentConfig() throws Exception { - model = new GenerativeModel(MODEL_NAME, vertexAi); - GenerateContentConfig config = - GenerateContentConfig.newBuilder() - .setGenerationConfig(GENERATION_CONFIG) - .setSafetySettings(safetySettings) - .setTools(tools) - .build(); - - Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); - field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); - - when(mockPredictionServiceClient.streamGenerateContentCallable()) - .thenReturn(mockServerStreamCallable); - when(mockServerStreamCallable.call(any(GenerateContentRequest.class))) - .thenReturn(mockServerStream); - when(mockServerStream.iterator()).thenReturn(mockServerStreamIterator); - - ResponseStream unused = model.generateContentStream(TEXT, config); - - ArgumentCaptor request = - ArgumentCaptor.forClass(GenerateContentRequest.class); - verify(mockServerStreamCallable).call(request.capture()); - - assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG); - assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING); - assertThat(request.getValue().getTools(0)).isEqualTo(TOOL); - } }