diff --git a/src/service/utils/tools.ts b/src/service/utils/tools.ts index 113ea54fd..8f43d3ce2 100644 --- a/src/service/utils/tools.ts +++ b/src/service/utils/tools.ts @@ -6,7 +6,7 @@ import { OpenApi, User } from '../mongo'; import { formatPrice } from '@/utils/user'; import { ERROR_ENUM } from '../errorCode'; import { countChatTokens } from '@/utils/tools'; -import { ChatCompletionRequestMessageRoleEnum } from 'openai'; +import { ChatCompletionRequestMessageRoleEnum, ChatCompletionRequestMessage } from 'openai'; import { ChatModelEnum } from '@/constants/model'; /* 密码加密 */ @@ -88,6 +88,13 @@ export const authOpenApiKey = async (req: NextApiRequest) => { export const httpsAgent = (fast: boolean) => fast ? global.httpsAgentFast : global.httpsAgentNormal; +/* delete invalid symbol */ +const simplifyStr = (str: string) => + str + .replace(/\n+/g, '\n') // 连续空行 + .replace(/[^\S\r\n]+/g, ' ') // 连续空白内容 + .trim(); + /* 聊天内容 tokens 截断 */ export const openaiChatFilter = ({ model, @@ -98,40 +105,44 @@ export const openaiChatFilter = ({ prompts: ChatItemType[]; maxTokens: number; }) => { - const formatPrompts = prompts.map((item) => ({ - obj: item.obj, - value: item.value - // .replace(/[\u3000\u3001\uff01-\uff5e\u3002]/g, ' ') // 中文标点改空格 - .replace(/\n+/g, '\n') // 连续空行 - .replace(/[^\S\r\n]+/g, ' ') // 连续空白内容 - .trim() - })); - - let chats: ChatItemType[] = []; - let systemPrompt: ChatItemType | null = null; - - // System 词保留 - if (formatPrompts[0]?.obj === 'SYSTEM') { - systemPrompt = formatPrompts.shift() as ChatItemType; - } - - // 格式化文本内容成 chatgpt 格式 + // role map const map = { Human: ChatCompletionRequestMessageRoleEnum.User, AI: ChatCompletionRequestMessageRoleEnum.Assistant, SYSTEM: ChatCompletionRequestMessageRoleEnum.System }; + let rawTextLen = 0; + const formatPrompts = prompts.map((item) => { + const val = simplifyStr(item.value); + rawTextLen += val.length; + return { + role: map[item.obj], + content: val + }; + }); + + // 长度太小时,不需要进行 token 截断 + if (rawTextLen < maxTokens * 0.5) { + return formatPrompts; + } + + // 根据 tokens 截断内容 + const chats: ChatCompletionRequestMessage[] = []; + let systemPrompt: ChatCompletionRequestMessage | null = null; + + // System 词保留 + if (formatPrompts[0]?.role === 'system') { + systemPrompt = formatPrompts.shift() as ChatCompletionRequestMessage; + } + let messages: { role: ChatCompletionRequestMessageRoleEnum; content: string }[] = []; // 从后往前截取对话内容 for (let i = formatPrompts.length - 1; i >= 0; i--) { chats.unshift(formatPrompts[i]); - messages = (systemPrompt ? [systemPrompt, ...chats] : chats).map((item) => ({ - role: map[item.obj], - content: item.value - })); + messages = systemPrompt ? [systemPrompt, ...chats] : chats; const tokens = countChatTokens({ model, @@ -147,7 +158,7 @@ export const openaiChatFilter = ({ return messages; }; -/* system 内容截断 */ +/* system 内容截断. 相似度从高到低 */ export const systemPromptFilter = ({ model, prompts, @@ -161,7 +172,7 @@ export const systemPromptFilter = ({ // 从前往前截取 for (let i = 0; i < prompts.length; i++) { - const prompt = prompts[i].replace(/\n+/g, '\n'); + const prompt = simplifyStr(prompts[i]); splitText += `${prompt}\n`; const tokens = countChatTokens({ model, messages: [{ role: 'system', content: splitText }] }); @@ -170,5 +181,5 @@ export const systemPromptFilter = ({ } } - return splitText.slice(0, splitText.length - 1).replace(/\n+/g, '\n'); + return splitText.slice(0, splitText.length - 1); };