Skip to content

Commit

Permalink
feat(vertexai): add WithREST option to vertexai client (#9389)
Browse files Browse the repository at this point in the history
Initially configure whether the client uses REST or gRPC
  • Loading branch information
eliben committed Feb 8, 2024
1 parent 1879551 commit f5d56eb
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 3 deletions.
16 changes: 13 additions & 3 deletions vertexai/genai/client.go
Expand Up @@ -46,16 +46,26 @@ type Client struct {
// Clients should be reused instead of created as needed. The methods of Client
// are safe for concurrent use by multiple goroutines.
//
// You may configure the client by passing in options from the [google.golang.org/api/option]
// package.
// You may configure the client by passing in options from the
// [google.golang.org/api/option] package. You may also use options defined in
// this package, such as [WithREST].
func NewClient(ctx context.Context, projectID, location string, opts ...option.ClientOption) (*Client, error) {
opts = append([]option.ClientOption{
option.WithEndpoint(fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)),
}, opts...)
c, err := aiplatform.NewPredictionClient(ctx, opts...)
conf := newConfig(opts...)

var c *aiplatform.PredictionClient
var err error
if conf.withREST {
c, err = aiplatform.NewPredictionRESTClient(ctx, opts...)
} else {
c, err = aiplatform.NewPredictionClient(ctx, opts...)
}
if err != nil {
return nil, err
}

c.SetGoogleClientInfo("gccl", internal.Version)
return &Client{
c: c,
Expand Down
21 changes: 21 additions & 0 deletions vertexai/genai/client_test.go
Expand Up @@ -263,6 +263,27 @@ func TestLive(t *testing.T) {
})
}

func TestLiveREST(t *testing.T) {
if *projectID == "" || *modelName == "" {
t.Skip("need -project and -model")
}
ctx := context.Background()
client, err := NewClient(ctx, *projectID, "us-central1", WithREST())
if err != nil {
t.Fatal(err)
}
defer client.Close()
model := client.GenerativeModel(*modelName)
model.SetTemperature(0.0)

resp, err := model.GenerateContent(ctx, Text("What is the average size of a swallow?"))
if err != nil {
t.Fatal(err)
}
got := responseString(resp)
checkMatch(t, got, `15.* cm|[1-9].* inches`)
}

func TestJoinResponses(t *testing.T) {
r1 := &GenerateContentResponse{
Candidates: []*Candidate{
Expand Down
57 changes: 57 additions & 0 deletions vertexai/genai/option.go
@@ -0,0 +1,57 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package genai

import (
"google.golang.org/api/option"
"google.golang.org/api/option/internaloption"
)

// WithREST is an option that enables REST transport for the client.
// The default transport (if this option isn't provided) is gRPC.
func WithREST() option.ClientOption {
return &withREST{}
}

func (w *withREST) applyVertexaiOpt(c *config) {
c.withREST = true
}

type config struct {
// withREST tells the client to use REST as the underlying transport.
withREST bool
}

// newConfig generates a new config with all the given
// vertexaiClientOptions applied.
func newConfig(opts ...option.ClientOption) config {
var conf config
for _, opt := range opts {
if vOpt, ok := opt.(vertexaiClientOption); ok {
vOpt.applyVertexaiOpt(&conf)
}
}
return conf
}

// A vertexaiClientOption is an option for a vertexai client.
type vertexaiClientOption interface {
option.ClientOption
applyVertexaiOpt(*config)
}

type withREST struct {
internaloption.EmbeddableAdapter
}

0 comments on commit f5d56eb

Please sign in to comment.