Skip to content

Commit

Permalink
feat: [vertexai] Support Function calling (#10242)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 599919035

Co-authored-by: Jaycee Li <jayceeli@google.com>
  • Loading branch information
copybara-service[bot] and jaycee-li committed Jan 22, 2024
1 parent 2da4e3e commit 89d2b15
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 6 deletions.
2 changes: 1 addition & 1 deletion java-vertexai/README.md
Expand Up @@ -18,7 +18,7 @@ If you are using Maven with [BOM][libraries-bom], add this to your pom.xml file:
<dependency>
<groupId>com.google.cloud</groupId>
<artifactId>libraries-bom</artifactId>
<version>26.30.0</version>
<version>26.29.0</version>
<type>pom</type>
<scope>import</scope>
</dependency>
Expand Down
Expand Up @@ -22,11 +22,11 @@
import com.google.cloud.vertexai.api.CountTokensRequest;
import com.google.cloud.vertexai.api.CountTokensResponse;
import com.google.cloud.vertexai.api.GenerateContentRequest;
import com.google.cloud.vertexai.api.GenerateContentRequest.Builder;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.Part;
import com.google.cloud.vertexai.api.SafetySetting;
import com.google.cloud.vertexai.api.Tool;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -40,8 +40,131 @@ public class GenerativeModel {
private final VertexAI vertexAi;
private GenerationConfig generationConfig = null;
private List<SafetySetting> safetySettings = null;
private List<Tool> tools = null;
private Transport transport;

public static Builder newBuilder() {
return new Builder();
}

private GenerativeModel(Builder builder) {
this.modelName = builder.modelName;

this.vertexAi = builder.vertexAi;

this.resourceName =
String.format(
"projects/%s/locations/%s/publishers/google/models/%s",
this.vertexAi.getProjectId(), this.vertexAi.getLocation(), this.modelName);

if (builder.generationConfig != null) {
this.generationConfig = builder.generationConfig;
}
if (builder.safetySettings != null) {
this.safetySettings = builder.safetySettings;
}
if (builder.tools != null) {
this.tools = builder.tools;
}

if (builder.transport != null) {
this.transport = builder.transport;
} else {
this.transport = this.vertexAi.getTransport();
}
}

/** Builder class for {@link GenerativeModel}. */
public static class Builder {
private String modelName;
private VertexAI vertexAi;
private GenerationConfig generationConfig;
private List<SafetySetting> safetySettings;
private List<Tool> tools;
private Transport transport;

private Builder() {}

public GenerativeModel build() {
if (this.modelName == null) {
throw new IllegalArgumentException(
"modelName is required. Please call setModelName() before building.");
}
if (this.vertexAi == null) {
throw new IllegalArgumentException(
"vertexAi is required. Please call setVertexAi() before building.");
}
return new GenerativeModel(this);
}

/**
* Set the name of the generative model. This is required for building a GenerativeModel
* instance. Supported format: "gemini-pro", "models/gemini-pro",
* "publishers/google/models/gemini-pro", where "gemini-pro" is the model name. Valid model
* names can be found at
* https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models
*/
public Builder setModelName(String modelName) {
this.modelName = validateModelName(modelName);
return this;
}

/**
* Set {@link com.google.cloud.vertexai.VertexAI} that contains the default configs for the
* generative model. This is required for building a GenerativeModel instance.
*/
public Builder setVertexAi(VertexAI vertexAi) {
this.vertexAi = vertexAi;
return this;
}

/**
* Set {@link com.google.cloud.vertexai.api.GenerationConfig} that will be used by default to
* interact with the generative model.
*/
public Builder setGenerationConfig(GenerationConfig generationConfig) {
this.generationConfig = generationConfig;
return this;
}

/**
* Set a list of {@link com.google.cloud.vertexai.api.SafetySetting} that will be used by
* default to interact with the generative model.
*/
public Builder setSafetySettings(List<SafetySetting> safetySettings) {
this.safetySettings = new ArrayList<>();
for (SafetySetting safetySetting : safetySettings) {
if (safetySetting != null) {
this.safetySettings.add(safetySetting);
}
}
return this;
}

/**
* Set a list of {@link com.google.cloud.vertexai.api.Tool} that will be used by default to
* interact with the generative model.
*/
public Builder setTools(List<Tool> tools) {
this.tools = new ArrayList<>();
for (Tool tool : tools) {
if (tool != null) {
this.tools.add(tool);
}
}
return this;
}

/**
* Set the {@link Transport} layer for API calls in the generative model. It overrides the
* transport setting in {@link com.google.cloud.vertexai.VertexAI}
*/
public Builder setTransport(Transport transport) {
this.transport = transport;
return this;
}
}

/**
* Construct a GenerativeModel instance.
*
Expand Down Expand Up @@ -384,7 +507,8 @@ public GenerateContentResponse generateContent(
public GenerateContentResponse generateContent(
List<Content> contents, GenerationConfig generationConfig, List<SafetySetting> safetySettings)
throws IOException {
Builder requestBuilder = GenerateContentRequest.newBuilder().addAllContents(contents);
GenerateContentRequest.Builder requestBuilder =
GenerateContentRequest.newBuilder().addAllContents(contents);
if (generationConfig != null) {
requestBuilder.setGenerationConfig(generationConfig);
} else if (this.generationConfig != null) {
Expand All @@ -395,6 +519,9 @@ public GenerateContentResponse generateContent(
} else if (this.safetySettings != null) {
requestBuilder.addAllSafetySettings(this.safetySettings);
}
if (this.tools != null) {
requestBuilder.addAllTools(this.tools);
}
return ResponseHandler.aggregateStreamIntoResponse(generateContentStream(requestBuilder));
}

Expand Down Expand Up @@ -655,7 +782,8 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
public ResponseStream<GenerateContentResponse> generateContentStream(
List<Content> contents, GenerationConfig generationConfig, List<SafetySetting> safetySettings)
throws IOException {
Builder requestBuilder = GenerateContentRequest.newBuilder().addAllContents(contents);
GenerateContentRequest.Builder requestBuilder =
GenerateContentRequest.newBuilder().addAllContents(contents);
if (generationConfig != null) {
requestBuilder.setGenerationConfig(generationConfig);
} else if (this.generationConfig != null) {
Expand All @@ -666,6 +794,9 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
} else if (this.safetySettings != null) {
requestBuilder.addAllSafetySettings(this.safetySettings);
}
if (this.tools != null) {
requestBuilder.addAllTools(this.tools);
}
return generateContentStream(requestBuilder);
}

Expand All @@ -678,8 +809,8 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
* com.google.cloud.vertexai.api.GenerateContentResponse}
* @throws IOException if an I/O error occurs while making the API call
*/
private ResponseStream<GenerateContentResponse> generateContentStream(Builder requestBuilder)
throws IOException {
private ResponseStream<GenerateContentResponse> generateContentStream(
GenerateContentRequest.Builder requestBuilder) throws IOException {
GenerateContentRequest request = requestBuilder.setModel(this.resourceName).build();
ResponseStream<GenerateContentResponse> responseStream = null;
if (this.transport == Transport.REST) {
Expand Down Expand Up @@ -723,6 +854,16 @@ public void setSafetySettings(List<SafetySetting> safetySettings) {
}
}

/**
* Sets the value for {@link #getTools}, which will be used by default for generating response.
*/
public void setTools(List<Tool> tools) {
this.tools = new ArrayList<>();
for (Tool tool : tools) {
this.tools.add(tool);
}
}

/**
* Sets the value for {@link #getTransport}, which defines the layer for API calls in this
* generative model.
Expand Down Expand Up @@ -760,6 +901,15 @@ public List<SafetySetting> getSafetySettings() {
}
}

/** Returns a list of {@link com.google.cloud.vertexai.api.Tool} of this generative model. */
public List<Tool> getTools() {
if (this.tools != null) {
return Collections.unmodifiableList(this.tools);
} else {
return null;
}
}

public ChatSession startChat() {
return new ChatSession(this);
}
Expand Down

0 comments on commit 89d2b15

Please sign in to comment.