Skip to content

Commit

Permalink
feat: [vertexai] add fluent API in GenerativeModel
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 617585215
  • Loading branch information
jaycee-li authored and Copybara-Service committed Mar 20, 2024
1 parent a2407ab commit cc7669c
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

Expand All @@ -41,9 +40,9 @@ public final class GenerativeModel {
private final String modelName;
private final String resourceName;
private final VertexAI vertexAi;
private GenerationConfig generationConfig = GenerationConfig.getDefaultInstance();
private ImmutableList<SafetySetting> safetySettings = ImmutableList.of();
private ImmutableList<Tool> tools = ImmutableList.of();
private final GenerationConfig generationConfig;
private final ImmutableList<SafetySetting> safetySettings;
private final ImmutableList<Tool> tools;

/**
* Constructs a GenerativeModel instance.
Expand All @@ -59,8 +58,8 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
this(
modelName,
GenerationConfig.getDefaultInstance(),
new ArrayList<SafetySetting>(),
new ArrayList<Tool>(),
ImmutableList.of(),
ImmutableList.of(),
vertexAi);
}

Expand All @@ -81,22 +80,29 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
private GenerativeModel(
String modelName,
GenerationConfig generationConfig,
List<SafetySetting> safetySettings,
List<Tool> tools,
ImmutableList<SafetySetting> safetySettings,
ImmutableList<Tool> tools,
VertexAI vertexAi) {
checkArgument(
!Strings.isNullOrEmpty(modelName),
"modelName can't be null or empty. Please refer to"
+ " https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models"
+ " to find the right model name.");
checkNotNull(vertexAi, "VertexAI can't be null.");
checkNotNull(generationConfig, "GenerationConfig can't be null.");
checkNotNull(safetySettings, "ImmutableList<SafetySettings> can't be null.");
checkNotNull(tools, "ImmutableList<Tool> can't be null.");

modelName = reconcileModelName(modelName);
this.modelName = modelName;
this.resourceName =
String.format(
"projects/%s/locations/%s/publishers/google/models/%s",
vertexAi.getProjectId(), vertexAi.getLocation(), modelName);
checkNotNull(generationConfig, "GenerationConfig can't be null.");
checkNotNull(safetySettings, "List<SafetySettings> can't be null.");
checkNotNull(tools, "List<Tool> can't be null.");
this.vertexAi = vertexAi;
this.generationConfig = generationConfig;
this.safetySettings = ImmutableList.copyOf(safetySettings);
this.tools = ImmutableList.copyOf(tools);
this.safetySettings = safetySettings;
this.tools = tools;
}

/** Builder class for {@link GenerativeModel}. */
Expand Down Expand Up @@ -163,7 +169,6 @@ public Builder setSafetySettings(List<SafetySetting> safetySettings) {
checkNotNull(
safetySettings,
"safetySettings can't be null. Use an empty list if no safety settings is intended.");
safetySettings.removeIf(safetySetting -> safetySetting == null);
this.safetySettings = ImmutableList.copyOf(safetySettings);
return this;
}
Expand All @@ -175,12 +180,46 @@ public Builder setSafetySettings(List<SafetySetting> safetySettings) {
@BetaApi
public Builder setTools(List<Tool> tools) {
checkNotNull(tools, "tools can't be null. Use an empty list if no tool is to be used.");
tools.removeIf(tool -> tool == null);
this.tools = ImmutableList.copyOf(tools);
return this;
}
}

/**
* Creates a copy of the current model with updated GenerationConfig.
*
* @param generationConfig a {@link com.google.cloud.vertexai.api.GenerationConfig} that will be
* used in the new model.
* @return a new {@link GenerativeModel} instance with the specified GenerationConfig.
*/
public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) {
return new GenerativeModel(modelName, generationConfig, safetySettings, tools, vertexAi);
}

/**
* Creates a copy of the current model with updated safetySettings.
*
* @param safetySettings a list of {@link com.google.cloud.vertexai.api.SafetySetting} that will
* be used in the new model.
* @return a new {@link GenerativeModel} instance with the specified safetySettings.
*/
public GenerativeModel withSafetySettings(List<SafetySetting> safetySettings) {
return new GenerativeModel(
modelName, generationConfig, ImmutableList.copyOf(safetySettings), tools, vertexAi);
}

/**
* Creates a copy of the current model with updated tools.
*
* @param safetySettings a list of {@link com.google.cloud.vertexai.api.Tool} that will be used in
* the new model.
* @return a new {@link GenerativeModel} instance with the specified tools.
*/
public GenerativeModel withTools(List<Tool> tools) {
return new GenerativeModel(
modelName, generationConfig, safetySettings, ImmutableList.copyOf(tools), vertexAi);
}

/**
* Counts tokens in a text message.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,34 @@ public void testGenerateContentwithDefaultTools() throws Exception {
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void testGenerateContentwithFluentApi() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);

when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
.thenReturn(mockGenerateContentResponse);

GenerateContentResponse unused =
model
.withGenerationConfig(GENERATION_CONFIG)
.withSafetySettings(safetySettings)
.withTools(tools)
.generateContent(TEXT);

ArgumentCaptor<GenerateContentRequest> request =
ArgumentCaptor.forClass(GenerateContentRequest.class);
verify(mockUnaryCallable).call(request.capture());
assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT);
assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void testGenerateContentStreamwithText() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);
Expand Down Expand Up @@ -569,4 +597,34 @@ public void testGenerateContentStreamwithDefaultTools() throws Exception {
verify(mockServerStreamCallable).call(request.capture());
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void testGenerateContentStreamwithFluentApi() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);

when(mockPredictionServiceClient.streamGenerateContentCallable())
.thenReturn(mockServerStreamCallable);
when(mockServerStreamCallable.call(any(GenerateContentRequest.class)))
.thenReturn(mockServerStream);
when(mockServerStream.iterator()).thenReturn(mockServerStreamIterator);

ResponseStream unused =
model
.withGenerationConfig(GENERATION_CONFIG)
.withSafetySettings(safetySettings)
.withTools(tools)
.generateContentStream(TEXT);

ArgumentCaptor<GenerateContentRequest> request =
ArgumentCaptor.forClass(GenerateContentRequest.class);
verify(mockServerStreamCallable).call(request.capture());
assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT);
assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}
}

0 comments on commit cc7669c

Please sign in to comment.