Skip to content

Commit

Permalink
BREAKING_CHANGE: [vertexai] remove Transport from GenerativeModel (#1…
Browse files Browse the repository at this point in the history
…0530)

PiperOrigin-RevId: 615144883

Co-authored-by: Jaycee Li <jayceeli@google.com>
  • Loading branch information
copybara-service[bot] and jaycee-li committed Mar 16, 2024
1 parent e153330 commit f024111
Show file tree
Hide file tree
Showing 22 changed files with 8,102 additions and 160 deletions.
Expand Up @@ -42,7 +42,6 @@ public class GenerativeModel {
private GenerationConfig generationConfig = null;
private List<SafetySetting> safetySettings = null;
private List<Tool> tools = null;
private Transport transport;

public static Builder newBuilder() {
return new Builder();
Expand All @@ -67,12 +66,6 @@ private GenerativeModel(Builder builder) {
if (builder.tools != null) {
this.tools = builder.tools;
}

if (builder.transport != null) {
this.transport = builder.transport;
} else {
this.transport = this.vertexAi.getTransport();
}
}

/** Builder class for {@link GenerativeModel}. */
Expand All @@ -82,7 +75,6 @@ public static class Builder {
private GenerationConfig generationConfig;
private List<SafetySetting> safetySettings;
private List<Tool> tools;
private Transport transport;

private Builder() {}

Expand Down Expand Up @@ -158,15 +150,6 @@ public Builder setTools(List<Tool> tools) {
}
return this;
}

/**
* Sets the {@link Transport} layer for API calls in the generative model. It overrides the
* transport setting in {@link com.google.cloud.vertexai.VertexAI}
*/
public Builder setTransport(Transport transport) {
this.transport = transport;
return this;
}
}

/**
Expand All @@ -180,21 +163,7 @@ public Builder setTransport(Transport transport) {
* for the generative model
*/
public GenerativeModel(String modelName, VertexAI vertexAi) {
this(modelName, null, null, vertexAi, null);
}

/**
* Constructs a GenerativeModel instance.
*
* @param modelName the name of the generative model. Supported format: "gemini-pro",
* "models/gemini-pro", "publishers/google/models/gemini-pro"
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* for the generative model
* @param transport the {@link Transport} layer for API calls in the generative model. It
* overrides the transport setting in {@link com.google.cloud.vertexai.VertexAI}
*/
public GenerativeModel(String modelName, VertexAI vertexAi, Transport transport) {
this(modelName, null, null, vertexAi, transport);
this(modelName, null, null, vertexAi);
}

/**
Expand All @@ -209,25 +178,7 @@ public GenerativeModel(String modelName, VertexAI vertexAi, Transport transport)
*/
@BetaApi
public GenerativeModel(String modelName, GenerationConfig generationConfig, VertexAI vertexAi) {
this(modelName, generationConfig, null, vertexAi, null);
}

/**
* Constructs a GenerativeModel instance with default generation config.
*
* @param modelName the name of the generative model. Supported format: "gemini-pro",
* "models/gemini-pro", "publishers/google/models/gemini-pro"
* @param generationConfig a {@link com.google.cloud.vertexai.api.GenerationConfig} instance that
* will be used by default for generating response
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* for the generative model
* @param transport the {@link Transport} layer for API calls in the generative model. It
* overrides the transport setting in {@link com.google.cloud.vertexai.VertexAI}
*/
@BetaApi
public GenerativeModel(
String modelName, GenerationConfig generationConfig, VertexAI vertexAi, Transport transport) {
this(modelName, generationConfig, null, vertexAi, transport);
this(modelName, generationConfig, null, vertexAi);
}

/**
Expand All @@ -242,28 +193,7 @@ public GenerativeModel(
*/
@BetaApi("safetySettings is a preview feature.")
public GenerativeModel(String modelName, List<SafetySetting> safetySettings, VertexAI vertexAi) {
this(modelName, null, safetySettings, vertexAi, null);
}

/**
* Constructs a GenerativeModel instance with default safety settings.
*
* @param modelName the name of the generative model. Supported format: "gemini-pro",
* "models/gemini-pro", "publishers/google/models/gemini-pro"
* @param safetySettings a list of {@link com.google.cloud.vertexai.api.SafetySetting} instances
* that will be used by default for generating response
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* for the generative model
* @param transport the {@link Transport} layer for API calls in the generative model. It
* overrides the transport setting in {@link com.google.cloud.vertexai.VertexAI}
*/
@BetaApi("safetySettings is a preview feature.")
public GenerativeModel(
String modelName,
List<SafetySetting> safetySettings,
VertexAI vertexAi,
Transport transport) {
this(modelName, null, safetySettings, vertexAi, transport);
this(modelName, null, safetySettings, vertexAi);
}

/**
Expand All @@ -284,30 +214,6 @@ public GenerativeModel(
GenerationConfig generationConfig,
List<SafetySetting> safetySettings,
VertexAI vertexAi) {
this(modelName, generationConfig, safetySettings, vertexAi, null);
}

/**
* Constructs a GenerativeModel instance with default generation config and safety settings.
*
* @param modelName the name of the generative model. Supported format: "gemini-pro",
* "models/gemini-pro", "publishers/google/models/gemini-pro"
* @param generationConfig a {@link com.google.cloud.vertexai.api.GenerationConfig} instance that
* will be used by default for generating response
* @param safetySettings a list of {@link com.google.cloud.vertexai.api.SafetySetting} instances
* that will be used by default for generating response
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* for the generative model
* @param transport the {@link Transport} layer for API calls in the generative model. It
* overrides the transport setting in {@link com.google.cloud.vertexai.VertexAI}
*/
@BetaApi
public GenerativeModel(
String modelName,
GenerationConfig generationConfig,
List<SafetySetting> safetySettings,
VertexAI vertexAi,
Transport transport) {
modelName = reconcileModelName(modelName);
this.modelName = modelName;
this.resourceName =
Expand All @@ -324,11 +230,6 @@ public GenerativeModel(
}
}
this.vertexAi = vertexAi;
if (transport != null) {
this.transport = transport;
} else {
this.transport = vertexAi.getTransport();
}
}

/**
Expand Down Expand Up @@ -388,7 +289,7 @@ public CountTokensResponse countTokens(List<Content> contents) throws IOExceptio
@BetaApi
private CountTokensResponse countTokensFromRequest(CountTokensRequest request)
throws IOException {
if (this.transport == Transport.REST) {
if (vertexAi.getTransport() == Transport.REST) {
return vertexAi.getLlmUtilityRestClient().countTokens(request);
} else {
return vertexAi.getLlmUtilityClient().countTokens(request);
Expand Down Expand Up @@ -619,7 +520,7 @@ public GenerateContentResponse generateContent(
*/
private GenerateContentResponse generateContent(GenerateContentRequest request)
throws IOException {
if (this.transport == Transport.REST) {
if (vertexAi.getTransport() == Transport.REST) {
return vertexAi.getPredictionServiceRestClient().generateContentCallable().call(request);
} else {
return vertexAi.getPredictionServiceClient().generateContentCallable().call(request);
Expand Down Expand Up @@ -1031,7 +932,7 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
*/
private ResponseStream<GenerateContentResponse> generateContentStream(
GenerateContentRequest request) throws IOException {
if (this.transport == Transport.REST) {
if (vertexAi.getTransport() == Transport.REST) {
return new ResponseStream(
new ResponseStreamIteratorWithHistory(
vertexAi
Expand Down Expand Up @@ -1082,24 +983,11 @@ public void setTools(List<Tool> tools) {
}
}

/**
* Sets the value for {@link #getTransport}, which defines the layer for API calls in this
* generative model.
*/
public void setTransport(Transport transport) {
this.transport = transport;
}

/** Returns the model name of this generative model. */
public String getModelName() {
return this.modelName;
}

/** Returns the {@link Transport} layer for API calls in this generative model. */
public Transport getTransport() {
return this.transport;
}

/**
* Returns the {@link com.google.cloud.vertexai.api.GenerationConfig} of this generative model.
*/
Expand Down
Expand Up @@ -26,7 +26,6 @@
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.UnaryCallable;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.cloud.vertexai.Transport;
import com.google.cloud.vertexai.VertexAI;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.CountTokensRequest;
Expand All @@ -35,15 +34,18 @@
import com.google.cloud.vertexai.api.GenerateContentRequest;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.GoogleSearchRetrieval;
import com.google.cloud.vertexai.api.HarmCategory;
import com.google.cloud.vertexai.api.LlmUtilityServiceClient;
import com.google.cloud.vertexai.api.Part;
import com.google.cloud.vertexai.api.PredictionServiceClient;
import com.google.cloud.vertexai.api.Retrieval;
import com.google.cloud.vertexai.api.SafetySetting;
import com.google.cloud.vertexai.api.SafetySetting.HarmBlockThreshold;
import com.google.cloud.vertexai.api.Schema;
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.api.Type;
import com.google.cloud.vertexai.api.VertexAISearch;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Iterator;
Expand Down Expand Up @@ -96,14 +98,30 @@ public final class GenerativeModelTest {
.build())
.addRequired("location")))
.build();
private static final Tool GOOGLE_SEARCH_TOOL =
Tool.newBuilder()
.setGoogleSearchRetrieval(GoogleSearchRetrieval.newBuilder().setDisableAttribution(false))
.build();
private static final Tool VERTEX_AI_SEARCH_TOOL =
Tool.newBuilder()
.setRetrieval(
Retrieval.newBuilder()
.setVertexAiSearch(
VertexAISearch.newBuilder()
.setDatastore(
String.format(
"projects/%s/locations/%s/collections/%s/dataStores/%s",
PROJECT, "global", "default_collection", "test_123")))
.setDisableAttribution(false))
.build();

private static final String TEXT = "What is your name?";

private VertexAI vertexAi;
private GenerativeModel model;
private List<SafetySetting> safetySettings = Arrays.asList(SAFETY_SETTING);
private List<SafetySetting> defaultSafetySettings = Arrays.asList(DEFAULT_SAFETY_SETTING);
private List<Tool> tools = Arrays.asList(TOOL);
private List<Tool> tools = Arrays.asList(TOOL, GOOGLE_SEARCH_TOOL, VERTEX_AI_SEARCH_TOOL);

@Rule public final MockitoRule mocksRule = MockitoJUnit.rule();

Expand Down Expand Up @@ -169,7 +187,6 @@ public void testInstantiateGenerativeModelwithBuilder() {
assertThat(model.getGenerationConfig()).isNull();
assertThat(model.getSafetySettings()).isNull();
assertThat(model.getTools()).isNull();
assertThat(model.getTransport()).isEqualTo(Transport.GRPC);
}

@Test
Expand All @@ -181,13 +198,11 @@ public void testInstantiateGenerativeModelwithBuilderAllConfigs() {
.setGenerationConfig(GENERATION_CONFIG)
.setSafetySettings(safetySettings)
.setTools(tools)
.setTransport(Transport.REST)
.build();
assertThat(model.getModelName()).isEqualTo(MODEL_NAME);
assertThat(model.getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
assertThat(model.getSafetySettings()).isEqualTo(safetySettings);
assertThat(model.getTools()).isEqualTo(tools);
assertThat(model.getTransport()).isEqualTo(Transport.REST);
}

@Test
Expand Down

0 comments on commit f024111

Please sign in to comment.