Skip to content

Commit

Permalink
add doc
Browse files Browse the repository at this point in the history
  • Loading branch information
jba committed Dec 7, 2023
1 parent 4ffbf76 commit ed65e5d
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 3 deletions.
4 changes: 4 additions & 0 deletions vertexai/genai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@ import (
"context"
)

// A ChatSession provides interactive chat.
type ChatSession struct {
m *GenerativeModel
History []*Content
}

// StartChat starts a chat session.
func (m *GenerativeModel) StartChat() *ChatSession {
return &ChatSession{m: m}
}

// SendMessage sends a request to the model as part of a chat session.
func (cs *ChatSession) SendMessage(ctx context.Context, parts ...Part) (*GenerateContentResponse, error) {
// Call the underlying client with the entire history plus the argument Content.
cs.History = append(cs.History, newUserContent(parts))
Expand All @@ -41,6 +44,7 @@ func (cs *ChatSession) SendMessage(ctx context.Context, parts ...Part) (*Generat
return resp, nil
}

// SendMessageStream is like SendMessage, but with a streaming request.
func (cs *ChatSession) SendMessageStream(ctx context.Context, parts ...Part) *GenerateContentResponseIterator {
cs.History = append(cs.History, newUserContent(parts))
req := cs.m.newRequest(cs.History...)
Expand Down
12 changes: 10 additions & 2 deletions vertexai/genai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ func NewClient(ctx context.Context, projectID, location string, opts ...option.C
}, nil
}

// Close closes the client.
func (c *Client) Close() error {
return c.c.Close()
}
Expand Down Expand Up @@ -88,16 +89,17 @@ func (c *Client) GenerativeModel(name string) *GenerativeModel {
}
}

// Name returns the name of the model.
func (m *GenerativeModel) Name() string {
return m.name
}

// Use GenerateContent for a single request and response.
// GenerateContent produces a single request and response.
func (m *GenerativeModel) GenerateContent(ctx context.Context, parts ...Part) (*GenerateContentResponse, error) {
return m.generateContent(ctx, m.newRequest(newUserContent(parts)))
}

// Streaming version returns an iterator, following the pattern of the other Go clients.
// GenerateContentStream returns an iterator that enumerates responses.
func (m *GenerativeModel) GenerateContentStream(ctx context.Context, parts ...Part) *GenerateContentResponseIterator {
streamClient, err := m.c.c.StreamGenerateContent(ctx, m.newRequest(newUserContent(parts)))
return &GenerateContentResponseIterator{
Expand Down Expand Up @@ -137,13 +139,15 @@ func newUserContent(parts []Part) *Content {
return &Content{Role: roleUser, Parts: parts}
}

// GenerateContentResponseIterator is an iterator over GnerateContentResponse.
type GenerateContentResponseIterator struct {
sc pb.PredictionService_StreamGenerateContentClient
err error
merged *GenerateContentResponse
cs *ChatSession
}

// Next returns the next response.
func (iter *GenerateContentResponseIterator) Next() (*GenerateContentResponse, error) {
if iter.err != nil {
return nil, iter.err
Expand All @@ -170,6 +174,7 @@ func (iter *GenerateContentResponseIterator) Next() (*GenerateContentResponse, e
return gcp, nil
}

// GenerateContentResponse is the response from a GenerateContent or GenerateContentStream call.
type GenerateContentResponse struct {
Candidates []*Candidate
PromptFeedback *PromptFeedback
Expand All @@ -193,6 +198,7 @@ func protoToResponse(resp *pb.GenerateContentResponse) (*GenerateContentResponse
return &GenerateContentResponse{Candidates: cands}, nil
}

// PromptFeedback is feedback about a prompt.
type PromptFeedback struct {
BlockReason BlockedReason
BlockReasonMessage string
Expand All @@ -210,8 +216,10 @@ func protoToPromptFeedback(p *pb.GenerateContentResponse_PromptFeedback) *Prompt
}
}

// BlockedReason doc TBD.
type BlockedReason int32

// Constants for BlockedReason.
const (
BlockedReasonSafety = BlockedReason(pb.GenerateContentResponse_PromptFeedback_SAFETY)
BlockedReasonOther = BlockedReason(pb.GenerateContentResponse_PromptFeedback_OTHER)
Expand Down
21 changes: 20 additions & 1 deletion vertexai/genai/content.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,35 +23,43 @@ import (
pb "cloud.google.com/go/vertexai/internal/aiplatform/apiv1beta1/aiplatformpb"
)

// HarmCategory doc TBD.
type HarmCategory int32

// Constants for HarmCategory.
const (
HarmCategoryHateSpeech = HarmCategory(pb.HarmCategory_HARM_CATEGORY_HATE_SPEECH)
HarmCategoryDangerousContent = HarmCategory(pb.HarmCategory_HARM_CATEGORY_DANGEROUS_CONTENT)
HarmCategoryHarassment = HarmCategory(pb.HarmCategory_HARM_CATEGORY_HARASSMENT)
HarmCategorySexuallyExplicit = HarmCategory(pb.HarmCategory_HARM_CATEGORY_SEXUALLY_EXPLICIT)
)

// HarmBlockThreshold doc TBD.
type HarmBlockThreshold int32

// Constants for HarmBlock.
const (
HarmBlockLowAndAbove = HarmBlockThreshold(pb.SafetySetting_BLOCK_LOW_AND_ABOVE)
HarmBlockMediumAndAbove = HarmBlockThreshold(pb.SafetySetting_BLOCK_MEDIUM_AND_ABOVE)
HarmBlockOnlyHigh = HarmBlockThreshold(pb.SafetySetting_BLOCK_ONLY_HIGH)
HarmBlockNone = HarmBlockThreshold(pb.SafetySetting_BLOCK_NONE)
)

// HarmProbability doc TBD.
type HarmProbability int32

// Constants for HarmProbability.
const (
HarmProbabilityNegligible = HarmProbability(pb.SafetyRating_NEGLIGIBLE)
HarmProbabilityLow = HarmProbability(pb.SafetyRating_LOW)
HarmProbabilityMedium = HarmProbability(pb.SafetyRating_MEDIUM)
HarmProbabilityHigh = HarmProbability(pb.SafetyRating_HIGH)
)

// FinishReason doc TBD.
type FinishReason int32

// Constants for FinishReason.
const (
FinishReasonUnspecified = FinishReason(pb.Candidate_FINISH_REASON_UNSPECIFIED)
FinishReasonStop = FinishReason(pb.Candidate_STOP)
Expand All @@ -77,6 +85,7 @@ func (f FinishReason) String() string {
return fmt.Sprintf("FinishReason(%d)", f)
}

// MarshalJSON implements [encoding/json.Marshaler].
func (f FinishReason) MarshalJSON() ([]byte, error) {
return []byte(strconv.Quote(f.String())), nil
}
Expand All @@ -86,6 +95,7 @@ const (
roleModel = "model"
)

// Content doc TBD.
type Content struct {
Role string
Parts []Part
Expand All @@ -105,7 +115,7 @@ func protoToContent(c *pb.Content) *Content {
}
}

// A part is either a Text, a Blob, or a FileData.
// A Part is either a Text, a Blob, or a FileData.
type Part interface {
proto() *pb.Part
}
Expand All @@ -129,6 +139,7 @@ func protoToPart(p *pb.Part) Part {
}
}

// Text doc TBD.
type Text string

func (t Text) proto() *pb.Part {
Expand All @@ -137,6 +148,7 @@ func (t Text) proto() *pb.Part {
}
}

// Blob doc TBD.
type Blob struct {
MIMEType string
Data []byte
Expand All @@ -153,6 +165,7 @@ func (b Blob) proto() *pb.Part {
}
}

// FileData doc TBD.
type FileData struct {
MIMEType string
FileURI string
Expand Down Expand Up @@ -180,6 +193,7 @@ func ImageData(format string, data []byte) Blob {
}
}

// GenerationConfig doc TBD.
type GenerationConfig struct {
Temperature float32
TopP float32 // if non-zero, use nucleus sampling
Expand All @@ -205,6 +219,7 @@ func (c *GenerationConfig) proto() *pb.GenerationConfig {
}
}

// SafetySetting doc TBD.
type SafetySetting struct {
Category HarmCategory
Threshold HarmBlockThreshold
Expand All @@ -217,6 +232,7 @@ func (s *SafetySetting) proto() *pb.SafetySetting {
}
}

// SafetyRating doc TBD.
type SafetyRating struct {
Category HarmCategory
Probability HarmProbability
Expand All @@ -231,6 +247,7 @@ func protoToSafetyRating(r *pb.SafetyRating) *SafetyRating {
}
}

// CitationMetadata doc TBD.
type CitationMetadata struct {
Citations []*Citation
}
Expand All @@ -244,6 +261,7 @@ func protoToCitationMetadata(cm *pb.CitationMetadata) *CitationMetadata {
}
}

// Citation doc TBD.
type Citation struct {
StartIndex, EndIndex int32
URI string
Expand All @@ -270,6 +288,7 @@ func protoToCitation(c *pb.Citation) *Citation {
return r
}

// Candidate doc TBD.
type Candidate struct {
Index int32
Content *Content
Expand Down

0 comments on commit ed65e5d

Please sign in to comment.