diff --git a/java-vertexai/README.md b/java-vertexai/README.md index 4afc7a791a22..48283c527d02 100644 --- a/java-vertexai/README.md +++ b/java-vertexai/README.md @@ -18,7 +18,7 @@ If you are using Maven with [BOM][libraries-bom], add this to your pom.xml file: com.google.cloud libraries-bom - 26.30.0 + 26.29.0 pom import diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/preview/GenerativeModel.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/preview/GenerativeModel.java index 3061d9643ddb..3bdd3f76d39d 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/preview/GenerativeModel.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/preview/GenerativeModel.java @@ -22,11 +22,11 @@ import com.google.cloud.vertexai.api.CountTokensRequest; import com.google.cloud.vertexai.api.CountTokensResponse; import com.google.cloud.vertexai.api.GenerateContentRequest; -import com.google.cloud.vertexai.api.GenerateContentRequest.Builder; import com.google.cloud.vertexai.api.GenerateContentResponse; import com.google.cloud.vertexai.api.GenerationConfig; import com.google.cloud.vertexai.api.Part; import com.google.cloud.vertexai.api.SafetySetting; +import com.google.cloud.vertexai.api.Tool; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -40,8 +40,131 @@ public class GenerativeModel { private final VertexAI vertexAi; private GenerationConfig generationConfig = null; private List safetySettings = null; + private List tools = null; private Transport transport; + public static Builder newBuilder() { + return new Builder(); + } + + private GenerativeModel(Builder builder) { + this.modelName = builder.modelName; + + this.vertexAi = builder.vertexAi; + + this.resourceName = + String.format( + "projects/%s/locations/%s/publishers/google/models/%s", + this.vertexAi.getProjectId(), this.vertexAi.getLocation(), this.modelName); + + if (builder.generationConfig != null) { + this.generationConfig = builder.generationConfig; + } + if (builder.safetySettings != null) { + this.safetySettings = builder.safetySettings; + } + if (builder.tools != null) { + this.tools = builder.tools; + } + + if (builder.transport != null) { + this.transport = builder.transport; + } else { + this.transport = this.vertexAi.getTransport(); + } + } + + /** Builder class for {@link GenerativeModel}. */ + public static class Builder { + private String modelName; + private VertexAI vertexAi; + private GenerationConfig generationConfig; + private List safetySettings; + private List tools; + private Transport transport; + + private Builder() {} + + public GenerativeModel build() { + if (this.modelName == null) { + throw new IllegalArgumentException( + "modelName is required. Please call setModelName() before building."); + } + if (this.vertexAi == null) { + throw new IllegalArgumentException( + "vertexAi is required. Please call setVertexAi() before building."); + } + return new GenerativeModel(this); + } + + /** + * Set the name of the generative model. This is required for building a GenerativeModel + * instance. Supported format: "gemini-pro", "models/gemini-pro", + * "publishers/google/models/gemini-pro", where "gemini-pro" is the model name. Valid model + * names can be found at + * https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models + */ + public Builder setModelName(String modelName) { + this.modelName = validateModelName(modelName); + return this; + } + + /** + * Set {@link com.google.cloud.vertexai.VertexAI} that contains the default configs for the + * generative model. This is required for building a GenerativeModel instance. + */ + public Builder setVertexAi(VertexAI vertexAi) { + this.vertexAi = vertexAi; + return this; + } + + /** + * Set {@link com.google.cloud.vertexai.api.GenerationConfig} that will be used by default to + * interact with the generative model. + */ + public Builder setGenerationConfig(GenerationConfig generationConfig) { + this.generationConfig = generationConfig; + return this; + } + + /** + * Set a list of {@link com.google.cloud.vertexai.api.SafetySetting} that will be used by + * default to interact with the generative model. + */ + public Builder setSafetySettings(List safetySettings) { + this.safetySettings = new ArrayList<>(); + for (SafetySetting safetySetting : safetySettings) { + if (safetySetting != null) { + this.safetySettings.add(safetySetting); + } + } + return this; + } + + /** + * Set a list of {@link com.google.cloud.vertexai.api.Tool} that will be used by default to + * interact with the generative model. + */ + public Builder setTools(List tools) { + this.tools = new ArrayList<>(); + for (Tool tool : tools) { + if (tool != null) { + this.tools.add(tool); + } + } + return this; + } + + /** + * Set the {@link Transport} layer for API calls in the generative model. It overrides the + * transport setting in {@link com.google.cloud.vertexai.VertexAI} + */ + public Builder setTransport(Transport transport) { + this.transport = transport; + return this; + } + } + /** * Construct a GenerativeModel instance. * @@ -384,7 +507,8 @@ public GenerateContentResponse generateContent( public GenerateContentResponse generateContent( List contents, GenerationConfig generationConfig, List safetySettings) throws IOException { - Builder requestBuilder = GenerateContentRequest.newBuilder().addAllContents(contents); + GenerateContentRequest.Builder requestBuilder = + GenerateContentRequest.newBuilder().addAllContents(contents); if (generationConfig != null) { requestBuilder.setGenerationConfig(generationConfig); } else if (this.generationConfig != null) { @@ -395,6 +519,9 @@ public GenerateContentResponse generateContent( } else if (this.safetySettings != null) { requestBuilder.addAllSafetySettings(this.safetySettings); } + if (this.tools != null) { + requestBuilder.addAllTools(this.tools); + } return ResponseHandler.aggregateStreamIntoResponse(generateContentStream(requestBuilder)); } @@ -655,7 +782,8 @@ public ResponseStream generateContentStream( public ResponseStream generateContentStream( List contents, GenerationConfig generationConfig, List safetySettings) throws IOException { - Builder requestBuilder = GenerateContentRequest.newBuilder().addAllContents(contents); + GenerateContentRequest.Builder requestBuilder = + GenerateContentRequest.newBuilder().addAllContents(contents); if (generationConfig != null) { requestBuilder.setGenerationConfig(generationConfig); } else if (this.generationConfig != null) { @@ -666,6 +794,9 @@ public ResponseStream generateContentStream( } else if (this.safetySettings != null) { requestBuilder.addAllSafetySettings(this.safetySettings); } + if (this.tools != null) { + requestBuilder.addAllTools(this.tools); + } return generateContentStream(requestBuilder); } @@ -678,8 +809,8 @@ public ResponseStream generateContentStream( * com.google.cloud.vertexai.api.GenerateContentResponse} * @throws IOException if an I/O error occurs while making the API call */ - private ResponseStream generateContentStream(Builder requestBuilder) - throws IOException { + private ResponseStream generateContentStream( + GenerateContentRequest.Builder requestBuilder) throws IOException { GenerateContentRequest request = requestBuilder.setModel(this.resourceName).build(); ResponseStream responseStream = null; if (this.transport == Transport.REST) { @@ -723,6 +854,16 @@ public void setSafetySettings(List safetySettings) { } } + /** + * Sets the value for {@link #getTools}, which will be used by default for generating response. + */ + public void setTools(List tools) { + this.tools = new ArrayList<>(); + for (Tool tool : tools) { + this.tools.add(tool); + } + } + /** * Sets the value for {@link #getTransport}, which defines the layer for API calls in this * generative model. @@ -760,6 +901,15 @@ public List getSafetySettings() { } } + /** Returns a list of {@link com.google.cloud.vertexai.api.Tool} of this generative model. */ + public List getTools() { + if (this.tools != null) { + return Collections.unmodifiableList(this.tools); + } else { + return null; + } + } + public ChatSession startChat() { return new ChatSession(this); }