Skip to content

Commit

Permalink
feat: support for custom models, drop of support for v1/completions
Browse files Browse the repository at this point in the history
  • Loading branch information
liby committed Nov 8, 2023
1 parent c0b2841 commit a3b9d39
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 55 deletions.
34 changes: 25 additions & 9 deletions src/info.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,21 @@
"identifier": "deploymentName",
"type": "text",
"title": "Dep. Name",
"desc": "可选项。此值为在部署模型时为部署选择的自定义名称,可在 Azure 门户中的 “资源管理>“部署下查看",
"desc": "可选项。此值为在部署 Azure 模型时为部署选择的自定义名称,可在 Azure 门户中的 “资源管理>“部署下查看",
"textConfig": {
"type": "visible"
}
},
{
"identifier": "apiVersion",
"type": "text",
"title": "API Version",
"desc": "可选项。此值为在使用 Azure 模型时采用的 Chat completions API 版本,不支持 2023-03-15-preview 之前的版本",
"textConfig": {
"type": "visible",
"placeholderText": "2023-08-01-preview"
}
},
{
"identifier": "apiKeys",
"type": "text",
Expand All @@ -47,6 +57,10 @@
"title": "模型",
"defaultValue": "gpt-3.5-turbo",
"menuValues": [
{
"title": "custom",
"value": "custom"
},
{
"title": "gpt-3.5-turbo-1106 (recommended)",
"value": "gpt-3.5-turbo-1106"
Expand Down Expand Up @@ -90,17 +104,19 @@
{
"title": "gpt-4-32k-0613",
"value": "gpt-4-32k-0613"
},
{
"title": "text-davinci-003",
"value": "text-davinci-003"
},
{
"title": "text-davinci-002",
"value": "text-davinci-002"
}
]
},
{
"identifier": "customModel",
"type": "text",
"title": "自定义模型",
"desc": "可选项。当 Model 选择为 custom 时,此项为必填项。请填写有效的模型名称",
"textConfig": {
"type": "visible",
"placeholderText": "gpt-3.5-turbo"
}
},
{
"identifier": "customSystemPrompt",
"type": "text",
Expand Down
83 changes: 37 additions & 46 deletions src/main.js
Original file line number Diff line number Diff line change
@@ -1,19 +1,6 @@
//@ts-check

var lang = require("./lang.js");
var ChatGPTModels = [
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-1106",
"gpt-4",
"gpt-4-0314",
"gpt-4-0613",
"gpt-4-32k",
"gpt-4-32k-0314",
"gpt-4-32k-0613",
];

var SYSTEM_PROMPT = "You are a translation engine that can only translate text and cannot interpret it."

Expand Down Expand Up @@ -146,11 +133,10 @@ function replacePromptKeywords(prompt, query) {
}

/**
* @param {typeof ChatGPTModels[number]} model
* @param {boolean} isChatGPTModel
* @param {string} model
* @param {Bob.TranslateQuery} query
* @returns {{
* model: typeof ChatGPTModels[number];
* model: string;
* temperature: number;
* max_tokens: number;
* top_p: number;
Expand All @@ -163,7 +149,7 @@ function replacePromptKeywords(prompt, query) {
* prompt?: string;
* }}
*/
function buildRequestBody(model, isChatGPTModel, query) {
function buildRequestBody(model, query) {
let { customSystemPrompt, customUserPrompt } = $option;
const { generatedSystemPrompt, generatedUserPrompt } = generatePrompts(query);

Expand All @@ -174,7 +160,7 @@ function buildRequestBody(model, isChatGPTModel, query) {
const userPrompt = customUserPrompt || generatedUserPrompt;

const standardBody = {
model,
model: model,
stream: true,
temperature: 0.2,
max_tokens: 1000,
Expand All @@ -183,26 +169,19 @@ function buildRequestBody(model, isChatGPTModel, query) {
presence_penalty: 1,
};

if (isChatGPTModel) {
return {
...standardBody,
model,
messages: [
{
role: "system",
content: systemPrompt,
},
{
role: "user",
content: userPrompt,
},
],
};
}
return {
...standardBody,
model,
prompt: userPrompt,
model: model,
messages: [
{
role: "system",
content: systemPrompt,
},
{
role: "user",
content: userPrompt,
},
],
};
}

Expand All @@ -225,12 +204,11 @@ function handleError(query, result) {

/**
* @param {Bob.TranslateQuery} query
* @param {boolean} isChatGPTModel
* @param {string} targetText
* @param {string} textFromResponse
* @returns {string}
*/
function handleResponse(query, isChatGPTModel, targetText, textFromResponse) {
function handleResponse(query, targetText, textFromResponse) {
if (textFromResponse !== '[DONE]') {
try {
const dataObj = JSON.parse(textFromResponse);
Expand All @@ -246,7 +224,7 @@ function handleResponse(query, isChatGPTModel, targetText, textFromResponse) {
return targetText;
}

const content = isChatGPTModel ? choices[0].delta.content : choices[0].text;
const content = choices[0].delta.content;
if (content !== undefined) {
targetText += content;
query.onStream({
Expand Down Expand Up @@ -284,7 +262,18 @@ function translate(query) {
});
}

const { model, apiKeys, apiUrl, deploymentName } = $option;
const { model, customModel, apiKeys, apiVersion, apiUrl, deploymentName } = $option;

const isCustomModelRequired = model === "custom";
if (isCustomModelRequired && !customModel) {
query.onCompletion({
error: {
type: "param",
message: "配置错误 - 请确保您在插件配置中填入了正确的自定义模型名称",
addtion: "请在插件配置中填写自定义模型名称",
},
});
}

if (!apiKeys) {
query.onCompletion({
Expand All @@ -295,20 +284,22 @@ function translate(query) {
},
});
}

const modelValue = isCustomModelRequired ? customModel : model;

const trimmedApiKeys = apiKeys.endsWith(",") ? apiKeys.slice(0, -1) : apiKeys;
const apiKeySelection = trimmedApiKeys.split(",").map(key => key.trim());
const apiKey = apiKeySelection[Math.floor(Math.random() * apiKeySelection.length)];

const modifiedApiUrl = ensureHttpsAndNoTrailingSlash(apiUrl || "https://api.openai.com");

const isChatGPTModel = ChatGPTModels.includes(model);
const isAzureServiceProvider = modifiedApiUrl.includes("openai.azure.com");
let apiUrlPath = isChatGPTModel ? "/v1/chat/completions" : "/v1/completions";
let apiUrlPath = "/v1/chat/completions";
const apiVersionQuery = apiVersion ? `?api-version=${apiVersion}` : "?api-version=2023-08-01-preview";

if (isAzureServiceProvider) {
if (deploymentName) {
apiUrlPath = `/openai/deployments/${deploymentName}`;
apiUrlPath += isChatGPTModel ? "/chat/completions?api-version=2023-03-15-preview" : "/completions?api-version=2022-12-01";
apiUrlPath = `/openai/deployments/${deploymentName}/chat/completions${apiVersionQuery}`;
} else {
query.onCompletion({
error: {
Expand All @@ -321,7 +312,7 @@ function translate(query) {
}

const header = buildHeader(isAzureServiceProvider, apiKey);
const body = buildRequestBody(model, isChatGPTModel, query);
const body = buildRequestBody(modelValue, query);


let targetText = ""; // 初始化拼接结果变量
Expand Down Expand Up @@ -351,7 +342,7 @@ function translate(query) {
if (match) {
// 如果是一个完整的消息,处理它并从缓冲变量中移除
const textFromResponse = match[1].trim();
targetText = handleResponse(query, isChatGPTModel, targetText, textFromResponse);
targetText = handleResponse(query, targetText, textFromResponse);
buffer = buffer.slice(match[0].length);
} else {
// 如果没有完整的消息,等待更多的数据
Expand Down

0 comments on commit a3b9d39

Please sign in to comment.