Skip to content

Commit

Permalink
chore(vertexai): remove flags from client test (#10150)
Browse files Browse the repository at this point in the history
The flags preclude running `go test ./...`, which is natural for Go projects. The model is replaced by a constant (OK to have model name in the code now). The project is replaced by an env var.
  • Loading branch information
eliben committed May 13, 2024
1 parent 8875511 commit 1f9ecc4
Showing 1 changed file with 26 additions and 19 deletions.
45 changes: 26 additions & 19 deletions vertexai/genai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package genai
import (
"context"
"errors"
"flag"
"fmt"
"io"
"log"
Expand All @@ -31,24 +30,26 @@ import (
"google.golang.org/api/iterator"
)

var (
projectID = flag.String("project", "", "project ID")
modelName = flag.String("model", "", "model")
)

const defaultModel = "gemini-1.0-pro"
const imageFile = "personWorkingOnComputer.jpg"

func TestLive(t *testing.T) {
if *projectID == "" || *modelName == "" {
t.Skip("need -project and -model")
projectID := os.Getenv("VERTEX_PROJECT_ID")
if testing.Short() {
t.Skip("skipping live test in -short mode")
}

if projectID == "" {
t.Skip("set a VERTEX_PROJECT_ID env var to run live tests")
}

ctx := context.Background()
client, err := NewClient(ctx, *projectID, "us-central1")
client, err := NewClient(ctx, projectID, "us-central1")
if err != nil {
t.Fatal(err)
}
defer client.Close()
model := client.GenerativeModel(*modelName)
model := client.GenerativeModel(defaultModel)
model.Temperature = Ptr[float32](0)

t.Run("GenerateContent", func(t *testing.T) {
Expand All @@ -60,7 +61,7 @@ func TestLive(t *testing.T) {
checkMatch(t, got, `15.* cm|[1-9].* inches`)
})
t.Run("system-instructions", func(t *testing.T) {
model := client.GenerativeModel(*modelName)
model := client.GenerativeModel(defaultModel)
model.Temperature = Ptr[float32](0)
model.SystemInstruction = &Content{
Parts: []Part{Text("You are Yoda from Star Wars.")},
Expand Down Expand Up @@ -127,7 +128,7 @@ func TestLive(t *testing.T) {
})

t.Run("image", func(t *testing.T) {
vmodel := client.GenerativeModel(*modelName + "-vision")
vmodel := client.GenerativeModel(defaultModel + "-vision")
vmodel.Temperature = Ptr[float32](0)

data, err := os.ReadFile(filepath.Join("testdata", imageFile))
Expand Down Expand Up @@ -182,7 +183,7 @@ func TestLive(t *testing.T) {
}
})
t.Run("max-tokens", func(t *testing.T) {
maxModel := client.GenerativeModel(*modelName)
maxModel := client.GenerativeModel(defaultModel)
maxModel.Temperature = Ptr(float32(0))
maxModel.SetMaxOutputTokens(10)
res, err := maxModel.GenerateContent(ctx, Text("What is a dog?"))
Expand All @@ -196,7 +197,7 @@ func TestLive(t *testing.T) {
}
})
t.Run("max-tokens-streaming", func(t *testing.T) {
maxModel := client.GenerativeModel(*modelName)
maxModel := client.GenerativeModel(defaultModel)
maxModel.Temperature = Ptr[float32](0)
maxModel.MaxOutputTokens = Ptr[int32](10)
iter := maxModel.GenerateContentStream(ctx, Text("What is a dog?"))
Expand Down Expand Up @@ -246,7 +247,7 @@ func TestLive(t *testing.T) {
},
}},
}
model := client.GenerativeModel(*modelName)
model := client.GenerativeModel(defaultModel)
model.SetTemperature(0)
model.Tools = []*Tool{weatherTool}
t.Run("funcall", func(t *testing.T) {
Expand Down Expand Up @@ -299,16 +300,22 @@ func TestLive(t *testing.T) {
}

func TestLiveREST(t *testing.T) {
if *projectID == "" || *modelName == "" {
t.Skip("need -project and -model")
projectID := os.Getenv("VERTEX_PROJECT_ID")
if testing.Short() {
t.Skip("skipping live test in -short mode")
}

if projectID == "" {
t.Skip("set a VERTEX_PROJECT_ID env var to run live tests")
}

ctx := context.Background()
client, err := NewClient(ctx, *projectID, "us-central1", WithREST())
client, err := NewClient(ctx, projectID, "us-central1", WithREST())
if err != nil {
t.Fatal(err)
}
defer client.Close()
model := client.GenerativeModel(*modelName)
model := client.GenerativeModel(defaultModel)
model.SetTemperature(0.0)

resp, err := model.GenerateContent(ctx, Text("What is the average size of a swallow?"))
Expand Down

0 comments on commit 1f9ecc4

Please sign in to comment.