Skip to content

Commit

Permalink
feat: [vertexai] add generateContentAsync methods to GenerativeModel
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 617951189
  • Loading branch information
jaycee-li authored and Copybara-Service committed Mar 21, 2024
1 parent b5e8e3d commit c8b48d9
Show file tree
Hide file tree
Showing 4 changed files with 411 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,97 @@
import com.google.cloud.vertexai.api.Candidate.FinishReason;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.SafetySetting;
import com.google.cloud.vertexai.api.Tool;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

/** Represents a conversation between the user and the model */
public final class ChatSession {
private final GenerativeModel model;
private final Optional<ChatSession> rootChatSession;
private List<Content> history = new ArrayList<>();
private ResponseStream<GenerateContentResponse> currentResponseStream = null;
private GenerateContentResponse currentResponse = null;
private Optional<ResponseStream<GenerateContentResponse>> currentResponseStream;
private Optional<GenerateContentResponse> currentResponse;

/**
* Creates a new chat session given a GenerativeModel instance. Configurations of the chat (e.g.,
* GenerationConfig) inherits from the model.
*/
@BetaApi
public ChatSession(GenerativeModel model) {
this(model, Optional.empty());
}

/**
* Creates a new chat session given a GenerativeModel instance and a root chat session.
* Configurations of the chat (e.g., GenerationConfig) inherits from the model.
*
* @param model a {@link GenerativeModel} instance that generates contents in the chat.
* @param rootChatSession a root {@link ChatSession} instance. All the chat history in the current
* chat session will be merged to the root chat session.
* @return a {@link ChatSession} instance.
*/
@BetaApi
private ChatSession(GenerativeModel model, Optional<ChatSession> rootChatSession) {
if (model == null) {
throw new IllegalArgumentException("model should not be null");
}
this.model = model;
this.rootChatSession = rootChatSession;
currentResponseStream = Optional.empty();
currentResponse = Optional.empty();
}

/**
* Creates a copy of the current ChatSession with updated GenerationConfig.
*
* @param generationConfig a {@link com.google.cloud.vertexai.api.GenerationConfig} that will be
* used in the new ChatSession.
* @return a new {@link ChatSession} instance with the specified GenerationConfig.
*/
@BetaApi
public ChatSession withGenerationConfig(GenerationConfig generationConfig) {
ChatSession rootChat = rootChatSession.orElse(this);
ChatSession newChatSession =
new ChatSession(model.withGenerationConfig(generationConfig), Optional.of(rootChat));
newChatSession.setHistory(history);
return newChatSession;
}

/**
* Creates a copy of the current ChatSession with updated SafetySettings.
*
* @param safetySettings a {@link com.google.cloud.vertexai.api.SafetySetting} that will be used
* in the new ChatSession.
* @return a new {@link ChatSession} instance with the specified SafetySettings.
*/
@BetaApi
public ChatSession withSafetySettings(List<SafetySetting> safetySettings) {
ChatSession rootChat = rootChatSession.orElse(this);
ChatSession newChatSession =
new ChatSession(model.withSafetySettings(safetySettings), Optional.of(rootChat));
newChatSession.setHistory(history);
return newChatSession;
}

/**
* Creates a copy of the current ChatSession with updated Tools.
*
* @param tools a {@link com.google.cloud.vertexai.api.Tool} that will be used in the new
* ChatSession.
* @return a new {@link ChatSession} instance with the specified Tools.
*/
@BetaApi
public ChatSession withTools(List<Tool> tools) {
ChatSession rootChat = rootChatSession.orElse(this);
ChatSession newChatSession = new ChatSession(model.withTools(tools), Optional.of(rootChat));
newChatSession.setHistory(history);
return newChatSession;
}

/**
Expand Down Expand Up @@ -69,8 +142,8 @@ public ResponseStream<GenerateContentResponse> sendMessageStream(Content content
checkLastResponseAndEditHistory();
history.add(content);
ResponseStream<GenerateContentResponse> respStream = model.generateContentStream(history);
currentResponseStream = respStream;
currentResponse = null;
setCurrentResponseStream(Optional.of(respStream));

return respStream;
}

Expand All @@ -96,8 +169,7 @@ public GenerateContentResponse sendMessage(Content content) throws IOException {
checkLastResponseAndEditHistory();
history.add(content);
GenerateContentResponse response = model.generateContent(history);
currentResponse = response;
currentResponseStream = null;
setCurrentResponse(Optional.of(response));
return response;
}

Expand All @@ -112,38 +184,37 @@ private void removeLastContent() {
* @throws IllegalStateException if the response stream is not finished.
*/
private void checkLastResponseAndEditHistory() {
if (currentResponseStream == null && currentResponse == null) {
return;
} else if (currentResponseStream != null && !currentResponseStream.isConsumed()) {
throw new IllegalStateException("Response stream is not consumed");
} else if (currentResponseStream != null && currentResponseStream.isConsumed()) {
GenerateContentResponse response = aggregateStreamIntoResponse(currentResponseStream);
FinishReason finishReason = getFinishReason(response);
if (finishReason != FinishReason.STOP && finishReason != FinishReason.MAX_TOKENS) {
// We also remove the request from the history.
removeLastContent();
currentResponseStream = null;
throw new IllegalStateException(
String.format(
"The last round of conversation will not be added to history because response"
+ " stream did not finish normally. Finish reason is %s.",
finishReason));
}
history.add(getContent(response));
} else if (currentResponseStream == null && currentResponse != null) {
FinishReason finishReason = getFinishReason(currentResponse);
if (finishReason != FinishReason.STOP && finishReason != FinishReason.MAX_TOKENS) {
// We also remove the request from the history.
removeLastContent();
currentResponse = null;
throw new IllegalStateException(
String.format(
"The last round of conversation will not be added to history because response did"
+ " not finish normally. Finish reason is %s.",
finishReason));
}
history.add(getContent(currentResponse));
currentResponse = null;
getCurrentResponse()
.ifPresent(
currentResponse -> {
setCurrentResponse(Optional.empty());
checkFinishReasonAndRemoveLastContent(currentResponse);
history.add(getContent(currentResponse));
});
getCurrentResponseStream()
.ifPresent(
responseStream -> {
if (!responseStream.isConsumed()) {
throw new IllegalStateException("Response stream is not consumed");
} else {
setCurrentResponseStream(Optional.empty());
GenerateContentResponse response = aggregateStreamIntoResponse(responseStream);
checkFinishReasonAndRemoveLastContent(response);
history.add(getContent(response));
}
});
}

/** Removes the last content in the history if the response finished with problems. */
private void checkFinishReasonAndRemoveLastContent(GenerateContentResponse response) {
FinishReason finishReason = getFinishReason(response);
if (finishReason != FinishReason.STOP && finishReason != FinishReason.MAX_TOKENS) {
removeLastContent();
throw new IllegalStateException(
String.format(
"The last round of conversation will not be added to history because response"
+ " stream did not finish normally. Finish reason is %s.",
finishReason));
}
}

Expand All @@ -169,9 +240,62 @@ public List<Content> getHistory() {
return Collections.unmodifiableList(history);
}

/**
* Returns the current response of the root chat session (if exists) or the current chat session.
*/
private Optional<GenerateContentResponse> getCurrentResponse() {
if (rootChatSession.isPresent()) {
return rootChatSession.get().getCurrentResponse();
} else {
return currentResponse;
}
}

/**
* Returns the current responseStream of the root chat session (if exists) or the current chat
* session.
*/
private Optional<ResponseStream<GenerateContentResponse>> getCurrentResponseStream() {
if (rootChatSession.isPresent()) {
return rootChatSession.get().getCurrentResponseStream();
} else {
return currentResponseStream;
}
}

/** Set the history to a list of Content */
@BetaApi
public void setHistory(List<Content> history) {
this.history = history;
}

/** Sets the current response of the root chat session (if exists) or the current chat session. */
private void setCurrentResponse(Optional<GenerateContentResponse> response) {
if (currentResponseStream.isPresent()) {
throw new IllegalStateException(
"currentResponse and currentResponseStream cannot be set together");
}
if (rootChatSession.isPresent()) {
rootChatSession.get().setCurrentResponse(response);
} else {
currentResponse = response;
}
}

/**
* Sets the current responseStream of the root chat session (if exists) or the current chat
* session.
*/
private void setCurrentResponseStream(
Optional<ResponseStream<GenerateContentResponse>> responseStream) {
if (currentResponse.isPresent()) {
throw new IllegalStateException(
"currentResponseStream and currentResponse cannot be set together");
}
if (rootChatSession.isPresent()) {
rootChatSession.get().setCurrentResponseStream(responseStream);
} else {
currentResponseStream = responseStream;
}
}
}

0 comments on commit c8b48d9

Please sign in to comment.