perf: chat framwork

This commit is contained in:
archer 2023-05-03 15:28:25 +08:00
parent 91decc3683
commit 00a99261ae
No known key found for this signature in database
GPG Key ID: 569A5660D2379E28
23 changed files with 811 additions and 1011 deletions

File diff suppressed because one or more lines are too long

View File

@ -1,6 +1,7 @@
import type { ModelSchema } from '@/types/mongoSchema';
export const embeddingModel = 'text-embedding-ada-002';
export type EmbeddingModelType = 'text-embedding-ada-002';
export enum OpenAiChatEnum {
'GPT35' = 'gpt-3.5-turbo',
@ -25,7 +26,7 @@ export const ChatModelMap = {
},
[OpenAiChatEnum.GPT432k]: {
name: 'Gpt4-32k',
contextMaxToken: 8000,
contextMaxToken: 32000,
maxTemperature: 1.5,
price: 30
}

View File

@ -1,14 +1,15 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase } from '@/service/mongo';
import { getOpenAIApi, authChat } from '@/service/utils/auth';
import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
import { authChat } from '@/service/utils/auth';
import { modelServiceToolMap } from '@/service/utils/chat';
import { ChatItemSimpleType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { gpt35StreamResponse } from '@/service/utils/openai';
import { searchKb_openai } from '@/service/tools/searchKb';
import { resStreamResponse } from '@/service/utils/chat';
import { searchKb } from '@/service/plugins/searchKb';
import { ChatRoleEnum } from '@/constants/chat';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@ -41,7 +42,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
await connectToDatabase();
let startTime = Date.now();
const { model, showModelDetail, content, userApiKey, systemKey, userId } = await authChat({
const { model, showModelDetail, content, userApiKey, systemApiKey, userId } = await authChat({
modelId,
chatId,
authorization
@ -54,9 +55,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 使用了知识库搜索
if (model.chat.useKb) {
const { code, searchPrompt } = await searchKb_openai({
apiKey: userApiKey || systemKey,
isPay: !userApiKey,
const { code, searchPrompt } = await searchKb({
userApiKey,
systemApiKey,
text: prompt.value,
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity,
model,
@ -73,53 +74,37 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 没有用知识库搜索,仅用系统提示词
model.chat.systemPrompt &&
prompts.unshift({
obj: 'SYSTEM',
obj: ChatRoleEnum.System,
value: model.chat.systemPrompt
});
}
// 控制总 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({
model: model.chat.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 300
});
// 计算温度
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
2
);
// console.log(filterPrompts);
// 获取 chatAPI
const chatAPI = getOpenAIApi(userApiKey || systemKey);
// 发出请求
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.chat.chatModel,
temperature: Number(temperature) || 0,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
stream: true,
stop: ['.!?。']
},
{
timeout: 40000,
responseType: 'stream',
...axiosConfig()
}
);
const { streamResponse } = await modelServiceToolMap[model.chat.chatModel].chatCompletion({
apiKey: userApiKey || systemApiKey,
temperature: +temperature,
messages: prompts,
stream: true
});
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
step = 1;
const { responseContent } = await gpt35StreamResponse({
const { totalTokens, finishMessages } = await resStreamResponse({
model: model.chat.chatModel,
res,
stream,
chatResponse,
chatResponse: streamResponse,
prompts,
systemPrompt:
showModelDetail && filterPrompts[0].role === 'system' ? filterPrompts[0].content : ''
showModelDetail && prompts[0].obj === ChatRoleEnum.System ? prompts[0].value : ''
});
// 只有使用平台的 key 才计费
@ -128,7 +113,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
chatModel: model.chat.chatModel,
userId,
chatId,
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
textLen: finishMessages.map((item) => item.value).join('').length,
tokens: totalTokens
});
} catch (err: any) {
if (step === 1) {

View File

@ -1,14 +1,14 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase } from '@/service/mongo';
import { getOpenAIApi, authOpenApiKey, authModel } from '@/service/utils/auth';
import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
import { authOpenApiKey, authModel } from '@/service/utils/auth';
import { modelServiceToolMap, resStreamResponse } from '@/service/utils/chat';
import { ChatItemSimpleType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { gpt35StreamResponse } from '@/service/utils/openai';
import { searchKb_openai } from '@/service/tools/searchKb';
import { searchKb } from '@/service/plugins/searchKb';
import { ChatRoleEnum } from '@/constants/chat';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@ -64,9 +64,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
if (model.chat.useKb) {
const similarity = ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22;
const { code, searchPrompt } = await searchKb_openai({
apiKey,
isPay: true,
const { code, searchPrompt } = await searchKb({
systemApiKey: apiKey,
text: prompts[prompts.length - 1].value,
similarity,
model,
@ -83,69 +82,55 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 没有用知识库搜索,仅用系统提示词
if (model.chat.systemPrompt) {
prompts.unshift({
obj: 'SYSTEM',
obj: ChatRoleEnum.System,
value: model.chat.systemPrompt
});
}
}
// 控制总 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({
model: model.chat.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 300
});
// 计算温度
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
2
);
// console.log(filterPrompts);
// 获取 chatAPI
const chatAPI = getOpenAIApi(apiKey);
// 发出请求
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.chat.chatModel,
temperature: Number(temperature) || 0,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
stream: isStream,
stop: ['.!?。']
},
{
timeout: 180000,
responseType: isStream ? 'stream' : 'json',
...axiosConfig()
}
);
const { streamResponse, responseMessages, responseText, totalTokens } =
await modelServiceToolMap[model.chat.chatModel].chatCompletion({
apiKey,
temperature: +temperature,
messages: prompts,
stream: isStream
});
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
let responseContent = '';
let textLen = 0;
let tokens = totalTokens;
if (isStream) {
step = 1;
const streamResponse = await gpt35StreamResponse({
const { finishMessages, totalTokens } = await resStreamResponse({
model: model.chat.chatModel,
res,
stream,
chatResponse
chatResponse: streamResponse,
prompts
});
responseContent = streamResponse.responseContent;
textLen = finishMessages.map((item) => item.value).join('').length;
tokens = totalTokens;
} else {
responseContent = chatResponse.data.choices?.[0]?.message?.content || '';
textLen = responseMessages.map((item) => item.value).join('').length;
jsonRes(res, {
data: responseContent
data: responseText
});
}
// 只有使用平台的 key 才计费
pushChatBill({
isPay: true,
chatModel: model.chat.chatModel,
userId,
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
textLen,
tokens
});
} catch (err: any) {
if (step === 1) {

View File

@ -1,144 +0,0 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase, Model } from '@/service/mongo';
import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
import { ChatItemSimpleType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
import { ChatModelMap } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { gpt35StreamResponse } from '@/service/utils/openai';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
let step = 0; // step=1时表示开始了流响应
const stream = new PassThrough();
stream.on('error', () => {
console.log('error: ', 'stream error');
stream.destroy();
});
res.on('close', () => {
stream.destroy();
});
res.on('error', () => {
console.log('error: ', 'request error');
stream.destroy();
});
try {
const {
prompts,
modelId,
isStream = true
} = req.body as {
prompts: ChatItemSimpleType[];
modelId: string;
isStream: boolean;
};
if (!prompts || !modelId) {
throw new Error('缺少参数');
}
if (!Array.isArray(prompts)) {
throw new Error('prompts is not array');
}
if (prompts.length > 30 || prompts.length === 0) {
throw new Error('prompts length range 1-30');
}
await connectToDatabase();
let startTime = Date.now();
const { apiKey, userId } = await authOpenApiKey(req);
const model = await Model.findOne({
_id: modelId,
userId
});
if (!model) {
throw new Error('无权使用该模型');
}
const modelConstantsData = ChatModelMap[model.chat.chatModel];
// 如果有系统提示词,自动插入
if (model.chat.systemPrompt) {
prompts.unshift({
obj: 'SYSTEM',
value: model.chat.systemPrompt
});
}
// 控制在 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({
model: model.chat.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 300
});
// console.log(filterPrompts);
// 计算温度
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
2
);
// 获取 chatAPI
const chatAPI = getOpenAIApi(apiKey);
// 发出请求
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.chat.chatModel,
temperature: Number(temperature) || 0,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
stream: isStream,
stop: ['.!?。']
},
{
timeout: 40000,
responseType: isStream ? 'stream' : 'json',
...axiosConfig()
}
);
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
let responseContent = '';
if (isStream) {
step = 1;
const streamResponse = await gpt35StreamResponse({
res,
stream,
chatResponse
});
responseContent = streamResponse.responseContent;
} else {
responseContent = chatResponse.data.choices?.[0]?.message?.content || '';
jsonRes(res, {
data: responseContent
});
}
// 只有使用平台的 key 才计费
pushChatBill({
isPay: true,
chatModel: model.chat.chatModel,
userId,
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
});
} catch (err: any) {
if (step === 1) {
// 直接结束流
console.log('error结束');
stream.destroy();
} else {
res.status(500);
jsonRes(res, {
code: 500,
error: err
});
}
}
}

View File

@ -1,14 +1,14 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase, Model } from '@/service/mongo';
import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
import { authOpenApiKey } from '@/service/utils/auth';
import { resStreamResponse, modelServiceToolMap } from '@/service/utils/chat';
import { ChatItemSimpleType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
import { ChatModelMap, ModelVectorSearchModeMap, OpenAiChatEnum } from '@/constants/model';
import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { gpt35StreamResponse } from '@/service/utils/openai';
import { searchKb_openai } from '@/service/tools/searchKb';
import { searchKb } from '@/service/plugins/searchKb';
import { ChatRoleEnum } from '@/constants/chat';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@ -57,20 +57,16 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
console.log('laf gpt start');
// 获取 chatAPI
const chatAPI = getOpenAIApi(apiKey);
// 请求一次 chatgpt 拆解需求
const promptResponse = await chatAPI.createChatCompletion(
{
model: OpenAiChatEnum.GPT35,
temperature: 0,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
messages: [
{
role: 'system',
content: `服务端逻辑生成器.根据用户输入的需求,拆解成 laf 云函数实现的步骤,只返回步骤,按格式返回步骤: 1.\n2.\n3.\n ......
const { responseText: resolveText, totalTokens: resolveTokens } = await modelServiceToolMap[
model.chat.chatModel
].chatCompletion({
apiKey,
temperature: 0,
messages: [
{
obj: ChatRoleEnum.System,
value: `服务端逻辑生成器.根据用户输入的需求,拆解成 laf 云函数实现的步骤,只返回步骤,按格式返回步骤: 1.\n2.\n3.\n ......
:
hello world
1. : "hello world"
@ -103,35 +99,25 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
5. , updateTime.
6. ,"blogs", blogId {blogText, tags, updateTime}.
7. "更新博客记录成功"`
},
{
role: 'user',
content: prompt.value
}
]
},
{
timeout: 180000,
...axiosConfig()
}
);
},
{
obj: ChatRoleEnum.Human,
value: prompt.value
}
],
stream: false
});
const promptResolve = promptResponse.data.choices?.[0]?.message?.content || '';
if (!promptResolve) {
throw new Error('gpt 异常');
}
prompt.value += ` ${promptResolve}`;
prompt.value += ` ${resolveText}`;
console.log('prompt resolve success, time:', `${(Date.now() - startTime) / 1000}s`);
// 读取对话内容
const prompts = [prompt];
// 获取向量匹配到的提示词
const { searchPrompt } = await searchKb_openai({
isPay: true,
apiKey,
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22,
const { searchPrompt } = await searchKb({
systemApiKey: apiKey,
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity,
text: prompt.value,
model,
userId
@ -139,49 +125,41 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
searchPrompt && prompts.unshift(searchPrompt);
// 控制上下文 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({
model: model.chat.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 300
});
// console.log(filterPrompts);
// 计算温度
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
2
);
// 发出请求
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.chat.chatModel,
temperature: Number(temperature) || 0,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
stream: isStream
},
{
timeout: 180000,
responseType: isStream ? 'stream' : 'json',
...axiosConfig()
}
);
let responseContent = '';
// 发出请求
const { streamResponse, responseMessages, responseText, totalTokens } =
await modelServiceToolMap[model.chat.chatModel].chatCompletion({
apiKey,
temperature: +temperature,
messages: prompts,
stream: isStream
});
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
let textLen = resolveText.length;
let tokens = resolveTokens;
if (isStream) {
step = 1;
const streamResponse = await gpt35StreamResponse({
const { finishMessages, totalTokens } = await resStreamResponse({
model: model.chat.chatModel,
res,
stream,
chatResponse
chatResponse: streamResponse,
prompts
});
responseContent = streamResponse.responseContent;
textLen += finishMessages.map((item) => item.value).join('').length;
tokens += totalTokens;
} else {
responseContent = chatResponse.data.choices?.[0]?.message?.content || '';
textLen += responseMessages.map((item) => item.value).join('').length;
tokens += totalTokens;
jsonRes(res, {
data: responseContent
data: responseText
});
}
@ -191,7 +169,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
isPay: true,
chatModel: model.chat.chatModel,
userId,
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
textLen,
tokens
});
} catch (err: any) {
if (step === 1) {

View File

@ -1,159 +0,0 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase, Model } from '@/service/mongo';
import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
import { ChatItemSimpleType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
import {
ChatModelMap,
ModelVectorSearchModeMap,
ModelVectorSearchModeEnum
} from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { gpt35StreamResponse } from '@/service/utils/openai';
import { searchKb_openai } from '@/service/tools/searchKb';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
let step = 0; // step=1时表示开始了流响应
const stream = new PassThrough();
stream.on('error', () => {
console.log('error: ', 'stream error');
stream.destroy();
});
res.on('close', () => {
stream.destroy();
});
res.on('error', () => {
console.log('error: ', 'request error');
stream.destroy();
});
try {
const {
prompts,
modelId,
isStream = true
} = req.body as {
prompts: ChatItemSimpleType[];
modelId: string;
isStream: boolean;
};
if (!prompts || !modelId) {
throw new Error('缺少参数');
}
if (!Array.isArray(prompts)) {
throw new Error('prompts is not array');
}
if (prompts.length > 30 || prompts.length === 0) {
throw new Error('prompts length range 1-30');
}
await connectToDatabase();
let startTime = Date.now();
/* 凭证校验 */
const { apiKey, userId } = await authOpenApiKey(req);
const model = await Model.findOne({
_id: modelId,
userId
});
if (!model) {
throw new Error('无权使用该模型');
}
const modelConstantsData = ChatModelMap[model.chat.chatModel];
// 获取向量匹配到的提示词
const { code, searchPrompt } = await searchKb_openai({
isPay: true,
apiKey,
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22,
text: prompts[prompts.length - 1].value,
model,
userId
});
// search result is empty
if (code === 201) {
return res.send(searchPrompt?.value);
}
searchPrompt && prompts.unshift(searchPrompt);
// 控制在 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({
model: model.chat.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 300
});
// console.log(filterPrompts);
// 计算温度
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
2
);
const chatAPI = getOpenAIApi(apiKey);
// 发出请求
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.chat.chatModel,
temperature: Number(temperature) || 0,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
stream: isStream,
stop: ['.!?。']
},
{
timeout: 180000,
responseType: isStream ? 'stream' : 'json',
...axiosConfig()
}
);
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
let responseContent = '';
if (isStream) {
step = 1;
const streamResponse = await gpt35StreamResponse({
res,
stream,
chatResponse
});
responseContent = streamResponse.responseContent;
} else {
responseContent = chatResponse.data.choices?.[0]?.message?.content || '';
jsonRes(res, {
data: responseContent
});
}
pushChatBill({
isPay: true,
chatModel: model.chat.chatModel,
userId,
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
});
// jsonRes(res);
} catch (err: any) {
if (step === 1) {
// 直接结束流
console.log('error结束');
stream.destroy();
} else {
res.status(500);
jsonRes(res, {
code: 500,
error: err
});
}
}
}

View File

@ -1,5 +0,0 @@
export enum OpenAiTuneStatusEnum {
cancelled = 'cancelled',
succeeded = 'succeeded',
pending = 'pending'
}

View File

@ -1,14 +1,13 @@
import { SplitData } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/auth';
import { axiosConfig } from '@/service/utils/tools';
import { getOpenApiKey } from '../utils/openai';
import type { ChatCompletionRequestMessage } from 'openai';
import { getApiKey } from '../utils/auth';
import { OpenAiChatEnum } from '@/constants/model';
import { pushSplitDataBill } from '@/service/events/pushBill';
import { generateVector } from './generateVector';
import { openaiError2 } from '../errorCode';
import { PgClient } from '@/service/pg';
import { ModelSplitDataSchema } from '@/types/mongoSchema';
import { modelServiceToolMap } from '../utils/chat';
import { ChatRoleEnum } from '@/constants/chat';
export async function generateQA(next = false): Promise<any> {
if (process.env.queueTask !== '1') {
@ -47,11 +46,11 @@ export async function generateQA(next = false): Promise<any> {
// 获取 openapi Key
let userApiKey = '',
systemKey = '';
systemApiKey = '';
try {
const key = await getOpenApiKey(dataItem.userId);
const key = await getApiKey({ model: OpenAiChatEnum.GPT35, userId: dataItem.userId });
userApiKey = key.userApiKey;
systemKey = key.systemKey;
systemApiKey = key.systemApiKey;
} catch (error: any) {
if (error?.code === 501) {
// 余额不够了, 清空该记录
@ -69,55 +68,44 @@ export async function generateQA(next = false): Promise<any> {
const startTime = Date.now();
// 获取 openai 请求实例
const chatAPI = getOpenAIApi(userApiKey || systemKey);
const systemPrompt: ChatCompletionRequestMessage = {
role: 'system',
content: `你是出题人
${dataItem.prompt || '下面是"一段长文本"'}
从中选出5至20个题目和答案.答案详细.按格式返回: Q1:
A1:
Q2:
A2:
...`
};
// 请求 chatgpt 获取回答
const response = await Promise.allSettled(
textList.map((text) =>
chatAPI
.createChatCompletion(
{
model: OpenAiChatEnum.GPT35,
temperature: 0.8,
n: 1,
messages: [
systemPrompt,
{
role: 'user',
content: text
}
]
},
{
timeout: 180000,
...axiosConfig()
}
)
.then((res) => {
const rawContent = res?.data.choices[0].message?.content || ''; // chatgpt 原本的回复
const result = formatSplitText(res?.data.choices[0].message?.content || ''); // 格式化后的QA对
modelServiceToolMap[OpenAiChatEnum.GPT35]
.chatCompletion({
apiKey: userApiKey || systemApiKey,
temperature: 0.8,
messages: [
{
obj: ChatRoleEnum.System,
value: `你是出题人
${dataItem.prompt || '下面是"一段长文本"'}
从中选出5至20个题目和答案.答案详细.按格式返回: Q1:
A1:
Q2:
A2:
...`
},
{
obj: 'Human',
value: text
}
],
stream: false
})
.then(({ totalTokens, responseText, responseMessages }) => {
const result = formatSplitText(responseText); // 格式化后的QA对
console.log(`split result length: `, result.length);
// 计费
pushSplitDataBill({
isPay: !userApiKey && result.length > 0,
userId: dataItem.userId,
type: 'QA',
text: systemPrompt.content + text + rawContent,
tokenLen: res.data.usage?.total_tokens || 0
textLen: responseMessages.map((item) => item.value).join('').length,
totalTokens
});
return {
rawContent,
rawContent: responseText,
result
};
})

View File

@ -1,6 +1,8 @@
import { openaiCreateEmbedding, getOpenApiKey } from '../utils/openai';
import { openaiCreateEmbedding } from '../utils/chat/openai';
import { getApiKey } from '../utils/auth';
import { openaiError2 } from '../errorCode';
import { PgClient } from '@/service/pg';
import { embeddingModel } from '@/constants/model';
export async function generateVector(next = false): Promise<any> {
if (process.env.queueTask !== '1') {
@ -40,11 +42,11 @@ export async function generateVector(next = false): Promise<any> {
dataId = dataItem.id;
// 获取 openapi Key
let userApiKey, systemKey;
let userApiKey, systemApiKey;
try {
const res = await getOpenApiKey(dataItem.userId);
const res = await getApiKey({ model: embeddingModel, userId: dataItem.userId });
userApiKey = res.userApiKey;
systemKey = res.systemKey;
systemApiKey = res.systemApiKey;
} catch (error: any) {
if (error?.code === 501) {
await PgClient.delete('modelData', {
@ -61,8 +63,8 @@ export async function generateVector(next = false): Promise<any> {
const { vector } = await openaiCreateEmbedding({
text: dataItem.q,
userId: dataItem.userId,
isPay: !userApiKey,
apiKey: userApiKey || systemKey
userApiKey,
systemApiKey
});
// 更新 pg 向量和状态数据

View File

@ -1,60 +1,54 @@
import { connectToDatabase, Bill, User } from '../mongo';
import { ChatModelMap, OpenAiChatEnum, ChatModelType, embeddingModel } from '@/constants/model';
import { BillTypeEnum } from '@/constants/user';
import { countChatTokens } from '@/utils/tools';
export const pushChatBill = async ({
isPay,
chatModel,
userId,
chatId,
messages
textLen,
tokens
}: {
isPay: boolean;
chatModel: ChatModelType;
userId: string;
chatId?: '' | string;
messages: { role: 'system' | 'user' | 'assistant'; content: string }[];
textLen: number;
tokens: number;
}) => {
console.log(`chat generate success. text len: ${textLen}. token len: ${tokens}. pay:${isPay}`);
if (!isPay) return;
let billId = '';
try {
// 计算 token 数量
const tokens = countChatTokens({ model: chatModel, messages });
const text = messages.map((item) => item.content).join('');
await connectToDatabase();
console.log(
`chat generate success. text len: ${text.length}. token len: ${tokens}. pay:${isPay}`
);
// 计算价格
const unitPrice = ChatModelMap[chatModel]?.price || 5;
const price = unitPrice * tokens;
if (isPay) {
await connectToDatabase();
try {
// 插入 Bill 记录
const res = await Bill.create({
userId,
type: 'chat',
modelName: chatModel,
chatId: chatId ? chatId : undefined,
textLen,
tokenLen: tokens,
price
});
billId = res._id;
// 计算价格
const unitPrice = ChatModelMap[chatModel]?.price || 5;
const price = unitPrice * tokens;
try {
// 插入 Bill 记录
const res = await Bill.create({
userId,
type: 'chat',
modelName: chatModel,
chatId: chatId ? chatId : undefined,
textLen: text.length,
tokenLen: tokens,
price
});
billId = res._id;
// 账号扣费
await User.findByIdAndUpdate(userId, {
$inc: { balance: -price }
});
} catch (error) {
console.log('创建账单失败:', error);
billId && Bill.findByIdAndDelete(billId);
}
// 账号扣费
await User.findByIdAndUpdate(userId, {
$inc: { balance: -price }
});
} catch (error) {
console.log('创建账单失败:', error);
billId && Bill.findByIdAndDelete(billId);
}
} catch (error) {
console.log(error);
@ -64,54 +58,49 @@ export const pushChatBill = async ({
export const pushSplitDataBill = async ({
isPay,
userId,
tokenLen,
text,
totalTokens,
textLen,
type
}: {
isPay: boolean;
userId: string;
tokenLen: number;
text: string;
totalTokens: number;
textLen: number;
type: `${BillTypeEnum}`;
}) => {
await connectToDatabase();
console.log(
`splitData generate success. text len: ${textLen}. token len: ${totalTokens}. pay:${isPay}`
);
if (!isPay) return;
let billId;
try {
console.log(
`splitData generate success. text len: ${text.length}. token len: ${tokenLen}. pay:${isPay}`
);
await connectToDatabase();
if (isPay) {
try {
// 获取模型单价格, 都是用 gpt35 拆分
const unitPrice = ChatModelMap[OpenAiChatEnum.GPT35]?.price || 3;
// 计算价格
const price = unitPrice * tokenLen;
// 获取模型单价格, 都是用 gpt35 拆分
const unitPrice = ChatModelMap[OpenAiChatEnum.GPT35].price || 3;
// 计算价格
const price = unitPrice * totalTokens;
// 插入 Bill 记录
const res = await Bill.create({
userId,
type,
modelName: OpenAiChatEnum.GPT35,
textLen: text.length,
tokenLen,
price
});
billId = res._id;
// 插入 Bill 记录
const res = await Bill.create({
userId,
type,
modelName: OpenAiChatEnum.GPT35,
textLen,
tokenLen: totalTokens,
price
});
billId = res._id;
// 账号扣费
await User.findByIdAndUpdate(userId, {
$inc: { balance: -price }
});
} catch (error) {
console.log('创建账单失败:', error);
billId && Bill.findByIdAndDelete(billId);
}
}
// 账号扣费
await User.findByIdAndUpdate(userId, {
$inc: { balance: -price }
});
} catch (error) {
console.log(error);
console.log('创建账单失败:', error);
billId && Bill.findByIdAndDelete(billId);
}
};
@ -126,41 +115,40 @@ export const pushGenerateVectorBill = async ({
text: string;
tokenLen: number;
}) => {
await connectToDatabase();
console.log(
`vector generate success. text len: ${text.length}. token len: ${tokenLen}. pay:${isPay}`
);
if (!isPay) return;
let billId;
try {
console.log(
`vector generate success. text len: ${text.length}. token len: ${tokenLen}. pay:${isPay}`
);
await connectToDatabase();
if (isPay) {
try {
const unitPrice = 0.4;
// 计算价格. 至少为1
let price = unitPrice * tokenLen;
price = price > 1 ? price : 1;
try {
const unitPrice = 0.4;
// 计算价格. 至少为1
let price = unitPrice * tokenLen;
price = price > 1 ? price : 1;
// 插入 Bill 记录
const res = await Bill.create({
userId,
type: BillTypeEnum.vector,
modelName: embeddingModel,
textLen: text.length,
tokenLen,
price
});
billId = res._id;
// 插入 Bill 记录
const res = await Bill.create({
userId,
type: BillTypeEnum.vector,
modelName: embeddingModel,
textLen: text.length,
tokenLen,
price
});
billId = res._id;
// 账号扣费
await User.findByIdAndUpdate(userId, {
$inc: { balance: -price }
});
} catch (error) {
console.log('创建账单失败:', error);
billId && Bill.findByIdAndDelete(billId);
}
// 账号扣费
await User.findByIdAndUpdate(userId, {
$inc: { balance: -price }
});
} catch (error) {
console.log('创建账单失败:', error);
billId && Bill.findByIdAndDelete(billId);
}
} catch (error) {
console.log(error);

View File

@ -1,5 +1,6 @@
import { Schema, model, models, Model } from 'mongoose';
import { ChatSchema as ChatType } from '@/types/mongoSchema';
import { ChatRoleMap } from '@/constants/chat';
const ChatSchema = new Schema({
userId: {
@ -36,7 +37,7 @@ const ChatSchema = new Schema({
obj: {
type: String,
required: true,
enum: ['Human', 'AI', 'SYSTEM']
enum: Object.keys(ChatRoleMap)
},
value: {
type: String,

View File

@ -1,22 +1,23 @@
import { openaiCreateEmbedding } from '../utils/openai';
import { PgClient } from '@/service/pg';
import { ModelDataStatusEnum, ModelVectorSearchModeEnum, ChatModelMap } from '@/constants/model';
import { ModelSchema } from '@/types/mongoSchema';
import { systemPromptFilter } from '../utils/tools';
import { openaiCreateEmbedding } from '../utils/chat/openai';
import { ChatRoleEnum } from '@/constants/chat';
import { sliceTextByToken } from '@/utils/chat';
/**
* use openai embedding search kb
*/
export const searchKb_openai = async ({
apiKey,
isPay = true,
export const searchKb = async ({
userApiKey,
systemApiKey,
text,
similarity = 0.2,
model,
userId
}: {
apiKey: string;
isPay: boolean;
userApiKey?: string;
systemApiKey: string;
text: string;
model: ModelSchema;
userId: string;
@ -24,7 +25,7 @@ export const searchKb_openai = async ({
}): Promise<{
code: 200 | 201;
searchPrompt?: {
obj: 'Human' | 'AI' | 'SYSTEM';
obj: `${ChatRoleEnum}`;
value: string;
};
}> => {
@ -32,8 +33,8 @@ export const searchKb_openai = async ({
// 获取提示词的向量
const { vector: promptVector } = await openaiCreateEmbedding({
isPay,
apiKey,
userApiKey,
systemApiKey,
userId,
text
});
@ -61,7 +62,7 @@ export const searchKb_openai = async ({
return {
code: 201,
searchPrompt: {
obj: 'AI',
obj: ChatRoleEnum.AI,
value: '对不起,你的问题不在知识库中。'
}
};
@ -72,7 +73,7 @@ export const searchKb_openai = async ({
code: 200,
searchPrompt: model.chat.systemPrompt
? {
obj: 'SYSTEM',
obj: ChatRoleEnum.System,
value: model.chat.systemPrompt
}
: undefined
@ -81,16 +82,16 @@ export const searchKb_openai = async ({
// 有匹配情况下system 添加知识库内容。
// 系统提示词过滤,最多 65% tokens
const filterSystemPrompt = systemPromptFilter({
const filterSystemPrompt = sliceTextByToken({
model: model.chat.chatModel,
prompts: systemPrompts,
maxTokens: Math.floor(modelConstantsData.contextMaxToken * 0.65)
text: systemPrompts.join('\n'),
length: Math.floor(modelConstantsData.contextMaxToken * 0.65)
});
return {
code: 200,
searchPrompt: {
obj: 'SYSTEM',
obj: ChatRoleEnum.System,
value: `
${model.chat.systemPrompt}
${

View File

@ -1,14 +1,18 @@
import { Configuration, OpenAIApi } from 'openai';
import type { NextApiRequest } from 'next';
import jwt from 'jsonwebtoken';
import { Chat, Model, OpenApi, User } from '../mongo';
import type { ModelSchema } from '@/types/mongoSchema';
import { getOpenApiKey } from './openai';
import type { ChatItemSimpleType } from '@/types/chat';
import mongoose from 'mongoose';
import { defaultModel } from '@/constants/model';
import { formatPrice } from '@/utils/user';
import { ERROR_ENUM } from '../errorCode';
import {
ChatModelType,
OpenAiChatEnum,
embeddingModel,
EmbeddingModelType
} from '@/constants/model';
/* 校验 token */
export const authToken = (token?: string): Promise<string> => {
@ -29,13 +33,63 @@ export const authToken = (token?: string): Promise<string> => {
});
};
export const getOpenAIApi = (apiKey: string) => {
const configuration = new Configuration({
apiKey,
basePath: process.env.OPENAI_BASE_URL
});
/* 获取 api 请求的 key */
export const getApiKey = async ({
model,
userId
}: {
model: ChatModelType | EmbeddingModelType;
userId: string;
}) => {
const user = await User.findById(userId);
if (!user) {
return Promise.reject({
code: 501,
message: '找不到用户'
});
}
return new OpenAIApi(configuration);
const keyMap = {
[OpenAiChatEnum.GPT35]: {
userApiKey: user.openaiKey || '',
systemApiKey: process.env.OPENAIKEY as string
},
[OpenAiChatEnum.GPT4]: {
userApiKey: user.openaiKey || '',
systemApiKey: process.env.OPENAIKEY as string
},
[OpenAiChatEnum.GPT432k]: {
userApiKey: user.openaiKey || '',
systemApiKey: process.env.OPENAIKEY as string
},
[embeddingModel]: {
userApiKey: user.openaiKey || '',
systemApiKey: process.env.OPENAIKEY as string
}
};
// 有自己的key
if (keyMap[model].userApiKey) {
return {
user,
userApiKey: keyMap[model].userApiKey,
systemApiKey: ''
};
}
// 平台账号余额校验
if (formatPrice(user.balance) <= 0) {
return Promise.reject({
code: 501,
message: '账号余额不足'
});
}
return {
user,
userApiKey: '',
systemApiKey: keyMap[model].systemApiKey
};
};
// 模型使用权校验
@ -122,11 +176,11 @@ export const authChat = async ({
]);
}
// 获取 user 的 apiKey
const { userApiKey, systemKey } = await getOpenApiKey(userId);
const { userApiKey, systemApiKey } = await getApiKey({ model: model.chat.chatModel, userId });
return {
userApiKey,
systemKey,
systemApiKey,
content,
userId,
model,

View File

@ -0,0 +1,155 @@
import { ChatItemSimpleType } from '@/types/chat';
import { modelToolMap } from '@/utils/chat';
import type { ChatModelType } from '@/constants/model';
import { ChatRoleEnum, SYSTEM_PROMPT_PREFIX } from '@/constants/chat';
import { OpenAiChatEnum } from '@/constants/model';
import { chatResponse, openAiStreamResponse } from './openai';
import type { NextApiResponse } from 'next';
import type { PassThrough } from 'stream';
export type ChatCompletionType = {
apiKey: string;
temperature: number;
messages: ChatItemSimpleType[];
stream: boolean;
};
export type StreamResponseType = {
stream: PassThrough;
chatResponse: any;
prompts: ChatItemSimpleType[];
};
export const modelServiceToolMap = {
[OpenAiChatEnum.GPT35]: {
chatCompletion: (data: ChatCompletionType) =>
chatResponse({ model: OpenAiChatEnum.GPT35, ...data }),
streamResponse: (data: StreamResponseType) =>
openAiStreamResponse({
model: OpenAiChatEnum.GPT35,
...data
})
},
[OpenAiChatEnum.GPT4]: {
chatCompletion: (data: ChatCompletionType) =>
chatResponse({ model: OpenAiChatEnum.GPT4, ...data }),
streamResponse: (data: StreamResponseType) =>
openAiStreamResponse({
model: OpenAiChatEnum.GPT4,
...data
})
},
[OpenAiChatEnum.GPT432k]: {
chatCompletion: (data: ChatCompletionType) =>
chatResponse({ model: OpenAiChatEnum.GPT432k, ...data }),
streamResponse: (data: StreamResponseType) =>
openAiStreamResponse({
model: OpenAiChatEnum.GPT432k,
...data
})
}
};
/* delete invalid symbol */
const simplifyStr = (str: string) =>
str
.replace(/\n+/g, '\n') // 连续空行
.replace(/[^\S\r\n]+/g, ' ') // 连续空白内容
.trim();
/* 聊天上下文 tokens 截断 */
export const ChatContextFilter = ({
model,
prompts,
maxTokens
}: {
model: ChatModelType;
prompts: ChatItemSimpleType[];
maxTokens: number;
}) => {
let rawTextLen = 0;
const formatPrompts = prompts.map<ChatItemSimpleType>((item) => {
const val = simplifyStr(item.value);
rawTextLen += val.length;
return {
obj: item.obj,
value: val
};
});
// 长度太小时,不需要进行 token 截断
if (formatPrompts.length <= 2 || rawTextLen < maxTokens * 0.5) {
return formatPrompts;
}
// 根据 tokens 截断内容
const chats: ChatItemSimpleType[] = [];
let systemPrompt: ChatItemSimpleType | null = null;
// System 词保留
if (formatPrompts[0].obj === ChatRoleEnum.System) {
const prompt = formatPrompts.shift();
if (prompt) {
systemPrompt = prompt;
}
}
let messages: ChatItemSimpleType[] = [];
// 从后往前截取对话内容
for (let i = formatPrompts.length - 1; i >= 0; i--) {
chats.unshift(formatPrompts[i]);
messages = systemPrompt ? [systemPrompt, ...chats] : chats;
const tokens = modelToolMap[model].countTokens({
messages
});
/* 整体 tokens 超出范围 */
if (tokens >= maxTokens) {
return systemPrompt ? [systemPrompt, ...chats.slice(1)] : chats.slice(1);
}
}
return messages;
};
/* stream response */
export const resStreamResponse = async ({
model,
res,
stream,
chatResponse,
systemPrompt,
prompts
}: StreamResponseType & {
model: ChatModelType;
res: NextApiResponse;
systemPrompt?: string;
}) => {
// 创建响应流
res.setHeader('Content-Type', 'text/event-stream;charset-utf-8');
res.setHeader('Access-Control-Allow-Origin', '*');
res.setHeader('X-Accel-Buffering', 'no');
res.setHeader('Cache-Control', 'no-cache, no-transform');
stream.pipe(res);
const { responseContent, totalTokens, finishMessages } = await modelServiceToolMap[
model
].streamResponse({
chatResponse,
stream,
prompts
});
// push system prompt
!stream.destroyed &&
systemPrompt &&
stream.push(`${SYSTEM_PROMPT_PREFIX}${systemPrompt.replace(/\n/g, '<br/>')}`);
// close stream
!stream.destroyed && stream.push(null);
stream.destroy();
return { responseContent, totalTokens, finishMessages };
};

View File

@ -0,0 +1,174 @@
import { Configuration, OpenAIApi } from 'openai';
import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
import { axiosConfig } from '../tools';
import { ChatModelMap, embeddingModel, OpenAiChatEnum } from '@/constants/model';
import { pushGenerateVectorBill } from '../../events/pushBill';
import { adaptChatItem_openAI } from '@/utils/chat/openai';
import { modelToolMap } from '@/utils/chat';
import { ChatCompletionType, ChatContextFilter, StreamResponseType } from './index';
import { ChatRoleEnum } from '@/constants/chat';
export const getOpenAIApi = (apiKey: string) => {
const configuration = new Configuration({
apiKey,
basePath: process.env.OPENAI_BASE_URL
});
return new OpenAIApi(configuration);
};
/* 获取向量 */
export const openaiCreateEmbedding = async ({
userApiKey,
systemApiKey,
userId,
text
}: {
userApiKey?: string;
systemApiKey: string;
userId: string;
text: string;
}) => {
// 获取 chatAPI
const chatAPI = getOpenAIApi(userApiKey || systemApiKey);
// 把输入的内容转成向量
const res = await chatAPI
.createEmbedding(
{
model: embeddingModel,
input: text
},
{
timeout: 60000,
...axiosConfig()
}
)
.then((res) => ({
tokenLen: res.data.usage.total_tokens || 0,
vector: res.data.data?.[0]?.embedding || []
}));
pushGenerateVectorBill({
isPay: !userApiKey,
userId,
text,
tokenLen: res.tokenLen
});
return {
vector: res.vector,
chatAPI
};
};
/* 模型对话 */
export const chatResponse = async ({
model,
apiKey,
temperature,
messages,
stream
}: ChatCompletionType & { model: `${OpenAiChatEnum}` }) => {
const filterMessages = ChatContextFilter({
model,
prompts: messages,
maxTokens: Math.ceil(ChatModelMap[model].contextMaxToken * 0.9)
});
const adaptMessages = adaptChatItem_openAI({ messages: filterMessages });
const chatAPI = getOpenAIApi(apiKey);
const response = await chatAPI.createChatCompletion(
{
model,
temperature: Number(temperature) || 0,
messages: adaptMessages,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
stream,
stop: ['.!?。']
},
{
timeout: stream ? 40000 : 240000,
responseType: stream ? 'stream' : 'json',
...axiosConfig()
}
);
let responseText = '';
let totalTokens = 0;
// adapt data
if (!stream) {
responseText = response.data.choices[0].message?.content || '';
totalTokens = response.data.usage?.total_tokens || 0;
}
return {
streamResponse: response,
responseMessages: filterMessages.concat({ obj: 'AI', value: responseText }),
responseText,
totalTokens
};
};
/* openai stream response */
export const openAiStreamResponse = async ({
model,
stream,
chatResponse,
prompts
}: StreamResponseType & {
model: `${OpenAiChatEnum}`;
}) => {
try {
let responseContent = '';
const onParse = async (event: ParsedEvent | ReconnectInterval) => {
if (event.type !== 'event') return;
const data = event.data;
if (data === '[DONE]') return;
try {
const json = JSON.parse(data);
const content: string = json?.choices?.[0].delta.content || '';
responseContent += content;
!stream.destroyed && content && stream.push(content.replace(/\n/g, '<br/>'));
} catch (error) {
error;
}
};
try {
const decoder = new TextDecoder();
const parser = createParser(onParse);
for await (const chunk of chatResponse.data as any) {
if (stream.destroyed) {
// 流被中断了,直接忽略后面的内容
break;
}
parser.feed(decoder.decode(chunk, { stream: true }));
}
} catch (error) {
console.log('pipe error', error);
}
// count tokens
const finishMessages = prompts.concat({
obj: ChatRoleEnum.AI,
value: responseContent
});
const totalTokens = modelToolMap[model].countTokens({
messages: finishMessages
});
return {
responseContent,
totalTokens,
finishMessages
};
} catch (error) {
return Promise.reject(error);
}
};

View File

@ -1,179 +0,0 @@
import type { NextApiResponse } from 'next';
import type { PassThrough } from 'stream';
import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
import { getOpenAIApi } from '@/service/utils/auth';
import { axiosConfig } from './tools';
import { User } from '../models/user';
import { formatPrice } from '@/utils/user';
import { embeddingModel } from '@/constants/model';
import { pushGenerateVectorBill } from '../events/pushBill';
import { SYSTEM_PROMPT_PREFIX } from '@/constants/chat';
/* 获取用户 api 的 openai 信息 */
export const getUserApiOpenai = async (userId: string) => {
const user = await User.findById(userId);
const userApiKey = user?.openaiKey;
if (!userApiKey) {
return Promise.reject('缺少ApiKey, 无法请求');
}
return {
user,
openai: getOpenAIApi(userApiKey),
apiKey: userApiKey
};
};
/* 获取 open api key如果用户没有自己的key就用平台的用平台记得加账单 */
export const getOpenApiKey = async (userId: string) => {
const user = await User.findById(userId);
if (!user) {
return Promise.reject({
code: 501,
message: '找不到用户'
});
}
const userApiKey = user?.openaiKey;
// 有自己的key
if (userApiKey) {
return {
user,
userApiKey,
systemKey: ''
};
}
// 平台账号余额校验
if (formatPrice(user.balance) <= 0) {
return Promise.reject({
code: 501,
message: '账号余额不足'
});
}
return {
user,
userApiKey: '',
systemKey: process.env.OPENAIKEY as string
};
};
/* 获取向量 */
export const openaiCreateEmbedding = async ({
isPay,
userId,
apiKey,
text
}: {
isPay: boolean;
userId: string;
apiKey: string;
text: string;
}) => {
// 获取 chatAPI
const chatAPI = getOpenAIApi(apiKey);
// 把输入的内容转成向量
const res = await chatAPI
.createEmbedding(
{
model: embeddingModel,
input: text
},
{
timeout: 60000,
...axiosConfig()
}
)
.then((res) => ({
tokenLen: res.data.usage.total_tokens || 0,
vector: res.data.data?.[0]?.embedding || []
}));
pushGenerateVectorBill({
isPay,
userId,
text,
tokenLen: res.tokenLen
});
return {
vector: res.vector,
chatAPI
};
};
/* gpt35 响应 */
export const gpt35StreamResponse = ({
res,
stream,
chatResponse,
systemPrompt = ''
}: {
res: NextApiResponse;
stream: PassThrough;
chatResponse: any;
systemPrompt?: string;
}) =>
new Promise<{ responseContent: string }>(async (resolve, reject) => {
try {
// 创建响应流
res.setHeader('Content-Type', 'text/event-stream;charset-utf-8');
res.setHeader('Access-Control-Allow-Origin', '*');
res.setHeader('X-Accel-Buffering', 'no');
res.setHeader('Cache-Control', 'no-cache, no-transform');
stream.pipe(res);
let responseContent = '';
const onParse = async (event: ParsedEvent | ReconnectInterval) => {
if (event.type !== 'event') return;
const data = event.data;
if (data === '[DONE]') return;
try {
const json = JSON.parse(data);
const content: string = json?.choices?.[0].delta.content || '';
responseContent += content;
if (!stream.destroyed && content) {
stream.push(content.replace(/\n/g, '<br/>'));
}
} catch (error) {
error;
}
};
try {
const decoder = new TextDecoder();
const parser = createParser(onParse);
for await (const chunk of chatResponse.data as any) {
if (stream.destroyed) {
// 流被中断了,直接忽略后面的内容
break;
}
parser.feed(decoder.decode(chunk, { stream: true }));
}
} catch (error) {
console.log('pipe error', error);
}
// push system prompt
!stream.destroyed &&
systemPrompt &&
stream.push(`${SYSTEM_PROMPT_PREFIX}${systemPrompt.replace(/\n/g, '<br/>')}`);
// close stream
!stream.destroyed && stream.push(null);
stream.destroy();
resolve({
responseContent
});
} catch (error) {
reject(error);
}
});

View File

@ -1,9 +1,5 @@
import crypto from 'crypto';
import jwt from 'jsonwebtoken';
import { ChatItemSimpleType } from '@/types/chat';
import { countChatTokens, sliceTextByToken } from '@/utils/tools';
import { ChatCompletionRequestMessageRoleEnum, ChatCompletionRequestMessage } from 'openai';
import type { ChatModelType } from '@/constants/model';
/* 密码加密 */
export const hashPassword = (psw: string) => {
@ -30,92 +26,3 @@ export const axiosConfig = () => ({
auth: process.env.OPENAI_BASE_URL_AUTH || ''
}
});
/* delete invalid symbol */
const simplifyStr = (str: string) =>
str
.replace(/\n+/g, '\n') // 连续空行
.replace(/[^\S\r\n]+/g, ' ') // 连续空白内容
.trim();
/* 聊天内容 tokens 截断 */
export const openaiChatFilter = ({
model,
prompts,
maxTokens
}: {
model: ChatModelType;
prompts: ChatItemSimpleType[];
maxTokens: number;
}) => {
// role map
const map = {
Human: ChatCompletionRequestMessageRoleEnum.User,
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
};
let rawTextLen = 0;
const formatPrompts = prompts.map((item) => {
const val = simplifyStr(item.value);
rawTextLen += val.length;
return {
role: map[item.obj],
content: val
};
});
// 长度太小时,不需要进行 token 截断
if (rawTextLen < maxTokens * 0.5) {
return formatPrompts;
}
// 根据 tokens 截断内容
const chats: ChatCompletionRequestMessage[] = [];
let systemPrompt: ChatCompletionRequestMessage | null = null;
// System 词保留
if (formatPrompts[0]?.role === 'system') {
systemPrompt = formatPrompts.shift() as ChatCompletionRequestMessage;
}
let messages: { role: ChatCompletionRequestMessageRoleEnum; content: string }[] = [];
// 从后往前截取对话内容
for (let i = formatPrompts.length - 1; i >= 0; i--) {
chats.unshift(formatPrompts[i]);
messages = systemPrompt ? [systemPrompt, ...chats] : chats;
const tokens = countChatTokens({
model,
messages
});
/* 整体 tokens 超出范围 */
if (tokens >= maxTokens) {
return systemPrompt ? [systemPrompt, ...chats.slice(1)] : chats.slice(1);
}
}
return messages;
};
/* system 内容截断. 相似度从高到低 */
export const systemPromptFilter = ({
model,
prompts,
maxTokens
}: {
model: 'gpt-4' | 'gpt-4-32k' | 'gpt-3.5-turbo';
prompts: string[];
maxTokens: number;
}) => {
const systemPrompt = prompts.join('\n');
return sliceTextByToken({
model,
text: systemPrompt,
length: maxTokens
});
};

4
src/types/chat.d.ts vendored
View File

@ -1,5 +1,7 @@
import { ChatRoleEnum } from '@/constants/chat';
export type ChatItemSimpleType = {
obj: 'Human' | 'AI' | 'SYSTEM';
obj: `${ChatRoleEnum}`;
value: string;
systemPrompt?: string;
};

39
src/utils/chat/index.ts Normal file
View File

@ -0,0 +1,39 @@
import { OpenAiChatEnum } from '@/constants/model';
import type { ChatModelType } from '@/constants/model';
import type { ChatItemSimpleType } from '@/types/chat';
import { countOpenAIToken, getOpenAiEncMap, adaptChatItem_openAI } from './openai';
export type CountTokenType = { messages: ChatItemSimpleType[] };
export const modelToolMap = {
[OpenAiChatEnum.GPT35]: {
countTokens: ({ messages }: CountTokenType) =>
countOpenAIToken({ model: OpenAiChatEnum.GPT35, messages }),
adaptChatMessages: adaptChatItem_openAI
},
[OpenAiChatEnum.GPT4]: {
countTokens: ({ messages }: CountTokenType) =>
countOpenAIToken({ model: OpenAiChatEnum.GPT4, messages }),
adaptChatMessages: adaptChatItem_openAI
},
[OpenAiChatEnum.GPT432k]: {
countTokens: ({ messages }: CountTokenType) =>
countOpenAIToken({ model: OpenAiChatEnum.GPT432k, messages }),
adaptChatMessages: adaptChatItem_openAI
}
};
export const sliceTextByToken = ({
model = 'gpt-3.5-turbo',
text,
length
}: {
model: ChatModelType;
text: string;
length: number;
}) => {
const enc = getOpenAiEncMap()[model];
const encodeText = enc.encode(text);
const decoder = new TextDecoder();
return decoder.decode(enc.decode(encodeText.slice(0, length)));
};

106
src/utils/chat/openai.ts Normal file
View File

@ -0,0 +1,106 @@
import { encoding_for_model, type Tiktoken } from '@dqbd/tiktoken';
import type { ChatItemSimpleType } from '@/types/chat';
import { ChatRoleEnum } from '@/constants/chat';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
import Graphemer from 'graphemer';
const textDecoder = new TextDecoder();
const graphemer = new Graphemer();
export const adaptChatItem_openAI = ({
messages
}: {
messages: ChatItemSimpleType[];
}): ChatCompletionRequestMessage[] => {
const map = {
[ChatRoleEnum.AI]: ChatCompletionRequestMessageRoleEnum.Assistant,
[ChatRoleEnum.Human]: ChatCompletionRequestMessageRoleEnum.User,
[ChatRoleEnum.System]: ChatCompletionRequestMessageRoleEnum.System
};
return messages.map((item) => ({
role: map[item.obj] || ChatCompletionRequestMessageRoleEnum.System,
content: item.value || ''
}));
};
/* count openai chat token*/
let OpenAiEncMap: Record<string, Tiktoken>;
export const getOpenAiEncMap = () => {
if (OpenAiEncMap) return OpenAiEncMap;
OpenAiEncMap = {
'gpt-3.5-turbo': encoding_for_model('gpt-3.5-turbo', {
'<|im_start|>': 100264,
'<|im_end|>': 100265,
'<|im_sep|>': 100266
}),
'gpt-4': encoding_for_model('gpt-4', {
'<|im_start|>': 100264,
'<|im_end|>': 100265,
'<|im_sep|>': 100266
}),
'gpt-4-32k': encoding_for_model('gpt-4-32k', {
'<|im_start|>': 100264,
'<|im_end|>': 100265,
'<|im_sep|>': 100266
})
};
return OpenAiEncMap;
};
export function countOpenAIToken({
messages,
model
}: {
messages: ChatItemSimpleType[];
model: 'gpt-3.5-turbo' | 'gpt-4' | 'gpt-4-32k';
}) {
function getChatGPTEncodingText(
messages: { role: 'system' | 'user' | 'assistant'; content: string; name?: string }[],
model: 'gpt-3.5-turbo' | 'gpt-4' | 'gpt-4-32k'
) {
const isGpt3 = model === 'gpt-3.5-turbo';
const msgSep = isGpt3 ? '\n' : '';
const roleSep = isGpt3 ? '\n' : '<|im_sep|>';
return [
messages
.map(({ name = '', role, content }) => {
return `<|im_start|>${name || role}${roleSep}${content}<|im_end|>`;
})
.join(msgSep),
`<|im_start|>assistant${roleSep}`
].join(msgSep);
}
function text2TokensLen(encoder: Tiktoken, inputText: string) {
const encoding = encoder.encode(inputText, 'all');
const segments: { text: string; tokens: { id: number; idx: number }[] }[] = [];
let byteAcc: number[] = [];
let tokenAcc: { id: number; idx: number }[] = [];
let inputGraphemes = graphemer.splitGraphemes(inputText);
for (let idx = 0; idx < encoding.length; idx++) {
const token = encoding[idx]!;
byteAcc.push(...encoder.decode_single_token_bytes(token));
tokenAcc.push({ id: token, idx });
const segmentText = textDecoder.decode(new Uint8Array(byteAcc));
const graphemes = graphemer.splitGraphemes(segmentText);
if (graphemes.every((item, idx) => inputGraphemes[idx] === item)) {
segments.push({ text: segmentText, tokens: tokenAcc });
byteAcc = [];
tokenAcc = [];
inputGraphemes = inputGraphemes.slice(graphemes.length);
}
}
return segments.reduce((memo, i) => memo + i.tokens.length, 0) ?? 0;
}
const adaptMessages = adaptChatItem_openAI({ messages });
return text2TokensLen(getOpenAiEncMap()[model], getChatGPTEncodingText(adaptMessages, model));
}

View File

@ -1,6 +1,6 @@
import mammoth from 'mammoth';
import Papa from 'papaparse';
import { getEncMap } from './tools';
import { getOpenAiEncMap } from './chat/openai';
/**
* txt
@ -154,7 +154,7 @@ export const splitText_token = ({
maxLen: number;
slideLen: number;
}) => {
const enc = getEncMap()['gpt-3.5-turbo'];
const enc = getOpenAiEncMap()['gpt-3.5-turbo'];
// filter empty text. encode sentence
const encodeText = enc.encode(text);

View File

@ -1,33 +1,5 @@
import crypto from 'crypto';
import { useToast } from '@/hooks/useToast';
import { encoding_for_model, type Tiktoken } from '@dqbd/tiktoken';
import Graphemer from 'graphemer';
import type { ChatModelType } from '@/constants/model';
const textDecoder = new TextDecoder();
const graphemer = new Graphemer();
let encMap: Record<string, Tiktoken>;
export const getEncMap = () => {
if (encMap) return encMap;
encMap = {
'gpt-3.5-turbo': encoding_for_model('gpt-3.5-turbo', {
'<|im_start|>': 100264,
'<|im_end|>': 100265,
'<|im_sep|>': 100266
}),
'gpt-4': encoding_for_model('gpt-4', {
'<|im_start|>': 100264,
'<|im_end|>': 100265,
'<|im_sep|>': 100266
}),
'gpt-4-32k': encoding_for_model('gpt-4-32k', {
'<|im_start|>': 100264,
'<|im_end|>': 100265,
'<|im_sep|>': 100266
})
};
return encMap;
};
/**
* copy text data
@ -79,75 +51,3 @@ export const Obj2Query = (obj: Record<string, string | number>) => {
}
return queryParams.toString();
};
/* 格式化 chat 聊天内容 */
function getChatGPTEncodingText(
messages: { role: 'system' | 'user' | 'assistant'; content: string; name?: string }[],
model: 'gpt-3.5-turbo' | 'gpt-4' | 'gpt-4-32k'
) {
const isGpt3 = model === 'gpt-3.5-turbo';
const msgSep = isGpt3 ? '\n' : '';
const roleSep = isGpt3 ? '\n' : '<|im_sep|>';
return [
messages
.map(({ name = '', role, content }) => {
return `<|im_start|>${name || role}${roleSep}${content}<|im_end|>`;
})
.join(msgSep),
`<|im_start|>assistant${roleSep}`
].join(msgSep);
}
function text2TokensLen(encoder: Tiktoken, inputText: string) {
const encoding = encoder.encode(inputText, 'all');
const segments: { text: string; tokens: { id: number; idx: number }[] }[] = [];
let byteAcc: number[] = [];
let tokenAcc: { id: number; idx: number }[] = [];
let inputGraphemes = graphemer.splitGraphemes(inputText);
for (let idx = 0; idx < encoding.length; idx++) {
const token = encoding[idx]!;
byteAcc.push(...encoder.decode_single_token_bytes(token));
tokenAcc.push({ id: token, idx });
const segmentText = textDecoder.decode(new Uint8Array(byteAcc));
const graphemes = graphemer.splitGraphemes(segmentText);
if (graphemes.every((item, idx) => inputGraphemes[idx] === item)) {
segments.push({ text: segmentText, tokens: tokenAcc });
byteAcc = [];
tokenAcc = [];
inputGraphemes = inputGraphemes.slice(graphemes.length);
}
}
return segments.reduce((memo, i) => memo + i.tokens.length, 0) ?? 0;
}
export const countChatTokens = ({
model = 'gpt-3.5-turbo',
messages
}: {
model?: ChatModelType;
messages: { role: 'system' | 'user' | 'assistant'; content: string }[];
}) => {
const text = getChatGPTEncodingText(messages, model);
return text2TokensLen(getEncMap()[model], text);
};
export const sliceTextByToken = ({
model = 'gpt-3.5-turbo',
text,
length
}: {
model?: ChatModelType;
text: string;
length: number;
}) => {
const enc = getEncMap()[model];
const encodeText = enc.encode(text);
const decoder = new TextDecoder();
return decoder.decode(enc.decode(encodeText.slice(0, length)));
};