Skip to content

Commit b90c291

Browse files
authoredMay 6, 2024··
feat: [vertexai] adding system instruction support (#10775)
* feat(systeminstructions): adding system instructions support * feat(systeminstructions): formatting tweaks * feat(systeminstructions): use an optional for system instructions * feat(systeminstructions): update field name, add getter and setter * feat(systeminstructions): adding integration test for system instructions
1 parent e8cc3b5 commit b90c291

File tree

3 files changed

+116
-18
lines changed

3 files changed

+116
-18
lines changed
 

‎java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java

+82-18
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import java.io.IOException;
3737
import java.util.Arrays;
3838
import java.util.List;
39+
import java.util.Optional;
3940

4041
/** This class holds a generative model that can complete what you provided. */
4142
public final class GenerativeModel {
@@ -45,6 +46,7 @@ public final class GenerativeModel {
4546
private final GenerationConfig generationConfig;
4647
private final ImmutableList<SafetySetting> safetySettings;
4748
private final ImmutableList<Tool> tools;
49+
private final Optional<Content> systemInstruction;
4850

4951
/**
5052
* Constructs a GenerativeModel instance.
@@ -53,7 +55,7 @@ public final class GenerativeModel {
5355
* "models/gemini-pro", "publishers/google/models/gemini-pro", where "gemini-pro" is the model
5456
* name. Valid model names can be found at
5557
* https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models
56-
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
58+
* @param vertexAi a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
5759
* for the generative model
5860
*/
5961
public GenerativeModel(String modelName, VertexAI vertexAi) {
@@ -62,6 +64,7 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
6264
GenerationConfig.getDefaultInstance(),
6365
ImmutableList.of(),
6466
ImmutableList.of(),
67+
Optional.empty(),
6568
vertexAi);
6669
}
6770

@@ -76,14 +79,15 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
7679
* that will be used by default for generating response
7780
* @param tools a list of {@link com.google.cloud.vertexai.api.Tool} instances that can be used by
7881
* the model as auxiliary tools to generate content.
79-
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
82+
* @param vertexAi a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
8083
* for the generative model
8184
*/
8285
private GenerativeModel(
8386
String modelName,
8487
GenerationConfig generationConfig,
8588
ImmutableList<SafetySetting> safetySettings,
8689
ImmutableList<Tool> tools,
90+
Optional<Content> systemInstruction,
8791
VertexAI vertexAi) {
8892
checkArgument(
8993
!Strings.isNullOrEmpty(modelName),
@@ -105,6 +109,7 @@ private GenerativeModel(
105109
this.generationConfig = generationConfig;
106110
this.safetySettings = safetySettings;
107111
this.tools = tools;
112+
this.systemInstruction = systemInstruction;
108113
}
109114

110115
/** Builder class for {@link GenerativeModel}. */
@@ -114,20 +119,22 @@ public static class Builder {
114119
private GenerationConfig generationConfig = GenerationConfig.getDefaultInstance();
115120
private ImmutableList<SafetySetting> safetySettings = ImmutableList.of();
116121
private ImmutableList<Tool> tools = ImmutableList.of();
122+
private Optional<Content> systemInstruction = Optional.empty();
117123

118124
public GenerativeModel build() {
119125
checkArgument(
120126
!Strings.isNullOrEmpty(modelName),
121127
"modelName is required. Please call setModelName() before building.");
122128
checkNotNull(vertexAi, "vertexAi is required. Please call setVertexAi() before building.");
123-
return new GenerativeModel(modelName, generationConfig, safetySettings, tools, vertexAi);
129+
return new GenerativeModel(
130+
modelName, generationConfig, safetySettings, tools, systemInstruction, vertexAi);
124131
}
125132

126133
/**
127134
* Sets the name of the generative model. This is required for building a GenerativeModel
128135
* instance. Supported format: "gemini-pro", "models/gemini-pro",
129136
* "publishers/google/models/gemini-pro", where "gemini-pro" is the model name. Valid model
130-
* names can be found at
137+
* names can be found in the Gemini models documentation
131138
* https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models
132139
*/
133140
@CanIgnoreReturnValue
@@ -187,6 +194,19 @@ public Builder setTools(List<Tool> tools) {
187194
this.tools = ImmutableList.copyOf(tools);
188195
return this;
189196
}
197+
198+
/**
199+
* Sets a system instruction that will be used by default to interact with the generative model.
200+
*/
201+
@CanIgnoreReturnValue
202+
public Builder setSystemInstruction(Content systemInstruction) {
203+
checkNotNull(
204+
systemInstruction,
205+
"system instruction can't be null. "
206+
+ "Use Optional.empty() if no system instruction should be provided.");
207+
this.systemInstruction = Optional.of(systemInstruction);
208+
return this;
209+
}
190210
}
191211

192212
/**
@@ -197,7 +217,13 @@ public Builder setTools(List<Tool> tools) {
197217
* @return a new {@link GenerativeModel} instance with the specified GenerationConfig.
198218
*/
199219
public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) {
200-
return new GenerativeModel(modelName, generationConfig, safetySettings, tools, vertexAi);
220+
return new GenerativeModel(
221+
modelName,
222+
generationConfig,
223+
ImmutableList.copyOf(safetySettings),
224+
ImmutableList.copyOf(tools),
225+
systemInstruction,
226+
vertexAi);
201227
}
202228

203229
/**
@@ -209,19 +235,46 @@ public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) {
209235
*/
210236
public GenerativeModel withSafetySettings(List<SafetySetting> safetySettings) {
211237
return new GenerativeModel(
212-
modelName, generationConfig, ImmutableList.copyOf(safetySettings), tools, vertexAi);
238+
modelName,
239+
generationConfig,
240+
ImmutableList.copyOf(safetySettings),
241+
ImmutableList.copyOf(tools),
242+
systemInstruction,
243+
vertexAi);
213244
}
214245

215246
/**
216247
* Creates a copy of the current model with updated tools.
217248
*
218-
* @param safetySettings a list of {@link com.google.cloud.vertexai.api.Tool} that will be used in
219-
* the new model.
249+
* @param tools a list of {@link com.google.cloud.vertexai.api.Tool} that will be used in the new
250+
* model.
220251
* @return a new {@link GenerativeModel} instance with the specified tools.
221252
*/
222253
public GenerativeModel withTools(List<Tool> tools) {
223254
return new GenerativeModel(
224-
modelName, generationConfig, safetySettings, ImmutableList.copyOf(tools), vertexAi);
255+
modelName,
256+
generationConfig,
257+
ImmutableList.copyOf(safetySettings),
258+
ImmutableList.copyOf(tools),
259+
systemInstruction,
260+
vertexAi);
261+
}
262+
263+
/**
264+
* Creates a copy of the current model with updated system instructions.
265+
*
266+
* @param systemInstruction a {@link com.google.cloud.vertexai.api.Content} containing system
267+
* instructions.
268+
* @return a new {@link GenerativeModel} instance with the specified tools.
269+
*/
270+
public GenerativeModel withSystemInstruction(Content systemInstruction) {
271+
return new GenerativeModel(
272+
modelName,
273+
generationConfig,
274+
ImmutableList.copyOf(safetySettings),
275+
ImmutableList.copyOf(tools),
276+
Optional.of(systemInstruction),
277+
vertexAi);
225278
}
226279

227280
/**
@@ -453,13 +506,20 @@ private ApiFuture<GenerateContentResponse> generateContentAsync(GenerateContentR
453506
*/
454507
private GenerateContentRequest buildGenerateContentRequest(List<Content> contents) {
455508
checkArgument(contents != null && !contents.isEmpty(), "contents can't be null or empty.");
456-
return GenerateContentRequest.newBuilder()
457-
.setModel(resourceName)
458-
.addAllContents(contents)
459-
.setGenerationConfig(generationConfig)
460-
.addAllSafetySettings(safetySettings)
461-
.addAllTools(tools)
462-
.build();
509+
510+
GenerateContentRequest.Builder requestBuilder =
511+
GenerateContentRequest.newBuilder()
512+
.setModel(resourceName)
513+
.addAllContents(contents)
514+
.setGenerationConfig(generationConfig)
515+
.addAllSafetySettings(safetySettings)
516+
.addAllTools(tools);
517+
518+
if (systemInstruction.isPresent()) {
519+
requestBuilder.setSystemInstruction(systemInstruction.get());
520+
}
521+
522+
return requestBuilder.build();
463523
}
464524

465525
/** Returns the model name of this generative model. */
@@ -475,8 +535,7 @@ public GenerationConfig getGenerationConfig() {
475535
}
476536

477537
/**
478-
* Returns a list of {@link com.google.cloud.vertexai.api.SafetySettings} of this generative
479-
* model.
538+
* Returns a list of {@link com.google.cloud.vertexai.api.SafetySetting} of this generative model.
480539
*/
481540
public ImmutableList<SafetySetting> getSafetySettings() {
482541
return safetySettings;
@@ -487,6 +546,11 @@ public ImmutableList<Tool> getTools() {
487546
return tools;
488547
}
489548

549+
/** Returns the optional system instruction of this generative model. */
550+
public Optional<Content> getSystemInstruction() {
551+
return systemInstruction;
552+
}
553+
490554
public ChatSession startChat() {
491555
return new ChatSession(this);
492556
}

‎java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java

+22
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,28 @@ public void testGenerateContentwithContents() throws Exception {
324324
assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT);
325325
}
326326

327+
@Test
328+
public void testGenerateContentwithSystemInstructions() throws Exception {
329+
String systemInstructionText =
330+
"You're a helpful assistant that starts all its answers with: \"COOL\"";
331+
Content systemInstructions = ContentMaker.fromString(systemInstructionText);
332+
333+
model = new GenerativeModel(MODEL_NAME, vertexAi).withSystemInstruction(systemInstructions);
334+
335+
when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
336+
when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
337+
.thenReturn(mockGenerateContentResponse);
338+
339+
Content content = ContentMaker.fromString(TEXT);
340+
GenerateContentResponse unused = model.generateContent(Arrays.asList(content));
341+
342+
ArgumentCaptor<GenerateContentRequest> request =
343+
ArgumentCaptor.forClass(GenerateContentRequest.class);
344+
verify(mockUnaryCallable).call(request.capture());
345+
assertThat(request.getValue().getSystemInstruction().getParts(0).getText())
346+
.isEqualTo(systemInstructionText);
347+
}
348+
327349
@Test
328350
public void testGenerateContentwithDefaultGenerationConfig() throws Exception {
329351
model =

‎java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java

+12
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ public class ITGenerativeModelIntegrationTest {
5858

5959
// Tested content
6060
private static final String TEXT = "What do you think about Google Pixel?";
61+
private static final String PIRATE_INSTRUCTION = "Speak like a pirate when answering questions.";
6162
private static final String IMAGE_INQUIRY = "Please describe this image: ";
6263
private static final String IMAGE_URL = "https://picsum.photos/id/1/200/300";
6364
private static final String VIDEO_INQUIRY = "Please summarize this video: ";
@@ -259,4 +260,15 @@ public void countTokens_withPlainText_returnsNonZeroTokens() throws IOException
259260
logger.info(String.format("Print number of tokens:\n%s", tokens));
260261
assertThat(tokens.getTotalTokens()).isGreaterThan(0);
261262
}
263+
264+
@Test
265+
public void generateContent_withSystemInstruction() throws Exception {
266+
GenerativeModel pirateModel =
267+
textModel.withSystemInstruction(ContentMaker.fromString(PIRATE_INSTRUCTION));
268+
269+
// GenAI output is flaky so we always print out the response.
270+
// For the same reason, we don't do assertions much.
271+
GenerateContentResponse pirateResponse = pirateModel.generateContent(TEXT);
272+
assertNonEmptyAndLogResponse(name.getMethodName(), TEXT, pirateResponse);
273+
}
262274
}

0 commit comments

Comments
 (0)
Please sign in to comment.