Mirror of https://github.com/labring/FastGPT.git (synced 2025-12-26 04:32:50 +00:00)

perf: chat framework

This commit is contained in:
parent 91decc3683
commit 00a99261ae

File diff suppressed because one or more lines are too long
@@ -1,6 +1,7 @@
 import type { ModelSchema } from '@/types/mongoSchema';
 
 export const embeddingModel = 'text-embedding-ada-002';
+export type EmbeddingModelType = 'text-embedding-ada-002';
 
 export enum OpenAiChatEnum {
   'GPT35' = 'gpt-3.5-turbo',
@@ -25,7 +26,7 @@ export const ChatModelMap = {
   },
   [OpenAiChatEnum.GPT432k]: {
     name: 'Gpt4-32k',
-    contextMaxToken: 8000,
+    contextMaxToken: 32000,
     maxTemperature: 1.5,
     price: 30
   }
@@ -1,14 +1,15 @@
 import type { NextApiRequest, NextApiResponse } from 'next';
 import { connectToDatabase } from '@/service/mongo';
-import { getOpenAIApi, authChat } from '@/service/utils/auth';
-import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
+import { authChat } from '@/service/utils/auth';
+import { modelServiceToolMap } from '@/service/utils/chat';
 import { ChatItemSimpleType } from '@/types/chat';
 import { jsonRes } from '@/service/response';
 import { PassThrough } from 'stream';
 import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model';
 import { pushChatBill } from '@/service/events/pushBill';
-import { gpt35StreamResponse } from '@/service/utils/openai';
-import { searchKb_openai } from '@/service/tools/searchKb';
+import { resStreamResponse } from '@/service/utils/chat';
+import { searchKb } from '@/service/plugins/searchKb';
+import { ChatRoleEnum } from '@/constants/chat';
 
 /* Send prompt */
 export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@@ -41,7 +42,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
     await connectToDatabase();
     let startTime = Date.now();
 
-    const { model, showModelDetail, content, userApiKey, systemKey, userId } = await authChat({
+    const { model, showModelDetail, content, userApiKey, systemApiKey, userId } = await authChat({
       modelId,
       chatId,
       authorization
@@ -54,9 +55,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
 
     // Knowledge-base search is enabled
     if (model.chat.useKb) {
-      const { code, searchPrompt } = await searchKb_openai({
-        apiKey: userApiKey || systemKey,
-        isPay: !userApiKey,
+      const { code, searchPrompt } = await searchKb({
+        userApiKey,
+        systemApiKey,
         text: prompt.value,
         similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity,
         model,
@@ -73,53 +74,37 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
       // No knowledge-base search; use the system prompt only
       model.chat.systemPrompt &&
         prompts.unshift({
-          obj: 'SYSTEM',
+          obj: ChatRoleEnum.System,
           value: model.chat.systemPrompt
         });
     }
 
-    // Cap total tokens to avoid overflowing the context
-    const filterPrompts = openaiChatFilter({
-      model: model.chat.chatModel,
-      prompts,
-      maxTokens: modelConstantsData.contextMaxToken - 300
-    });
-
     // Compute temperature
     const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
       2
     );
-    // console.log(filterPrompts);
-    // Get the chat API
-    const chatAPI = getOpenAIApi(userApiKey || systemKey);
-
-    // Send the request
-    const chatResponse = await chatAPI.createChatCompletion(
-      {
-        model: model.chat.chatModel,
-        temperature: Number(temperature) || 0,
-        messages: filterPrompts,
-        frequency_penalty: 0.5, // higher = less repetition
-        presence_penalty: -0.5, // higher = more new content
-        stream: true,
-        stop: ['.!?。']
-      },
-      {
-        timeout: 40000,
-        responseType: 'stream',
-        ...axiosConfig()
-      }
-    );
+    const { streamResponse } = await modelServiceToolMap[model.chat.chatModel].chatCompletion({
+      apiKey: userApiKey || systemApiKey,
+      temperature: +temperature,
+      messages: prompts,
+      stream: true
+    });
 
     console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
 
     step = 1;
 
-    const { responseContent } = await gpt35StreamResponse({
+    const { totalTokens, finishMessages } = await resStreamResponse({
+      model: model.chat.chatModel,
       res,
       stream,
-      chatResponse,
+      chatResponse: streamResponse,
+      prompts,
       systemPrompt:
-        showModelDetail && filterPrompts[0].role === 'system' ? filterPrompts[0].content : ''
+        showModelDetail && prompts[0].obj === ChatRoleEnum.System ? prompts[0].value : ''
     });
 
     // Only bill when the platform key is used
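Note: the core of this refactor is that API routes no longer call the OpenAI SDK inline; they dispatch through `modelServiceToolMap` keyed by model name. A minimal, self-contained sketch of that pattern (not the repo's code; the echo implementation is a stand-in for a real provider call):

```ts
// Sketch of the provider-dispatch pattern this commit introduces.
type ChatMessage = { obj: 'Human' | 'AI' | 'System'; value: string };
type ChatCompletionArgs = {
  apiKey: string;
  temperature: number;
  messages: ChatMessage[];
  stream: boolean;
};
type ChatResult = { responseText: string; totalTokens: number };

// One entry per model; each wraps its own provider call.
const serviceToolMap: Record<string, (data: ChatCompletionArgs) => Promise<ChatResult>> = {
  'gpt-3.5-turbo': async ({ messages }) => ({
    responseText: `echo: ${messages[messages.length - 1].value}`, // stand-in for a real API call
    totalTokens: 0
  }),
  'gpt-4': async ({ messages }) => ({
    responseText: `echo: ${messages[messages.length - 1].value}`,
    totalTokens: 0
  })
};

// A route then dispatches by model name and never touches a provider SDK directly,
// so adding a provider is one more map entry with no route changes.
async function handle(model: string, args: ChatCompletionArgs) {
  return serviceToolMap[model](args);
}
```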
@@ -128,7 +113,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
       chatModel: model.chat.chatModel,
       userId,
       chatId,
-      messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
+      textLen: finishMessages.map((item) => item.value).join('').length,
+      tokens: totalTokens
     });
   } catch (err: any) {
     if (step === 1) {
@@ -1,14 +1,14 @@
 import type { NextApiRequest, NextApiResponse } from 'next';
 import { connectToDatabase } from '@/service/mongo';
-import { getOpenAIApi, authOpenApiKey, authModel } from '@/service/utils/auth';
-import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
+import { authOpenApiKey, authModel } from '@/service/utils/auth';
+import { modelServiceToolMap, resStreamResponse } from '@/service/utils/chat';
 import { ChatItemSimpleType } from '@/types/chat';
 import { jsonRes } from '@/service/response';
 import { PassThrough } from 'stream';
 import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model';
 import { pushChatBill } from '@/service/events/pushBill';
-import { gpt35StreamResponse } from '@/service/utils/openai';
-import { searchKb_openai } from '@/service/tools/searchKb';
+import { searchKb } from '@/service/plugins/searchKb';
+import { ChatRoleEnum } from '@/constants/chat';
 
 /* Send prompt */
 export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@@ -64,9 +64,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
     if (model.chat.useKb) {
       const similarity = ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22;
 
-      const { code, searchPrompt } = await searchKb_openai({
-        apiKey,
-        isPay: true,
+      const { code, searchPrompt } = await searchKb({
+        systemApiKey: apiKey,
         text: prompts[prompts.length - 1].value,
         similarity,
         model,
@@ -83,69 +82,55 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
       // No knowledge-base search; use the system prompt only
       if (model.chat.systemPrompt) {
         prompts.unshift({
-          obj: 'SYSTEM',
+          obj: ChatRoleEnum.System,
           value: model.chat.systemPrompt
         });
       }
     }
 
-    // Cap total tokens to avoid overflowing the context
-    const filterPrompts = openaiChatFilter({
-      model: model.chat.chatModel,
-      prompts,
-      maxTokens: modelConstantsData.contextMaxToken - 300
-    });
-
     // Compute temperature
     const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
       2
     );
-    // console.log(filterPrompts);
-    // Get the chat API
-    const chatAPI = getOpenAIApi(apiKey);
-
-    // Send the request
-    const chatResponse = await chatAPI.createChatCompletion(
-      {
-        model: model.chat.chatModel,
-        temperature: Number(temperature) || 0,
-        messages: filterPrompts,
-        frequency_penalty: 0.5, // higher = less repetition
-        presence_penalty: -0.5, // higher = more new content
-        stream: isStream,
-        stop: ['.!?。']
-      },
-      {
-        timeout: 180000,
-        responseType: isStream ? 'stream' : 'json',
-        ...axiosConfig()
-      }
-    );
+    const { streamResponse, responseMessages, responseText, totalTokens } =
+      await modelServiceToolMap[model.chat.chatModel].chatCompletion({
+        apiKey,
+        temperature: +temperature,
+        messages: prompts,
+        stream: isStream
+      });
 
     console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
 
-    let responseContent = '';
+    let textLen = 0;
+    let tokens = totalTokens;
 
     if (isStream) {
       step = 1;
-      const streamResponse = await gpt35StreamResponse({
+      const { finishMessages, totalTokens } = await resStreamResponse({
+        model: model.chat.chatModel,
         res,
         stream,
-        chatResponse
+        chatResponse: streamResponse,
+        prompts
       });
-      responseContent = streamResponse.responseContent;
+      textLen = finishMessages.map((item) => item.value).join('').length;
+      tokens = totalTokens;
     } else {
-      responseContent = chatResponse.data.choices?.[0]?.message?.content || '';
+      textLen = responseMessages.map((item) => item.value).join('').length;
       jsonRes(res, {
-        data: responseContent
+        data: responseText
       });
     }
 
     // Only bill when the platform key is used
     pushChatBill({
       isPay: true,
       chatModel: model.chat.chatModel,
       userId,
-      messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
+      textLen,
+      tokens
     });
   } catch (err: any) {
     if (step === 1) {
@@ -1,144 +0,0 @@
-import type { NextApiRequest, NextApiResponse } from 'next';
-import { connectToDatabase, Model } from '@/service/mongo';
-import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
-import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
-import { ChatItemSimpleType } from '@/types/chat';
-import { jsonRes } from '@/service/response';
-import { PassThrough } from 'stream';
-import { ChatModelMap } from '@/constants/model';
-import { pushChatBill } from '@/service/events/pushBill';
-import { gpt35StreamResponse } from '@/service/utils/openai';
-
-/* Send prompt */
-export default async function handler(req: NextApiRequest, res: NextApiResponse) {
-  let step = 0; // step=1 means the stream response has started
-  const stream = new PassThrough();
-  stream.on('error', () => {
-    console.log('error: ', 'stream error');
-    stream.destroy();
-  });
-  res.on('close', () => {
-    stream.destroy();
-  });
-  res.on('error', () => {
-    console.log('error: ', 'request error');
-    stream.destroy();
-  });
-
-  try {
-    const {
-      prompts,
-      modelId,
-      isStream = true
-    } = req.body as {
-      prompts: ChatItemSimpleType[];
-      modelId: string;
-      isStream: boolean;
-    };
-
-    if (!prompts || !modelId) {
-      throw new Error('缺少参数');
-    }
-    if (!Array.isArray(prompts)) {
-      throw new Error('prompts is not array');
-    }
-    if (prompts.length > 30 || prompts.length === 0) {
-      throw new Error('prompts length range 1-30');
-    }
-
-    await connectToDatabase();
-    let startTime = Date.now();
-
-    const { apiKey, userId } = await authOpenApiKey(req);
-
-    const model = await Model.findOne({
-      _id: modelId,
-      userId
-    });
-
-    if (!model) {
-      throw new Error('无权使用该模型');
-    }
-
-    const modelConstantsData = ChatModelMap[model.chat.chatModel];
-
-    // Auto-insert the system prompt if present
-    if (model.chat.systemPrompt) {
-      prompts.unshift({
-        obj: 'SYSTEM',
-        value: model.chat.systemPrompt
-      });
-    }
-
-    // Cap tokens to avoid overflowing
-    const filterPrompts = openaiChatFilter({
-      model: model.chat.chatModel,
-      prompts,
-      maxTokens: modelConstantsData.contextMaxToken - 300
-    });
-
-    // console.log(filterPrompts);
-    // Compute temperature
-    const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
-      2
-    );
-    // Get the chat API
-    const chatAPI = getOpenAIApi(apiKey);
-    // Send the request
-    const chatResponse = await chatAPI.createChatCompletion(
-      {
-        model: model.chat.chatModel,
-        temperature: Number(temperature) || 0,
-        messages: filterPrompts,
-        frequency_penalty: 0.5, // higher = less repetition
-        presence_penalty: -0.5, // higher = more new content
-        stream: isStream,
-        stop: ['.!?。']
-      },
-      {
-        timeout: 40000,
-        responseType: isStream ? 'stream' : 'json',
-        ...axiosConfig()
-      }
-    );
-
-    console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
-
-    let responseContent = '';
-
-    if (isStream) {
-      step = 1;
-      const streamResponse = await gpt35StreamResponse({
-        res,
-        stream,
-        chatResponse
-      });
-      responseContent = streamResponse.responseContent;
-    } else {
-      responseContent = chatResponse.data.choices?.[0]?.message?.content || '';
-      jsonRes(res, {
-        data: responseContent
-      });
-    }
-
-    // Only bill when the platform key is used
-    pushChatBill({
-      isPay: true,
-      chatModel: model.chat.chatModel,
-      userId,
-      messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
-    });
-  } catch (err: any) {
-    if (step === 1) {
-      // End the stream directly
-      console.log('error,结束');
-      stream.destroy();
-    } else {
-      res.status(500);
-      jsonRes(res, {
-        code: 500,
-        error: err
-      });
-    }
-  }
-}
@@ -1,14 +1,14 @@
 import type { NextApiRequest, NextApiResponse } from 'next';
 import { connectToDatabase, Model } from '@/service/mongo';
-import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
-import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
+import { authOpenApiKey } from '@/service/utils/auth';
+import { resStreamResponse, modelServiceToolMap } from '@/service/utils/chat';
 import { ChatItemSimpleType } from '@/types/chat';
 import { jsonRes } from '@/service/response';
 import { PassThrough } from 'stream';
-import { ChatModelMap, ModelVectorSearchModeMap, OpenAiChatEnum } from '@/constants/model';
+import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model';
 import { pushChatBill } from '@/service/events/pushBill';
-import { gpt35StreamResponse } from '@/service/utils/openai';
-import { searchKb_openai } from '@/service/tools/searchKb';
+import { searchKb } from '@/service/plugins/searchKb';
+import { ChatRoleEnum } from '@/constants/chat';
 
 /* Send prompt */
 export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@@ -57,20 +57,16 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
 
     console.log('laf gpt start');
 
-    // Get the chat API
-    const chatAPI = getOpenAIApi(apiKey);
-
     // Ask ChatGPT once to decompose the requirement
-    const promptResponse = await chatAPI.createChatCompletion(
-      {
-        model: OpenAiChatEnum.GPT35,
-        temperature: 0,
-        frequency_penalty: 0.5, // higher = less repetition
-        presence_penalty: -0.5, // higher = more new content
-        messages: [
-          {
-            role: 'system',
-            content: `服务端逻辑生成器.根据用户输入的需求,拆解成 laf 云函数实现的步骤,只返回步骤,按格式返回步骤: 1.\n2.\n3.\n ......
+    const { responseText: resolveText, totalTokens: resolveTokens } = await modelServiceToolMap[
+      model.chat.chatModel
+    ].chatCompletion({
+      apiKey,
+      temperature: 0,
+      messages: [
+        {
+          obj: ChatRoleEnum.System,
+          value: `服务端逻辑生成器.根据用户输入的需求,拆解成 laf 云函数实现的步骤,只返回步骤,按格式返回步骤: 1.\n2.\n3.\n ......
 下面是一些例子:
 一个 hello world 例子
 1. 返回字符串: "hello world"
@@ -103,35 +99,25 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
 5. 获取当前时间,记录为 updateTime.
 6. 更新数据库数据,表为"blogs",更新符合 blogId 的记录的内容为{blogText, tags, updateTime}.
 7. 返回结果 "更新博客记录成功"`
-        },
-        {
-          role: 'user',
-          content: prompt.value
-        }
-      ]
-    },
-    {
-      timeout: 180000,
-      ...axiosConfig()
-    }
-  );
+        },
+        {
+          obj: ChatRoleEnum.Human,
+          value: prompt.value
+        }
+      ],
+      stream: false
+    });
 
-    const promptResolve = promptResponse.data.choices?.[0]?.message?.content || '';
-    if (!promptResolve) {
-      throw new Error('gpt 异常');
-    }
-
-    prompt.value += ` ${promptResolve}`;
+    prompt.value += ` ${resolveText}`;
     console.log('prompt resolve success, time:', `${(Date.now() - startTime) / 1000}s`);
 
     // Read the conversation content
     const prompts = [prompt];
 
     // Get vector-matched prompts
-    const { searchPrompt } = await searchKb_openai({
-      isPay: true,
-      apiKey,
-      similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22,
+    const { searchPrompt } = await searchKb({
+      systemApiKey: apiKey,
+      similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity,
       text: prompt.value,
       model,
       userId
@@ -139,49 +125,41 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
 
     searchPrompt && prompts.unshift(searchPrompt);
 
-    // Cap context tokens to avoid overflowing
-    const filterPrompts = openaiChatFilter({
-      model: model.chat.chatModel,
-      prompts,
-      maxTokens: modelConstantsData.contextMaxToken - 300
-    });
-
-    // console.log(filterPrompts);
     // Compute temperature
     const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
       2
     );
-    // Send the request
-    const chatResponse = await chatAPI.createChatCompletion(
-      {
-        model: model.chat.chatModel,
-        temperature: Number(temperature) || 0,
-        messages: filterPrompts,
-        frequency_penalty: 0.5, // higher = less repetition
-        presence_penalty: -0.5, // higher = more new content
-        stream: isStream
-      },
-      {
-        timeout: 180000,
-        responseType: isStream ? 'stream' : 'json',
-        ...axiosConfig()
-      }
-    );
-
-    let responseContent = '';
+    // Send the request
+    const { streamResponse, responseMessages, responseText, totalTokens } =
+      await modelServiceToolMap[model.chat.chatModel].chatCompletion({
+        apiKey,
+        temperature: +temperature,
+        messages: prompts,
+        stream: isStream
+      });
 
     console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
 
+    let textLen = resolveText.length;
+    let tokens = resolveTokens;
+
     if (isStream) {
       step = 1;
-      const streamResponse = await gpt35StreamResponse({
+      const { finishMessages, totalTokens } = await resStreamResponse({
+        model: model.chat.chatModel,
         res,
         stream,
-        chatResponse
+        chatResponse: streamResponse,
+        prompts
       });
-      responseContent = streamResponse.responseContent;
+      textLen += finishMessages.map((item) => item.value).join('').length;
+      tokens += totalTokens;
     } else {
-      responseContent = chatResponse.data.choices?.[0]?.message?.content || '';
+      textLen += responseMessages.map((item) => item.value).join('').length;
+      tokens += totalTokens;
       jsonRes(res, {
-        data: responseContent
+        data: responseText
       });
     }
 
@@ -191,7 +169,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
       isPay: true,
       chatModel: model.chat.chatModel,
       userId,
-      messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
+      textLen,
+      tokens
     });
   } catch (err: any) {
     if (step === 1) {
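Note: the laf route above is a two-phase flow: one non-stream completion decomposes the request into numbered steps, which are appended to the user prompt before the main completion runs. A hedged, standalone sketch of that pattern (`chatOnce` is a hypothetical stand-in for `chatCompletion`):

```ts
// Sketch of the two-phase laf flow; mirrors prompt.value += ` ${resolveText}` above.
async function lafTwoPhase(
  chatOnce: (messages: { obj: string; value: string }[]) => Promise<string>,
  userPrompt: string
): Promise<string> {
  // Phase 1: decompose the requirement into steps.
  const steps = await chatOnce([
    { obj: 'System', value: 'Decompose the request into numbered steps.' },
    { obj: 'Human', value: userPrompt }
  ]);
  // Phase 2: answer with the decomposition appended to the prompt.
  const enriched = `${userPrompt} ${steps}`;
  return chatOnce([{ obj: 'Human', value: enriched }]);
}
```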
@@ -1,159 +0,0 @@
-import type { NextApiRequest, NextApiResponse } from 'next';
-import { connectToDatabase, Model } from '@/service/mongo';
-import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
-import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
-import { ChatItemSimpleType } from '@/types/chat';
-import { jsonRes } from '@/service/response';
-import { PassThrough } from 'stream';
-import {
-  ChatModelMap,
-  ModelVectorSearchModeMap,
-  ModelVectorSearchModeEnum
-} from '@/constants/model';
-import { pushChatBill } from '@/service/events/pushBill';
-import { gpt35StreamResponse } from '@/service/utils/openai';
-import { searchKb_openai } from '@/service/tools/searchKb';
-
-/* Send prompt */
-export default async function handler(req: NextApiRequest, res: NextApiResponse) {
-  let step = 0; // step=1 means the stream response has started
-  const stream = new PassThrough();
-  stream.on('error', () => {
-    console.log('error: ', 'stream error');
-    stream.destroy();
-  });
-  res.on('close', () => {
-    stream.destroy();
-  });
-  res.on('error', () => {
-    console.log('error: ', 'request error');
-    stream.destroy();
-  });
-
-  try {
-    const {
-      prompts,
-      modelId,
-      isStream = true
-    } = req.body as {
-      prompts: ChatItemSimpleType[];
-      modelId: string;
-      isStream: boolean;
-    };
-
-    if (!prompts || !modelId) {
-      throw new Error('缺少参数');
-    }
-    if (!Array.isArray(prompts)) {
-      throw new Error('prompts is not array');
-    }
-    if (prompts.length > 30 || prompts.length === 0) {
-      throw new Error('prompts length range 1-30');
-    }
-
-    await connectToDatabase();
-    let startTime = Date.now();
-
-    /* Credential check */
-    const { apiKey, userId } = await authOpenApiKey(req);
-
-    const model = await Model.findOne({
-      _id: modelId,
-      userId
-    });
-
-    if (!model) {
-      throw new Error('无权使用该模型');
-    }
-
-    const modelConstantsData = ChatModelMap[model.chat.chatModel];
-
-    // Get vector-matched prompts
-    const { code, searchPrompt } = await searchKb_openai({
-      isPay: true,
-      apiKey,
-      similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22,
-      text: prompts[prompts.length - 1].value,
-      model,
-      userId
-    });
-
-    // search result is empty
-    if (code === 201) {
-      return res.send(searchPrompt?.value);
-    }
-
-    searchPrompt && prompts.unshift(searchPrompt);
-
-    // Cap tokens to avoid overflowing
-    const filterPrompts = openaiChatFilter({
-      model: model.chat.chatModel,
-      prompts,
-      maxTokens: modelConstantsData.contextMaxToken - 300
-    });
-
-    // console.log(filterPrompts);
-    // Compute temperature
-    const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
-      2
-    );
-    const chatAPI = getOpenAIApi(apiKey);
-
-    // Send the request
-    const chatResponse = await chatAPI.createChatCompletion(
-      {
-        model: model.chat.chatModel,
-        temperature: Number(temperature) || 0,
-        messages: filterPrompts,
-        frequency_penalty: 0.5, // higher = less repetition
-        presence_penalty: -0.5, // higher = more new content
-        stream: isStream,
-        stop: ['.!?。']
-      },
-      {
-        timeout: 180000,
-        responseType: isStream ? 'stream' : 'json',
-        ...axiosConfig()
-      }
-    );
-
-    console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
-
-    let responseContent = '';
-
-    if (isStream) {
-      step = 1;
-      const streamResponse = await gpt35StreamResponse({
-        res,
-        stream,
-        chatResponse
-      });
-      responseContent = streamResponse.responseContent;
-    } else {
-      responseContent = chatResponse.data.choices?.[0]?.message?.content || '';
-      jsonRes(res, {
-        data: responseContent
-      });
-    }
-
-    pushChatBill({
-      isPay: true,
-      chatModel: model.chat.chatModel,
-      userId,
-      messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
-    });
-    // jsonRes(res);
-  } catch (err: any) {
-    if (step === 1) {
-      // End the stream directly
-      console.log('error,结束');
-      stream.destroy();
-    } else {
-      res.status(500);
-      jsonRes(res, {
-        code: 500,
-        error: err
-      });
-    }
-  }
-}
@@ -1,5 +0,0 @@
-export enum OpenAiTuneStatusEnum {
-  cancelled = 'cancelled',
-  succeeded = 'succeeded',
-  pending = 'pending'
-}
@@ -1,14 +1,13 @@
 import { SplitData } from '@/service/mongo';
-import { getOpenAIApi } from '@/service/utils/auth';
-import { axiosConfig } from '@/service/utils/tools';
-import { getOpenApiKey } from '../utils/openai';
-import type { ChatCompletionRequestMessage } from 'openai';
+import { getApiKey } from '../utils/auth';
 import { OpenAiChatEnum } from '@/constants/model';
 import { pushSplitDataBill } from '@/service/events/pushBill';
 import { generateVector } from './generateVector';
 import { openaiError2 } from '../errorCode';
 import { PgClient } from '@/service/pg';
 import { ModelSplitDataSchema } from '@/types/mongoSchema';
+import { modelServiceToolMap } from '../utils/chat';
+import { ChatRoleEnum } from '@/constants/chat';
 
 export async function generateQA(next = false): Promise<any> {
   if (process.env.queueTask !== '1') {
@@ -47,11 +46,11 @@ export async function generateQA(next = false): Promise<any> {
 
   // Get the openapi key
   let userApiKey = '',
-    systemKey = '';
+    systemApiKey = '';
   try {
-    const key = await getOpenApiKey(dataItem.userId);
+    const key = await getApiKey({ model: OpenAiChatEnum.GPT35, userId: dataItem.userId });
     userApiKey = key.userApiKey;
-    systemKey = key.systemKey;
+    systemApiKey = key.systemApiKey;
   } catch (error: any) {
     if (error?.code === 501) {
       // Balance exhausted; clear this record
@@ -69,55 +68,44 @@ export async function generateQA(next = false): Promise<any> {
 
   const startTime = Date.now();
 
-  // Get an openai request instance
-  const chatAPI = getOpenAIApi(userApiKey || systemKey);
-  const systemPrompt: ChatCompletionRequestMessage = {
-    role: 'system',
-    content: `你是出题人
-${dataItem.prompt || '下面是"一段长文本"'}
-从中选出5至20个题目和答案.答案详细.按格式返回: Q1:
-A1:
-Q2:
-A2:
-...`
-  };
-
   // Ask ChatGPT for the answers
   const response = await Promise.allSettled(
     textList.map((text) =>
-      chatAPI
-        .createChatCompletion(
-          {
-            model: OpenAiChatEnum.GPT35,
-            temperature: 0.8,
-            n: 1,
-            messages: [
-              systemPrompt,
-              {
-                role: 'user',
-                content: text
-              }
-            ]
-          },
-          {
-            timeout: 180000,
-            ...axiosConfig()
-          }
-        )
-        .then((res) => {
-          const rawContent = res?.data.choices[0].message?.content || ''; // ChatGPT's raw reply
-          const result = formatSplitText(res?.data.choices[0].message?.content || ''); // formatted QA pairs
+      modelServiceToolMap[OpenAiChatEnum.GPT35]
+        .chatCompletion({
+          apiKey: userApiKey || systemApiKey,
+          temperature: 0.8,
+          messages: [
+            {
+              obj: ChatRoleEnum.System,
+              value: `你是出题人
+${dataItem.prompt || '下面是"一段长文本"'}
+从中选出5至20个题目和答案.答案详细.按格式返回: Q1:
+A1:
+Q2:
+A2:
+...`
+            },
+            {
+              obj: 'Human',
+              value: text
+            }
+          ],
+          stream: false
+        })
+        .then(({ totalTokens, responseText, responseMessages }) => {
+          const result = formatSplitText(responseText); // formatted QA pairs
           console.log(`split result length: `, result.length);
           // Billing
           pushSplitDataBill({
             isPay: !userApiKey && result.length > 0,
             userId: dataItem.userId,
             type: 'QA',
-            text: systemPrompt.content + text + rawContent,
-            tokenLen: res.data.usage?.total_tokens || 0
+            textLen: responseMessages.map((item) => item.value).join('').length,
+            totalTokens
           });
           return {
-            rawContent,
+            rawContent: responseText,
            result
           };
         })
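Note: `formatSplitText` itself is not shown in this diff. As a hedged illustration only, a parser for the "Q1:/A1:" layout the prompt requests could look like this (hypothetical; not the repo's implementation):

```ts
// Hypothetical sketch of a formatSplitText-style parser for the Q1:/A1: format.
function parseQA(raw: string): { q: string; a: string }[] {
  const result: { q: string; a: string }[] = [];
  // Capture each Qn:...An:... pair up to the next question or end of text.
  const regex = /Q\d+:([\s\S]*?)A\d+:([\s\S]*?)(?=Q\d+:|$)/g;
  let match: RegExpExecArray | null;
  while ((match = regex.exec(raw)) !== null) {
    const q = match[1].trim();
    const a = match[2].trim();
    if (q && a) result.push({ q, a });
  }
  return result;
}
```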
@@ -1,6 +1,8 @@
-import { openaiCreateEmbedding, getOpenApiKey } from '../utils/openai';
+import { openaiCreateEmbedding } from '../utils/chat/openai';
+import { getApiKey } from '../utils/auth';
 import { openaiError2 } from '../errorCode';
 import { PgClient } from '@/service/pg';
+import { embeddingModel } from '@/constants/model';
 
 export async function generateVector(next = false): Promise<any> {
   if (process.env.queueTask !== '1') {
@@ -40,11 +42,11 @@ export async function generateVector(next = false): Promise<any> {
     dataId = dataItem.id;
 
     // Get the openapi key
-    let userApiKey, systemKey;
+    let userApiKey, systemApiKey;
     try {
-      const res = await getOpenApiKey(dataItem.userId);
+      const res = await getApiKey({ model: embeddingModel, userId: dataItem.userId });
       userApiKey = res.userApiKey;
-      systemKey = res.systemKey;
+      systemApiKey = res.systemApiKey;
     } catch (error: any) {
       if (error?.code === 501) {
         await PgClient.delete('modelData', {
@@ -61,8 +63,8 @@ export async function generateVector(next = false): Promise<any> {
     const { vector } = await openaiCreateEmbedding({
       text: dataItem.q,
       userId: dataItem.userId,
-      isPay: !userApiKey,
-      apiKey: userApiKey || systemKey
+      userApiKey,
+      systemApiKey
     });
 
     // Update the pg vector and status data
@@ -1,60 +1,54 @@
 import { connectToDatabase, Bill, User } from '../mongo';
 import { ChatModelMap, OpenAiChatEnum, ChatModelType, embeddingModel } from '@/constants/model';
 import { BillTypeEnum } from '@/constants/user';
-import { countChatTokens } from '@/utils/tools';
 
 export const pushChatBill = async ({
   isPay,
   chatModel,
   userId,
   chatId,
-  messages
+  textLen,
+  tokens
 }: {
   isPay: boolean;
   chatModel: ChatModelType;
   userId: string;
   chatId?: '' | string;
-  messages: { role: 'system' | 'user' | 'assistant'; content: string }[];
+  textLen: number;
+  tokens: number;
 }) => {
+  console.log(`chat generate success. text len: ${textLen}. token len: ${tokens}. pay:${isPay}`);
+  if (!isPay) return;
+
   let billId = '';
 
   try {
-    // Count tokens
-    const tokens = countChatTokens({ model: chatModel, messages });
-    const text = messages.map((item) => item.content).join('');
+    await connectToDatabase();
 
-    console.log(
-      `chat generate success. text len: ${text.length}. token len: ${tokens}. pay:${isPay}`
-    );
+    // Compute the price
+    const unitPrice = ChatModelMap[chatModel]?.price || 5;
+    const price = unitPrice * tokens;
 
-    if (isPay) {
-      await connectToDatabase();
-
-      // Compute the price
-      const unitPrice = ChatModelMap[chatModel]?.price || 5;
-      const price = unitPrice * tokens;
-
-      try {
-        // Insert a Bill record
-        const res = await Bill.create({
-          userId,
-          type: 'chat',
-          modelName: chatModel,
-          chatId: chatId ? chatId : undefined,
-          textLen: text.length,
-          tokenLen: tokens,
-          price
-        });
-        billId = res._id;
-
-        // Deduct from the account balance
-        await User.findByIdAndUpdate(userId, {
-          $inc: { balance: -price }
-        });
-      } catch (error) {
-        console.log('创建账单失败:', error);
-        billId && Bill.findByIdAndDelete(billId);
-      }
-    }
+    try {
+      // Insert a Bill record
+      const res = await Bill.create({
+        userId,
+        type: 'chat',
+        modelName: chatModel,
+        chatId: chatId ? chatId : undefined,
+        textLen,
+        tokenLen: tokens,
+        price
+      });
+      billId = res._id;
+
+      // Deduct from the account balance
+      await User.findByIdAndUpdate(userId, {
+        $inc: { balance: -price }
+      });
+    } catch (error) {
+      console.log('创建账单失败:', error);
+      billId && Bill.findByIdAndDelete(billId);
+    }
   } catch (error) {
     console.log(error);
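Note: the billing helpers now receive precomputed `textLen`/`tokens` from the chat layer instead of recounting tokens from raw messages, and they bail out early when `isPay` is false. A small sketch of the price arithmetic, assuming the fallback unit price of 5 that the diff shows:

```ts
// Sketch of pushChatBill's price computation (not the repo's code).
// unitPrice falls back to 5 when a model has no configured price.
function chatPrice(tokens: number, unitPrice = 5): number {
  return unitPrice * tokens; // stored in the same integer unit as Bill.price
}

console.log(chatPrice(1200)); // 6000
```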
@@ -64,54 +58,49 @@ export const pushChatBill = async ({
 export const pushSplitDataBill = async ({
   isPay,
   userId,
-  tokenLen,
-  text,
+  totalTokens,
+  textLen,
   type
 }: {
   isPay: boolean;
   userId: string;
-  tokenLen: number;
-  text: string;
+  totalTokens: number;
+  textLen: number;
   type: `${BillTypeEnum}`;
 }) => {
+  await connectToDatabase();
+  console.log(
+    `splitData generate success. text len: ${textLen}. token len: ${totalTokens}. pay:${isPay}`
+  );
+  if (!isPay) return;
+
   let billId;
 
   try {
-    console.log(
-      `splitData generate success. text len: ${text.length}. token len: ${tokenLen}. pay:${isPay}`
-    );
-    await connectToDatabase();
-
-    if (isPay) {
-      try {
-        // Model unit price; splitting always uses gpt35
-        const unitPrice = ChatModelMap[OpenAiChatEnum.GPT35]?.price || 3;
-        // Compute the price
-        const price = unitPrice * tokenLen;
+    // Model unit price; splitting always uses gpt35
+    const unitPrice = ChatModelMap[OpenAiChatEnum.GPT35].price || 3;
+    // Compute the price
+    const price = unitPrice * totalTokens;
 
-        // Insert a Bill record
-        const res = await Bill.create({
-          userId,
-          type,
-          modelName: OpenAiChatEnum.GPT35,
-          textLen: text.length,
-          tokenLen,
-          price
-        });
-        billId = res._id;
+    // Insert a Bill record
+    const res = await Bill.create({
+      userId,
+      type,
+      modelName: OpenAiChatEnum.GPT35,
+      textLen,
+      tokenLen: totalTokens,
+      price
+    });
+    billId = res._id;
 
-        // Deduct from the account balance
-        await User.findByIdAndUpdate(userId, {
-          $inc: { balance: -price }
-        });
-      } catch (error) {
-        console.log('创建账单失败:', error);
-        billId && Bill.findByIdAndDelete(billId);
-      }
-    }
+    // Deduct from the account balance
+    await User.findByIdAndUpdate(userId, {
+      $inc: { balance: -price }
+    });
   } catch (error) {
-    console.log(error);
+    console.log('创建账单失败:', error);
+    billId && Bill.findByIdAndDelete(billId);
   }
 };
 
@@ -126,41 +115,40 @@ export const pushGenerateVectorBill = async ({
   text: string;
   tokenLen: number;
 }) => {
+  await connectToDatabase();
+  console.log(
+    `vector generate success. text len: ${text.length}. token len: ${tokenLen}. pay:${isPay}`
+  );
+  if (!isPay) return;
+
   let billId;
 
   try {
-    console.log(
-      `vector generate success. text len: ${text.length}. token len: ${tokenLen}. pay:${isPay}`
-    );
-    await connectToDatabase();
-
-    if (isPay) {
-      try {
-        const unitPrice = 0.4;
-        // Compute the price; at least 1
-        let price = unitPrice * tokenLen;
-        price = price > 1 ? price : 1;
+    const unitPrice = 0.4;
+    // Compute the price; at least 1
+    let price = unitPrice * tokenLen;
+    price = price > 1 ? price : 1;
 
-        // Insert a Bill record
-        const res = await Bill.create({
-          userId,
-          type: BillTypeEnum.vector,
-          modelName: embeddingModel,
-          textLen: text.length,
-          tokenLen,
-          price
-        });
-        billId = res._id;
+    // Insert a Bill record
+    const res = await Bill.create({
+      userId,
+      type: BillTypeEnum.vector,
+      modelName: embeddingModel,
+      textLen: text.length,
+      tokenLen,
+      price
+    });
+    billId = res._id;
 
-        // Deduct from the account balance
-        await User.findByIdAndUpdate(userId, {
-          $inc: { balance: -price }
-        });
-      } catch (error) {
-        console.log('创建账单失败:', error);
-        billId && Bill.findByIdAndDelete(billId);
-      }
-    }
+    // Deduct from the account balance
+    await User.findByIdAndUpdate(userId, {
+      $inc: { balance: -price }
+    });
   } catch (error) {
-    console.log(error);
+    console.log('创建账单失败:', error);
+    billId && Bill.findByIdAndDelete(billId);
@@ -1,5 +1,6 @@
 import { Schema, model, models, Model } from 'mongoose';
 import { ChatSchema as ChatType } from '@/types/mongoSchema';
+import { ChatRoleMap } from '@/constants/chat';
 
 const ChatSchema = new Schema({
   userId: {
@@ -36,7 +37,7 @@ const ChatSchema = new Schema({
     obj: {
       type: String,
       required: true,
-      enum: ['Human', 'AI', 'SYSTEM']
+      enum: Object.keys(ChatRoleMap)
     },
     value: {
       type: String,
@@ -1,22 +1,23 @@
-import { openaiCreateEmbedding } from '../utils/openai';
 import { PgClient } from '@/service/pg';
 import { ModelDataStatusEnum, ModelVectorSearchModeEnum, ChatModelMap } from '@/constants/model';
 import { ModelSchema } from '@/types/mongoSchema';
-import { systemPromptFilter } from '../utils/tools';
+import { openaiCreateEmbedding } from '../utils/chat/openai';
+import { ChatRoleEnum } from '@/constants/chat';
+import { sliceTextByToken } from '@/utils/chat';
 
 /**
  * use openai embedding search kb
  */
-export const searchKb_openai = async ({
-  apiKey,
-  isPay = true,
+export const searchKb = async ({
+  userApiKey,
+  systemApiKey,
   text,
   similarity = 0.2,
   model,
   userId
 }: {
-  apiKey: string;
-  isPay: boolean;
+  userApiKey?: string;
+  systemApiKey: string;
   text: string;
   model: ModelSchema;
   userId: string;
@@ -24,7 +25,7 @@ export const searchKb_openai = async ({
 }): Promise<{
   code: 200 | 201;
   searchPrompt?: {
-    obj: 'Human' | 'AI' | 'SYSTEM';
+    obj: `${ChatRoleEnum}`;
     value: string;
   };
 }> => {
@@ -32,8 +33,8 @@ export const searchKb_openai = async ({
 
   // Get the prompt's embedding vector
   const { vector: promptVector } = await openaiCreateEmbedding({
-    isPay,
-    apiKey,
+    userApiKey,
+    systemApiKey,
     userId,
     text
   });
@@ -61,7 +62,7 @@ export const searchKb_openai = async ({
     return {
       code: 201,
       searchPrompt: {
-        obj: 'AI',
+        obj: ChatRoleEnum.AI,
         value: '对不起,你的问题不在知识库中。'
       }
     };
@@ -72,7 +73,7 @@ export const searchKb_openai = async ({
       code: 200,
       searchPrompt: model.chat.systemPrompt
         ? {
-            obj: 'SYSTEM',
+            obj: ChatRoleEnum.System,
             value: model.chat.systemPrompt
           }
         : undefined
@@ -81,16 +82,16 @@ export const searchKb_openai = async ({
 
   // On a match, add the KB content to the system prompt.
   // Filter the system prompt down to at most 65% of the tokens
-  const filterSystemPrompt = systemPromptFilter({
+  const filterSystemPrompt = sliceTextByToken({
     model: model.chat.chatModel,
-    prompts: systemPrompts,
-    maxTokens: Math.floor(modelConstantsData.contextMaxToken * 0.65)
+    text: systemPrompts.join('\n'),
+    length: Math.floor(modelConstantsData.contextMaxToken * 0.65)
   });
 
   return {
     code: 200,
     searchPrompt: {
-      obj: 'SYSTEM',
+      obj: ChatRoleEnum.System,
       value: `
 ${model.chat.systemPrompt}
 ${
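Note: `searchKb` distinguishes an empty match (code 201, returning a canned AI reply) from a usable system prompt (code 200). A hedged sketch of how a caller branches on that contract, mirroring the API routes earlier in this diff:

```ts
// Sketch (not repo code) of consuming searchKb's { code, searchPrompt } contract:
// code 201 = nothing matched, reply directly and stop;
// code 200 = prepend the returned system prompt to the conversation.
type SearchPrompt = { obj: string; value: string };

function applySearchResult(
  prompts: SearchPrompt[],
  result: { code: 200 | 201; searchPrompt?: SearchPrompt }
): { reply?: string } {
  if (result.code === 201) {
    return { reply: result.searchPrompt?.value }; // e.g. '对不起,你的问题不在知识库中。'
  }
  if (result.searchPrompt) {
    prompts.unshift(result.searchPrompt);
  }
  return {};
}
```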
@@ -1,14 +1,18 @@
-import { Configuration, OpenAIApi } from 'openai';
 import type { NextApiRequest } from 'next';
 import jwt from 'jsonwebtoken';
 import { Chat, Model, OpenApi, User } from '../mongo';
 import type { ModelSchema } from '@/types/mongoSchema';
-import { getOpenApiKey } from './openai';
 import type { ChatItemSimpleType } from '@/types/chat';
 import mongoose from 'mongoose';
-import { defaultModel } from '@/constants/model';
 import { formatPrice } from '@/utils/user';
 import { ERROR_ENUM } from '../errorCode';
+import {
+  ChatModelType,
+  OpenAiChatEnum,
+  embeddingModel,
+  EmbeddingModelType
+} from '@/constants/model';
 
 /* Verify the token */
 export const authToken = (token?: string): Promise<string> => {
@@ -29,13 +33,63 @@ export const authToken = (token?: string): Promise<string> => {
   });
 };
 
-export const getOpenAIApi = (apiKey: string) => {
-  const configuration = new Configuration({
-    apiKey,
-    basePath: process.env.OPENAI_BASE_URL
-  });
+/* Get the key for api requests */
+export const getApiKey = async ({
+  model,
+  userId
+}: {
+  model: ChatModelType | EmbeddingModelType;
+  userId: string;
+}) => {
+  const user = await User.findById(userId);
+  if (!user) {
+    return Promise.reject({
+      code: 501,
+      message: '找不到用户'
+    });
+  }
 
-  return new OpenAIApi(configuration);
+  const keyMap = {
+    [OpenAiChatEnum.GPT35]: {
+      userApiKey: user.openaiKey || '',
+      systemApiKey: process.env.OPENAIKEY as string
+    },
+    [OpenAiChatEnum.GPT4]: {
+      userApiKey: user.openaiKey || '',
+      systemApiKey: process.env.OPENAIKEY as string
+    },
+    [OpenAiChatEnum.GPT432k]: {
+      userApiKey: user.openaiKey || '',
+      systemApiKey: process.env.OPENAIKEY as string
+    },
+    [embeddingModel]: {
+      userApiKey: user.openaiKey || '',
+      systemApiKey: process.env.OPENAIKEY as string
+    }
+  };
+
+  // The user has their own key
+  if (keyMap[model].userApiKey) {
+    return {
+      user,
+      userApiKey: keyMap[model].userApiKey,
+      systemApiKey: ''
+    };
+  }
+
+  // Verify the platform account balance
+  if (formatPrice(user.balance) <= 0) {
+    return Promise.reject({
+      code: 501,
+      message: '账号余额不足'
+    });
+  }
+
+  return {
+    user,
+    userApiKey: '',
+    systemApiKey: keyMap[model].systemApiKey
+  };
 };
 
 // Verify model usage permission
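Note: `getApiKey` resolves a per-model key pair: the user's own key wins (and suppresses billing), otherwise the platform key is used after a balance check, with a 501 rejection for a missing user or empty balance. A hedged usage sketch under those assumptions:

```ts
// Sketch: deciding which key to send and whether the call is billable.
// Assumes getApiKey's contract from the diff: exactly one of userApiKey /
// systemApiKey is non-empty.
type GetApiKey = (args: { model: string; userId: string }) => Promise<{
  userApiKey: string;
  systemApiKey: string;
}>;

async function resolveKey(getApiKey: GetApiKey, userId: string) {
  const { userApiKey, systemApiKey } = await getApiKey({ model: 'gpt-3.5-turbo', userId });
  return {
    apiKey: userApiKey || systemApiKey,
    isPay: !userApiKey // bill only when the platform key is used
  };
}
```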
@@ -122,11 +176,11 @@ export const authChat = async ({
     ]);
   }
   // Get the user's apiKey
-  const { userApiKey, systemKey } = await getOpenApiKey(userId);
+  const { userApiKey, systemApiKey } = await getApiKey({ model: model.chat.chatModel, userId });
 
   return {
     userApiKey,
-    systemKey,
+    systemApiKey,
     content,
     userId,
     model,
@@ -0,0 +1,155 @@
+import { ChatItemSimpleType } from '@/types/chat';
+import { modelToolMap } from '@/utils/chat';
+import type { ChatModelType } from '@/constants/model';
+import { ChatRoleEnum, SYSTEM_PROMPT_PREFIX } from '@/constants/chat';
+import { OpenAiChatEnum } from '@/constants/model';
+import { chatResponse, openAiStreamResponse } from './openai';
+import type { NextApiResponse } from 'next';
+import type { PassThrough } from 'stream';
+
+export type ChatCompletionType = {
+  apiKey: string;
+  temperature: number;
+  messages: ChatItemSimpleType[];
+  stream: boolean;
+};
+export type StreamResponseType = {
+  stream: PassThrough;
+  chatResponse: any;
+  prompts: ChatItemSimpleType[];
+};
+
+export const modelServiceToolMap = {
+  [OpenAiChatEnum.GPT35]: {
+    chatCompletion: (data: ChatCompletionType) =>
+      chatResponse({ model: OpenAiChatEnum.GPT35, ...data }),
+    streamResponse: (data: StreamResponseType) =>
+      openAiStreamResponse({
+        model: OpenAiChatEnum.GPT35,
+        ...data
+      })
+  },
+  [OpenAiChatEnum.GPT4]: {
+    chatCompletion: (data: ChatCompletionType) =>
+      chatResponse({ model: OpenAiChatEnum.GPT4, ...data }),
+    streamResponse: (data: StreamResponseType) =>
+      openAiStreamResponse({
+        model: OpenAiChatEnum.GPT4,
+        ...data
+      })
+  },
+  [OpenAiChatEnum.GPT432k]: {
+    chatCompletion: (data: ChatCompletionType) =>
+      chatResponse({ model: OpenAiChatEnum.GPT432k, ...data }),
+    streamResponse: (data: StreamResponseType) =>
+      openAiStreamResponse({
+        model: OpenAiChatEnum.GPT432k,
+        ...data
+      })
+  }
+};
+
+/* delete invalid symbol */
+const simplifyStr = (str: string) =>
+  str
+    .replace(/\n+/g, '\n') // consecutive blank lines
+    .replace(/[^\S\r\n]+/g, ' ') // consecutive whitespace
+    .trim();
+
+/* Truncate the chat context by tokens */
+export const ChatContextFilter = ({
+  model,
+  prompts,
+  maxTokens
+}: {
+  model: ChatModelType;
+  prompts: ChatItemSimpleType[];
+  maxTokens: number;
+}) => {
+  let rawTextLen = 0;
+  const formatPrompts = prompts.map<ChatItemSimpleType>((item) => {
+    const val = simplifyStr(item.value);
+    rawTextLen += val.length;
+    return {
+      obj: item.obj,
+      value: val
+    };
+  });
+
+  // Too short to need token truncation
+  if (formatPrompts.length <= 2 || rawTextLen < maxTokens * 0.5) {
+    return formatPrompts;
+  }
+
+  // Truncate the content by tokens
+  const chats: ChatItemSimpleType[] = [];
+  let systemPrompt: ChatItemSimpleType | null = null;
+
+  // Keep the System prompt
+  if (formatPrompts[0].obj === ChatRoleEnum.System) {
+    const prompt = formatPrompts.shift();
+    if (prompt) {
+      systemPrompt = prompt;
+    }
+  }
+
+  let messages: ChatItemSimpleType[] = [];
+
+  // Take conversation content from newest to oldest
+  for (let i = formatPrompts.length - 1; i >= 0; i--) {
+    chats.unshift(formatPrompts[i]);
+
+    messages = systemPrompt ? [systemPrompt, ...chats] : chats;
+
+    const tokens = modelToolMap[model].countTokens({
+      messages
+    });
+
+    /* Total tokens out of range */
+    if (tokens >= maxTokens) {
+      return systemPrompt ? [systemPrompt, ...chats.slice(1)] : chats.slice(1);
+    }
+  }
+
+  return messages;
+};
+
+/* stream response */
+export const resStreamResponse = async ({
+  model,
+  res,
+  stream,
+  chatResponse,
+  systemPrompt,
+  prompts
+}: StreamResponseType & {
+  model: ChatModelType;
+  res: NextApiResponse;
+  systemPrompt?: string;
+}) => {
+  // Create the response stream
+  res.setHeader('Content-Type', 'text/event-stream;charset-utf-8');
+  res.setHeader('Access-Control-Allow-Origin', '*');
+  res.setHeader('X-Accel-Buffering', 'no');
+  res.setHeader('Cache-Control', 'no-cache, no-transform');
+  stream.pipe(res);
+
+  const { responseContent, totalTokens, finishMessages } = await modelServiceToolMap[
+    model
+  ].streamResponse({
+    chatResponse,
+    stream,
+    prompts
+  });
+
+  // push system prompt
+  !stream.destroyed &&
+    systemPrompt &&
+    stream.push(`${SYSTEM_PROMPT_PREFIX}${systemPrompt.replace(/\n/g, '<br/>')}`);
+
+  // close stream
+  !stream.destroyed && stream.push(null);
+  stream.destroy();
+
+  return { responseContent, totalTokens, finishMessages };
+};
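Note: `ChatContextFilter` pins the system prompt, then walks the history from newest to oldest until the token budget is hit. A standalone sketch of the same sliding-window idea, with string length standing in for a real tokenizer:

```ts
// Standalone sketch of the ChatContextFilter strategy (not the repo's code):
// keep the system prompt, then the newest messages that fit the budget.
type Msg = { obj: 'System' | 'Human' | 'AI'; value: string };

function contextFilter(prompts: Msg[], maxTokens: number): Msg[] {
  const rest = [...prompts];
  const system = rest[0]?.obj === 'System' ? rest.shift() : undefined;

  const kept: Msg[] = [];
  for (let i = rest.length - 1; i >= 0; i--) {
    kept.unshift(rest[i]);
    const all = system ? [system, ...kept] : kept;
    const tokens = all.reduce((sum, m) => sum + m.value.length, 0); // stand-in for countTokens
    if (tokens >= maxTokens) {
      // Over budget: drop the oldest non-system message just added.
      return system ? [system, ...kept.slice(1)] : kept.slice(1);
    }
  }
  return system ? [system, ...kept] : kept;
}
```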
@@ -0,0 +1,174 @@
+import { Configuration, OpenAIApi } from 'openai';
+import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
+import { axiosConfig } from '../tools';
+import { ChatModelMap, embeddingModel, OpenAiChatEnum } from '@/constants/model';
+import { pushGenerateVectorBill } from '../../events/pushBill';
+import { adaptChatItem_openAI } from '@/utils/chat/openai';
+import { modelToolMap } from '@/utils/chat';
+import { ChatCompletionType, ChatContextFilter, StreamResponseType } from './index';
+import { ChatRoleEnum } from '@/constants/chat';
+
+export const getOpenAIApi = (apiKey: string) => {
+  const configuration = new Configuration({
+    apiKey,
+    basePath: process.env.OPENAI_BASE_URL
+  });
+
+  return new OpenAIApi(configuration);
+};
+
+/* Get the embedding vector */
+export const openaiCreateEmbedding = async ({
+  userApiKey,
+  systemApiKey,
+  userId,
+  text
+}: {
+  userApiKey?: string;
+  systemApiKey: string;
+  userId: string;
+  text: string;
+}) => {
+  // Get the chat API
+  const chatAPI = getOpenAIApi(userApiKey || systemApiKey);
+
+  // Convert the input into a vector
+  const res = await chatAPI
+    .createEmbedding(
+      {
+        model: embeddingModel,
+        input: text
+      },
+      {
+        timeout: 60000,
+        ...axiosConfig()
+      }
+    )
+    .then((res) => ({
+      tokenLen: res.data.usage.total_tokens || 0,
+      vector: res.data.data?.[0]?.embedding || []
+    }));
+
+  pushGenerateVectorBill({
+    isPay: !userApiKey,
+    userId,
+    text,
+    tokenLen: res.tokenLen
+  });
+
+  return {
+    vector: res.vector,
+    chatAPI
+  };
+};
+
+/* Model chat */
+export const chatResponse = async ({
+  model,
+  apiKey,
+  temperature,
+  messages,
+  stream
+}: ChatCompletionType & { model: `${OpenAiChatEnum}` }) => {
+  const filterMessages = ChatContextFilter({
+    model,
+    prompts: messages,
+    maxTokens: Math.ceil(ChatModelMap[model].contextMaxToken * 0.9)
+  });
+
+  const adaptMessages = adaptChatItem_openAI({ messages: filterMessages });
+  const chatAPI = getOpenAIApi(apiKey);
+
+  const response = await chatAPI.createChatCompletion(
+    {
+      model,
+      temperature: Number(temperature) || 0,
+      messages: adaptMessages,
+      frequency_penalty: 0.5, // higher = less repetition
+      presence_penalty: -0.5, // higher = more new content
+      stream,
+      stop: ['.!?。']
+    },
+    {
+      timeout: stream ? 40000 : 240000,
+      responseType: stream ? 'stream' : 'json',
+      ...axiosConfig()
+    }
+  );
+
+  let responseText = '';
+  let totalTokens = 0;
+
+  // adapt data
+  if (!stream) {
+    responseText = response.data.choices[0].message?.content || '';
+    totalTokens = response.data.usage?.total_tokens || 0;
+  }
+
+  return {
+    streamResponse: response,
+    responseMessages: filterMessages.concat({ obj: 'AI', value: responseText }),
+    responseText,
+    totalTokens
+  };
+};
+
+/* openai stream response */
+export const openAiStreamResponse = async ({
+  model,
+  stream,
+  chatResponse,
+  prompts
+}: StreamResponseType & {
+  model: `${OpenAiChatEnum}`;
+}) => {
+  try {
+    let responseContent = '';
+
+    const onParse = async (event: ParsedEvent | ReconnectInterval) => {
+      if (event.type !== 'event') return;
+      const data = event.data;
+      if (data === '[DONE]') return;
+      try {
+        const json = JSON.parse(data);
+        const content: string = json?.choices?.[0].delta.content || '';
+        responseContent += content;
+
+        !stream.destroyed && content && stream.push(content.replace(/\n/g, '<br/>'));
+      } catch (error) {
+        error;
+      }
+    };
+
+    try {
+      const decoder = new TextDecoder();
+      const parser = createParser(onParse);
+      for await (const chunk of chatResponse.data as any) {
+        if (stream.destroyed) {
+          // Stream interrupted; ignore the rest
+          break;
+        }
+        parser.feed(decoder.decode(chunk, { stream: true }));
+      }
+    } catch (error) {
+      console.log('pipe error', error);
+    }
+
+    // count tokens
+    const finishMessages = prompts.concat({
+      obj: ChatRoleEnum.AI,
+      value: responseContent
+    });
+    const totalTokens = modelToolMap[model].countTokens({
+      messages: finishMessages
+    });
+
+    return {
+      responseContent,
+      totalTokens,
+      finishMessages
+    };
+  } catch (error) {
+    return Promise.reject(error);
+  }
+};
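Note: `chatResponse` returns the raw axios response for the streaming path plus adapted `responseText`/`totalTokens` for the non-stream path. A hedged usage sketch of a non-stream call, assuming the `ChatItemSimpleType` `{ obj, value }` shape and the `OPENAIKEY` env var this diff already uses:

```ts
// Sketch of a non-stream call through chatResponse, per the signature above;
// the import path is the one introduced by this commit.
import { chatResponse } from '@/service/utils/chat/openai';
import { ChatRoleEnum } from '@/constants/chat';

async function ask(question: string) {
  const { responseText, totalTokens } = await chatResponse({
    model: 'gpt-3.5-turbo',
    apiKey: process.env.OPENAIKEY as string,
    temperature: 0,
    messages: [{ obj: ChatRoleEnum.Human, value: question }],
    stream: false
  });
  return { responseText, totalTokens };
}
```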
@ -1,179 +0,0 @@
|
|||
import type { NextApiResponse } from 'next';
|
||||
import type { PassThrough } from 'stream';
|
||||
import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
|
||||
import { getOpenAIApi } from '@/service/utils/auth';
|
||||
import { axiosConfig } from './tools';
|
||||
import { User } from '../models/user';
|
||||
import { formatPrice } from '@/utils/user';
|
||||
import { embeddingModel } from '@/constants/model';
|
||||
import { pushGenerateVectorBill } from '../events/pushBill';
|
||||
import { SYSTEM_PROMPT_PREFIX } from '@/constants/chat';
|
||||
|
||||
/* 获取用户 api 的 openai 信息 */
|
||||
export const getUserApiOpenai = async (userId: string) => {
|
||||
const user = await User.findById(userId);
|
||||
|
||||
const userApiKey = user?.openaiKey;
|
||||
|
||||
if (!userApiKey) {
|
||||
return Promise.reject('缺少ApiKey, 无法请求');
|
||||
}
|
||||
|
||||
return {
|
||||
user,
|
||||
openai: getOpenAIApi(userApiKey),
|
||||
apiKey: userApiKey
|
||||
};
|
||||
};
|
||||
|
||||
/* 获取 open api key,如果用户没有自己的key,就用平台的,用平台记得加账单 */
|
||||
export const getOpenApiKey = async (userId: string) => {
|
||||
const user = await User.findById(userId);
|
||||
if (!user) {
|
||||
return Promise.reject({
|
||||
code: 501,
|
||||
message: '找不到用户'
|
||||
});
|
||||
}
|
||||
|
||||
const userApiKey = user?.openaiKey;
|
||||
|
||||
// 有自己的key
|
||||
if (userApiKey) {
|
||||
return {
|
||||
user,
|
||||
userApiKey,
|
||||
systemKey: ''
|
||||
};
|
||||
}
|
||||
|
||||
// 平台账号余额校验
|
||||
if (formatPrice(user.balance) <= 0) {
|
||||
return Promise.reject({
|
||||
code: 501,
|
||||
message: '账号余额不足'
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
user,
|
||||
userApiKey: '',
|
||||
systemKey: process.env.OPENAIKEY as string
|
||||
};
|
||||
};
|
||||
|
||||
/* 获取向量 */
|
||||
export const openaiCreateEmbedding = async ({
|
||||
isPay,
|
||||
userId,
|
||||
apiKey,
|
||||
text
|
||||
}: {
|
||||
isPay: boolean;
|
||||
userId: string;
|
||||
apiKey: string;
|
||||
text: string;
|
||||
}) => {
|
||||
// 获取 chatAPI
|
||||
const chatAPI = getOpenAIApi(apiKey);
|
||||
|
||||
// 把输入的内容转成向量
|
||||
const res = await chatAPI
|
||||
.createEmbedding(
|
||||
{
|
||||
model: embeddingModel,
|
||||
input: text
|
||||
},
|
||||
{
|
||||
timeout: 60000,
|
||||
...axiosConfig()
|
||||
}
|
||||
)
|
||||
.then((res) => ({
|
||||
tokenLen: res.data.usage.total_tokens || 0,
|
||||
vector: res.data.data?.[0]?.embedding || []
|
||||
}));
|
||||
|
||||
pushGenerateVectorBill({
|
||||
isPay,
|
||||
userId,
|
||||
text,
|
||||
tokenLen: res.tokenLen
|
||||
});
|
||||
|
||||
return {
|
||||
vector: res.vector,
|
||||
chatAPI
|
||||
};
|
||||
};
|
||||
|
||||
/* GPT-3.5 stream response */
export const gpt35StreamResponse = ({
  res,
  stream,
  chatResponse,
  systemPrompt = ''
}: {
  res: NextApiResponse;
  stream: PassThrough;
  chatResponse: any;
  systemPrompt?: string;
}) =>
  new Promise<{ responseContent: string }>(async (resolve, reject) => {
    try {
      // set up the SSE response headers and pipe the stream to the client
      res.setHeader('Content-Type', 'text/event-stream; charset=utf-8');
      res.setHeader('Access-Control-Allow-Origin', '*');
      res.setHeader('X-Accel-Buffering', 'no');
      res.setHeader('Cache-Control', 'no-cache, no-transform');
      stream.pipe(res);

      let responseContent = '';

      const onParse = async (event: ParsedEvent | ReconnectInterval) => {
        if (event.type !== 'event') return;
        const data = event.data;
        if (data === '[DONE]') return;
        try {
          const json = JSON.parse(data);
          const content: string = json?.choices?.[0]?.delta?.content || '';
          responseContent += content;

          if (!stream.destroyed && content) {
            stream.push(content.replace(/\n/g, '<br/>'));
          }
        } catch (error) {
          // ignore malformed SSE chunks
        }
      };

      try {
        const decoder = new TextDecoder();
        const parser = createParser(onParse);
        for await (const chunk of chatResponse.data as any) {
          if (stream.destroyed) {
            // the stream was aborted; skip the remaining chunks
            break;
          }
          parser.feed(decoder.decode(chunk, { stream: true }));
        }
      } catch (error) {
        console.log('pipe error', error);
      }

      // push system prompt
      !stream.destroyed &&
        systemPrompt &&
        stream.push(`${SYSTEM_PROMPT_PREFIX}${systemPrompt.replace(/\n/g, '<br/>')}`);

      // close stream
      !stream.destroyed && stream.push(null);
      stream.destroy();

      resolve({
        responseContent
      });
    } catch (error) {
      reject(error);
    }
  });
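A sketch of the handler-side wiring, assuming the openai v3 SDK with `responseType: 'stream'`; the timeout and temperature values are illustrative only:

// hypothetical wiring inside an API route
const stream = new PassThrough();
res.on('close', () => stream.destroy()); // stop pushing if the client disconnects
const chatResponse = await chatAPI.createChatCompletion(
  { model: 'gpt-3.5-turbo', temperature: 0.7, messages: filterPrompts, stream: true },
  { timeout: 40000, responseType: 'stream', ...axiosConfig() }
);
const { responseContent } = await gpt35StreamResponse({ res, stream, chatResponse });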
@@ -1,9 +1,5 @@
import crypto from 'crypto';
import jwt from 'jsonwebtoken';
import { ChatItemSimpleType } from '@/types/chat';
import { countChatTokens, sliceTextByToken } from '@/utils/tools';
import { ChatCompletionRequestMessageRoleEnum, ChatCompletionRequestMessage } from 'openai';
import type { ChatModelType } from '@/constants/model';

/* Password hashing */
export const hashPassword = (psw: string) => {

@@ -30,92 +26,3 @@ export const axiosConfig = () => ({
  auth: process.env.OPENAI_BASE_URL_AUTH || ''
  }
});

/* delete invalid symbols */
const simplifyStr = (str: string) =>
  str
    .replace(/\n+/g, '\n') // collapse consecutive blank lines
    .replace(/[^\S\r\n]+/g, ' ') // collapse consecutive whitespace
    .trim();

/* Truncate chat content by token count */
export const openaiChatFilter = ({
  model,
  prompts,
  maxTokens
}: {
  model: ChatModelType;
  prompts: ChatItemSimpleType[];
  maxTokens: number;
}) => {
  // role map
  const map = {
    Human: ChatCompletionRequestMessageRoleEnum.User,
    AI: ChatCompletionRequestMessageRoleEnum.Assistant,
    SYSTEM: ChatCompletionRequestMessageRoleEnum.System
  };

  let rawTextLen = 0;
  const formatPrompts = prompts.map((item) => {
    const val = simplifyStr(item.value);
    rawTextLen += val.length;
    return {
      role: map[item.obj],
      content: val
    };
  });

  // when the raw text is short, token-based truncation is unnecessary
  if (rawTextLen < maxTokens * 0.5) {
    return formatPrompts;
  }

  // truncate content by token count
  const chats: ChatCompletionRequestMessage[] = [];
  let systemPrompt: ChatCompletionRequestMessage | null = null;

  // always keep the system prompt
  if (formatPrompts[0]?.role === 'system') {
    systemPrompt = formatPrompts.shift() as ChatCompletionRequestMessage;
  }

  let messages: { role: ChatCompletionRequestMessageRoleEnum; content: string }[] = [];

  // collect conversation messages from newest to oldest
  for (let i = formatPrompts.length - 1; i >= 0; i--) {
    chats.unshift(formatPrompts[i]);

    messages = systemPrompt ? [systemPrompt, ...chats] : chats;

    const tokens = countChatTokens({
      model,
      messages
    });

    /* total tokens exceed the limit: drop the oldest message just added */
    if (tokens >= maxTokens) {
      return systemPrompt ? [systemPrompt, ...chats.slice(1)] : chats.slice(1);
    }
  }

  return messages;
};
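A sketch of the truncation behaviour (message values are hypothetical): the system prompt survives, while the oldest Human/AI turns are dropped until the rest fits under `maxTokens`:

const filtered = openaiChatFilter({
  model: 'gpt-3.5-turbo',
  prompts: [
    { obj: 'SYSTEM', value: 'You are a helpful assistant.' },
    { obj: 'Human', value: 'First question...' },
    { obj: 'AI', value: 'First answer...' },
    { obj: 'Human', value: 'Latest question' }
  ],
  maxTokens: 4000 - 300 // leave headroom for the completion
});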
/* Truncate system content; prompts are sorted by similarity, high to low */
export const systemPromptFilter = ({
  model,
  prompts,
  maxTokens
}: {
  model: 'gpt-4' | 'gpt-4-32k' | 'gpt-3.5-turbo';
  prompts: string[];
  maxTokens: number;
}) => {
  const systemPrompt = prompts.join('\n');

  return sliceTextByToken({
    model,
    text: systemPrompt,
    length: maxTokens
  });
};
@@ -1,5 +1,7 @@
import { ChatRoleEnum } from '@/constants/chat';

export type ChatItemSimpleType = {
  obj: 'Human' | 'AI' | 'SYSTEM';
  obj: `${ChatRoleEnum}`;
  value: string;
  systemPrompt?: string;
};
@@ -0,0 +1,39 @@
import { OpenAiChatEnum } from '@/constants/model';
import type { ChatModelType } from '@/constants/model';
import type { ChatItemSimpleType } from '@/types/chat';
import { countOpenAIToken, getOpenAiEncMap, adaptChatItem_openAI } from './openai';

export type CountTokenType = { messages: ChatItemSimpleType[] };

export const modelToolMap = {
  [OpenAiChatEnum.GPT35]: {
    countTokens: ({ messages }: CountTokenType) =>
      countOpenAIToken({ model: OpenAiChatEnum.GPT35, messages }),
    adaptChatMessages: adaptChatItem_openAI
  },
  [OpenAiChatEnum.GPT4]: {
    countTokens: ({ messages }: CountTokenType) =>
      countOpenAIToken({ model: OpenAiChatEnum.GPT4, messages }),
    adaptChatMessages: adaptChatItem_openAI
  },
  [OpenAiChatEnum.GPT432k]: {
    countTokens: ({ messages }: CountTokenType) =>
      countOpenAIToken({ model: OpenAiChatEnum.GPT432k, messages }),
    adaptChatMessages: adaptChatItem_openAI
  }
};
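A usage sketch of the dispatch table (the message content is hypothetical, and this assumes `ChatRoleEnum.Human` serializes to 'Human'):

const tokens = modelToolMap[OpenAiChatEnum.GPT35].countTokens({
  messages: [{ obj: 'Human', value: 'What is FastGPT?' }]
});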
export const sliceTextByToken = ({
  model = 'gpt-3.5-turbo',
  text,
  length
}: {
  model: ChatModelType;
  text: string;
  length: number;
}) => {
  const enc = getOpenAiEncMap()[model];
  const encodeText = enc.encode(text);
  const decoder = new TextDecoder();
  return decoder.decode(enc.decode(encodeText.slice(0, length)));
};
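For example, a hedged one-liner that keeps only the first 2000 tokens of a long document (`longText` is hypothetical):

const truncated = sliceTextByToken({ model: 'gpt-3.5-turbo', text: longText, length: 2000 });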
@@ -0,0 +1,106 @@
import { encoding_for_model, type Tiktoken } from '@dqbd/tiktoken';
import type { ChatItemSimpleType } from '@/types/chat';
import { ChatRoleEnum } from '@/constants/chat';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';

import Graphemer from 'graphemer';

const textDecoder = new TextDecoder();
const graphemer = new Graphemer();

export const adaptChatItem_openAI = ({
  messages
}: {
  messages: ChatItemSimpleType[];
}): ChatCompletionRequestMessage[] => {
  const map = {
    [ChatRoleEnum.AI]: ChatCompletionRequestMessageRoleEnum.Assistant,
    [ChatRoleEnum.Human]: ChatCompletionRequestMessageRoleEnum.User,
    [ChatRoleEnum.System]: ChatCompletionRequestMessageRoleEnum.System
  };
  return messages.map((item) => ({
    role: map[item.obj] || ChatCompletionRequestMessageRoleEnum.System,
    content: item.value || ''
  }));
};
/* count OpenAI chat tokens */
let OpenAiEncMap: Record<string, Tiktoken>;
export const getOpenAiEncMap = () => {
  if (OpenAiEncMap) return OpenAiEncMap;
  OpenAiEncMap = {
    'gpt-3.5-turbo': encoding_for_model('gpt-3.5-turbo', {
      '<|im_start|>': 100264,
      '<|im_end|>': 100265,
      '<|im_sep|>': 100266
    }),
    'gpt-4': encoding_for_model('gpt-4', {
      '<|im_start|>': 100264,
      '<|im_end|>': 100265,
      '<|im_sep|>': 100266
    }),
    'gpt-4-32k': encoding_for_model('gpt-4-32k', {
      '<|im_start|>': 100264,
      '<|im_end|>': 100265,
      '<|im_sep|>': 100266
    })
  };
  return OpenAiEncMap;
};
export function countOpenAIToken({
  messages,
  model
}: {
  messages: ChatItemSimpleType[];
  model: 'gpt-3.5-turbo' | 'gpt-4' | 'gpt-4-32k';
}) {
  function getChatGPTEncodingText(
    messages: { role: 'system' | 'user' | 'assistant'; content: string; name?: string }[],
    model: 'gpt-3.5-turbo' | 'gpt-4' | 'gpt-4-32k'
  ) {
    const isGpt3 = model === 'gpt-3.5-turbo';

    const msgSep = isGpt3 ? '\n' : '';
    const roleSep = isGpt3 ? '\n' : '<|im_sep|>';

    return [
      messages
        .map(({ name = '', role, content }) => {
          return `<|im_start|>${name || role}${roleSep}${content}<|im_end|>`;
        })
        .join(msgSep),
      `<|im_start|>assistant${roleSep}`
    ].join(msgSep);
  }
  function text2TokensLen(encoder: Tiktoken, inputText: string) {
    const encoding = encoder.encode(inputText, 'all');
    const segments: { text: string; tokens: { id: number; idx: number }[] }[] = [];

    let byteAcc: number[] = [];
    let tokenAcc: { id: number; idx: number }[] = [];
    let inputGraphemes = graphemer.splitGraphemes(inputText);

    for (let idx = 0; idx < encoding.length; idx++) {
      const token = encoding[idx]!;
      byteAcc.push(...encoder.decode_single_token_bytes(token));
      tokenAcc.push({ id: token, idx });

      const segmentText = textDecoder.decode(new Uint8Array(byteAcc));
      const graphemes = graphemer.splitGraphemes(segmentText);

      if (graphemes.every((item, idx) => inputGraphemes[idx] === item)) {
        segments.push({ text: segmentText, tokens: tokenAcc });

        byteAcc = [];
        tokenAcc = [];
        inputGraphemes = inputGraphemes.slice(graphemes.length);
      }
    }

    return segments.reduce((memo, i) => memo + i.tokens.length, 0) ?? 0;
  }

  const adaptMessages = adaptChatItem_openAI({ messages });

  return text2TokensLen(getOpenAiEncMap()[model], getChatGPTEncodingText(adaptMessages, model));
}
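A sketch of counting tokens before prompt truncation (values hypothetical, assuming the `ChatRoleEnum` members serialize to 'System'/'Human'):

const promptTokens = countOpenAIToken({
  model: 'gpt-3.5-turbo',
  messages: [
    { obj: 'System', value: 'You are a helpful assistant.' },
    { obj: 'Human', value: 'Summarize this document.' }
  ]
});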
@@ -1,6 +1,6 @@
import mammoth from 'mammoth';
import Papa from 'papaparse';
import { getEncMap } from './tools';
import { getOpenAiEncMap } from './chat/openai';

/**
 * Read the content of a txt file
@@ -154,7 +154,7 @@ export const splitText_token = ({
  maxLen: number;
  slideLen: number;
}) => {
  const enc = getEncMap()['gpt-3.5-turbo'];
  const enc = getOpenAiEncMap()['gpt-3.5-turbo'];
  // filter empty text. encode sentence
  const encodeText = enc.encode(text);
@@ -1,33 +1,5 @@
import crypto from 'crypto';
import { useToast } from '@/hooks/useToast';
import { encoding_for_model, type Tiktoken } from '@dqbd/tiktoken';
import Graphemer from 'graphemer';
import type { ChatModelType } from '@/constants/model';

const textDecoder = new TextDecoder();
const graphemer = new Graphemer();
let encMap: Record<string, Tiktoken>;
export const getEncMap = () => {
  if (encMap) return encMap;
  encMap = {
    'gpt-3.5-turbo': encoding_for_model('gpt-3.5-turbo', {
      '<|im_start|>': 100264,
      '<|im_end|>': 100265,
      '<|im_sep|>': 100266
    }),
    'gpt-4': encoding_for_model('gpt-4', {
      '<|im_start|>': 100264,
      '<|im_end|>': 100265,
      '<|im_sep|>': 100266
    }),
    'gpt-4-32k': encoding_for_model('gpt-4-32k', {
      '<|im_start|>': 100264,
      '<|im_end|>': 100265,
      '<|im_sep|>': 100266
    })
  };
  return encMap;
};

/**
 * copy text data
@@ -79,75 +51,3 @@ export const Obj2Query = (obj: Record<string, string | number>) => {
  }
  return queryParams.toString();
};

/* Format chat messages into the encoding text */
function getChatGPTEncodingText(
  messages: { role: 'system' | 'user' | 'assistant'; content: string; name?: string }[],
  model: 'gpt-3.5-turbo' | 'gpt-4' | 'gpt-4-32k'
) {
  const isGpt3 = model === 'gpt-3.5-turbo';

  const msgSep = isGpt3 ? '\n' : '';
  const roleSep = isGpt3 ? '\n' : '<|im_sep|>';

  return [
    messages
      .map(({ name = '', role, content }) => {
        return `<|im_start|>${name || role}${roleSep}${content}<|im_end|>`;
      })
      .join(msgSep),
    `<|im_start|>assistant${roleSep}`
  ].join(msgSep);
}
function text2TokensLen(encoder: Tiktoken, inputText: string) {
  const encoding = encoder.encode(inputText, 'all');
  const segments: { text: string; tokens: { id: number; idx: number }[] }[] = [];

  let byteAcc: number[] = [];
  let tokenAcc: { id: number; idx: number }[] = [];
  let inputGraphemes = graphemer.splitGraphemes(inputText);

  for (let idx = 0; idx < encoding.length; idx++) {
    const token = encoding[idx]!;
    byteAcc.push(...encoder.decode_single_token_bytes(token));
    tokenAcc.push({ id: token, idx });

    const segmentText = textDecoder.decode(new Uint8Array(byteAcc));
    const graphemes = graphemer.splitGraphemes(segmentText);

    if (graphemes.every((item, idx) => inputGraphemes[idx] === item)) {
      segments.push({ text: segmentText, tokens: tokenAcc });

      byteAcc = [];
      tokenAcc = [];
      inputGraphemes = inputGraphemes.slice(graphemes.length);
    }
  }

  return segments.reduce((memo, i) => memo + i.tokens.length, 0) ?? 0;
}
export const countChatTokens = ({
  model = 'gpt-3.5-turbo',
  messages
}: {
  model?: ChatModelType;
  messages: { role: 'system' | 'user' | 'assistant'; content: string }[];
}) => {
  const text = getChatGPTEncodingText(messages, model);
  return text2TokensLen(getEncMap()[model], text);
};

export const sliceTextByToken = ({
  model = 'gpt-3.5-turbo',
  text,
  length
}: {
  model?: ChatModelType;
  text: string;
  length: number;
}) => {
  const enc = getEncMap()[model];
  const encodeText = enc.encode(text);
  const decoder = new TextDecoder();
  return decoder.decode(enc.decode(encodeText.slice(0, length)));
};