
Commit d2faf1a

Authored Mar 6, 2024
feat: Added setLlmTokenCountCallback API endpoint to register a callback for calculating token count when none is provided (#2065)
1 parent 793abe8 commit d2faf1a

File tree: 11 files changed, +322 -11 lines
 

‎api.js

+34 -1

@@ -1839,7 +1839,7 @@ API.prototype.setErrorGroupCallback = function setErrorGroupCallback(callback) {
   )
   metric.incrementCallCount()

-  if (!this.shim.isFunction(callback) || this.shim.isPromise(callback)) {
+  if (!this.shim.isFunction(callback) || this.shim.isAsyncFunction(callback)) {
     logger.warn(
       'Error Group callback must be a synchronous function, Error Group attribute will not be added'
     )
@@ -1849,4 +1849,37 @@ API.prototype.setErrorGroupCallback = function setErrorGroupCallback(callback) {
   this.agent.errors.errorGroupCallback = callback
 }

+/**
+ * Registers a callback which will be used for calculating token counts on Llm events when they are not
+ * available. This function will typically only be used if `ai_monitoring.record_content.enabled` is false
+ * and you want to still capture token counts for Llm events.
+ *
+ * Provided callbacks must return an integer value for the token count for a given piece of content.
+ *
+ * @param {Function} callback - synchronous function called to calculate token count for content.
+ * @example
+ * // @param {string} model - name of model (i.e. gpt-3.5-turbo)
+ * // @param {string} content - prompt or completion response
+ * function tokenCallback(model, content) {
+ *   // calculate tokens based on model and content
+ *   // return token count
+ *   return 40
+ * }
+ */
+API.prototype.setLlmTokenCountCallback = function setLlmTokenCountCallback(callback) {
+  const metric = this.agent.metrics.getOrCreateMetric(
+    NAMES.SUPPORTABILITY.API + '/setLlmTokenCountCallback'
+  )
+  metric.incrementCallCount()
+
+  if (!this.shim.isFunction(callback) || this.shim.isAsyncFunction(callback)) {
+    logger.warn(
+      'Llm token count callback must be a synchronous function, callback will not be registered.'
+    )
+    return
+  }
+
+  this.agent.llm.tokenCountCallback = callback
+}
+
 module.exports = API
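For context, a minimal application-side sketch of the new endpoint, assuming the agent is loaded through the standard `require('newrelic')` entry point; the 4-characters-per-token estimate below is a placeholder heuristic, not part of this commit:

// Application-side sketch: register a synchronous token-count callback.
// The ~4 characters per token heuristic is illustrative only.
const newrelic = require('newrelic')

newrelic.setLlmTokenCountCallback(function estimateTokens(model, content) {
  if (!content) {
    return 0
  }
  // Must return an integer token count for the given model/content pair.
  return Math.ceil(content.length / 4)
})

Because the agent rejects async callbacks (see the `isAsyncFunction` guard above), any tokenizer used here has to run synchronously.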

‎lib/instrumentation/restify.js

+1 -1

@@ -93,7 +93,7 @@ function wrapMiddleware(shim, middleware, _name, route) {
   })

   const wrappedMw = shim.recordMiddleware(middleware, spec)
-  if (middleware.constructor.name === 'AsyncFunction') {
+  if (shim.isAsyncFunction(middleware)) {
     return async function asyncShim() {
       return wrappedMw.apply(this, arguments)
     }

‎lib/llm-events/openai/chat-completion-message.js

+6 -2

@@ -20,9 +20,13 @@ module.exports = class LlmChatCompletionMessage extends LlmEvent {
     }

     if (this.is_response) {
-      this.token_count = response?.usage?.completion_tokens
+      this.token_count =
+        response?.usage?.completion_tokens ||
+        agent.llm?.tokenCountCallback?.(this['response.model'], message?.content)
     } else {
-      this.token_count = response?.usage?.prompt_tokens
+      this.token_count =
+        response?.usage?.prompt_tokens ||
+        agent.llm?.tokenCountCallback?.(request.model || request.engine, message?.content)
     }
   }
 }
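In both branches above, and in the embedding change below, the precedence is the same: a token count reported by the provider in `response.usage` wins, and the registered callback is only consulted when that value is absent. A standalone paraphrase of that resolution, using a hypothetical helper name:

// Hypothetical paraphrase of the fallback used by the Llm events (not part of the commit).
function resolveTokenCount(usageCount, agent, model, content) {
  // Provider-reported usage takes precedence; otherwise fall back to the
  // optional tokenCountCallback registered via setLlmTokenCountCallback.
  return usageCount || agent.llm?.tokenCountCallback?.(model, content)
}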

‎lib/llm-events/openai/embedding.js

+3 -1

@@ -14,6 +14,8 @@ module.exports = class LlmEmbedding extends LlmEvent {
     if (agent.config.ai_monitoring.record_content.enabled === true) {
       this.input = request.input?.toString()
     }
-    this.token_count = response?.usage?.prompt_tokens
+    this.token_count =
+      response?.usage?.prompt_tokens ||
+      agent.llm?.tokenCountCallback?.(this['request.model'], request.input?.toString())
   }
 }

‎lib/shim/shim.js

+15

@@ -125,6 +125,7 @@ Shim.prototype.getName = getName
 Shim.prototype.isObject = isObject
 Shim.prototype.isFunction = isFunction
 Shim.prototype.isPromise = isPromise
+Shim.prototype.isAsyncFunction = isAsyncFunction
 Shim.prototype.isString = isString
 Shim.prototype.isNumber = isNumber
 Shim.prototype.isBoolean = isBoolean
@@ -1345,6 +1346,20 @@ function isPromise(obj) {
   return obj && typeof obj.then === 'function'
 }

+/**
+ * Determines if function is an async function.
+ * Note it does not test if the return value of function is a
+ * promise or async function
+ *
+ * @memberof Shim.prototype
+ * @param {Function} fn - function to test if async
+ * @returns {boolean} True if the function is an async function
+ */
+function isAsyncFunction(fn) {
+  return fn.constructor.name === 'AsyncFunction'
+}
+
 /**
  * Determines if the given value is null.
  *
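A quick sketch of what the new helper does and does not detect, assuming `shim` is a Shim instance: a function declared with the `async` keyword is flagged, while an ordinary function that merely returns a promise is not; that latter case is what `isPromise` covers, on values rather than functions.

// Sketch only: isAsyncFunction vs. isPromise (assumes a Shim instance `shim`).
async function declaredAsync() {}
function returnsPromise() {
  return Promise.resolve(1)
}

shim.isAsyncFunction(declaredAsync) // true  -> rejected by the callback setters
shim.isAsyncFunction(returnsPromise) // false -> accepted, even though it returns a promise
shim.isPromise(returnsPromise()) // true  -> isPromise inspects the returned value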

‎test/unit/api/api-llm.test.js

+47 -1

@@ -28,7 +28,8 @@ tap.test('Agent API LLM methods', (t) => {
     loggerMock.warn.reset()
     const agent = helper.loadMockedAgent()
     t.context.api = new API(agent)
-    t.context.api.agent.config.ai_monitoring.enabled = true
+    agent.config.ai_monitoring.enabled = true
+    t.context.agent = agent
   })

   t.afterEach((t) => {
@@ -119,4 +120,49 @@ tap.test('Agent API LLM methods', (t) => {
       })
     })
   })
+
+  t.test('setLlmTokenCount should register callback to calculate token counts', async (t) => {
+    const { api, agent } = t.context
+    function callback(model, content) {
+      if (model === 'foo' && content === 'bar') {
+        return 10
+      }
+
+      return 1
+    }
+    api.setLlmTokenCountCallback(callback)
+    t.same(agent.llm.tokenCountCallback, callback)
+  })
+
+  t.test('should not store token count callback if it is async', async (t) => {
+    const { api, agent } = t.context
+    async function callback(model, content) {
+      return await new Promise((resolve) => {
+        if (model === 'foo' && content === 'bar') {
+          resolve(10)
+        }
+      })
+    }
+    api.setLlmTokenCountCallback(callback)
+    t.same(agent.llm.tokenCountCallback, undefined)
+    t.equal(loggerMock.warn.callCount, 1)
+    t.equal(
+      loggerMock.warn.args[0][0],
+      'Llm token count callback must be a synchronous function, callback will not be registered.'
+    )
+  })
+
+  t.test(
+    'should not store token count callback if callback is not actually a function',
+    async (t) => {
+      const { api, agent } = t.context
+      api.setLlmTokenCountCallback({ unit: 'test' })
+      t.same(agent.llm.tokenCountCallback, undefined)
+      t.equal(loggerMock.warn.callCount, 1)
+      t.equal(
+        loggerMock.warn.args[0][0],
+        'Llm token count callback must be a synchronous function, callback will not be registered.'
+      )
+    }
+  )
 })

‎test/unit/api/api-set-error-group-callback.test.js

+2 -2

@@ -84,8 +84,8 @@ tap.test('Agent API = set Error Group callback', (t) => {
   })

   t.test('should not attach the callback when async function', (t) => {
-    function callback() {
-      return new Promise((resolve) => {
+    async function callback() {
+      return await new Promise((resolve) => {
        setTimeout(() => {
          resolve()
        }, 200)

‎test/unit/api/stub.test.js

+1 -1

@@ -8,7 +8,7 @@
 const tap = require('tap')
 const API = require('../../../stub_api')

-const EXPECTED_API_COUNT = 34
+const EXPECTED_API_COUNT = 35

 tap.test('Agent API - Stubbed Agent API', (t) => {
   t.autoend()

‎test/unit/llm-events/openai/chat-completion-message.test.js

+111 -2

@@ -11,8 +11,6 @@ const helper = require('../../../lib/agent_helper')
 const { req, chatRes, getExpectedResult } = require('./common')

 tap.test('LlmChatCompletionMessage', (t) => {
-  t.autoend()
-
   let agent
   t.beforeEach(() => {
     agent = helper.loadMockedAgent()
@@ -104,4 +102,115 @@ tap.test('LlmChatCompletionMessage', (t) => {
       t.end()
     })
   })
+
+  t.test('should use token_count from tokenCountCallback for prompt message', (t) => {
+    const api = helper.getAgentApi()
+    const expectedCount = 4
+    function cb(model, content) {
+      t.equal(model, 'gpt-3.5-turbo-0613')
+      t.equal(content, 'What is a woodchuck?')
+      return expectedCount
+    }
+    api.setLlmTokenCountCallback(cb)
+    helper.runInTransaction(agent, () => {
+      api.startSegment('fakeSegment', false, () => {
+        const segment = api.shim.getActiveSegment()
+        const summaryId = 'chat-summary-id'
+        delete chatRes.usage
+        const chatMessageEvent = new LlmChatCompletionMessage({
+          agent,
+          segment,
+          request: req,
+          response: chatRes,
+          completionId: summaryId,
+          message: req.messages[0],
+          index: 0
+        })
+        t.equal(chatMessageEvent.token_count, expectedCount)
+        t.end()
+      })
+    })
+  })
+
+  t.test('should use token_count from tokenCountCallback for completion messages', (t) => {
+    const api = helper.getAgentApi()
+    const expectedCount = 4
+    function cb(model, content) {
+      t.equal(model, 'gpt-3.5-turbo-0613')
+      t.equal(content, 'a lot')
+      return expectedCount
+    }
+    api.setLlmTokenCountCallback(cb)
+    helper.runInTransaction(agent, () => {
+      api.startSegment('fakeSegment', false, () => {
+        const segment = api.shim.getActiveSegment()
+        const summaryId = 'chat-summary-id'
+        delete chatRes.usage
+        const chatMessageEvent = new LlmChatCompletionMessage({
+          agent,
+          segment,
+          request: req,
+          response: chatRes,
+          completionId: summaryId,
+          message: chatRes.choices[0].message,
+          index: 2
+        })
+        t.equal(chatMessageEvent.token_count, expectedCount)
+        t.end()
+      })
+    })
+  })
+
+  t.test('should not set token_count if not set in usage nor a callback registered', (t) => {
+    const api = helper.getAgentApi()
+    helper.runInTransaction(agent, () => {
+      api.startSegment('fakeSegment', false, () => {
+        const segment = api.shim.getActiveSegment()
+        const summaryId = 'chat-summary-id'
+        delete chatRes.usage
+        const chatMessageEvent = new LlmChatCompletionMessage({
+          agent,
+          segment,
+          request: req,
+          response: chatRes,
+          completionId: summaryId,
+          message: chatRes.choices[0].message,
+          index: 2
+        })
+        t.equal(chatMessageEvent.token_count, undefined)
+        t.end()
+      })
+    })
+  })
+
+  t.test(
+    'should not set token_count if not set in usage nor a callback registered returns count',
+    (t) => {
+      const api = helper.getAgentApi()
+      function cb() {
+        // empty cb
+      }
+      api.setLlmTokenCountCallback(cb)
+      helper.runInTransaction(agent, () => {
+        api.startSegment('fakeSegment', false, () => {
+          const segment = api.shim.getActiveSegment()
+          const summaryId = 'chat-summary-id'
+          delete chatRes.usage
+          const chatMessageEvent = new LlmChatCompletionMessage({
+            agent,
+            segment,
+            request: req,
+            response: chatRes,
+            completionId: summaryId,
+            message: chatRes.choices[0].message,
+            index: 2
+          })
+          t.equal(chatMessageEvent.token_count, undefined)
+          t.end()
+        })
+      })
+    }
+  )
+
+  t.end()
 })

‎test/unit/llm-events/openai/embedding.test.js

+50

@@ -113,4 +113,54 @@ tap.test('LlmEmbedding', (t) => {
       t.end()
     })
   })
+
+  t.test('should calculate token count from tokenCountCallback', (t) => {
+    const req = {
+      input: 'This is my test input',
+      model: 'gpt-3.5-turbo-0613'
+    }
+
+    const api = helper.getAgentApi()
+
+    function cb(model, content) {
+      if (model === req.model) {
+        return content.length
+      }
+    }
+
+    api.setLlmTokenCountCallback(cb)
+    helper.runInTransaction(agent, () => {
+      const segment = api.shim.getActiveSegment()
+      delete res.usage
+      const embeddingEvent = new LlmEmbedding({
+        agent,
+        segment,
+        request: req,
+        response: res
+      })
+      t.equal(embeddingEvent.token_count, 21)
+      t.end()
+    })
+  })
+
+  t.test('should not set token count when not present in usage nor tokenCountCallback', (t) => {
+    const req = {
+      input: 'This is my test input',
+      model: 'gpt-3.5-turbo-0613'
+    }
+
+    const api = helper.getAgentApi()
+    helper.runInTransaction(agent, () => {
+      const segment = api.shim.getActiveSegment()
+      delete res.usage
+      const embeddingEvent = new LlmEmbedding({
+        agent,
+        segment,
+        request: req,
+        response: res
+      })
+      t.equal(embeddingEvent.token_count, undefined)
+      t.end()
+    })
+  })
 })

‎test/versioned/openai/chat-completions.tap.js

+52

@@ -190,6 +190,58 @@ tap.test('OpenAI instrumentation - chat completions', (t) => {
     }
   )

+  t.test('should call the tokenCountCallback in streaming', (test) => {
+    const { client, agent } = t.context
+    const promptContent = 'Streamed response'
+    const promptContent2 = 'What does 1 plus 1 equal?'
+    let res = ''
+    const expectedModel = 'gpt-4'
+    const api = helper.getAgentApi()
+    function cb(model, content) {
+      t.equal(model, expectedModel)
+      if (content === promptContent || content === promptContent2) {
+        return 53
+      } else if (content === res) {
+        return 11
+      }
+    }
+    api.setLlmTokenCountCallback(cb)
+    test.teardown(() => {
+      delete agent.llm.tokenCountCallback
+    })
+    helper.runInTransaction(agent, async (tx) => {
+      const stream = await client.chat.completions.create({
+        max_tokens: 100,
+        temperature: 0.5,
+        model: expectedModel,
+        messages: [
+          { role: 'user', content: promptContent },
+          { role: 'user', content: promptContent2 }
+        ],
+        stream: true
+      })
+
+      for await (const chunk of stream) {
+        res += chunk.choices[0]?.delta?.content
+      }
+
+      const events = agent.customEventAggregator.events.toArray()
+      const chatMsgs = events.filter(([{ type }]) => type === 'LlmChatCompletionMessage')
+      test.llmMessages({
+        tokenUsage: true,
+        tx,
+        chatMsgs,
+        id: 'chatcmpl-8MzOfSMbLxEy70lYAolSwdCzfguQZ',
+        model: expectedModel,
+        resContent: res,
+        reqContent: promptContent
+      })
+
+      tx.end()
+      test.end()
+    })
+  })
+
   t.test('handles error in stream', (test) => {
     const { client, agent } = t.context
     helper.runInTransaction(agent, async (tx) => {
