
Commit ff2e509
Authored Jun 27, 2024
feat: Added AIM support for Meta Llama3 models in AWS Bedrock (#2306)
1 parent 0bf8908 · commit ff2e509

11 files changed, +85 -37 lines

‎ai-support.json

(+13 -1)

@@ -64,7 +64,19 @@
       }
     ]
   },
-
+  {
+    "name": "Meta Llama3",
+    "features": [
+      {
+        "title": "Text",
+        "supported": true
+      },
+      {
+        "title": "Image",
+        "supported": false
+      }
+    ]
+  },
   {
     "name": "Amazon Titan",
     "features": [

‎lib/llm-events/aws-bedrock/bedrock-command.js

(+5 -5)

@@ -37,7 +37,7 @@ class BedrockCommand {
       result = this.#body.max_tokens_to_sample
     } else if (this.isClaude3() === true || this.isCohere() === true) {
       result = this.#body.max_tokens
-    } else if (this.isLlama2() === true) {
+    } else if (this.isLlama() === true) {
       result = this.#body.max_gen_length
     } else if (this.isTitan() === true) {
       result = this.#body.textGenerationConfig?.maxTokenCount
@@ -80,7 +80,7 @@
       this.isClaude() === true ||
       this.isAi21() === true ||
       this.isCohere() === true ||
-      this.isLlama2() === true
+      this.isLlama() === true
     ) {
       result = this.#body.prompt
     } else if (this.isClaude3() === true) {
@@ -104,7 +104,7 @@
       this.isClaude3() === true ||
       this.isAi21() === true ||
       this.isCohere() === true ||
-      this.isLlama2() === true
+      this.isLlama() === true
     ) {
       result = this.#body.temperature
     }
@@ -131,8 +131,8 @@
     return this.#modelId.startsWith('cohere.embed')
   }

-  isLlama2() {
-    return this.#modelId.startsWith('meta.llama2')
+  isLlama() {
+    return this.#modelId.startsWith('meta.llama')
   }

   isTitan() {
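
The rename works because both Llama generations share the meta.llama model-ID prefix. Below is a minimal, standalone sketch of the check; the class shape is simplified for illustration, and only the isLlama() body mirrors the diff.

// Sketch: the prefix-based matcher after the rename. The real
// BedrockCommand class holds the parsed request; here it is reduced
// to just the model ID.
class BedrockCommandSketch {
  #modelId
  constructor(modelId) {
    this.#modelId = modelId
  }

  // Matches meta.llama2-* and meta.llama3-* model IDs alike.
  isLlama() {
    return this.#modelId.startsWith('meta.llama')
  }
}

console.log(new BedrockCommandSketch('meta.llama2-13b-chat-v1').isLlama()) // true
console.log(new BedrockCommandSketch('meta.llama3-8b-instruct-v1:0').isLlama()) // true
console.log(new BedrockCommandSketch('amazon.titan-text-lite-v1').isLlama()) // false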

‎lib/llm-events/aws-bedrock/bedrock-response.js

(+2 -2)

@@ -70,7 +70,7 @@
     } else if (cmd.isCohere() === true) {
       this.#completions = body.generations?.map((g) => g.text) ?? []
       this.#id = body.id
-    } else if (cmd.isLlama2() === true) {
+    } else if (cmd.isLlama() === true) {
       body.generation && this.#completions.push(body.generation)
     } else if (cmd.isTitan() === true) {
       this.#completions = body.results?.map((r) => r.outputText) ?? []
@@ -107,7 +107,7 @@
       result = this.#parsedBody.stop_reason
     } else if (cmd.isCohere() === true) {
       result = this.#parsedBody.generations?.find((r) => r.finish_reason !== null)?.finish_reason
-    } else if (cmd.isLlama2() === true) {
+    } else if (cmd.isLlama() === true) {
       result = this.#parsedBody.stop_reason
     } else if (cmd.isTitan() === true) {
       result = this.#parsedBody.results?.find((r) => r.completionReason !== null)?.completionReason
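
Llama 2 and Llama 3 return the same non-streaming body shape, which is why a single branch now serves both. A hedged sketch of the extraction above, assuming the { generation, stop_reason } body used by this commit's test fixtures:

// Sketch: pulling the completion text and finish reason out of a
// Llama response body. The body shape follows the fixtures in this
// commit's unit tests, not an authoritative Bedrock schema.
const body = JSON.parse('{"generation":"llama-response","stop_reason":"done"}')

const completions = []
body.generation && completions.push(body.generation)
const finishReason = body.stop_reason

console.log(completions) // [ 'llama-response' ]
console.log(finishReason) // 'done'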

‎lib/llm-events/aws-bedrock/stream-handler.js

(+3 -3)

@@ -114,9 +114,9 @@
     } else if (bedrockCommand.isCohereEmbed() === true) {
       this.stopReasonKey = 'nr_none'
       this.generator = handleCohereEmbed
-    } else if (bedrockCommand.isLlama2() === true) {
+    } else if (bedrockCommand.isLlama() === true) {
       this.stopReasonKey = 'stop_reason'
-      this.generator = handleLlama2
+      this.generator = handleLlama
     } else if (bedrockCommand.isTitan() === true) {
       this.stopReasonKey = 'completionReason'
       this.generator = handleTitan
@@ -271,7 +271,7 @@ async function* handleCohereEmbed() {
   }
 }

-async function* handleLlama2() {
+async function* handleLlama() {
   let currentBody = {}
   let generation = ''
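
Streamed Llama chunks carry the same keys, with stop_reason populated only on the final chunk. A rough, self-contained sketch of folding such chunks into one completion (chunk data is borrowed from the unit-test fixtures further down; this is not the handler's actual implementation):

// Sketch: accumulating Llama stream chunks. Chunk shape matches the
// { generation, stop_reason } fixtures used in stream-handler.test.js.
async function* llamaChunks() {
  yield { generation: '1', stop_reason: null }
  yield { generation: '2', stop_reason: 'done' }
}

async function collect() {
  let generation = ''
  let stopReason = null
  for await (const chunk of llamaChunks()) {
    generation += chunk.generation ?? ''
    if (chunk.stop_reason) stopReason = chunk.stop_reason
  }
  console.log(generation, stopReason) // 12 done
}

collect()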

‎test/lib/aws-server-stubs/ai-server/index.js

(+5 -2)

@@ -114,8 +114,11 @@ function handler(req, res) {
     }

     case 'meta.llama2-13b-chat-v1':
-    case 'meta.llama2-70b-chat-v1': {
-      response = responses.llama2.get(payload.prompt)
+    case 'meta.llama2-70b-chat-v1':
+    // llama3 responses are identical, just return llama2 data
+    case 'meta.llama3-8b-instruct-v1:0':
+    case 'meta.llama3-70b-instruct-v1:0': {
+      response = responses.llama.get(payload.prompt)
       break
     }

‎test/lib/aws-server-stubs/ai-server/responses/index.js

(+2 -2)

@@ -10,13 +10,13 @@ const amazon = require('./amazon')
 const claude = require('./claude')
 const claude3 = require('./claude3')
 const cohere = require('./cohere')
-const llama2 = require('./llama2')
+const llama = require('./llama')

 module.exports = {
   ai21,
   amazon,
   claude,
   claude3,
   cohere,
-  llama2
+  llama
 }

‎test/lib/aws-server-stubs/ai-server/responses/llama2.js ‎test/lib/aws-server-stubs/ai-server/responses/llama.js

(file renamed: llama2.js → llama.js; +3 -3)

@@ -8,7 +8,7 @@
 const responses = new Map()
 const { contentType, reqId } = require('./constants')

-responses.set('text llama2 ultimate question', {
+responses.set('text llama ultimate question', {
   headers: {
     'content-type': contentType,
     'x-amzn-requestid': reqId,
@@ -25,7 +25,7 @@ responses.set('text llama2 ultimate question', {
   }
 })

-responses.set('text llama2 ultimate question streamed', {
+responses.set('text llama ultimate question streamed', {
   headers: {
     'content-type': 'application/vnd.amazon.eventstream',
     'x-amzn-requestid': reqId,
@@ -68,7 +68,7 @@ responses.set('text llama2 ultimate question streamed', {
   ]
 })

-responses.set('text llama2 ultimate question error', {
+responses.set('text llama ultimate question error', {
   headers: {
     'content-type': contentType,
     'x-amzn-requestid': reqId,

‎test/unit/llm-events/aws-bedrock/bedrock-command.test.js

(+35 -3)

@@ -52,6 +52,13 @@ const llama2 = {
   }
 }

+const llama3 = {
+  modelId: 'meta.llama3-8b-instruct-v1:0',
+  body: {
+    prompt: 'who are you'
+  }
+}
+
 const titan = {
   modelId: 'amazon.titan-text-lite-v1',
   body: {
@@ -85,7 +92,7 @@ tap.test('non-conforming command is handled gracefully', async (t) => {
     'Claude3',
     'Cohere',
     'CohereEmbed',
-    'Llama2',
+    'Llama',
     'Titan',
     'TitanEmbed'
   ]) {
@@ -212,7 +219,7 @@ tap.test('cohere embed minimal command works', async (t) => {
 tap.test('llama2 minimal command works', async (t) => {
   t.context.updatePayload(structuredClone(llama2))
   const cmd = new BedrockCommand(t.context.input)
-  t.equal(cmd.isLlama2(), true)
+  t.equal(cmd.isLlama(), true)
   t.equal(cmd.maxTokens, undefined)
   t.equal(cmd.modelId, llama2.modelId)
   t.equal(cmd.modelType, 'completion')
@@ -226,7 +233,32 @@ tap.test('llama2 complete command works', async (t) => {
   payload.body.temperature = 0.5
   t.context.updatePayload(payload)
   const cmd = new BedrockCommand(t.context.input)
-  t.equal(cmd.isLlama2(), true)
+  t.equal(cmd.isLlama(), true)
+  t.equal(cmd.maxTokens, 25)
+  t.equal(cmd.modelId, payload.modelId)
+  t.equal(cmd.modelType, 'completion')
+  t.equal(cmd.prompt, payload.body.prompt)
+  t.equal(cmd.temperature, payload.body.temperature)
+})
+
+tap.test('llama3 minimal command works', async (t) => {
+  t.context.updatePayload(structuredClone(llama3))
+  const cmd = new BedrockCommand(t.context.input)
+  t.equal(cmd.isLlama(), true)
+  t.equal(cmd.maxTokens, undefined)
+  t.equal(cmd.modelId, llama3.modelId)
+  t.equal(cmd.modelType, 'completion')
+  t.equal(cmd.prompt, llama3.body.prompt)
+  t.equal(cmd.temperature, undefined)
+})
+
+tap.test('llama3 complete command works', async (t) => {
+  const payload = structuredClone(llama3)
+  payload.body.max_gen_length = 25
+  payload.body.temperature = 0.5
+  t.context.updatePayload(payload)
+  const cmd = new BedrockCommand(t.context.input)
+  t.equal(cmd.isLlama(), true)
   t.equal(cmd.maxTokens, 25)
   t.equal(cmd.modelId, payload.modelId)
   t.equal(cmd.modelType, 'completion')

‎test/unit/llm-events/aws-bedrock/bedrock-response.test.js

(+9 -9)

@@ -38,8 +38,8 @@ const cohere = {
   ]
 }

-const llama2 = {
-  generation: 'llama2-response',
+const llama = {
+  generation: 'llama-response',
   stop_reason: 'done'
 }

@@ -79,7 +79,7 @@ tap.beforeEach((t) => {
     isCohere() {
       return false
     },
-    isLlama2() {
+    isLlama() {
       return false
     },
     isTitan() {
@@ -172,8 +172,8 @@ tap.test('cohere complete responses work', async (t) => {
   t.equal(res.statusCode, 200)
 })

-tap.test('llama2 malformed responses work', async (t) => {
-  t.context.bedrockCommand.isLlama2 = () => true
+tap.test('llama malformed responses work', async (t) => {
+  t.context.bedrockCommand.isLlama = () => true
   const res = new BedrockResponse(t.context)
   t.same(res.completions, [])
   t.equal(res.finishReason, undefined)
@@ -183,11 +183,11 @@ tap.test('llama2 malformed responses work', async (t) => {
   t.equal(res.statusCode, 200)
 })

-tap.test('llama2 complete responses work', async (t) => {
-  t.context.bedrockCommand.isLlama2 = () => true
-  t.context.updatePayload(structuredClone(llama2))
+tap.test('llama complete responses work', async (t) => {
+  t.context.bedrockCommand.isLlama = () => true
+  t.context.updatePayload(structuredClone(llama))
   const res = new BedrockResponse(t.context)
-  t.same(res.completions, ['llama2-response'])
+  t.same(res.completions, ['llama-response'])
   t.equal(res.finishReason, 'done')
   t.same(res.headers, t.context.response.response.headers)
   t.equal(res.id, undefined)

‎test/unit/llm-events/aws-bedrock/stream-handler.test.js

(+5 -5)

@@ -45,7 +45,7 @@ tap.beforeEach((t) => {
     isClaude3() {
       return false
     },
-    isLlama2() {
+    isLlama() {
       return false
     },
     isTitan() {
@@ -242,15 +242,15 @@ tap.test('handles cohere embedding streams', async (t) => {
   t.equal(br.statusCode, 200)
 })

-tap.test('handles llama2 streams', async (t) => {
-  t.context.passThroughParams.bedrockCommand.isLlama2 = () => true
+tap.test('handles llama streams', async (t) => {
+  t.context.passThroughParams.bedrockCommand.isLlama = () => true
   t.context.chunks = [
     { generation: '1', stop_reason: null },
     { generation: '2', stop_reason: 'done', ...t.context.metrics }
   ]
   const handler = new StreamHandler(t.context)

-  t.equal(handler.generator.name, 'handleLlama2')
+  t.equal(handler.generator.name, 'handleLlama')
   for await (const event of handler.generator()) {
     t.type(event.chunk.bytes, Uint8Array)
   }
@@ -267,7 +267,7 @@ tap.test('handles llama2 streams', async (t) => {
   })

   const bc = new BedrockCommand({
-    modelId: 'meta.llama2',
+    modelId: 'meta.llama',
     body: JSON.stringify({
       prompt: 'prompt',
       max_gen_length: 5

‎test/versioned/aws-sdk-v3/bedrock-chat-completions.tap.js

(+3 -2)

@@ -48,7 +48,7 @@ const requests = {
     body: JSON.stringify({ prompt, temperature: 0.5, max_tokens: 100 }),
     modelId
   }),
-  llama2: (prompt, modelId) => ({
+  llama: (prompt, modelId) => ({
     body: JSON.stringify({ prompt, max_gen_length: 100, temperature: 0.5 }),
     modelId
   })
@@ -98,7 +98,8 @@ tap.afterEach(async (t) => {
   { modelId: 'anthropic.claude-v2', resKey: 'claude' },
   { modelId: 'anthropic.claude-3-haiku-20240307-v1:0', resKey: 'claude3' },
   { modelId: 'cohere.command-text-v14', resKey: 'cohere' },
-  { modelId: 'meta.llama2-13b-chat-v1', resKey: 'llama2' }
+  { modelId: 'meta.llama2-13b-chat-v1', resKey: 'llama' },
+  { modelId: 'meta.llama3-8b-instruct-v1:0', resKey: 'llama' }
 ].forEach(({ modelId, resKey }) => {
   tap.test(`${modelId}: should properly create completion segment`, (t) => {
     const { bedrock, client, responses, agent, expectedExternalPath } = t.context
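
For orientation, each request these versioned tests send corresponds to a plain InvokeModelCommand call. A minimal sketch of invoking the newly supported Llama 3 model directly; the region and prompt are placeholders, and the body keys follow this commit's fixtures rather than a canonical Bedrock schema:

// Sketch: calling a Llama 3 model through @aws-sdk/client-bedrock-runtime,
// the client these versioned tests exercise.
const { BedrockRuntimeClient, InvokeModelCommand } = require('@aws-sdk/client-bedrock-runtime')

async function main() {
  const client = new BedrockRuntimeClient({ region: 'us-east-1' }) // placeholder region
  const response = await client.send(
    new InvokeModelCommand({
      modelId: 'meta.llama3-8b-instruct-v1:0',
      contentType: 'application/json',
      accept: 'application/json',
      // Body keys mirror the test fixtures above.
      body: JSON.stringify({ prompt: 'who are you', max_gen_length: 100, temperature: 0.5 })
    })
  )
  // response.body arrives as bytes containing JSON.
  const parsed = JSON.parse(Buffer.from(response.body).toString('utf8'))
  console.log(parsed.generation, parsed.stop_reason)
}

main().catch(console.error)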
