Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add getFunctionCalls to ResponseHanlder #10499

Merged
merged 1 commit into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
import com.google.cloud.vertexai.api.Citation;
import com.google.cloud.vertexai.api.CitationMetadata;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.FunctionCall;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.Part;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand All @@ -33,20 +35,15 @@
public class ResponseHandler {

/**
* Get the text message in a GenerateContentResponse.
* Gets the text message in a GenerateContentResponse.
*
* @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance
* @return a String that aggregates all the text parts in the response
* @throws IllegalArgumentException if the response has 0 or more than 1 candidates, or if the
* response is blocked by safety reason or unauthorized citations
*/
public static String getText(GenerateContentResponse response) {
FinishReason finishReason = getFinishReason(response);
if (finishReason == FinishReason.SAFETY) {
throw new IllegalArgumentException("The response is blocked due to safety reason.");
} else if (finishReason == FinishReason.RECITATION) {
throw new IllegalArgumentException("The response is blocked due to unauthorized citations.");
}
checkFinishReason(getFinishReason(response));

String text = "";
List<Part> parts = response.getCandidates(0).getContent().getPartsList();
Expand All @@ -58,26 +55,40 @@ public static String getText(GenerateContentResponse response) {
}

/**
* Get the content in a GenerateContentResponse.
* Gets the list of function calls in a GenerateContentResponse.
*
* @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance
* @return a list of {@link com.google.cloud.vertexai.api.FunctionCall} in the response
* @throws IllegalArgumentException if the response has 0 or more than 1 candidates, or if the
* response is blocked by safety reason or unauthorized citations
*/
public static ImmutableList<FunctionCall> getFunctionCalls(GenerateContentResponse response) {
checkFinishReason(getFinishReason(response));
if (response.getCandidatesCount() == 0) {
return ImmutableList.of();
}
return response.getCandidates(0).getContent().getPartsList().stream()
.filter((part) -> part.hasFunctionCall())
.map((part) -> part.getFunctionCall())
.collect(ImmutableList.toImmutableList());
}

/**
* Gets the content in a GenerateContentResponse.
*
* @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance
* @return the {@link com.google.cloud.vertexai.api.Content} in the response
* @throws IllegalArgumentException if the response has 0 or more than 1 candidates, or if the
* response is blocked by safety reason or unauthorized citations
*/
public static Content getContent(GenerateContentResponse response) {
FinishReason finishReason = getFinishReason(response);
if (finishReason == FinishReason.SAFETY) {
throw new IllegalArgumentException("The response is blocked due to safety reason.");
} else if (finishReason == FinishReason.RECITATION) {
throw new IllegalArgumentException("The response is blocked due to unauthorized citations.");
}
checkFinishReason(getFinishReason(response));

return response.getCandidates(0).getContent();
}

/**
* Get the finish reason in a GenerateContentResponse.
* Gets the finish reason in a GenerateContentResponse.
*
* @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance
* @return the {@link com.google.cloud.vertexai.api.FinishReason} in the response
Expand All @@ -93,7 +104,7 @@ public static FinishReason getFinishReason(GenerateContentResponse response) {
return response.getCandidates(0).getFinishReason();
}

/** Aggregate a stream of responses into a single GenerateContentResponse. */
/** Aggregates a stream of responses into a single GenerateContentResponse. */
static GenerateContentResponse aggregateStreamIntoResponse(
ResponseStream<GenerateContentResponse> responseStream) {
GenerateContentResponse res = GenerateContentResponse.getDefaultInstance();
Expand Down Expand Up @@ -170,4 +181,12 @@ static GenerateContentResponse aggregateStreamIntoResponse(

return res;
}

private static void checkFinishReason(FinishReason finishReason) {
if (finishReason == FinishReason.SAFETY) {
throw new IllegalArgumentException("The response is blocked due to safety reason.");
} else if (finishReason == FinishReason.RECITATION) {
throw new IllegalArgumentException("The response is blocked due to unauthorized citations.");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
import com.google.cloud.vertexai.api.Citation;
import com.google.cloud.vertexai.api.CitationMetadata;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.FunctionCall;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.Part;
import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import java.util.Iterator;
import org.junit.Rule;
Expand All @@ -47,6 +49,13 @@ public final class ResponseHandlerTest {
.addParts(Part.newBuilder().setText(TEXT_1))
.addParts(Part.newBuilder().setText(TEXT_2))
.build();
private static final Content CONTENT_WITH_FNCTION_CALL =
Content.newBuilder()
.addParts(Part.newBuilder().setText(TEXT_1))
.addParts(Part.newBuilder().setFunctionCall(FunctionCall.getDefaultInstance()))
.addParts(Part.newBuilder().setText(TEXT_2))
.addParts(Part.newBuilder().setFunctionCall(FunctionCall.getDefaultInstance()))
.build();
private static final Citation CITATION_1 =
Citation.newBuilder().setUri("gs://citation1").setStartIndex(1).setEndIndex(2).build();
private static final Citation CITATION_2 =
Expand All @@ -61,10 +70,14 @@ public final class ResponseHandlerTest {
.setContent(CONTENT)
.setCitationMetadata(CitationMetadata.newBuilder().addCitations(CITATION_2))
.build();
private static final Candidate CANDIDATE_3 =
Candidate.newBuilder().setContent(CONTENT_WITH_FNCTION_CALL).build();
private static final GenerateContentResponse RESPONSE_1 =
GenerateContentResponse.newBuilder().addCandidates(CANDIDATE_1).build();
private static final GenerateContentResponse RESPONSE_2 =
GenerateContentResponse.newBuilder().addCandidates(CANDIDATE_2).build();
private static final GenerateContentResponse RESPONSE_3 =
GenerateContentResponse.newBuilder().addCandidates(CANDIDATE_3).build();
private static final GenerateContentResponse INVALID_RESPONSE =
GenerateContentResponse.newBuilder()
.addCandidates(CANDIDATE_1)
Expand Down Expand Up @@ -94,6 +107,28 @@ public void testGetTextFromInvalidResponse() {
INVALID_RESPONSE.getCandidatesCount()));
}

@Test
public void testGetFunctionCallsFromResponse() {
ImmutableList<FunctionCall> functionCalls = ResponseHandler.getFunctionCalls(RESPONSE_3);
assertThat(functionCalls.size()).isEqualTo(2);
assertThat(functionCalls.get(0)).isEqualTo(FunctionCall.getDefaultInstance());
assertThat(functionCalls.get(1)).isEqualTo(FunctionCall.getDefaultInstance());
}

@Test
public void testGetFunctionCallsFromInvalidResponse() {
IllegalArgumentException thrown =
assertThrows(
IllegalArgumentException.class,
() -> ResponseHandler.getFunctionCalls(INVALID_RESPONSE));
assertThat(thrown)
.hasMessageThat()
.isEqualTo(
String.format(
"This response should have exactly 1 candidate, but it has %s.",
INVALID_RESPONSE.getCandidatesCount()));
}

@Test
public void testGetContentFromResponse() {
Content content = ResponseHandler.getContent(RESPONSE_1);
Expand Down