fix(ai-proxy): Fix Cohere breaking with the model parameter in the body; Fix OpenAI token counting for function requests; Fix users sending their own model parameter #13000

Status: Open. Wants to merge 8 commits into base: master.
5 changes: 5 additions & 0 deletions changelog/unreleased/kong/ai-proxy-azure-streaming copy.yml
Member @vm-001 commented on Jun 12, 2024:

Let's rename this file to make it more appropriate.

@@ -0,0 +1,5 @@
message: |
**AI-proxy-plugin**: Fix a bug where setting OpenAI SDK model parameter "null" caused analytics
Member commented:

Use the past tense "Fixed". (@outsinre, please correct me if I'm wrong.)

to not be written to the logging plugin(s).
scope: Plugin
type: bugfix
5 changes: 5 additions & 0 deletions changelog/unreleased/kong/ai-proxy-azure-streaming.yml
@@ -0,0 +1,5 @@
message: |
**AI-proxy-plugin**: Fix a bug where certain Azure models would return partial tokens/words
when in response-streaming mode.
scope: Plugin
type: bugfix
5 changes: 5 additions & 0 deletions changelog/unreleased/kong/ai-proxy-fix-model-parameter.yml
@@ -0,0 +1,5 @@
message: |
**AI-proxy-plugin**: Fix a bug where Cohere and Anthropic providers don't read the `model` parameter properly
from the caller's request body.
scope: Plugin
type: bugfix
@@ -0,0 +1,5 @@
message: |
**AI-proxy-plugin**: Fix the bug where using "OpenAI Function" inference requests would log a
request error, and then hang until timeout.
scope: Plugin
type: bugfix
5 changes: 5 additions & 0 deletions changelog/unreleased/kong/ai-proxy-fix-sending-own-model.yml
@@ -0,0 +1,5 @@
message: |
**AI-proxy-plugin**: Fix a bug where AI Proxy would still allow callers to specify their own model,
ignoring the plugin-configured model name.
scope: Plugin
type: bugfix
@@ -0,0 +1,5 @@
message: |
**AI-proxy-plugin**: Fix a bug where AI Proxy would not give the plugin's configured
model tuning options precedence over those in the user's LLM request.
scope: Plugin
type: bugfix
16 changes: 5 additions & 11 deletions kong/llm/drivers/anthropic.lua
@@ -93,8 +93,8 @@ local transformers_to = {
return nil, nil, err
end

messages.temperature = request_table.temperature or (model.options and model.options.temperature) or nil
messages.max_tokens = request_table.max_tokens or (model.options and model.options.max_tokens) or nil
messages.temperature = (model.options and model.options.temperature) or request_table.temperature or nil
Member @vm-001 commented on Jun 12, 2024:

The trailing `or nil` can be omitted:

> v = a or b or nil
> print(v)
nil
> v1 = a or b 
> print(v1)
nil

Member commented:

The same applies in the following places.

messages.max_tokens = (model.options and model.options.max_tokens) or request_table.max_tokens or nil
messages.model = model.name or request_table.model
messages.stream = request_table.stream or false -- explicitly set this if nil

@@ -110,9 +110,8 @@ local transformers_to = {
return nil, nil, err
end

prompt.temperature = request_table.temperature or (model.options and model.options.temperature) or nil
prompt.max_tokens_to_sample = request_table.max_tokens or (model.options and model.options.max_tokens) or nil
prompt.model = model.name
prompt.temperature = (model.options and model.options.temperature) or request_table.temperature or nil
prompt.max_tokens_to_sample = (model.options and model.options.max_tokens) or request_table.max_tokens or nil
prompt.model = model.name or request_table.model
prompt.stream = request_table.stream or false -- explicitly set this if nil

@@ -442,12 +441,7 @@ function _M.post_request(conf)
end

function _M.pre_request(conf, body)
-- check for user trying to bring own model
if body and body.model then
return nil, "cannot use own model for this instance"
end

return true, nil
return true
end

-- returns err or nil
13 changes: 2 additions & 11 deletions kong/llm/drivers/cohere.lua
@@ -400,12 +400,7 @@ function _M.post_request(conf)
end

function _M.pre_request(conf, body)
-- check for user trying to bring own model
if body and body.model then
return false, "cannot use own model for this instance"
end

return true, nil
return true
end

function _M.subrequest(body, conf, http_opts, return_res_table)
@@ -467,7 +462,7 @@ end
function _M.configure_request(conf)
local parsed_url

if conf.model.options.upstream_url then
if conf.model.options and conf.model.options.upstream_url then
Member commented:

Can conf.model be nil?

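For reference, a nil-safe lookup along the lines of the guard being added would look like the sketch below (illustrative only; whether conf.model itself can be nil is exactly the open question above):

-- hypothetical nil-safe lookup: returns nil instead of erroring when either
-- the model table or its options table is absent
local function get_upstream_url(conf)
  return conf.model
    and conf.model.options
    and conf.model.options.upstream_url
end
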
parsed_url = socket_url.parse(conf.model.options.upstream_url)
else
parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME])
@@ -476,10 +471,6 @@ function _M.configure_request(conf)
or ai_shared.operation_map[DRIVER_NAME][conf.route_type]
and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
or "/"

if not parsed_url.path then
return false, fmt("operation %s is not supported for cohere provider", conf.route_type)
end
end

-- if the path is read from a URL capture, ensure that it is valid
4 changes: 2 additions & 2 deletions kong/llm/drivers/openai.lua
@@ -18,15 +18,15 @@ end

local transformers_to = {
["llm/v1/chat"] = function(request_table, model_info, route_type)
request_table.model = request_table.model or model_info.name
request_table.model = model_info.name or request_table.model
request_table.stream = request_table.stream or false -- explicitly set this
request_table.top_k = nil -- explicitly remove unsupported default

return request_table, "application/json", nil
end,

["llm/v1/completions"] = function(request_table, model_info, route_type)
request_table.model = model_info.name
request_table.model = model_info.name or request_table.model
request_table.stream = request_table.stream or false -- explicitly set this
request_table.top_k = nil -- explicitly remove unsupported default

36 changes: 27 additions & 9 deletions kong/llm/drivers/shared.lua
@@ -131,10 +131,10 @@ _M.clear_response_headers = {
-- @return {string} error if any is thrown - request should definitely be terminated if this is not nil
function _M.merge_config_defaults(request, options, request_format)
if options then
request.temperature = request.temperature or options.temperature
request.max_tokens = request.max_tokens or options.max_tokens
request.top_p = request.top_p or options.top_p
request.top_k = request.top_k or options.top_k
request.temperature = options.temperature or request.temperature
request.max_tokens = options.max_tokens or request.max_tokens
request.top_p = options.top_p or request.top_p
request.top_k = options.top_k or request.top_k
end

return request, nil
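
The reordering above means the plugin's configured tuning options now take precedence over values supplied in the caller's request, and request values only fill options that are not configured. A stand-alone sketch of the resulting merge (the numbers are made up for illustration):

-- plugin-configured options and a caller's request body (hypothetical values)
local options = { temperature = 0.2, max_tokens = 256 }
local request = { temperature = 1.0, max_tokens = 1024, top_p = 0.9 }

-- configured values win; request values survive only where nothing is configured
request.temperature = options.temperature or request.temperature  -- 0.2
request.max_tokens  = options.max_tokens  or request.max_tokens   -- 256
request.top_p       = options.top_p       or request.top_p        -- 0.9 (not configured)
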
@@ -206,19 +206,35 @@ function _M.frame_to_events(frame)
}
end
else
-- standard SSE parser
local event_lines = split(frame, "\n")
local struct = { event = nil, id = nil, data = nil }

for _, dat in ipairs(event_lines) do
for i, dat in ipairs(event_lines) do
if #dat < 1 then
events[#events + 1] = struct
struct = { event = nil, id = nil, data = nil }
end

-- test for truncated chunk on the last line (no trailing \r\n\r\n)
if #dat > 0 and #event_lines == i then
ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame head")
kong.ctx.plugin.truncated_frame = dat
break -- stop parsing immediately, server has done something wrong
end

-- test for abnormal start-of-frame (truncation tail)
if kong and kong.ctx.plugin.truncated_frame then
-- this is the tail of a previous incomplete chunk
ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame tail")
dat = fmt("%s%s", kong.ctx.plugin.truncated_frame, dat)
kong.ctx.plugin.truncated_frame = nil
end

local s1, _ = str_find(dat, ":") -- find where the cut point is

if s1 and s1 ~= 1 then
local field = str_sub(dat, 1, s1-1) -- returns "data " from data: hello world
local field = str_sub(dat, 1, s1-1) -- returns "data" from data: hello world
local value = str_ltrim(str_sub(dat, s1+1)) -- returns "hello world" from data: hello world

-- for now not checking if the value is already been set
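
The truncation handling added to frame_to_events above stashes the incomplete trailing line of a frame and prepends it to the start of the next frame. A stand-alone sketch of the same reassembly idea in plain Lua (no Kong context; the real code stashes the fragment in kong.ctx.plugin):

-- pending holds the tail of a frame that was cut mid-line
local pending = nil

local function split_lines(s)
  local out = {}
  for line in s:gmatch("([^\n]*)\n") do
    out[#out + 1] = line
  end
  local tail = s:match("([^\n]*)$")   -- piece after the last newline, "" if none
  if tail ~= "" then
    out[#out + 1] = tail
  end
  return out
end

local function feed(frame)
  if pending then
    frame = pending .. frame          -- glue the stashed head onto this frame
    pending = nil
  end

  local lines = split_lines(frame)

  -- a frame that does not end in a newline was truncated; keep its tail for later
  if frame:sub(-1) ~= "\n" then
    pending = table.remove(lines)
  end

  return lines
end

-- feed('data: {"par')   --> {}                         (tail stashed)
-- feed('tial"}\n\n')    --> { 'data: {"partial"}', "" }
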
@@ -249,7 +265,7 @@ function _M.to_ollama(request_table, model)

-- common parameters
input.stream = request_table.stream or false -- for future capability
input.model = model.name
input.model = model.name or request_table.name

if model.options then
input.options = {}
@@ -603,8 +619,10 @@ end
-- Function to count the number of words in a string
local function count_words(str)
local count = 0
for word in str:gmatch("%S+") do
count = count + 1
if type(str) == "string" then
for word in str:gmatch("%S+") do
count = count + 1
end
end
return count
end
15 changes: 12 additions & 3 deletions kong/plugins/ai-proxy/handler.lua
@@ -34,7 +34,8 @@ local function get_token_text(event_t)
-- - event_t.choices[1].delta.content
-- - event_t.choices[1].text
-- - ""
return (first_choice.delta or EMPTY).content or first_choice.text or ""
local token_text = (first_choice.delta or EMPTY).content or first_choice.text or ""
return (type(token_text) == "string" and token_text) or ""
end
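
This guard (and the matching type check added to count_words in shared.lua) matters for function/tool-call responses: their stream chunks carry no text content, and with lua-cjson a JSON null decodes to cjson.null, a userdata value that is truthy in Lua, so a plain `or ""` fallback never fires. A small illustration (the JSON payload here is made up):

-- illustration: a streaming chunk from a function-call response, where
-- "content" is JSON null rather than a string
local cjson = require("cjson.safe")

local event = cjson.decode('{"choices":[{"delta":{"content":null,"tool_calls":[{"index":0}]}}]}')
local delta = event.choices[1].delta

print(type(delta.content))  --> "userdata" (cjson.null), which is truthy

local old = delta.content or ""                                          -- still cjson.null
local new = (type(delta.content) == "string" and delta.content) or ""    -- ""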


@@ -334,17 +335,25 @@ function _M:access(conf)

-- copy from the user request if present
if (not multipart) and (not conf_m.model.name) and (request_table.model) then
conf_m.model.name = request_table.model
if request_table.model ~= cjson.null then
Contributor commented:

Should we check the case where the model in the request is a blank string?

Contributor commented:

How about checking type(request_table.model) == "string"?

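A sketch of the guard these two comments are suggesting (hypothetical, not code from this PR): only accept the caller's model when it is a non-empty string.

-- hypothetical helper capturing the suggested check: rejects nil, cjson.null,
-- other non-strings, and the blank string
local function usable_model_name(m)
  if type(m) == "string" and m ~= "" then
    return m
  end
  return nil
end
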
conf_m.model.name = request_table.model
end
elseif multipart then
conf_m.model.name = "NOT_SPECIFIED"
end

-- check that the user isn't trying to override the plugin conf model in the request body
if request_table and request_table.model and type(request_table.model) == "string" then
if request_table.model ~= conf_m.model.name then
return bad_request("cannot use own model - must be: " .. conf_m.model.name)
end
end

-- model is stashed in the copied plugin conf, for consistency in transformation functions
if not conf_m.model.name then
return bad_request("model parameter not found in request, nor in gateway configuration")
end

-- stash for analytics later
kong_ctx_plugin.llm_model_requested = conf_m.model.name

-- check the incoming format is the same as the configured LLM format
6 changes: 3 additions & 3 deletions spec/03-plugins/38-ai-proxy/01-unit_spec.lua
@@ -629,7 +629,7 @@ describe(PLUGIN_NAME .. ": (unit)", function()
SAMPLE_LLM_V1_CHAT_WITH_SOME_OPTS,
{
max_tokens = 1024,
top_p = 1.0,
top_p = 0.5,
},
"llm/v1/chat"
)
@@ -638,9 +638,9 @@

assert.is_nil(err)
assert.same({
max_tokens = 256,
max_tokens = 1024,
temperature = 0.1,
top_p = 0.2,
top_p = 0.5,
some_extra_param = "string_val",
another_extra_param = 0.5,
}, formatted)
46 changes: 46 additions & 0 deletions spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
@@ -841,6 +841,52 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
}, json.choices[1].message)
end)

it("good request, parses model of cjson.null", function()
local body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json")
body = cjson.decode(body)
body.model = cjson.null
body = cjson.encode(body)

local r = client:get("/openai/llm/v1/chat/good", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
},
body = body,
})

-- validate that the request succeeded, response status 200
local body = assert.res_status(200 , r)
local json = cjson.decode(body)

-- check this is in the 'kong' response format
assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2")
assert.equals(json.model, "gpt-3.5-turbo-0613")
assert.equals(json.object, "chat.completion")

assert.is_table(json.choices)
assert.is_table(json.choices[1].message)
assert.same({
content = "The sum of 1 + 1 is 2.",
role = "assistant",
}, json.choices[1].message)
end)

it("tries to override configured model", function()
local r = client:get("/openai/llm/v1/chat/good", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
},
body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_own_model.json"),
})

local body = assert.res_status(400 , r)
local json = cjson.decode(body)

assert.same(json, {error = { message = "cannot use own model - must be: gpt-3.5-turbo" } })
end)

it("bad upstream response", function()
local r = client:get("/openai/llm/v1/chat/bad_upstream_response", {
headers = {