diff --git a/java-vertexai/README.md b/java-vertexai/README.md
index 4afc7a791a22..48283c527d02 100644
--- a/java-vertexai/README.md
+++ b/java-vertexai/README.md
@@ -18,7 +18,7 @@ If you are using Maven with [BOM][libraries-bom], add this to your pom.xml file:
com.google.cloud
libraries-bom
- 26.30.0
+ 26.29.0
pom
import
diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/preview/GenerativeModel.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/preview/GenerativeModel.java
index 3061d9643ddb..3bdd3f76d39d 100644
--- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/preview/GenerativeModel.java
+++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/preview/GenerativeModel.java
@@ -22,11 +22,11 @@
import com.google.cloud.vertexai.api.CountTokensRequest;
import com.google.cloud.vertexai.api.CountTokensResponse;
import com.google.cloud.vertexai.api.GenerateContentRequest;
-import com.google.cloud.vertexai.api.GenerateContentRequest.Builder;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.Part;
import com.google.cloud.vertexai.api.SafetySetting;
+import com.google.cloud.vertexai.api.Tool;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
@@ -40,8 +40,131 @@ public class GenerativeModel {
private final VertexAI vertexAi;
private GenerationConfig generationConfig = null;
private List safetySettings = null;
+ private List tools = null;
private Transport transport;
+ public static Builder newBuilder() {
+ return new Builder();
+ }
+
+ private GenerativeModel(Builder builder) {
+ this.modelName = builder.modelName;
+
+ this.vertexAi = builder.vertexAi;
+
+ this.resourceName =
+ String.format(
+ "projects/%s/locations/%s/publishers/google/models/%s",
+ this.vertexAi.getProjectId(), this.vertexAi.getLocation(), this.modelName);
+
+ if (builder.generationConfig != null) {
+ this.generationConfig = builder.generationConfig;
+ }
+ if (builder.safetySettings != null) {
+ this.safetySettings = builder.safetySettings;
+ }
+ if (builder.tools != null) {
+ this.tools = builder.tools;
+ }
+
+ if (builder.transport != null) {
+ this.transport = builder.transport;
+ } else {
+ this.transport = this.vertexAi.getTransport();
+ }
+ }
+
+ /** Builder class for {@link GenerativeModel}. */
+ public static class Builder {
+ private String modelName;
+ private VertexAI vertexAi;
+ private GenerationConfig generationConfig;
+ private List safetySettings;
+ private List tools;
+ private Transport transport;
+
+ private Builder() {}
+
+ public GenerativeModel build() {
+ if (this.modelName == null) {
+ throw new IllegalArgumentException(
+ "modelName is required. Please call setModelName() before building.");
+ }
+ if (this.vertexAi == null) {
+ throw new IllegalArgumentException(
+ "vertexAi is required. Please call setVertexAi() before building.");
+ }
+ return new GenerativeModel(this);
+ }
+
+ /**
+ * Set the name of the generative model. This is required for building a GenerativeModel
+ * instance. Supported format: "gemini-pro", "models/gemini-pro",
+ * "publishers/google/models/gemini-pro", where "gemini-pro" is the model name. Valid model
+ * names can be found at
+ * https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models
+ */
+ public Builder setModelName(String modelName) {
+ this.modelName = validateModelName(modelName);
+ return this;
+ }
+
+ /**
+ * Set {@link com.google.cloud.vertexai.VertexAI} that contains the default configs for the
+ * generative model. This is required for building a GenerativeModel instance.
+ */
+ public Builder setVertexAi(VertexAI vertexAi) {
+ this.vertexAi = vertexAi;
+ return this;
+ }
+
+ /**
+ * Set {@link com.google.cloud.vertexai.api.GenerationConfig} that will be used by default to
+ * interact with the generative model.
+ */
+ public Builder setGenerationConfig(GenerationConfig generationConfig) {
+ this.generationConfig = generationConfig;
+ return this;
+ }
+
+ /**
+ * Set a list of {@link com.google.cloud.vertexai.api.SafetySetting} that will be used by
+ * default to interact with the generative model.
+ */
+ public Builder setSafetySettings(List safetySettings) {
+ this.safetySettings = new ArrayList<>();
+ for (SafetySetting safetySetting : safetySettings) {
+ if (safetySetting != null) {
+ this.safetySettings.add(safetySetting);
+ }
+ }
+ return this;
+ }
+
+ /**
+ * Set a list of {@link com.google.cloud.vertexai.api.Tool} that will be used by default to
+ * interact with the generative model.
+ */
+ public Builder setTools(List tools) {
+ this.tools = new ArrayList<>();
+ for (Tool tool : tools) {
+ if (tool != null) {
+ this.tools.add(tool);
+ }
+ }
+ return this;
+ }
+
+ /**
+ * Set the {@link Transport} layer for API calls in the generative model. It overrides the
+ * transport setting in {@link com.google.cloud.vertexai.VertexAI}
+ */
+ public Builder setTransport(Transport transport) {
+ this.transport = transport;
+ return this;
+ }
+ }
+
/**
* Construct a GenerativeModel instance.
*
@@ -384,7 +507,8 @@ public GenerateContentResponse generateContent(
public GenerateContentResponse generateContent(
List contents, GenerationConfig generationConfig, List safetySettings)
throws IOException {
- Builder requestBuilder = GenerateContentRequest.newBuilder().addAllContents(contents);
+ GenerateContentRequest.Builder requestBuilder =
+ GenerateContentRequest.newBuilder().addAllContents(contents);
if (generationConfig != null) {
requestBuilder.setGenerationConfig(generationConfig);
} else if (this.generationConfig != null) {
@@ -395,6 +519,9 @@ public GenerateContentResponse generateContent(
} else if (this.safetySettings != null) {
requestBuilder.addAllSafetySettings(this.safetySettings);
}
+ if (this.tools != null) {
+ requestBuilder.addAllTools(this.tools);
+ }
return ResponseHandler.aggregateStreamIntoResponse(generateContentStream(requestBuilder));
}
@@ -655,7 +782,8 @@ public ResponseStream generateContentStream(
public ResponseStream generateContentStream(
List contents, GenerationConfig generationConfig, List safetySettings)
throws IOException {
- Builder requestBuilder = GenerateContentRequest.newBuilder().addAllContents(contents);
+ GenerateContentRequest.Builder requestBuilder =
+ GenerateContentRequest.newBuilder().addAllContents(contents);
if (generationConfig != null) {
requestBuilder.setGenerationConfig(generationConfig);
} else if (this.generationConfig != null) {
@@ -666,6 +794,9 @@ public ResponseStream generateContentStream(
} else if (this.safetySettings != null) {
requestBuilder.addAllSafetySettings(this.safetySettings);
}
+ if (this.tools != null) {
+ requestBuilder.addAllTools(this.tools);
+ }
return generateContentStream(requestBuilder);
}
@@ -678,8 +809,8 @@ public ResponseStream generateContentStream(
* com.google.cloud.vertexai.api.GenerateContentResponse}
* @throws IOException if an I/O error occurs while making the API call
*/
- private ResponseStream generateContentStream(Builder requestBuilder)
- throws IOException {
+ private ResponseStream generateContentStream(
+ GenerateContentRequest.Builder requestBuilder) throws IOException {
GenerateContentRequest request = requestBuilder.setModel(this.resourceName).build();
ResponseStream responseStream = null;
if (this.transport == Transport.REST) {
@@ -723,6 +854,16 @@ public void setSafetySettings(List safetySettings) {
}
}
+ /**
+ * Sets the value for {@link #getTools}, which will be used by default for generating response.
+ */
+ public void setTools(List tools) {
+ this.tools = new ArrayList<>();
+ for (Tool tool : tools) {
+ this.tools.add(tool);
+ }
+ }
+
/**
* Sets the value for {@link #getTransport}, which defines the layer for API calls in this
* generative model.
@@ -760,6 +901,15 @@ public List getSafetySettings() {
}
}
+ /** Returns a list of {@link com.google.cloud.vertexai.api.Tool} of this generative model. */
+ public List getTools() {
+ if (this.tools != null) {
+ return Collections.unmodifiableList(this.tools);
+ } else {
+ return null;
+ }
+ }
+
public ChatSession startChat() {
return new ChatSession(this);
}