Skip to content

Commit

Permalink
feat: [vertexai] adding system instruction support (#10775)
Browse files Browse the repository at this point in the history
* feat(systeminstructions): adding system instructions support

* feat(systeminstructions): formatting tweaks

* feat(systeminstructions): use an optional for system instructions

* feat(systeminstructions): update field name, add getter and setter

* feat(systeminstructions): adding integration test for system instructions
  • Loading branch information
glaforge committed May 6, 2024
1 parent e8cc3b5 commit b90c291
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

/** This class holds a generative model that can complete what you provided. */
public final class GenerativeModel {
Expand All @@ -45,6 +46,7 @@ public final class GenerativeModel {
private final GenerationConfig generationConfig;
private final ImmutableList<SafetySetting> safetySettings;
private final ImmutableList<Tool> tools;
private final Optional<Content> systemInstruction;

/**
* Constructs a GenerativeModel instance.
Expand All @@ -53,7 +55,7 @@ public final class GenerativeModel {
* "models/gemini-pro", "publishers/google/models/gemini-pro", where "gemini-pro" is the model
* name. Valid model names can be found at
* https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* @param vertexAi a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* for the generative model
*/
public GenerativeModel(String modelName, VertexAI vertexAi) {
Expand All @@ -62,6 +64,7 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
GenerationConfig.getDefaultInstance(),
ImmutableList.of(),
ImmutableList.of(),
Optional.empty(),
vertexAi);
}

Expand All @@ -76,14 +79,15 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
* that will be used by default for generating response
* @param tools a list of {@link com.google.cloud.vertexai.api.Tool} instances that can be used by
* the model as auxiliary tools to generate content.
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* @param vertexAi a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* for the generative model
*/
private GenerativeModel(
String modelName,
GenerationConfig generationConfig,
ImmutableList<SafetySetting> safetySettings,
ImmutableList<Tool> tools,
Optional<Content> systemInstruction,
VertexAI vertexAi) {
checkArgument(
!Strings.isNullOrEmpty(modelName),
Expand All @@ -105,6 +109,7 @@ private GenerativeModel(
this.generationConfig = generationConfig;
this.safetySettings = safetySettings;
this.tools = tools;
this.systemInstruction = systemInstruction;
}

/** Builder class for {@link GenerativeModel}. */
Expand All @@ -114,20 +119,22 @@ public static class Builder {
private GenerationConfig generationConfig = GenerationConfig.getDefaultInstance();
private ImmutableList<SafetySetting> safetySettings = ImmutableList.of();
private ImmutableList<Tool> tools = ImmutableList.of();
private Optional<Content> systemInstruction = Optional.empty();

public GenerativeModel build() {
checkArgument(
!Strings.isNullOrEmpty(modelName),
"modelName is required. Please call setModelName() before building.");
checkNotNull(vertexAi, "vertexAi is required. Please call setVertexAi() before building.");
return new GenerativeModel(modelName, generationConfig, safetySettings, tools, vertexAi);
return new GenerativeModel(
modelName, generationConfig, safetySettings, tools, systemInstruction, vertexAi);
}

/**
* Sets the name of the generative model. This is required for building a GenerativeModel
* instance. Supported format: "gemini-pro", "models/gemini-pro",
* "publishers/google/models/gemini-pro", where "gemini-pro" is the model name. Valid model
* names can be found at
* names can be found in the Gemini models documentation
* https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models
*/
@CanIgnoreReturnValue
Expand Down Expand Up @@ -187,6 +194,19 @@ public Builder setTools(List<Tool> tools) {
this.tools = ImmutableList.copyOf(tools);
return this;
}

/**
 * Sets a system instruction that will be used by default to interact with the generative model.
 *
 * <p>The system instruction is optional: a builder on which this method is never called keeps
 * its initial empty value and builds a model without a system instruction.
 *
 * @param systemInstruction a non-null {@link com.google.cloud.vertexai.api.Content} holding the
 *     system instruction
 * @return this builder
 * @throws NullPointerException if {@code systemInstruction} is null
 */
@CanIgnoreReturnValue
public Builder setSystemInstruction(Content systemInstruction) {
  // The parameter is a plain Content, so a caller cannot pass Optional.empty() here; the
  // message points at the actual way to opt out, which is simply not calling this setter.
  checkNotNull(
      systemInstruction,
      "system instruction can't be null. "
          + "Do not call setSystemInstruction() if no system instruction should be provided.");
  this.systemInstruction = Optional.of(systemInstruction);
  return this;
}
}

/**
Expand All @@ -197,7 +217,13 @@ public Builder setTools(List<Tool> tools) {
* @return a new {@link GenerativeModel} instance with the specified GenerationConfig.
*/
public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) {
return new GenerativeModel(modelName, generationConfig, safetySettings, tools, vertexAi);
return new GenerativeModel(
modelName,
generationConfig,
ImmutableList.copyOf(safetySettings),
ImmutableList.copyOf(tools),
systemInstruction,
vertexAi);
}

/**
Expand All @@ -209,19 +235,46 @@ public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) {
*/
public GenerativeModel withSafetySettings(List<SafetySetting> safetySettings) {
return new GenerativeModel(
modelName, generationConfig, ImmutableList.copyOf(safetySettings), tools, vertexAi);
modelName,
generationConfig,
ImmutableList.copyOf(safetySettings),
ImmutableList.copyOf(tools),
systemInstruction,
vertexAi);
}

/**
* Creates a copy of the current model with updated tools.
*
* @param safetySettings a list of {@link com.google.cloud.vertexai.api.Tool} that will be used in
* the new model.
* @param tools a list of {@link com.google.cloud.vertexai.api.Tool} that will be used in the new
* model.
* @return a new {@link GenerativeModel} instance with the specified tools.
*/
public GenerativeModel withTools(List<Tool> tools) {
return new GenerativeModel(
modelName, generationConfig, safetySettings, ImmutableList.copyOf(tools), vertexAi);
modelName,
generationConfig,
ImmutableList.copyOf(safetySettings),
ImmutableList.copyOf(tools),
systemInstruction,
vertexAi);
}

/**
 * Creates a copy of the current model with updated system instructions.
 *
 * <p>The model name, generation config, safety settings, tools and {@code VertexAI} instance of
 * the current model are passed unchanged to the new instance.
 *
 * @param systemInstruction a non-null {@link com.google.cloud.vertexai.api.Content} containing
 *     system instructions.
 * @return a new {@link GenerativeModel} instance with the specified system instruction.
 * @throws NullPointerException if {@code systemInstruction} is null
 */
public GenerativeModel withSystemInstruction(Content systemInstruction) {
  // Fail with an explicit message instead of the bare NullPointerException Optional.of() throws.
  checkNotNull(systemInstruction, "system instruction can't be null.");
  // safetySettings and tools are already ImmutableList fields, so they are handed over as-is.
  return new GenerativeModel(
      modelName,
      generationConfig,
      safetySettings,
      tools,
      Optional.of(systemInstruction),
      vertexAi);
}

/**
Expand Down Expand Up @@ -453,13 +506,20 @@ private ApiFuture<GenerateContentResponse> generateContentAsync(GenerateContentR
*/
private GenerateContentRequest buildGenerateContentRequest(List<Content> contents) {
checkArgument(contents != null && !contents.isEmpty(), "contents can't be null or empty.");
return GenerateContentRequest.newBuilder()
.setModel(resourceName)
.addAllContents(contents)
.setGenerationConfig(generationConfig)
.addAllSafetySettings(safetySettings)
.addAllTools(tools)
.build();

GenerateContentRequest.Builder requestBuilder =
GenerateContentRequest.newBuilder()
.setModel(resourceName)
.addAllContents(contents)
.setGenerationConfig(generationConfig)
.addAllSafetySettings(safetySettings)
.addAllTools(tools);

if (systemInstruction.isPresent()) {
requestBuilder.setSystemInstruction(systemInstruction.get());
}

return requestBuilder.build();
}

/** Returns the model name of this generative model. */
Expand All @@ -475,8 +535,7 @@ public GenerationConfig getGenerationConfig() {
}

/**
* Returns a list of {@link com.google.cloud.vertexai.api.SafetySettings} of this generative
* model.
* Returns a list of {@link com.google.cloud.vertexai.api.SafetySetting} of this generative model.
*/
public ImmutableList<SafetySetting> getSafetySettings() {
return safetySettings;
Expand All @@ -487,6 +546,11 @@ public ImmutableList<Tool> getTools() {
return tools;
}

/**
 * Returns the system instruction of this generative model.
 *
 * @return an {@link Optional} holding the system instruction, or an empty {@link Optional} when
 *     none was configured
 */
public Optional<Content> getSystemInstruction() {
  return this.systemInstruction;
}

/**
 * Creates a new {@link ChatSession} constructed from this generative model.
 *
 * @return a new {@link ChatSession} instance
 */
public ChatSession startChat() {
  ChatSession session = new ChatSession(this);
  return session;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,28 @@ public void testGenerateContentwithContents() throws Exception {
assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT);
}

/**
 * Verifies that a system instruction set through {@code withSystemInstruction} is carried by the
 * {@link GenerateContentRequest} passed to the unary callable.
 */
@Test
public void testGenerateContentwithSystemInstructions() throws Exception {
  String expectedInstruction =
      "You're a helpful assistant that starts all its answers with: \"COOL\"";
  model =
      new GenerativeModel(MODEL_NAME, vertexAi)
          .withSystemInstruction(ContentMaker.fromString(expectedInstruction));

  // Route the generateContent call through the mocked callable and return the canned response.
  when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
  when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
      .thenReturn(mockGenerateContentResponse);

  GenerateContentResponse unused =
      model.generateContent(Arrays.asList(ContentMaker.fromString(TEXT)));

  // Capture the outgoing request and inspect its system instruction text.
  ArgumentCaptor<GenerateContentRequest> requestCaptor =
      ArgumentCaptor.forClass(GenerateContentRequest.class);
  verify(mockUnaryCallable).call(requestCaptor.capture());
  GenerateContentRequest sentRequest = requestCaptor.getValue();
  assertThat(sentRequest.getSystemInstruction().getParts(0).getText())
      .isEqualTo(expectedInstruction);
}

@Test
public void testGenerateContentwithDefaultGenerationConfig() throws Exception {
model =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public class ITGenerativeModelIntegrationTest {

// Tested content
private static final String TEXT = "What do you think about Google Pixel?";
private static final String PIRATE_INSTRUCTION = "Speak like a pirate when answering questions.";
private static final String IMAGE_INQUIRY = "Please describe this image: ";
private static final String IMAGE_URL = "https://picsum.photos/id/1/200/300";
private static final String VIDEO_INQUIRY = "Please summarize this video: ";
Expand Down Expand Up @@ -259,4 +260,15 @@ public void countTokens_withPlainText_returnsNonZeroTokens() throws IOException
logger.info(String.format("Print number of tokens:\n%s", tokens));
assertThat(tokens.getTotalTokens()).isGreaterThan(0);
}

/**
 * Sends a plain-text prompt through a copy of the text model that carries the pirate system
 * instruction, then checks that the response is non-empty and logs it.
 */
@Test
public void generateContent_withSystemInstruction() throws Exception {
  GenerativeModel modelWithInstruction =
      textModel.withSystemInstruction(ContentMaker.fromString(PIRATE_INSTRUCTION));
  GenerateContentResponse response = modelWithInstruction.generateContent(TEXT);

  // GenAI output is flaky, so the response is always printed and only lightly asserted on.
  assertNonEmptyAndLogResponse(name.getMethodName(), TEXT, response);
}
}

0 comments on commit b90c291

Please sign in to comment.