From 5c3d93eafe76c97e50fc174678743fe59be699af Mon Sep 17 00:00:00 2001 From: "copybara-service[bot]" <56741989+copybara-service[bot]@users.noreply.github.com> Date: Thu, 21 Mar 2024 16:48:32 -0700 Subject: [PATCH] feat: [vertexai] add fluent API in ChatSession (#10597) PiperOrigin-RevId: 617901539 Co-authored-by: Jaycee Li --- .../com/google/cloud/vertexai/VertexAI.java | 239 +++++++++++------- .../vertexai/generativeai/ChatSession.java | 205 ++++++++++++--- .../FunctionDeclarationMaker.java | 63 ----- .../generativeai/ChatSessionTest.java | 103 ++++++++ .../FunctionDeclarationMakerTest.java | 177 ------------- 5 files changed, 417 insertions(+), 370 deletions(-) delete mode 100644 java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/FunctionDeclarationMaker.java delete mode 100644 java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/FunctionDeclarationMakerTest.java diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java index a4a699bcc5a2..589272d20ce8 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java @@ -16,9 +16,6 @@ package com.google.cloud.vertexai; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; - import com.google.api.core.InternalApi; import com.google.api.gax.core.CredentialsProvider; import com.google.api.gax.core.FixedCredentialsProvider; @@ -31,10 +28,8 @@ import com.google.cloud.vertexai.api.LlmUtilityServiceSettings; import com.google.cloud.vertexai.api.PredictionServiceClient; import com.google.cloud.vertexai.api.PredictionServiceSettings; -import com.google.common.base.Strings; import java.io.IOException; import 
java.util.List; -import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Level; import java.util.logging.Logger; @@ -61,8 +56,9 @@ public class VertexAI implements AutoCloseable { private Transport transport = Transport.GRPC; // The clients will be instantiated lazily private PredictionServiceClient predictionServiceClient = null; + private PredictionServiceClient predictionServiceRestClient = null; private LlmUtilityServiceClient llmUtilityClient = null; - private final ReentrantLock lock = new ReentrantLock(); + private LlmUtilityServiceClient llmUtilityRestClient = null; /** * Construct a VertexAI instance. @@ -197,35 +193,32 @@ public Credentials getCredentials() throws IOException { /** Sets the value for {@link #getTransport()}. */ public void setTransport(Transport transport) { - checkNotNull(transport, "Transport can't be null."); - if (this.transport == transport) { - return; - } - this.transport = transport; - resetClients(); } /** Sets the value for {@link #getApiEndpoint()}. 
*/ public void setApiEndpoint(String apiEndpoint) { - checkArgument(!Strings.isNullOrEmpty(apiEndpoint), "Api endpoint can't be null or empty."); - if (this.apiEndpoint == apiEndpoint) { - return; - } this.apiEndpoint = apiEndpoint; - resetClients(); - } - private void resetClients() { if (this.predictionServiceClient != null) { this.predictionServiceClient.close(); this.predictionServiceClient = null; } + if (this.predictionServiceRestClient != null) { + this.predictionServiceRestClient.close(); + this.predictionServiceRestClient = null; + } + if (this.llmUtilityClient != null) { this.llmUtilityClient.close(); this.llmUtilityClient = null; } + + if (this.llmUtilityRestClient != null) { + this.llmUtilityRestClient.close(); + this.llmUtilityRestClient = null; + } } /** @@ -237,47 +230,78 @@ private void resetClients() { */ @InternalApi public PredictionServiceClient getPredictionServiceClient() throws IOException { - if (predictionServiceClient != null) { - return predictionServiceClient; + if (this.transport == Transport.GRPC) { + return getPredictionServiceGrpcClient(); + } else { + return getPredictionServiceRestClient(); } - lock.lock(); - try { - if (predictionServiceClient == null) { - PredictionServiceSettings settings = getPredictionServiceSettings(); - // Disable the warning message logged in getApplicationDefault - Logger defaultCredentialsProviderLogger = - Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider"); - Level previousLevel = defaultCredentialsProviderLogger.getLevel(); - defaultCredentialsProviderLogger.setLevel(Level.SEVERE); - predictionServiceClient = PredictionServiceClient.create(settings); - defaultCredentialsProviderLogger.setLevel(previousLevel); + } + + /** + * Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the + * first prediction API call is made. 
+ * + * @return {@link PredictionServiceClient} that send GRPC requests to the backing service through + * method calls that map to the API methods. + */ + private PredictionServiceClient getPredictionServiceGrpcClient() throws IOException { + if (predictionServiceClient == null) { + PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newBuilder(); + settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint)); + if (this.credentialsProvider != null) { + settingsBuilder.setCredentialsProvider(this.credentialsProvider); } - return predictionServiceClient; - } finally { - lock.unlock(); + HeaderProvider headerProvider = + FixedHeaderProvider.create( + "user-agent", + String.format( + "%s/%s", + Constants.USER_AGENT_HEADER, + GaxProperties.getLibraryVersion(PredictionServiceSettings.class))); + settingsBuilder.setHeaderProvider(headerProvider); + // Disable the warning message logged in getApplicationDefault + Logger defaultCredentialsProviderLogger = + Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider"); + Level previousLevel = defaultCredentialsProviderLogger.getLevel(); + defaultCredentialsProviderLogger.setLevel(Level.SEVERE); + predictionServiceClient = PredictionServiceClient.create(settingsBuilder.build()); + defaultCredentialsProviderLogger.setLevel(previousLevel); } + return predictionServiceClient; } - private PredictionServiceSettings getPredictionServiceSettings() throws IOException { - PredictionServiceSettings.Builder builder; - if (transport == Transport.REST) { - builder = PredictionServiceSettings.newHttpJsonBuilder(); - } else { - builder = PredictionServiceSettings.newBuilder(); - } - builder.setEndpoint(String.format("%s:443", this.apiEndpoint)); - if (this.credentialsProvider != null) { - builder.setCredentialsProvider(this.credentialsProvider); + /** + * Returns the {@link PredictionServiceClient} with REST. The client will be instantiated when the + * first prediction API call is made. 
+ * + * @return {@link PredictionServiceClient} that send REST requests to the backing service through + * method calls that map to the API methods. + */ + private PredictionServiceClient getPredictionServiceRestClient() throws IOException { + if (predictionServiceRestClient == null) { + PredictionServiceSettings.Builder settingsBuilder = + PredictionServiceSettings.newHttpJsonBuilder(); + settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint)); + if (this.credentialsProvider != null) { + settingsBuilder.setCredentialsProvider(this.credentialsProvider); + } + HeaderProvider headerProvider = + FixedHeaderProvider.create( + "user-agent", + String.format( + "%s/%s", + Constants.USER_AGENT_HEADER, + GaxProperties.getLibraryVersion(PredictionServiceSettings.class))); + settingsBuilder.setHeaderProvider(headerProvider); + // Disable the warning message logged in getApplicationDefault + Logger defaultCredentialsProviderLogger = + Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider"); + Level previousLevel = defaultCredentialsProviderLogger.getLevel(); + defaultCredentialsProviderLogger.setLevel(Level.SEVERE); + predictionServiceRestClient = PredictionServiceClient.create(settingsBuilder.build()); + defaultCredentialsProviderLogger.setLevel(previousLevel); } - HeaderProvider headerProvider = - FixedHeaderProvider.create( - "user-agent", - String.format( - "%s/%s", - Constants.USER_AGENT_HEADER, - GaxProperties.getLibraryVersion(PredictionServiceSettings.class))); - builder.setHeaderProvider(headerProvider); - return builder.build(); + return predictionServiceRestClient; } /** @@ -289,47 +313,78 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept */ @InternalApi public LlmUtilityServiceClient getLlmUtilityClient() throws IOException { - if (llmUtilityClient != null) { - return llmUtilityClient; + if (this.transport == Transport.GRPC) { + return getLlmUtilityGrpcClient(); + } else { + return 
getLlmUtilityRestClient(); } - lock.lock(); - try { - if (llmUtilityClient == null) { - LlmUtilityServiceSettings settings = getLlmUtilityServiceClientSettings(); - // Disable the warning message logged in getApplicationDefault - Logger defaultCredentialsProviderLogger = - Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider"); - Level previousLevel = defaultCredentialsProviderLogger.getLevel(); - defaultCredentialsProviderLogger.setLevel(Level.SEVERE); - llmUtilityClient = LlmUtilityServiceClient.create(settings); - defaultCredentialsProviderLogger.setLevel(previousLevel); + } + + /** + * Returns the {@link LlmUtilityServiceClient} with GRPC. The client will be instantiated when the + * first API call is made. + * + * @return {@link LlmUtilityServiceClient} that makes gRPC calls to the backing service through + * method calls that map to the API methods. + */ + private LlmUtilityServiceClient getLlmUtilityGrpcClient() throws IOException { + if (llmUtilityClient == null) { + LlmUtilityServiceSettings.Builder settingsBuilder = LlmUtilityServiceSettings.newBuilder(); + settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint)); + if (this.credentialsProvider != null) { + settingsBuilder.setCredentialsProvider(this.credentialsProvider); } - return llmUtilityClient; - } finally { - lock.unlock(); + HeaderProvider headerProvider = + FixedHeaderProvider.create( + "user-agent", + String.format( + "%s/%s", + Constants.USER_AGENT_HEADER, + GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class))); + settingsBuilder.setHeaderProvider(headerProvider); + // Disable the warning message logged in getApplicationDefault + Logger defaultCredentialsProviderLogger = + Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider"); + Level previousLevel = defaultCredentialsProviderLogger.getLevel(); + defaultCredentialsProviderLogger.setLevel(Level.SEVERE); + llmUtilityClient = LlmUtilityServiceClient.create(settingsBuilder.build()); + 
defaultCredentialsProviderLogger.setLevel(previousLevel); } + return llmUtilityClient; } - private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IOException { - LlmUtilityServiceSettings.Builder settingsBuilder; - if (transport == Transport.REST) { - settingsBuilder = LlmUtilityServiceSettings.newHttpJsonBuilder(); - } else { - settingsBuilder = LlmUtilityServiceSettings.newBuilder(); - } - settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint)); - if (this.credentialsProvider != null) { - settingsBuilder.setCredentialsProvider(this.credentialsProvider); + /** + * Returns the {@link LlmUtilityServiceClient} with REST. The client will be instantiated when the + * first API call is made. + * + * @return {@link LlmUtilityServiceClient} that makes REST requests to the backing service through + * method calls that map to the API methods. + */ + private LlmUtilityServiceClient getLlmUtilityRestClient() throws IOException { + if (llmUtilityRestClient == null) { + LlmUtilityServiceSettings.Builder settingsBuilder = + LlmUtilityServiceSettings.newHttpJsonBuilder(); + settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint)); + if (this.credentialsProvider != null) { + settingsBuilder.setCredentialsProvider(this.credentialsProvider); + } + HeaderProvider headerProvider = + FixedHeaderProvider.create( + "user-agent", + String.format( + "%s/%s", + Constants.USER_AGENT_HEADER, + GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class))); + settingsBuilder.setHeaderProvider(headerProvider); + // Disable the warning message logged in getApplicationDefault + Logger defaultCredentialsProviderLogger = + Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider"); + Level previousLevel = defaultCredentialsProviderLogger.getLevel(); + defaultCredentialsProviderLogger.setLevel(Level.SEVERE); + llmUtilityRestClient = LlmUtilityServiceClient.create(settingsBuilder.build()); + 
defaultCredentialsProviderLogger.setLevel(previousLevel); } - HeaderProvider headerProvider = - FixedHeaderProvider.create( - "user-agent", - String.format( - "%s/%s", - Constants.USER_AGENT_HEADER, - GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class))); - settingsBuilder.setHeaderProvider(headerProvider); - return settingsBuilder.build(); + return llmUtilityRestClient; } /** Closes the VertexAI instance together with all its instantiated clients. */ @@ -338,8 +393,14 @@ public void close() { if (predictionServiceClient != null) { predictionServiceClient.close(); } + if (predictionServiceRestClient != null) { + predictionServiceRestClient.close(); + } if (llmUtilityClient != null) { llmUtilityClient.close(); } + if (llmUtilityRestClient != null) { + llmUtilityRestClient.close(); + } } } diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java index 31a920d9f542..cca5a6172b7b 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java @@ -19,29 +19,101 @@ import static com.google.cloud.vertexai.generativeai.ResponseHandler.aggregateStreamIntoResponse; import static com.google.cloud.vertexai.generativeai.ResponseHandler.getContent; import static com.google.cloud.vertexai.generativeai.ResponseHandler.getFinishReason; +import static com.google.common.base.Preconditions.checkNotNull; import com.google.api.core.BetaApi; import com.google.cloud.vertexai.api.Candidate.FinishReason; import com.google.cloud.vertexai.api.Content; import com.google.cloud.vertexai.api.GenerateContentResponse; +import com.google.cloud.vertexai.api.GenerationConfig; +import com.google.cloud.vertexai.api.SafetySetting; +import 
com.google.cloud.vertexai.api.Tool; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Optional; /** Represents a conversation between the user and the model */ public final class ChatSession { private final GenerativeModel model; + private final Optional<ChatSession> rootChatSession; private List<Content> history = new ArrayList<>(); - private ResponseStream<GenerateContentResponse> currentResponseStream = null; - private GenerateContentResponse currentResponse = null; + private Optional<ResponseStream<GenerateContentResponse>> currentResponseStream; + private Optional<GenerateContentResponse> currentResponse; + /** + * Creates a new chat session given a GenerativeModel instance. Configurations of the chat (e.g., + * GenerationConfig) inherits from the model. + */ @BetaApi public ChatSession(GenerativeModel model) { - if (model == null) { - throw new IllegalArgumentException("model should not be null"); - } + this(model, Optional.empty()); + } + + /** + * Creates a new chat session given a GenerativeModel instance and a root chat session. + * Configurations of the chat (e.g., GenerationConfig) inherits from the model. + * + * @param model a {@link GenerativeModel} instance that generates contents in the chat. + * @param rootChatSession a root {@link ChatSession} instance. All the chat history in the current + * chat session will be merged to the root chat session. + * @return a {@link ChatSession} instance. + */ + @BetaApi + private ChatSession(GenerativeModel model, Optional<ChatSession> rootChatSession) { + checkNotNull(model, "model should not be null"); this.model = model; + this.rootChatSession = rootChatSession; + currentResponseStream = Optional.empty(); + currentResponse = Optional.empty(); + } + + /** + * Creates a copy of the current ChatSession with updated GenerationConfig. + * + * @param generationConfig a {@link com.google.cloud.vertexai.api.GenerationConfig} that will be + * used in the new ChatSession. + * @return a new {@link ChatSession} instance with the specified GenerationConfig. 
+ */ + @BetaApi + public ChatSession withGenerationConfig(GenerationConfig generationConfig) { + ChatSession rootChat = rootChatSession.orElse(this); + ChatSession newChatSession = + new ChatSession(model.withGenerationConfig(generationConfig), Optional.of(rootChat)); + newChatSession.history = history; + return newChatSession; + } + + /** + * Creates a copy of the current ChatSession with updated SafetySettings. + * + * @param safetySettings a {@link com.google.cloud.vertexai.api.SafetySetting} that will be used + * in the new ChatSession. + * @return a new {@link ChatSession} instance with the specified SafetySettings. + */ + @BetaApi + public ChatSession withSafetySettings(List<SafetySetting> safetySettings) { + ChatSession rootChat = rootChatSession.orElse(this); + ChatSession newChatSession = + new ChatSession(model.withSafetySettings(safetySettings), Optional.of(rootChat)); + newChatSession.setHistory(history); + return newChatSession; + } + + /** + * Creates a copy of the current ChatSession with updated Tools. + * + * @param tools a {@link com.google.cloud.vertexai.api.Tool} that will be used in the new + * ChatSession. + * @return a new {@link ChatSession} instance with the specified Tools. 
+ */ + @BetaApi + public ChatSession withTools(List<Tool> tools) { + ChatSession rootChat = rootChatSession.orElse(this); + ChatSession newChatSession = new ChatSession(model.withTools(tools), Optional.of(rootChat)); + newChatSession.setHistory(history); + return newChatSession; } /** @@ -69,8 +141,8 @@ public ResponseStream<GenerateContentResponse> sendMessageStream(Content content checkLastResponseAndEditHistory(); history.add(content); ResponseStream<GenerateContentResponse> respStream = model.generateContentStream(history); - currentResponseStream = respStream; - currentResponse = null; + setCurrentResponseStream(Optional.of(respStream)); + return respStream; } @@ -96,8 +168,7 @@ public GenerateContentResponse sendMessage(Content content) throws IOException { checkLastResponseAndEditHistory(); history.add(content); GenerateContentResponse response = model.generateContent(history); - currentResponse = response; - currentResponseStream = null; + setCurrentResponse(Optional.of(response)); return response; } @@ -112,38 +183,37 @@ private void removeLastContent() { * @throws IllegalStateException if the response stream is not finished. */ private void checkLastResponseAndEditHistory() { - if (currentResponseStream == null && currentResponse == null) { - return; - } else if (currentResponseStream != null && !currentResponseStream.isConsumed()) { - throw new IllegalStateException("Response stream is not consumed"); - } else if (currentResponseStream != null && currentResponseStream.isConsumed()) { - GenerateContentResponse response = aggregateStreamIntoResponse(currentResponseStream); - FinishReason finishReason = getFinishReason(response); - if (finishReason != FinishReason.STOP && finishReason != FinishReason.MAX_TOKENS) { - // We also remove the request from the history. - removeLastContent(); - currentResponseStream = null; - throw new IllegalStateException( - String.format( - "The last round of conversation will not be added to history because response" - + " stream did not finish normally. 
Finish reason is %s.", - finishReason)); - } - history.add(getContent(response)); - } else if (currentResponseStream == null && currentResponse != null) { - FinishReason finishReason = getFinishReason(currentResponse); - if (finishReason != FinishReason.STOP && finishReason != FinishReason.MAX_TOKENS) { - // We also remove the request from the history. - removeLastContent(); - currentResponse = null; - throw new IllegalStateException( - String.format( - "The last round of conversation will not be added to history because response did" - + " not finish normally. Finish reason is %s.", - finishReason)); - } - history.add(getContent(currentResponse)); - currentResponse = null; + getCurrentResponse() + .ifPresent( + currentResponse -> { + setCurrentResponse(Optional.empty()); + checkFinishReasonAndRemoveLastContent(currentResponse); + history.add(getContent(currentResponse)); + }); + getCurrentResponseStream() + .ifPresent( + responseStream -> { + if (!responseStream.isConsumed()) { + throw new IllegalStateException("Response stream is not consumed"); + } else { + setCurrentResponseStream(Optional.empty()); + GenerateContentResponse response = aggregateStreamIntoResponse(responseStream); + checkFinishReasonAndRemoveLastContent(response); + history.add(getContent(response)); + } + }); + } + + /** Removes the last content in the history if the response finished with problems. */ + private void checkFinishReasonAndRemoveLastContent(GenerateContentResponse response) { + FinishReason finishReason = getFinishReason(response); + if (finishReason != FinishReason.STOP && finishReason != FinishReason.MAX_TOKENS) { + removeLastContent(); + throw new IllegalStateException( + String.format( + "The last round of conversation will not be added to history because response" + + " stream did not finish normally. 
Finish reason is %s.", + finishReason)); } } @@ -169,9 +239,62 @@ public List<Content> getHistory() { return Collections.unmodifiableList(history); } + /** + * Returns the current response of the root chat session (if exists) or the current chat session. + */ + private Optional<GenerateContentResponse> getCurrentResponse() { + if (rootChatSession.isPresent()) { + return rootChatSession.get().currentResponse; + } else { + return currentResponse; + } + } + + /** + * Returns the current responseStream of the root chat session (if exists) or the current chat + * session. + */ + private Optional<ResponseStream<GenerateContentResponse>> getCurrentResponseStream() { + if (rootChatSession.isPresent()) { + return rootChatSession.get().currentResponseStream; + } else { + return currentResponseStream; + } + } + /** Set the history to a list of Content */ @BetaApi public void setHistory(List<Content> history) { this.history = history; } + + /** Sets the current response of the root chat session (if exists) or the current chat session. */ + private void setCurrentResponse(Optional<GenerateContentResponse> response) { + if (currentResponseStream.isPresent()) { + throw new IllegalStateException( + "currentResponse and currentResponseStream cannot be set together"); + } + if (rootChatSession.isPresent()) { + rootChatSession.get().setCurrentResponse(response); + } else { + currentResponse = response; + } + } + + /** + * Sets the current responseStream of the root chat session (if exists) or the current chat + * session. 
+ */ + private void setCurrentResponseStream( + Optional<ResponseStream<GenerateContentResponse>> responseStream) { + if (currentResponse.isPresent()) { + throw new IllegalStateException( + "currentResponseStream and currentResponse cannot be set together"); + } + if (rootChatSession.isPresent()) { + rootChatSession.get().setCurrentResponseStream(responseStream); + } else { + currentResponseStream = responseStream; + } + } } diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/FunctionDeclarationMaker.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/FunctionDeclarationMaker.java deleted file mode 100644 index c118df5b0d6f..000000000000 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/FunctionDeclarationMaker.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.google.cloud.vertexai.generativeai; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; - -import com.google.cloud.vertexai.api.FunctionDeclaration; -import com.google.common.base.Strings; -import com.google.gson.JsonObject; -import com.google.protobuf.InvalidProtocolBufferException; -import com.google.protobuf.util.JsonFormat; - -/** Helper class to create {@link com.google.cloud.vertexai.api.FunctionDeclaration} */ -public final class FunctionDeclarationMaker { - - /** - * Creates a FunctionDeclaration from a JsonString - * - * @param jsonString A valid Json String that can be parsed to a FunctionDeclaration. - * @return a {@link FunctionDeclaration} by parsing the input json String. - * @throws InvalidProtocolBufferException if the String can't be parsed into a FunctionDeclaration - * proto. - */ - public static FunctionDeclaration fromJsonString(String jsonString) - throws InvalidProtocolBufferException { - checkArgument(!Strings.isNullOrEmpty(jsonString), "Input String can't be null or empty."); - FunctionDeclaration.Builder builder = FunctionDeclaration.newBuilder(); - JsonFormat.parser().merge(jsonString, builder); - FunctionDeclaration declaration = builder.build(); - if (declaration.getName().isEmpty()) { - throw new IllegalArgumentException("name field must be present."); - } - return declaration; - } - - /** - * Creates a FunctionDeclaration from a JsonObject - * - * @param jsonObject A valid Json Object that can be parsed to a FunctionDeclaration. - * @return a {@link FunctionDeclaration} by parsing the input json Object. - * @throws InvalidProtocolBufferException if the jsonObject can't be parsed into a - * FunctionDeclaration proto. 
- */ - public static FunctionDeclaration fromJsonObject(JsonObject jsonObject) - throws InvalidProtocolBufferException { - checkNotNull(jsonObject, "JsonObject can't be null."); - return fromJsonString(jsonObject.toString()); - } -} diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java index e3b7ba11a40d..ef6733af4c60 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java @@ -18,14 +18,30 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.api.gax.rpc.UnaryCallable; +import com.google.cloud.vertexai.VertexAI; import com.google.cloud.vertexai.api.Candidate; import com.google.cloud.vertexai.api.Candidate.FinishReason; import com.google.cloud.vertexai.api.Content; +import com.google.cloud.vertexai.api.FunctionDeclaration; +import com.google.cloud.vertexai.api.GenerateContentRequest; import com.google.cloud.vertexai.api.GenerateContentResponse; +import com.google.cloud.vertexai.api.GenerationConfig; +import com.google.cloud.vertexai.api.HarmCategory; import com.google.cloud.vertexai.api.Part; +import com.google.cloud.vertexai.api.PredictionServiceClient; +import com.google.cloud.vertexai.api.SafetySetting; +import com.google.cloud.vertexai.api.SafetySetting.HarmBlockThreshold; +import com.google.cloud.vertexai.api.Schema; +import com.google.cloud.vertexai.api.Tool; +import com.google.cloud.vertexai.api.Type; import java.io.IOException; +import java.lang.reflect.Field; 
import java.util.Arrays; import java.util.Iterator; import java.util.List; @@ -34,6 +50,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -48,6 +65,8 @@ public final class ChatSessionTest { private static final String RESPONSE_STREAM_CHUNK2_TEXT = "But I'm happy to help you!"; private static final String FULL_RESPONSE_TEXT = RESPONSE_STREAM_CHUNK1_TEXT + RESPONSE_STREAM_CHUNK2_TEXT; + private static final String SAMPLE_MESSAGE_2 = "What can you do for me?"; + private static final String RESPONSE_TEXT_2 = "I can summarize a bo"; private static final GenerateContentResponse RESPONSE_STREAM_CHUNK1_RESPONSE = GenerateContentResponse.newBuilder() .addCandidates( @@ -82,6 +101,14 @@ public final class ChatSessionTest { .setContent( Content.newBuilder().addParts(Part.newBuilder().setText(FULL_RESPONSE_TEXT)))) .build(); + private static final GenerateContentResponse RESPONSE_FROM_UNARY_CALL_2 = + GenerateContentResponse.newBuilder() + .addCandidates( + Candidate.newBuilder() + .setFinishReason(FinishReason.MAX_TOKENS) + .setContent( + Content.newBuilder().addParts(Part.newBuilder().setText(RESPONSE_TEXT_2)))) + .build(); private static final GenerateContentResponse RESPONSE_FROM_UNARY_CALL_WITH_OTHER_FINISH_REASON = GenerateContentResponse.newBuilder() @@ -91,10 +118,43 @@ public final class ChatSessionTest { .setContent( Content.newBuilder().addParts(Part.newBuilder().setText(FULL_RESPONSE_TEXT)))) .build(); + + private static final GenerationConfig GENERATION_CONFIG = + GenerationConfig.newBuilder().setCandidateCount(1).build(); + private static final SafetySetting SAFETY_SETTING = + SafetySetting.newBuilder() + .setCategory(HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT) + .setThreshold(HarmBlockThreshold.BLOCK_LOW_AND_ABOVE) + .build(); + private static final Tool TOOL = + 
Tool.newBuilder() + .addFunctionDeclarations( + FunctionDeclaration.newBuilder() + .setName("getCurrentWeather") + .setDescription("Get the current weather in a given location") + .setParameters( + Schema.newBuilder() + .setType(Type.OBJECT) + .putProperties( + "location", + Schema.newBuilder() + .setType(Type.STRING) + .setDescription("location") + .build()) + .addRequired("location"))) + .build(); + @Rule public final MockitoRule mocksRule = MockitoJUnit.rule(); @Mock private GenerativeModel mockGenerativeModel; @Mock private Iterator mockServerStreamIterator; + + @Mock private PredictionServiceClient mockPredictionServiceClient; + + @Mock private UnaryCallable mockUnaryCallable; + + @Mock private GenerateContentResponse mockGenerateContentResponse; + private ChatSession chat; @Before @@ -243,4 +303,47 @@ public void sendMessageWithText_throwsIllegalStateExceptionWhenFinishReasonIsNot List history = chat.getHistory(); assertThat(history.size()).isEqualTo(0); } + + @Test + public void testChatSessionMergeHistoryToRootChatSession() throws Exception { + + // (Arrange) Set up the return value of the generateContent + VertexAI vertexAi = new VertexAI(PROJECT, LOCATION); + GenerativeModel model = new GenerativeModel("gemini-pro", vertexAi); + + Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); + field.setAccessible(true); + field.set(vertexAi, mockPredictionServiceClient); + + when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); + when(mockUnaryCallable.call(any(GenerateContentRequest.class))) + .thenReturn(RESPONSE_FROM_UNARY_CALL) + .thenReturn(RESPONSE_FROM_UNARY_CALL_2); + + // (Act) Send text message in root chat + ChatSession rootChat = model.startChat(); + GenerateContentResponse response = rootChat.sendMessage(SAMPLE_MESSAGE1); + // (Act) Create a child chat session and send message again + ChatSession childChat = + rootChat + .withGenerationConfig(GENERATION_CONFIG) + 
.withSafetySettings(Arrays.asList(SAFETY_SETTING)) + .withTools(Arrays.asList(TOOL)); + response = childChat.sendMessage(SAMPLE_MESSAGE_2); + + // (Assert) root chat history should contain all 4 contents + List history = rootChat.getHistory(); + assertThat(history.get(0).getParts(0).getText()).isEqualTo(SAMPLE_MESSAGE1); + assertThat(history.get(1).getParts(0).getText()).isEqualTo(FULL_RESPONSE_TEXT); + assertThat(history.get(2).getParts(0).getText()).isEqualTo(SAMPLE_MESSAGE_2); + assertThat(history.get(3).getParts(0).getText()).isEqualTo(RESPONSE_TEXT_2); + + // (Assert) the second request (from child chat) should contained updated configurations + ArgumentCaptor request = + ArgumentCaptor.forClass(GenerateContentRequest.class); + verify(mockUnaryCallable, times(2)).call(request.capture()); + assertThat(request.getAllValues().get(1).getGenerationConfig()).isEqualTo(GENERATION_CONFIG); + assertThat(request.getAllValues().get(1).getSafetySettings(0)).isEqualTo(SAFETY_SETTING); + assertThat(request.getAllValues().get(1).getTools(0)).isEqualTo(TOOL); + } } diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/FunctionDeclarationMakerTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/FunctionDeclarationMakerTest.java deleted file mode 100644 index f8894ba1ca99..000000000000 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/FunctionDeclarationMakerTest.java +++ /dev/null @@ -1,177 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.google.cloud.vertexai.generativeai; - -import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.assertThrows; - -import com.google.cloud.vertexai.api.FunctionDeclaration; -import com.google.cloud.vertexai.api.Schema; -import com.google.cloud.vertexai.api.Type; -import com.google.gson.JsonObject; -import com.google.protobuf.InvalidProtocolBufferException; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public final class FunctionDeclarationMakerTest { - - @Test - public void fromValidJsonStringTested_returnsFunctionDeclaration() - throws InvalidProtocolBufferException { - String jsonString = - "{\n" - + " \"name\": \"functionName\",\n" - + " \"description\": \"functionDescription\",\n" - + " \"parameters\": {\n" - + " \"type\": \"OBJECT\", \n" - + " \"properties\": {\n" - + " \"param1\": {\n" - + " \"type\": \"STRING\",\n" - + " \"description\": \"param1Description\"\n" - + " }\n" - + " }\n" - + " }\n" - + "}"; - - FunctionDeclaration functionDeclaration = FunctionDeclarationMaker.fromJsonString(jsonString); - - FunctionDeclaration expectedFunctionDeclaration = - FunctionDeclaration.newBuilder() - .setName("functionName") - .setDescription("functionDescription") - .setParameters( - Schema.newBuilder() - .setType(Type.OBJECT) - .putProperties( - "param1", - Schema.newBuilder() - .setType(Type.STRING) - .setDescription("param1Description") - .build())) - .build(); - 
assertThat(functionDeclaration).isEqualTo(expectedFunctionDeclaration); - } - - @Test - public void fromJsonStringWithInvalidType_throwsInvalidProtocolBufferException() - throws InvalidProtocolBufferException { - // Here we use "parameter" (singular) instead of "parameters" - String jsonString = - "{\n" - + " \"name\": \"functionName\",\n" - + " \"parameter\": {\n" - + " \"type\": \"OBJECT\", \n" - + " \"properties\": {\n" - + " \"param1\": {\n" - + " \"type\": \"STRING\",\n" - + " \"description\": \"param1Description\"\n" - + " }\n" - + " }\n" - + " }\n" - + "}"; - assertThrows( - InvalidProtocolBufferException.class, - () -> FunctionDeclarationMaker.fromJsonString(jsonString)); - } - - @Test - public void fromJsonStringNameMissing_throwsIllegalArgumentException() - throws InvalidProtocolBufferException { - String jsonString = - "{\n" - + " \"description\": \"functionDescription\",\n" - + " \"parameters\": {\n" - + " \"type\": \"OBJECT\", \n" - + " \"properties\": {\n" - + " \"param1\": {\n" - + " \"type\": \"STRING\",\n" - + " \"description\": \"param1Description\"\n" - + " }\n" - + " }\n" - + " }\n" - + "}"; - - IllegalArgumentException thrown = - assertThrows( - IllegalArgumentException.class, - () -> FunctionDeclarationMaker.fromJsonString(jsonString)); - assertThat(thrown).hasMessageThat().isEqualTo("name field must be present."); - } - - @Test - public void fromEmptyString_throwsIllegalArgumentException() - throws InvalidProtocolBufferException { - String jsonString = ""; - - IllegalArgumentException thrown = - assertThrows( - IllegalArgumentException.class, - () -> FunctionDeclarationMaker.fromJsonString(jsonString)); - assertThat(thrown).hasMessageThat().isEqualTo("Input String can't be null or empty."); - } - - @Test - public void fromJsonStringStringIsNull_throwsIllegalArgumentException() - throws InvalidProtocolBufferException { - String jsonString = null; - - IllegalArgumentException thrown = - assertThrows( - IllegalArgumentException.class, - () -> 
FunctionDeclarationMaker.fromJsonString(jsonString)); - assertThat(thrown).hasMessageThat().isEqualTo("Input String can't be null or empty."); - } - - @Test - public void fromValidJsonObject_returnsFunctionDeclaration() - throws InvalidProtocolBufferException { - JsonObject param1JsonObject = new JsonObject(); - param1JsonObject.addProperty("type", "STRING"); - param1JsonObject.addProperty("description", "param1Description"); - - JsonObject propertiesJsonObject = new JsonObject(); - propertiesJsonObject.add("param1", param1JsonObject); - - JsonObject parametersJsonObject = new JsonObject(); - parametersJsonObject.addProperty("type", "OBJECT"); - parametersJsonObject.add("properties", propertiesJsonObject); - - JsonObject jsonObject = new JsonObject(); - jsonObject.addProperty("name", "functionName"); - jsonObject.addProperty("description", "functionDescription"); - jsonObject.add("parameters", parametersJsonObject); - - FunctionDeclaration functionDeclaration = FunctionDeclarationMaker.fromJsonObject(jsonObject); - - FunctionDeclaration expectedFunctionDeclaration = - FunctionDeclaration.newBuilder() - .setName("functionName") - .setDescription("functionDescription") - .setParameters( - Schema.newBuilder() - .setType(Type.OBJECT) - .putProperties( - "param1", - Schema.newBuilder() - .setType(Type.STRING) - .setDescription("param1Description") - .build())) - .build(); - assertThat(functionDeclaration).isEqualTo(expectedFunctionDeclaration); - } -}