Skip to content

Commit

Permalink
[Backport 5.1] Azure OpenAI - experimental autocomplete support (#56323)
Browse files Browse the repository at this point in the history
Co-authored-by: Erik Seliger <erikseliger@me.com>
Co-authored-by: Chris Warwick <christopher.warwick@sourcegraph.com>
  • Loading branch information
3 people committed Sep 4, 2023
1 parent e60d87f commit 3f78b6d
Showing 1 changed file with 108 additions and 19 deletions.
127 changes: 108 additions & 19 deletions internal/completions/client/azureopenai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,21 @@ func (c *azureCompletionClient) Complete(
feature types.CompletionsFeature,
requestParams types.CompletionRequestParameters,
) (*types.CompletionResponse, error) {
var resp *http.Response
var err error
defer (func() {
if resp != nil {
resp.Body.Close()
}
})()
if feature == types.CompletionsFeatureCode {
return nil, errors.Newf("%q for Azure OpenAI is currently not supported")
resp, err = c.makeCompletionRequest(ctx, requestParams, false)
} else {
resp, err = c.makeRequest(ctx, requestParams, false)
}

resp, err := c.makeRequest(ctx, requestParams, false)
if err != nil {
return nil, err
}
defer resp.Body.Close()

var response openaiResponse
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
Expand All @@ -55,7 +61,7 @@ func (c *azureCompletionClient) Complete(
}

return &types.CompletionResponse{
Completion: response.Choices[0].Content,
Completion: response.Choices[0].Text,
StopReason: response.Choices[0].FinishReason,
}, nil
}
Expand All @@ -66,11 +72,22 @@ func (c *azureCompletionClient) Stream(
requestParams types.CompletionRequestParameters,
sendEvent types.SendCompletionEvent,
) error {
resp, err := c.makeRequest(ctx, requestParams, true)
var resp *http.Response
var err error

defer (func() {
if resp != nil {
resp.Body.Close()
}
})()
if feature == types.CompletionsFeatureCode {
resp, err = c.makeCompletionRequest(ctx, requestParams, true)
} else {
resp, err = c.makeRequest(ctx, requestParams, true)
}
if err != nil {
return err
}
defer resp.Body.Close()

dec := openai.NewDecoder(resp.Body)
var content string
Expand Down Expand Up @@ -114,28 +131,21 @@ func (c *azureCompletionClient) makeRequest(ctx context.Context, requestParams t
requestParams.TopP = 0
}

// TODO(sqs): make CompletionRequestParameters non-anthropic-specific
payload := azureChatCompletionsRequestParameters{
Temperature: requestParams.Temperature,
TopP: requestParams.TopP,
// TODO(sqs): map requestParams.TopK to openai
N: 1,
Stream: stream,
MaxTokens: requestParams.MaxTokensToSample,
// TODO: Our clients are currently heavily biased towards Anthropic,
// so the stop sequences we send might not actually be very useful
// for OpenAI.
Stop: requestParams.StopSequences,
N: 1,
Stream: stream,
MaxTokens: requestParams.MaxTokensToSample,
Stop: requestParams.StopSequences,
}
for _, m := range requestParams.Messages {
// TODO(sqs): map these 'roles' to openai system/user/assistant
var role string
switch m.Speaker {
case types.HUMAN_MESSAGE_SPEAKER:
role = "user"
case types.ASISSTANT_MESSAGE_SPEAKER:
role = "assistant"
//
default:
role = strings.ToLower(role)
}
Expand Down Expand Up @@ -179,6 +189,62 @@ func (c *azureCompletionClient) makeRequest(ctx context.Context, requestParams t
return resp, nil
}

// makeCompletionRequest issues a (non-chat) completions request against the
// configured Azure OpenAI deployment and returns the raw *http.Response for
// the caller to decode and close. When stream is true the payload requests a
// server-sent-events stream.
//
// The request is POSTed to /openai/deployments/<model>/completions with the
// pinned api-version 2023-05-15, authenticated via the Azure "api-key" header.
func (c *azureCompletionClient) makeCompletionRequest(ctx context.Context, requestParams types.CompletionRequestParameters, stream bool) (*http.Response, error) {
	// Negative sampling parameters are not meaningful to OpenAI; clamp to 0.
	if requestParams.TopK < 0 {
		requestParams.TopK = 0
	}
	if requestParams.TopP < 0 {
		requestParams.TopP = 0
	}

	prompt, err := getPrompt(requestParams.Messages)
	if err != nil {
		return nil, err
	}

	payload := azureCompletionsRequestParameters{
		Temperature: requestParams.Temperature,
		TopP:        requestParams.TopP,
		N:           1,
		Stream:      stream,
		MaxTokens:   requestParams.MaxTokensToSample,
		Stop:        requestParams.StopSequences,
		Prompt:      prompt,
	}

	reqBody, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}

	// Named `u` (not `url`) so the local does not shadow the net/url package.
	u, err := url.Parse(c.endpoint)
	if err != nil {
		return nil, errors.Wrap(err, "failed to parse configured endpoint")
	}
	q := u.Query()
	q.Add("api-version", "2023-05-15")
	u.RawQuery = q.Encode()
	// For Azure, the deployment name takes the place of the model in the path.
	u.Path = fmt.Sprintf("/openai/deployments/%s/completions", requestParams.Model)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), bytes.NewReader(reqBody))
	if err != nil {
		return nil, err
	}

	req.Header.Set("Content-Type", "application/json")
	// Azure OpenAI authenticates with an "api-key" header, not a Bearer token.
	req.Header.Set("api-key", c.accessToken)

	resp, err := c.cli.Do(req)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode != http.StatusOK {
		// NOTE(review): assumes types.NewErrStatusNotOK reads and closes
		// resp.Body — confirm, otherwise the body leaks on error responses.
		return nil, types.NewErrStatusNotOK("AzureOpenAI", resp)
	}

	return resp, nil
}

type azureChatCompletionsRequestParameters struct {
Messages []message `json:"messages"`
Temperature float32 `json:"temperature,omitempty"`
Expand All @@ -193,6 +259,21 @@ type azureChatCompletionsRequestParameters struct {
User string `json:"user,omitempty"`
}

// azureCompletionsRequestParameters is the JSON request body for the Azure
// OpenAI (non-chat) completions endpoint. Field names mirror the OpenAI
// completions API; zero-valued optional fields are omitted from the payload.
type azureCompletionsRequestParameters struct {
	Prompt           string             `json:"prompt"`                      // raw prompt text; required
	Temperature      float32            `json:"temperature,omitempty"`       // sampling temperature
	TopP             float32            `json:"top_p,omitempty"`             // nucleus-sampling probability mass
	N                int                `json:"n,omitempty"`                 // number of completions to generate
	Stream           bool               `json:"stream,omitempty"`            // request a server-sent-events stream
	Stop             []string           `json:"stop,omitempty"`              // stop sequences that end generation
	MaxTokens        int                `json:"max_tokens,omitempty"`        // cap on tokens to sample
	PresencePenalty  float32            `json:"presence_penalty,omitempty"`  // not set by this client
	FrequencyPenalty float32            `json:"frequency_penalty,omitempty"` // not set by this client
	LogitBias        map[string]float32 `json:"logit_bias,omitempty"`        // not set by this client
	Suffix           string             `json:"suffix,omitempty"`            // not set by this client
	User             string             `json:"user,omitempty"`              // not set by this client
}

type message struct {
Role string `json:"role"`
Content string `json:"content"`
Expand All @@ -211,7 +292,7 @@ type openaiChoiceDelta struct {
type openaiChoice struct {
Delta openaiChoiceDelta `json:"delta"`
Role string `json:"role"`
Content string `json:"content"`
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}

Expand All @@ -221,3 +302,11 @@ type openaiResponse struct {
Model string `json:"model"`
Choices []openaiChoice `json:"choices"`
}

// getPrompt extracts the single prompt string for a completions request.
// Code-completion requests are expected to carry exactly one message whose
// Text holds the raw prompt; any other message count is rejected.
func getPrompt(messages []types.Message) (string, error) {
	if len(messages) != 1 {
		// Error strings are lowercase per Go convention (staticcheck ST1005).
		return "", errors.New("expected to receive exactly one message with the prompt")
	}

	return messages[0].Text, nil
}

0 comments on commit 3f78b6d

Please sign in to comment.