Skip to content

Commit

Permalink
feat: [vertexai] sync the vertexai to google3 sot (#10225)
Browse files Browse the repository at this point in the history
It includes the following changes:

chore: remove the term "url" and replace with "uri"

chore: add user-agent header in Java SDK

feat: support "publishers/google/models/" prefix

feat: add apiEndpoint in VertexAI

chore: change the implementation of countTokens.

chore: switch to v1 gapic clients.

chore: remove URL support in from MultiModalData

chore: remove the logic to throw an exception when getting multi-modal
data in chat
  • Loading branch information
ZhenyiQ committed Jan 17, 2024
1 parent 57b0587 commit da6eea8
Show file tree
Hide file tree
Showing 318 changed files with 21,341 additions and 17,113 deletions.
4 changes: 2 additions & 2 deletions java-vertexai/README.md
Expand Up @@ -192,8 +192,8 @@ import java.util.Arrays;
import java.util.List;

public class Main {
private static final String PROJECT_ID = "cloud-llm-preview1";
private static final String LOCATION = "us-central1";
private static final String PROJECT_ID = <your project id>;
private static final String LOCATION = <location>;
private static final String MODEL_NAME = "gemini-pro";

public static void main(String[] args) throws IOException {
Expand Down
8 changes: 4 additions & 4 deletions java-vertexai/google-cloud-vertexai-bom/pom.xml
Expand Up @@ -31,13 +31,13 @@
</dependency>
<dependency>
<groupId>com.google.api.grpc</groupId>
<artifactId>grpc-google-cloud-vertexai-v1beta1</artifactId>
<version>0.3.0-SNAPSHOT</version><!-- {x-version-update:grpc-google-cloud-vertexai-v1beta1:current} -->
<artifactId>grpc-google-cloud-vertexai-v1</artifactId>
<version>0.3.0-SNAPSHOT</version><!-- {x-version-update:grpc-google-cloud-vertexai-v1:current} -->
</dependency>
<dependency>
<groupId>com.google.api.grpc</groupId>
<artifactId>proto-google-cloud-vertexai-v1beta1</artifactId>
<version>0.3.0-SNAPSHOT</version><!-- {x-version-update:proto-google-cloud-vertexai-v1beta1:current} -->
<artifactId>proto-google-cloud-vertexai-v1</artifactId>
<version>0.3.0-SNAPSHOT</version><!-- {x-version-update:proto-google-cloud-vertexai-v1:current} -->
</dependency>
</dependencies>
</dependencyManagement>
Expand Down
4 changes: 2 additions & 2 deletions java-vertexai/google-cloud-vertexai/pom.xml
Expand Up @@ -43,7 +43,7 @@

<dependency>
<groupId>com.google.api.grpc</groupId>
<artifactId>proto-google-cloud-vertexai-v1beta1</artifactId>
<artifactId>proto-google-cloud-vertexai-v1</artifactId>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
Expand Down Expand Up @@ -97,7 +97,7 @@

<dependency>
<groupId>com.google.api.grpc</groupId>
<artifactId>grpc-google-cloud-vertexai-v1beta1</artifactId>
<artifactId>grpc-google-cloud-vertexai-v1</artifactId>
<scope>test</scope>
</dependency>
<!-- Need testing utility classes for generated gRPC clients tests -->
Expand Down
@@ -0,0 +1,25 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.vertexai;

/** A class that holds all constants for vertexai. */
public final class Constants {
// Constants for VertexAI class
public static final String USER_AGENT_HEADER = "model-builder";

private Constants() {}
}
@@ -1,5 +1,5 @@
/*
* Copyright 2023 Google LLC
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright 2023 Google LLC
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,13 +16,19 @@

package com.google.cloud.vertexai;

import com.google.api.gax.core.CredentialsProvider;
import com.google.api.gax.core.FixedCredentialsProvider;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.api.gax.core.GaxProperties;
import com.google.api.gax.core.GoogleCredentialsProvider;
import com.google.api.gax.rpc.FixedHeaderProvider;
import com.google.api.gax.rpc.HeaderProvider;
import com.google.auth.Credentials;
import com.google.cloud.vertexai.api.LlmUtilityServiceClient;
import com.google.cloud.vertexai.api.LlmUtilityServiceSettings;
import com.google.cloud.vertexai.api.PredictionServiceClient;
import com.google.cloud.vertexai.api.PredictionServiceSettings;
import com.google.cloud.vertexai.api.stub.PredictionServiceStubSettings;
import java.io.IOException;
import java.util.List;
import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;

Expand All @@ -44,11 +50,14 @@ public class VertexAI implements AutoCloseable {

private final String projectId;
private final String location;
private final GoogleCredentials credentials;
private String apiEndpoint;
private CredentialsProvider credentialsProvider = null;
private Transport transport = Transport.GRPC;
// The clients will be instantiated lazily
private PredictionServiceClient predictionServiceClient = null;
private PredictionServiceClient predictionServiceRestClient = null;
private LlmUtilityServiceClient llmUtilityClient = null;
private LlmUtilityServiceClient llmUtilityRestClient = null;

/**
* Construct a VertexAI instance with custom credentials.
Expand All @@ -57,10 +66,11 @@ public class VertexAI implements AutoCloseable {
* @param location the default location to use when making API calls
* @param credentials the custom credentials to use when making API calls
*/
public VertexAI(String projectId, String location, GoogleCredentials credentials) {
public VertexAI(String projectId, String location, Credentials credentials) {
this.projectId = projectId;
this.location = location;
this.credentials = credentials;
this.apiEndpoint = String.format("%s-aiplatform.googleapis.com", this.location);
this.credentialsProvider = FixedCredentialsProvider.create(credentials);
}

/**
Expand All @@ -71,8 +81,7 @@ public VertexAI(String projectId, String location, GoogleCredentials credentials
* @param transport the default {@link Transport} layer to use to send API requests
* @param credentials the default custom credentials to use when making API calls
*/
public VertexAI(
String projectId, String location, Transport transport, GoogleCredentials credentials) {
public VertexAI(String projectId, String location, Transport transport, Credentials credentials) {
this(projectId, location, credentials);
this.transport = transport;
}
Expand All @@ -82,24 +91,22 @@ public VertexAI(
*
* @param projectId the default project to use when making API calls
* @param location the default location to use when making API calls
* @param scopes collection of scopes in the default credentials
* @param scopes collection of scopes in the default credentials. Make sure you have specified
* "https://www.googleapis.com/auth/cloud-platform" scope to access resources on Vertex AI.
*/
public VertexAI(String projectId, String location, String... scopes) throws IOException {
// Disable the warning message logged in getApplicationDefault
Logger logger = Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = logger.getLevel();
logger.setLevel(Level.SEVERE);
List<String> defaultScopes =
PredictionServiceStubSettings.defaultCredentialsProviderBuilder().getScopesToApply();
GoogleCredentials credentials =
CredentialsProvider credentialsProvider =
scopes.length == 0
? GoogleCredentials.getApplicationDefault().createScoped(defaultScopes)
: GoogleCredentials.getApplicationDefault().createScoped(scopes);
logger.setLevel(previousLevel);
? null
: GoogleCredentialsProvider.newBuilder()
.setScopesToApply(Arrays.asList(scopes))
.setUseJwtAccessWithScope(true)
.build();

this.projectId = projectId;
this.location = location;
this.credentials = credentials;
this.apiEndpoint = String.format("%s-aiplatform.googleapis.com", this.location);
this.credentialsProvider = credentialsProvider;
}

/**
Expand Down Expand Up @@ -131,28 +138,72 @@ public String getLocation() {
return this.location;
}

/** Returns the default endpoint to use when making API calls. */
public String getApiEndpoint() {
return this.apiEndpoint;
}

/** Returns the default credentials to use when making API calls. */
public GoogleCredentials getCredentials() {
return credentials;
public Credentials getCredentials() throws IOException {
return credentialsProvider.getCredentials();
}

/** Sets the value for {@link #getTransport()}. */
public void setTransport(Transport transport) {
this.transport = transport;
}

/** Sets the value for {@link #getApiEndpoint()}. */
public void setApiEndpoint(String apiEndpoint) {
this.apiEndpoint = apiEndpoint;

if (this.predictionServiceClient != null) {
this.predictionServiceClient.close();
this.predictionServiceClient = null;
}

if (this.predictionServiceRestClient != null) {
this.predictionServiceRestClient.close();
this.predictionServiceRestClient = null;
}

if (this.llmUtilityClient != null) {
this.llmUtilityClient.close();
this.llmUtilityClient = null;
}

if (this.llmUtilityRestClient != null) {
this.llmUtilityRestClient.close();
this.llmUtilityRestClient = null;
}
}

/**
* Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
* first prediction API call is made.
*/
public PredictionServiceClient getPredictionServiceClient() throws IOException {
if (predictionServiceClient == null) {
PredictionServiceSettings settings =
PredictionServiceSettings.newBuilder()
.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", this.location))
.setCredentialsProvider(FixedCredentialsProvider.create(this.credentials))
.build();
predictionServiceClient = PredictionServiceClient.create(settings);
PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newBuilder();
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
if (this.credentialsProvider != null) {
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
}
HeaderProvider headerProvider =
FixedHeaderProvider.create(
"user-agent",
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
settingsBuilder.setHeaderProvider(headerProvider);
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
predictionServiceClient = PredictionServiceClient.create(settingsBuilder.build());
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
return predictionServiceClient;
}
Expand All @@ -163,16 +214,92 @@ public PredictionServiceClient getPredictionServiceClient() throws IOException {
*/
public PredictionServiceClient getPredictionServiceRestClient() throws IOException {
if (predictionServiceRestClient == null) {
PredictionServiceSettings settings =
PredictionServiceSettings.newHttpJsonBuilder()
.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", this.location))
.setCredentialsProvider(FixedCredentialsProvider.create(this.credentials))
.build();
predictionServiceRestClient = PredictionServiceClient.create(settings);
PredictionServiceSettings.Builder settingsBuilder =
PredictionServiceSettings.newHttpJsonBuilder();
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
if (this.credentialsProvider != null) {
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
}
HeaderProvider headerProvider =
FixedHeaderProvider.create(
"user-agent",
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
settingsBuilder.setHeaderProvider(headerProvider);
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
predictionServiceRestClient = PredictionServiceClient.create(settingsBuilder.build());
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
return predictionServiceRestClient;
}

/**
* Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
* first prediction API call is made.
*/
public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {
if (llmUtilityClient == null) {
LlmUtilityServiceSettings.Builder settingsBuilder = LlmUtilityServiceSettings.newBuilder();
settingsBuilder.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", this.location));
if (this.credentialsProvider != null) {
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
}
HeaderProvider headerProvider =
FixedHeaderProvider.create(
"user-agent",
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
settingsBuilder.setHeaderProvider(headerProvider);
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
llmUtilityClient = LlmUtilityServiceClient.create(settingsBuilder.build());
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
return llmUtilityClient;
}

/**
* Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
* first prediction API call is made.
*/
public LlmUtilityServiceClient getLlmUtilityRestClient() throws IOException {
if (llmUtilityRestClient == null) {
LlmUtilityServiceSettings.Builder settingsBuilder =
LlmUtilityServiceSettings.newHttpJsonBuilder();
settingsBuilder.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", this.location));
if (this.credentialsProvider != null) {
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
}
HeaderProvider headerProvider =
FixedHeaderProvider.create(
"user-agent",
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
settingsBuilder.setHeaderProvider(headerProvider);
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
llmUtilityRestClient = LlmUtilityServiceClient.create(settingsBuilder.build());
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
return llmUtilityRestClient;
}

/** Closes the VertexAI instance together with all its instantiated clients. */
@Override
public void close() {
Expand All @@ -182,5 +309,11 @@ public void close() {
if (predictionServiceRestClient != null) {
predictionServiceRestClient.close();
}
if (llmUtilityClient != null) {
llmUtilityClient.close();
}
if (llmUtilityRestClient != null) {
llmUtilityRestClient.close();
}
}
}

0 comments on commit da6eea8

Please sign in to comment.