diff --git a/src/constants/model.ts b/src/constants/model.ts index 7b12ca631..34983ab0a 100644 --- a/src/constants/model.ts +++ b/src/constants/model.ts @@ -7,31 +7,35 @@ export enum ChatModelNameEnum { } export type ModelConstantsData = { + serviceCompany: `${ServiceName}`; name: string; model: `${ChatModelNameEnum}`; trainName: string; // 空字符串代表不能训练 maxToken: number; maxTemperature: number; + price: number; // 多少钱 / 1字,单位: 0.00001元 }; -export const ModelList: Record = { - openai: [ - { - name: 'chatGPT', - model: ChatModelNameEnum.GPT35, - trainName: 'turbo', - maxToken: 4000, - maxTemperature: 2 - }, - { - name: 'GPT3', - model: ChatModelNameEnum.GPT3, - trainName: 'davinci', - maxToken: 4000, - maxTemperature: 2 - } - ] -}; +export const ModelList: ModelConstantsData[] = [ + { + serviceCompany: 'openai', + name: 'chatGPT', + model: ChatModelNameEnum.GPT35, + trainName: 'turbo', + maxToken: 4000, + maxTemperature: 2, + price: 2 + }, + { + serviceCompany: 'openai', + name: 'GPT3', + model: ChatModelNameEnum.GPT3, + trainName: 'davinci', + maxToken: 4000, + maxTemperature: 2, + price: 20 + } +]; export enum TrainingStatusEnum { pending = 'pending', diff --git a/src/pages/api/chat/chatGpt.ts b/src/pages/api/chat/chatGpt.ts index 76f96325a..9cdf33118 100644 --- a/src/pages/api/chat/chatGpt.ts +++ b/src/pages/api/chat/chatGpt.ts @@ -1,6 +1,6 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser'; -import { connectToDatabase, Chat } from '@/service/mongo'; +import { connectToDatabase } from '@/service/mongo'; import { getOpenAIApi, authChat } from '@/service/utils/chat'; import { httpsAgent } from '@/service/utils/tools'; import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; @@ -9,6 +9,7 @@ import { jsonRes } from '@/service/response'; import type { ModelSchema } from '@/types/mongoSchema'; import { PassThrough } from 'stream'; import { ModelList } from '@/constants/model'; +import { pushBill } from '@/service/events/bill'; /* 发送提示词 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { @@ -24,7 +25,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) await connectToDatabase(); - const { chat, userApiKey } = await authChat(chatId); + const { chat, userApiKey, systemKey, userId } = await authChat(chatId); const model: ModelSchema = chat.modelId; @@ -58,16 +59,14 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) } // 计算温度 - const modelConstantsData = ModelList['openai'].find( - (item) => item.model === model.service.modelName - ); + const modelConstantsData = ModelList.find((item) => item.model === model.service.modelName); if (!modelConstantsData) { throw new Error('模型异常'); } const temperature = modelConstantsData.maxTemperature * (model.temperature / 10); // 获取 chatAPI - const chatAPI = getOpenAIApi(userApiKey); + const chatAPI = getOpenAIApi(userApiKey || systemKey); let startTime = Date.now(); // 发出请求 const chatResponse = await chatAPI.createChatCompletion( @@ -84,12 +83,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) httpsAgent } ); - console.log( - 'response success', - `time: ${(Date.now() - startTime) / 1000}s`, - `promptLen: ${formatPrompts.length}`, - `contentLen: ${formatPrompts.reduce((sum, item) => sum + item.content.length, 0)}` - ); + + console.log('api response time:', `time: ${(Date.now() - startTime) / 1000}s`); // 创建响应流 res.setHeader('Content-Type', 'text/event-stream;charset-utf-8'); @@ -97,6 +92,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) res.setHeader('X-Accel-Buffering', 'no'); res.setHeader('Cache-Control', 'no-cache, no-transform'); + let responseContent = ''; const pass = new PassThrough(); pass.pipe(res); @@ -108,6 +104,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const json = JSON.parse(data); const content: string = json?.choices?.[0].delta.content || ''; if (!content) return; + responseContent += content; // console.log('content:', content) pass.push(content.replace(/\n/g, '
')); } catch (error) { @@ -125,6 +122,17 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) console.log('pipe error', error); } pass.push(null); + + const promptsLen = formatPrompts.reduce((sum, item) => sum + item.content.length, 0); + console.log(`responseLen: ${responseContent.length}`, `promptLen: ${promptsLen}`); + // 只有使用平台的 key 才计费 + !userApiKey && + pushBill({ + modelName: model.service.modelName, + userId, + chatId, + textLen: promptsLen + responseContent.length + }); } catch (err: any) { res.status(500); jsonRes(res, { diff --git a/src/pages/api/chat/gpt3.ts b/src/pages/api/chat/gpt3.ts index 8f684cb69..7d86abac2 100644 --- a/src/pages/api/chat/gpt3.ts +++ b/src/pages/api/chat/gpt3.ts @@ -6,6 +6,7 @@ import { getOpenAIApi, authChat } from '@/service/utils/chat'; import { ChatItemType } from '@/types/chat'; import { httpsAgent } from '@/service/utils/tools'; import { ModelList } from '@/constants/model'; +import { pushBill } from '@/service/events/bill'; /* 发送提示词 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { @@ -18,20 +19,18 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) await connectToDatabase(); - const { chat, userApiKey } = await authChat(chatId); + const { chat, userApiKey, systemKey, userId } = await authChat(chatId); const model = chat.modelId; // 获取 chatAPI - const chatAPI = getOpenAIApi(userApiKey); + const chatAPI = getOpenAIApi(userApiKey || systemKey); // prompt处理 - const formatPrompt = prompt.map((item) => `${item.value}\n\n###\n\n`).join(''); + const formatPrompts = prompt.map((item) => `${item.value}\n\n###\n\n`).join(''); // 计算温度 - const modelConstantsData = ModelList['openai'].find( - (item) => item.model === model.service.modelName - ); + const modelConstantsData = ModelList.find((item) => item.model === model.service.modelName); if (!modelConstantsData) { throw new Error('模型异常'); } @@ -41,7 +40,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const response = await chatAPI.createCompletion( { model: model.service.modelName, - prompt: formatPrompt, + prompt: formatPrompts, temperature: temperature, // max_tokens: modelConstantsData.maxToken, top_p: 1, @@ -54,7 +53,18 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) } ); - const responseMessage = response.data.choices[0]?.text; + const responseMessage = response.data.choices[0]?.text || ''; + + const promptsLen = prompt.reduce((sum, item) => sum + item.value.length, 0); + console.log(`responseLen: ${responseMessage.length}`, `promptLen: ${promptsLen}`); + // 只有使用平台的 key 才计费 + !userApiKey && + pushBill({ + modelName: model.service.modelName, + userId, + chatId, + textLen: promptsLen + responseMessage.length + }); jsonRes(res, { data: responseMessage diff --git a/src/pages/api/model/create.ts b/src/pages/api/model/create.ts index 793321fc1..d9ae7f63a 100644 --- a/src/pages/api/model/create.ts +++ b/src/pages/api/model/create.ts @@ -4,19 +4,13 @@ import { jsonRes } from '@/service/response'; import { connectToDatabase } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; import { ModelStatusEnum, ModelList, ChatModelNameEnum } from '@/constants/model'; -import type { ServiceName } from '@/types/mongoSchema'; import { Model } from '@/service/models/model'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { - const { - name, - serviceModelName, - serviceModelCompany = 'openai' - } = req.body as { + const { name, serviceModelName } = req.body as { name: string; serviceModelName: `${ChatModelNameEnum}`; - serviceModelCompany: ServiceName; }; const { authorization } = req.headers; @@ -24,16 +18,14 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< throw new Error('无权操作'); } - if (!name || !serviceModelName || !serviceModelCompany) { + if (!name || !serviceModelName) { throw new Error('缺少参数'); } // 凭证校验 const userId = await authToken(authorization); - const modelItem = ModelList[serviceModelCompany].find( - (item) => item.model === serviceModelName - ); + const modelItem = ModelList.find((item) => item.model === serviceModelName); if (!modelItem) { throw new Error('模型不存在'); @@ -64,7 +56,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< userId, status: ModelStatusEnum.running, service: { - company: serviceModelCompany, + company: modelItem.serviceCompany, trainId: modelItem.trainName, chatModel: modelItem.model, modelName: modelItem.model diff --git a/src/pages/chat/components/SlideBar.tsx b/src/pages/chat/components/SlideBar.tsx index 5d04094f2..fa3fc4bf2 100644 --- a/src/pages/chat/components/SlideBar.tsx +++ b/src/pages/chat/components/SlideBar.tsx @@ -156,7 +156,7 @@ const SlideBar = ({ {/* 我的模型 & 历史记录 折叠框*/} - + {isSuccess && ( diff --git a/src/pages/model/components/CreateModel.tsx b/src/pages/model/components/CreateModel.tsx index b7dc5cf57..5e368acdc 100644 --- a/src/pages/model/components/CreateModel.tsx +++ b/src/pages/model/components/CreateModel.tsx @@ -42,7 +42,7 @@ const CreateModel = ({ formState: { errors } } = useForm({ defaultValues: { - serviceModelName: ModelList['openai'][0].model + serviceModelName: ModelList[0].model } }); @@ -95,7 +95,7 @@ const CreateModel = ({ required: '底层模型不能为空' })} > - {ModelList['openai'].map((item) => ( + {ModelList.map((item) => ( diff --git a/src/pages/model/detail.tsx b/src/pages/model/detail.tsx index 9a1ffe54d..646769f36 100644 --- a/src/pages/model/detail.tsx +++ b/src/pages/model/detail.tsx @@ -38,9 +38,7 @@ const ModelDetail = ({ modelId }: { modelId: string }) => { }); const canTrain = useMemo(() => { - const openai = ModelList[model.service.company].find( - (item) => item.model === model?.service.modelName - ); + const openai = ModelList.find((item) => item.model === model?.service.modelName); return openai && openai.trainName; }, [model]); diff --git a/src/pages/number/setting.tsx b/src/pages/number/setting.tsx index 4dc5d0274..2d4b39f05 100644 --- a/src/pages/number/setting.tsx +++ b/src/pages/number/setting.tsx @@ -68,17 +68,17 @@ const NumberSetting = () => { {userInfo?.email} - {/* + 余额: {userInfo?.balance} - + */} - */} + @@ -148,6 +148,55 @@ const NumberSetting = () => { + + + 使用记录 + + + + + + + + + + + + {accounts.map((item, i) => ( + + + + + + ))} + +
账号类型
+ + + + + } + colorScheme={'red'} + onClick={() => { + removeAccount(i); + handleSubmit(onclickSave)(); + }} + /> +
+
+
); }; diff --git a/src/service/events/bill.ts b/src/service/events/bill.ts new file mode 100644 index 000000000..79645fd4a --- /dev/null +++ b/src/service/events/bill.ts @@ -0,0 +1,41 @@ +import { connectToDatabase, Bill, User } from '../mongo'; +import { ModelList } from '@/constants/model'; + +export const pushBill = async ({ + modelName, + userId, + chatId, + textLen +}: { + modelName: string; + userId: string; + chatId: string; + textLen: number; +}) => { + await connectToDatabase(); + + const modelItem = ModelList.find((item) => item.model === modelName); + + if (!modelItem) return; + + const price = modelItem.price * textLen; + + let billId; + try { + // 插入 Bill 记录 + const res = await Bill.create({ + userId, + chatId, + textLen, + price + }); + billId = res._id; + + // 扣费 + await User.findByIdAndUpdate(userId, { + $inc: { balance: -price } + }); + } catch (error) { + Bill.findByIdAndDelete(billId); + } +}; diff --git a/src/service/models/bill.ts b/src/service/models/bill.ts new file mode 100644 index 000000000..51d00e6ba --- /dev/null +++ b/src/service/models/bill.ts @@ -0,0 +1,29 @@ +import { Schema, model, models } from 'mongoose'; + +const BillSchema = new Schema({ + userId: { + type: Schema.Types.ObjectId, + ref: 'user', + required: true + }, + chatId: { + type: Schema.Types.ObjectId, + ref: 'chat', + required: true + }, + time: { + type: Number, + default: () => Date.now() + }, + textLen: { + // 提示词+响应的总字数 + type: Number, + required: true + }, + price: { + type: Number, + required: true + } +}); + +export const Bill = models['bill'] || model('bill', BillSchema); diff --git a/src/service/mongo.ts b/src/service/mongo.ts index 3de00366a..5fc6f5b4c 100644 --- a/src/service/mongo.ts +++ b/src/service/mongo.ts @@ -30,3 +30,4 @@ export * from './models/chat'; export * from './models/model'; export * from './models/user'; export * from './models/training'; +export * from './models/bill'; diff --git a/src/service/utils/chat.ts b/src/service/utils/chat.ts index 81b785584..890826000 100644 --- a/src/service/utils/chat.ts +++ b/src/service/utils/chat.ts @@ -1,6 +1,7 @@ import { Configuration, OpenAIApi } from 'openai'; import { Chat } from '../mongo'; import type { ChatPopulate } from '@/types/mongoSchema'; +import { formatPrice } from '@/utils/user'; export const getOpenAIApi = (apiKey: string) => { const configuration = new Configuration({ @@ -40,12 +41,14 @@ export const authChat = async (chatId: string) => { const userApiKey = user.accounts?.find((item: any) => item.type === 'openai')?.value; - if (!userApiKey) { - return Promise.reject('缺少ApiKey, 无法请求'); + if (!userApiKey && formatPrice(user.balance) <= -1) { + return Promise.reject('该账号余额不足'); } return { userApiKey, - chat + systemKey: process.env.OPENAIKEY as string, + chat, + userId: user._id }; }; diff --git a/src/store/chat.ts b/src/store/chat.ts index 756041b8d..1b2a42cb7 100644 --- a/src/store/chat.ts +++ b/src/store/chat.ts @@ -19,6 +19,7 @@ export const useChatStore = create()( chatHistory: [], pushChatHistory(item: HistoryItem) { set((state) => { + if (state.chatHistory.find((history) => history.chatId === item.chatId)) return; state.chatHistory = [item, ...state.chatHistory].slice(0, 20); }); }, diff --git a/src/store/user.ts b/src/store/user.ts index af5003409..a127c545e 100644 --- a/src/store/user.ts +++ b/src/store/user.ts @@ -5,6 +5,7 @@ import type { UserType, UserUpdateParams } from '@/types/user'; import type { ModelSchema } from '@/types/mongoSchema'; import { setToken } from '@/utils/user'; import { getMyModels } from '@/api/model'; +import { formatPrice } from '@/utils/user'; type State = { userInfo: UserType | null; @@ -21,7 +22,10 @@ export const useUserStore = create()( userInfo: null, setUserInfo(user: UserType, token?: string) { set((state) => { - state.userInfo = user; + state.userInfo = { + ...user, + balance: formatPrice(user.balance) + }; }); token && setToken(token); }, diff --git a/src/utils/user.ts b/src/utils/user.ts index 190fc4314..f7749d0db 100644 --- a/src/utils/user.ts +++ b/src/utils/user.ts @@ -9,3 +9,10 @@ export const getToken = () => { export const clearToken = () => { localStorage.removeItem(tokenKey); }; + +/** + * 把数据库读取到的price,转化成元 + */ +export const formatPrice = (val: number) => { + return val / 100000; +};