Skip to content

Commit

Permalink
feat: [vertexai] add generateContentAsync methods to GenerativeModel (#…
Browse files Browse the repository at this point in the history
…10599)

PiperOrigin-RevId: 617951189

Co-authored-by: Jaycee Li <jayceeli@google.com>
  • Loading branch information
copybara-service[bot] and jaycee-li committed Mar 22, 2024
1 parent 5c3d93e commit 17b01c6
Show file tree
Hide file tree
Showing 5 changed files with 467 additions and 163 deletions.
Expand Up @@ -16,6 +16,9 @@

package com.google.cloud.vertexai;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;

import com.google.api.core.InternalApi;
import com.google.api.gax.core.CredentialsProvider;
import com.google.api.gax.core.FixedCredentialsProvider;
Expand All @@ -28,8 +31,10 @@
import com.google.cloud.vertexai.api.LlmUtilityServiceSettings;
import com.google.cloud.vertexai.api.PredictionServiceClient;
import com.google.cloud.vertexai.api.PredictionServiceSettings;
import com.google.common.base.Strings;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;
import java.util.logging.Logger;

Expand All @@ -56,9 +61,8 @@ public class VertexAI implements AutoCloseable {
private Transport transport = Transport.GRPC;
// The clients will be instantiated lazily
private PredictionServiceClient predictionServiceClient = null;
private PredictionServiceClient predictionServiceRestClient = null;
private LlmUtilityServiceClient llmUtilityClient = null;
private LlmUtilityServiceClient llmUtilityRestClient = null;
private final ReentrantLock lock = new ReentrantLock();

/**
* Construct a VertexAI instance.
Expand Down Expand Up @@ -193,32 +197,35 @@ public Credentials getCredentials() throws IOException {

/** Sets the value for {@link #getTransport()}. */
public void setTransport(Transport transport) {
checkNotNull(transport, "Transport can't be null.");
if (this.transport == transport) {
return;
}

this.transport = transport;
resetClients();
}

/** Sets the value for {@link #getApiEndpoint()}. */
public void setApiEndpoint(String apiEndpoint) {
checkArgument(!Strings.isNullOrEmpty(apiEndpoint), "Api endpoint can't be null or empty.");
if (this.apiEndpoint == apiEndpoint) {
return;
}
this.apiEndpoint = apiEndpoint;
resetClients();
}

private void resetClients() {
if (this.predictionServiceClient != null) {
this.predictionServiceClient.close();
this.predictionServiceClient = null;
}

if (this.predictionServiceRestClient != null) {
this.predictionServiceRestClient.close();
this.predictionServiceRestClient = null;
}

if (this.llmUtilityClient != null) {
this.llmUtilityClient.close();
this.llmUtilityClient = null;
}

if (this.llmUtilityRestClient != null) {
this.llmUtilityRestClient.close();
this.llmUtilityRestClient = null;
}
}

/**
Expand All @@ -230,78 +237,47 @@ public void setApiEndpoint(String apiEndpoint) {
*/
@InternalApi
public PredictionServiceClient getPredictionServiceClient() throws IOException {
if (this.transport == Transport.GRPC) {
return getPredictionServiceGrpcClient();
} else {
return getPredictionServiceRestClient();
if (predictionServiceClient != null) {
return predictionServiceClient;
}
}

/**
* Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
* first prediction API call is made.
*
* @return {@link PredictionServiceClient} that send GRPC requests to the backing service through
* method calls that map to the API methods.
*/
private PredictionServiceClient getPredictionServiceGrpcClient() throws IOException {
if (predictionServiceClient == null) {
PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newBuilder();
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
if (this.credentialsProvider != null) {
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
lock.lock();
try {
if (predictionServiceClient == null) {
PredictionServiceSettings settings = getPredictionServiceSettings();
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
predictionServiceClient = PredictionServiceClient.create(settings);
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
HeaderProvider headerProvider =
FixedHeaderProvider.create(
"user-agent",
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
settingsBuilder.setHeaderProvider(headerProvider);
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
predictionServiceClient = PredictionServiceClient.create(settingsBuilder.build());
defaultCredentialsProviderLogger.setLevel(previousLevel);
return predictionServiceClient;
} finally {
lock.unlock();
}
return predictionServiceClient;
}

/**
* Returns the {@link PredictionServiceClient} with REST. The client will be instantiated when the
* first prediction API call is made.
*
* @return {@link PredictionServiceClient} that send REST requests to the backing service through
* method calls that map to the API methods.
*/
private PredictionServiceClient getPredictionServiceRestClient() throws IOException {
if (predictionServiceRestClient == null) {
PredictionServiceSettings.Builder settingsBuilder =
PredictionServiceSettings.newHttpJsonBuilder();
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
if (this.credentialsProvider != null) {
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
}
HeaderProvider headerProvider =
FixedHeaderProvider.create(
"user-agent",
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
settingsBuilder.setHeaderProvider(headerProvider);
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
predictionServiceRestClient = PredictionServiceClient.create(settingsBuilder.build());
defaultCredentialsProviderLogger.setLevel(previousLevel);
private PredictionServiceSettings getPredictionServiceSettings() throws IOException {
PredictionServiceSettings.Builder builder;
if (transport == Transport.REST) {
builder = PredictionServiceSettings.newHttpJsonBuilder();
} else {
builder = PredictionServiceSettings.newBuilder();
}
builder.setEndpoint(String.format("%s:443", this.apiEndpoint));
if (this.credentialsProvider != null) {
builder.setCredentialsProvider(this.credentialsProvider);
}
return predictionServiceRestClient;
HeaderProvider headerProvider =
FixedHeaderProvider.create(
"user-agent",
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
builder.setHeaderProvider(headerProvider);
return builder.build();
}

/**
Expand All @@ -313,78 +289,47 @@ private PredictionServiceClient getPredictionServiceRestClient() throws IOExcept
*/
@InternalApi
public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {
if (this.transport == Transport.GRPC) {
return getLlmUtilityGrpcClient();
} else {
return getLlmUtilityRestClient();
if (llmUtilityClient != null) {
return llmUtilityClient;
}
}

/**
* Returns the {@link LlmUtilityServiceClient} with GRPC. The client will be instantiated when the
* first API call is made.
*
* @return {@link LlmUtilityServiceClient} that makes gRPC calls to the backing service through
* method calls that map to the API methods.
*/
private LlmUtilityServiceClient getLlmUtilityGrpcClient() throws IOException {
if (llmUtilityClient == null) {
LlmUtilityServiceSettings.Builder settingsBuilder = LlmUtilityServiceSettings.newBuilder();
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
if (this.credentialsProvider != null) {
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
lock.lock();
try {
if (llmUtilityClient == null) {
LlmUtilityServiceSettings settings = getLlmUtilityServiceClientSettings();
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
llmUtilityClient = LlmUtilityServiceClient.create(settings);
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
HeaderProvider headerProvider =
FixedHeaderProvider.create(
"user-agent",
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
settingsBuilder.setHeaderProvider(headerProvider);
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
llmUtilityClient = LlmUtilityServiceClient.create(settingsBuilder.build());
defaultCredentialsProviderLogger.setLevel(previousLevel);
return llmUtilityClient;
} finally {
lock.unlock();
}
return llmUtilityClient;
}

/**
* Returns the {@link LlmUtilityServiceClient} with REST. The client will be instantiated when the
* first API call is made.
*
* @return {@link LlmUtilityServiceClient} that makes REST requests to the backing service through
* method calls that map to the API methods.
*/
private LlmUtilityServiceClient getLlmUtilityRestClient() throws IOException {
if (llmUtilityRestClient == null) {
LlmUtilityServiceSettings.Builder settingsBuilder =
LlmUtilityServiceSettings.newHttpJsonBuilder();
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
if (this.credentialsProvider != null) {
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
}
HeaderProvider headerProvider =
FixedHeaderProvider.create(
"user-agent",
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
settingsBuilder.setHeaderProvider(headerProvider);
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
llmUtilityRestClient = LlmUtilityServiceClient.create(settingsBuilder.build());
defaultCredentialsProviderLogger.setLevel(previousLevel);
private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IOException {
LlmUtilityServiceSettings.Builder settingsBuilder;
if (transport == Transport.REST) {
settingsBuilder = LlmUtilityServiceSettings.newHttpJsonBuilder();
} else {
settingsBuilder = LlmUtilityServiceSettings.newBuilder();
}
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
if (this.credentialsProvider != null) {
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
}
return llmUtilityRestClient;
HeaderProvider headerProvider =
FixedHeaderProvider.create(
"user-agent",
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
settingsBuilder.setHeaderProvider(headerProvider);
return settingsBuilder.build();
}

/** Closes the VertexAI instance together with all its instantiated clients. */
Expand All @@ -393,14 +338,8 @@ public void close() {
if (predictionServiceClient != null) {
predictionServiceClient.close();
}
if (predictionServiceRestClient != null) {
predictionServiceRestClient.close();
}
if (llmUtilityClient != null) {
llmUtilityClient.close();
}
if (llmUtilityRestClient != null) {
llmUtilityRestClient.close();
}
}
}
@@ -0,0 +1,63 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.cloud.vertexai.generativeai;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;

import com.google.cloud.vertexai.api.FunctionDeclaration;
import com.google.common.base.Strings;
import com.google.gson.JsonObject;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;

/** Helper class to create {@link com.google.cloud.vertexai.api.FunctionDeclaration} */
public final class FunctionDeclarationMaker {

/**
* Creates a FunctionDeclaration from a JsonString
*
* @param jsonString A valid Json String that can be parsed to a FunctionDeclaration.
* @return a {@link FunctionDeclaration} by parsing the input json String.
* @throws InvalidProtocolBufferException if the String can't be parsed into a FunctionDeclaration
* proto.
*/
public static FunctionDeclaration fromJsonString(String jsonString)
throws InvalidProtocolBufferException {
checkArgument(!Strings.isNullOrEmpty(jsonString), "Input String can't be null or empty.");
FunctionDeclaration.Builder builder = FunctionDeclaration.newBuilder();
JsonFormat.parser().merge(jsonString, builder);
FunctionDeclaration declaration = builder.build();
if (declaration.getName().isEmpty()) {
throw new IllegalArgumentException("name field must be present.");
}
return declaration;
}

/**
* Creates a FunctionDeclaration from a JsonObject
*
* @param jsonObject A valid Json Object that can be parsed to a FunctionDeclaration.
* @return a {@link FunctionDeclaration} by parsing the input json Object.
* @throws InvalidProtocolBufferException if the jsonObject can't be parsed into a
* FunctionDeclaration proto.
*/
public static FunctionDeclaration fromJsonObject(JsonObject jsonObject)
throws InvalidProtocolBufferException {
checkNotNull(jsonObject, "JsonObject can't be null.");
return fromJsonString(jsonObject.toString());
}
}

0 comments on commit 17b01c6

Please sign in to comment.