Skip to content

Commit

Permalink
feat: Add getFunctionCalls to ResponseHanlder
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 613710240
  • Loading branch information
happy-qiao authored and Copybara-Service committed Mar 7, 2024
1 parent 5a06eca commit c81be82
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
import com.google.cloud.vertexai.api.Citation;
import com.google.cloud.vertexai.api.CitationMetadata;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.FunctionCall;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.Part;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand All @@ -33,20 +35,15 @@
public class ResponseHandler {

/**
* Get the text message in a GenerateContentResponse.
* Gets the text message in a GenerateContentResponse.
*
* @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance
* @return a String that aggregates all the text parts in the response
* @throws IllegalArgumentException if the response has 0 or more than 1 candidates, or if the
* response is blocked by safety reason or unauthorized citations
*/
public static String getText(GenerateContentResponse response) {
FinishReason finishReason = getFinishReason(response);
if (finishReason == FinishReason.SAFETY) {
throw new IllegalArgumentException("The response is blocked due to safety reason.");
} else if (finishReason == FinishReason.RECITATION) {
throw new IllegalArgumentException("The response is blocked due to unauthorized citations.");
}
checkFinishReason(getFinishReason(response));

String text = "";
List<Part> parts = response.getCandidates(0).getContent().getPartsList();
Expand All @@ -58,26 +55,40 @@ public static String getText(GenerateContentResponse response) {
}

/**
* Get the content in a GenerateContentResponse.
* Gets the list of function calls in a GenerateContentResponse.
*
* @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance
* @return a list of {@link com.google.cloud.vertexai.api.FunctionCall} in the response
* @throws IllegalArgumentException if the response has 0 or more than 1 candidates, or if the
* response is blocked by safety reason or unauthorized citations
*/
public static ImmutableList<FunctionCall> getFunctionCalls(GenerateContentResponse response) {
checkFinishReason(getFinishReason(response));
if (response.getCandidatesCount() == 0) {
return ImmutableList.of();
}
return response.getCandidates(0).getContent().getPartsList().stream()
.filter((part) -> part.hasFunctionCall())
.map((part) -> part.getFunctionCall())
.collect(ImmutableList.toImmutableList());
}

/**
* Gets the content in a GenerateContentResponse.
*
* @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance
* @return the {@link com.google.cloud.vertexai.api.Content} in the response
* @throws IllegalArgumentException if the response has 0 or more than 1 candidates, or if the
* response is blocked by safety reason or unauthorized citations
*/
public static Content getContent(GenerateContentResponse response) {
FinishReason finishReason = getFinishReason(response);
if (finishReason == FinishReason.SAFETY) {
throw new IllegalArgumentException("The response is blocked due to safety reason.");
} else if (finishReason == FinishReason.RECITATION) {
throw new IllegalArgumentException("The response is blocked due to unauthorized citations.");
}
checkFinishReason(getFinishReason(response));

return response.getCandidates(0).getContent();
}

/**
* Get the finish reason in a GenerateContentResponse.
* Gets the finish reason in a GenerateContentResponse.
*
* @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance
* @return the {@link com.google.cloud.vertexai.api.FinishReason} in the response
Expand All @@ -93,7 +104,7 @@ public static FinishReason getFinishReason(GenerateContentResponse response) {
return response.getCandidates(0).getFinishReason();
}

/** Aggregate a stream of responses into a single GenerateContentResponse. */
/** Aggregates a stream of responses into a single GenerateContentResponse. */
static GenerateContentResponse aggregateStreamIntoResponse(
ResponseStream<GenerateContentResponse> responseStream) {
GenerateContentResponse res = GenerateContentResponse.getDefaultInstance();
Expand Down Expand Up @@ -170,4 +181,12 @@ static GenerateContentResponse aggregateStreamIntoResponse(

return res;
}

private static void checkFinishReason(FinishReason finishReason) {
if (finishReason == FinishReason.SAFETY) {
throw new IllegalArgumentException("The response is blocked due to safety reason.");
} else if (finishReason == FinishReason.RECITATION) {
throw new IllegalArgumentException("The response is blocked due to unauthorized citations.");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
import com.google.cloud.vertexai.api.Citation;
import com.google.cloud.vertexai.api.CitationMetadata;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.FunctionCall;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.Part;
import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import java.util.Iterator;
import org.junit.Rule;
Expand All @@ -47,6 +49,13 @@ public final class ResponseHandlerTest {
.addParts(Part.newBuilder().setText(TEXT_1))
.addParts(Part.newBuilder().setText(TEXT_2))
.build();
private static final Content CONTENT_WITH_FNCTION_CALL =
Content.newBuilder()
.addParts(Part.newBuilder().setText(TEXT_1))
.addParts(Part.newBuilder().setFunctionCall(FunctionCall.getDefaultInstance()))
.addParts(Part.newBuilder().setText(TEXT_2))
.addParts(Part.newBuilder().setFunctionCall(FunctionCall.getDefaultInstance()))
.build();
private static final Citation CITATION_1 =
Citation.newBuilder().setUri("gs://citation1").setStartIndex(1).setEndIndex(2).build();
private static final Citation CITATION_2 =
Expand All @@ -61,10 +70,14 @@ public final class ResponseHandlerTest {
.setContent(CONTENT)
.setCitationMetadata(CitationMetadata.newBuilder().addCitations(CITATION_2))
.build();
private static final Candidate CANDIDATE_3 =
Candidate.newBuilder().setContent(CONTENT_WITH_FNCTION_CALL).build();
private static final GenerateContentResponse RESPONSE_1 =
GenerateContentResponse.newBuilder().addCandidates(CANDIDATE_1).build();
private static final GenerateContentResponse RESPONSE_2 =
GenerateContentResponse.newBuilder().addCandidates(CANDIDATE_2).build();
private static final GenerateContentResponse RESPONSE_3 =
GenerateContentResponse.newBuilder().addCandidates(CANDIDATE_3).build();
private static final GenerateContentResponse INVALID_RESPONSE =
GenerateContentResponse.newBuilder()
.addCandidates(CANDIDATE_1)
Expand Down Expand Up @@ -94,6 +107,28 @@ public void testGetTextFromInvalidResponse() {
INVALID_RESPONSE.getCandidatesCount()));
}

@Test
public void testGetFunctionCallsFromResponse() {
ImmutableList<FunctionCall> functionCalls = ResponseHandler.getFunctionCalls(RESPONSE_3);
assertThat(functionCalls.size()).isEqualTo(2);
assertThat(functionCalls.get(0)).isEqualTo(FunctionCall.getDefaultInstance());
assertThat(functionCalls.get(1)).isEqualTo(FunctionCall.getDefaultInstance());
}

@Test
public void testGetFunctionCallsFromInvalidResponse() {
IllegalArgumentException thrown =
assertThrows(
IllegalArgumentException.class,
() -> ResponseHandler.getFunctionCalls(INVALID_RESPONSE));
assertThat(thrown)
.hasMessageThat()
.isEqualTo(
String.format(
"This response should have exactly 1 candidate, but it has %s.",
INVALID_RESPONSE.getCandidatesCount()));
}

@Test
public void testGetContentFromResponse() {
Content content = ResponseHandler.getContent(RESPONSE_1);
Expand Down

0 comments on commit c81be82

Please sign in to comment.