mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: 应用相关接口,模型相关接口
This commit is contained in:
parent
d41f3214ec
commit
f28b222388
|
|
@ -27,7 +27,7 @@ share/python-wheels/
|
|||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# Usually these files are written by a python script froms a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
# Generated by Django 4.1.10 on 2023-10-24 12:13
|
||||
# Generated by Django 4.1.10 on 2023-11-14 07:30
|
||||
|
||||
import django.contrib.postgres.fields
|
||||
from django.db import migrations, models
|
||||
|
|
@ -11,6 +11,7 @@ class Migration(migrations.Migration):
|
|||
initial = True
|
||||
|
||||
dependencies = [
|
||||
('setting', '0003_alter_model_provider'),
|
||||
('users', '0001_initial'),
|
||||
('dataset', '0001_initial'),
|
||||
]
|
||||
|
|
@ -26,8 +27,9 @@ class Migration(migrations.Migration):
|
|||
('desc', models.CharField(max_length=128, verbose_name='引用描述')),
|
||||
('prologue', models.CharField(max_length=1024, verbose_name='开场白')),
|
||||
('example', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(blank=True, max_length=256), size=None, verbose_name='示例列表')),
|
||||
('dialogue_number', models.IntegerField(default=0, verbose_name='会话数量')),
|
||||
('status', models.BooleanField(default=True, verbose_name='是否发布')),
|
||||
('is_active', models.BooleanField(default=True)),
|
||||
('model', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='setting.model')),
|
||||
('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user')),
|
||||
],
|
||||
options={
|
||||
|
|
@ -35,13 +37,49 @@ class Migration(migrations.Migration):
|
|||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='ApplicationDatasetMapping',
|
||||
name='Chat',
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
|
||||
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
|
||||
('application', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='application.application')),
|
||||
('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset')),
|
||||
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
|
||||
('abstract', models.CharField(max_length=256, verbose_name='摘要')),
|
||||
('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application')),
|
||||
],
|
||||
options={
|
||||
'db_table': 'application_chat',
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='ChatRecord',
|
||||
fields=[
|
||||
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
|
||||
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
|
||||
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
|
||||
('vote_status', models.CharField(choices=[('-1', '未投票'), ('0', '赞同'), ('1', '反对')], default='-1', max_length=10, verbose_name='投票')),
|
||||
('source_id', models.UUIDField(verbose_name='资源id 段落/问题 id ')),
|
||||
('source_type', models.CharField(choices=[('0', '问题'), ('1', '段落')], default='0', max_length=1, verbose_name='资源类型')),
|
||||
('message_tokens', models.IntegerField(default=0, verbose_name='请求token数量')),
|
||||
('answer_tokens', models.IntegerField(default=0, verbose_name='响应token数量')),
|
||||
('problem_text', models.CharField(max_length=1024, verbose_name='问题')),
|
||||
('answer_text', models.CharField(max_length=1024, verbose_name='答案')),
|
||||
('improve_problem_id_list', django.contrib.postgres.fields.ArrayField(base_field=models.UUIDField(blank=True), default=list, size=None, verbose_name='改进标注列表')),
|
||||
('index', models.IntegerField(verbose_name='对话下标')),
|
||||
('chat', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.chat')),
|
||||
('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset', verbose_name='数据集')),
|
||||
('paragraph', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph', verbose_name='段落id')),
|
||||
],
|
||||
options={
|
||||
'db_table': 'application_chat_record',
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='ApplicationDatasetMapping',
|
||||
fields=[
|
||||
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
|
||||
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
|
||||
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
|
||||
('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application')),
|
||||
('dataset', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='dataset.dataset')),
|
||||
],
|
||||
options={
|
||||
'db_table': 'application_dataset_mapping',
|
||||
|
|
|
|||
|
|
@ -0,0 +1,25 @@
|
|||
# Generated by Django 4.1.10 on 2023-11-14 08:02
|
||||
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('dataset', '0001_initial'),
|
||||
('application', '0001_initial'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='chatrecord',
|
||||
name='dataset',
|
||||
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset', verbose_name='数据集'),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='chatrecord',
|
||||
name='paragraph',
|
||||
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph', verbose_name='段落id'),
|
||||
),
|
||||
]
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
# Generated by Django 4.1.10 on 2023-11-14 08:03
|
||||
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('dataset', '0001_initial'),
|
||||
('application', '0002_alter_chatrecord_dataset_alter_chatrecord_paragraph'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='chatrecord',
|
||||
name='dataset',
|
||||
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset', verbose_name='数据集'),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='chatrecord',
|
||||
name='paragraph',
|
||||
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph', verbose_name='段落id'),
|
||||
),
|
||||
]
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
# Generated by Django 4.1.10 on 2023-11-14 08:09
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('application', '0003_alter_chatrecord_dataset_alter_chatrecord_paragraph'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='chatrecord',
|
||||
name='source_id',
|
||||
field=models.UUIDField(null=True, verbose_name='资源id 段落/问题 id '),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='chatrecord',
|
||||
name='source_type',
|
||||
field=models.CharField(choices=[('0', '问题'), ('1', '段落'), ('-1', '其他')], default='-1', max_length=5, verbose_name='资源类型'),
|
||||
),
|
||||
]
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
# Generated by Django 4.1.10 on 2023-11-14 08:11
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('application', '0004_alter_chatrecord_source_id_and_more'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='chatrecord',
|
||||
name='source_type',
|
||||
field=models.CharField(blank=True, choices=[('0', '问题'), ('1', '段落')], default='0', max_length=2, null=True, verbose_name='资源类型'),
|
||||
),
|
||||
]
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
# Generated by Django 4.1.10 on 2023-11-15 09:55
|
||||
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
import uuid
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('users', '0001_initial'),
|
||||
('application', '0005_alter_chatrecord_source_type'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='ApplicationAccessToken',
|
||||
fields=[
|
||||
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
|
||||
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
|
||||
('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, primary_key=True, serialize=False, to='application.application', verbose_name='应用id')),
|
||||
('access_token', models.CharField(max_length=128, unique=True, verbose_name='用户公开访问 认证token')),
|
||||
('is_active', models.BooleanField(default=True, verbose_name='是否开启公开访问')),
|
||||
],
|
||||
options={
|
||||
'db_table': 'application_access_token',
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='ApplicationApiKey',
|
||||
fields=[
|
||||
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
|
||||
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
|
||||
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
|
||||
('secret_key', models.CharField(max_length=1024, unique=True, verbose_name='秘钥')),
|
||||
('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application', verbose_name='应用id')),
|
||||
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='users.user', verbose_name='用户id')),
|
||||
],
|
||||
options={
|
||||
'db_table': 'application_api_key',
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
@ -12,7 +12,9 @@ from django.contrib.postgres.fields import ArrayField
|
|||
from django.db import models
|
||||
|
||||
from common.mixins.app_model_mixin import AppModelMixin
|
||||
from dataset.models.data_set import DataSet
|
||||
from dataset.models.data_set import DataSet, Paragraph
|
||||
from embedding.models import SourceType
|
||||
from setting.models.model_management import Model
|
||||
from users.models import User
|
||||
|
||||
|
||||
|
|
@ -22,16 +24,62 @@ class Application(AppModelMixin):
|
|||
desc = models.CharField(max_length=128, verbose_name="引用描述")
|
||||
prologue = models.CharField(max_length=1024, verbose_name="开场白")
|
||||
example = ArrayField(verbose_name="示例列表", base_field=models.CharField(max_length=256, blank=True))
|
||||
dialogue_number = models.IntegerField(default=0, verbose_name="会话数量")
|
||||
status = models.BooleanField(default=True, verbose_name="是否发布")
|
||||
user = models.ForeignKey(User, on_delete=models.DO_NOTHING)
|
||||
is_active = models.BooleanField(default=True)
|
||||
model = models.ForeignKey(Model, on_delete=models.DO_NOTHING, db_constraint=False)
|
||||
|
||||
class Meta:
|
||||
db_table = "application"
|
||||
|
||||
|
||||
class ApplicationDatasetMapping(AppModelMixin):
|
||||
application = models.ForeignKey(Application, on_delete=models.DO_NOTHING)
|
||||
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
|
||||
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
|
||||
application = models.ForeignKey(Application, on_delete=models.CASCADE)
|
||||
dataset = models.ForeignKey(DataSet, on_delete=models.CASCADE)
|
||||
|
||||
class Meta:
|
||||
db_table = "application_dataset_mapping"
|
||||
|
||||
|
||||
class Chat(AppModelMixin):
|
||||
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
|
||||
application = models.ForeignKey(Application, on_delete=models.CASCADE)
|
||||
abstract = models.CharField(max_length=256, verbose_name="摘要")
|
||||
|
||||
class Meta:
|
||||
db_table = "application_chat"
|
||||
|
||||
|
||||
class VoteChoices(models.TextChoices):
|
||||
"""订单类型"""
|
||||
UN_VOTE = -1, '未投票'
|
||||
STAR = 0, '赞同'
|
||||
TRAMPLE = 1, '反对'
|
||||
|
||||
|
||||
class ChatRecord(AppModelMixin):
|
||||
"""
|
||||
对话日志 详情
|
||||
"""
|
||||
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
|
||||
chat = models.ForeignKey(Chat, on_delete=models.CASCADE)
|
||||
vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices,
|
||||
default=VoteChoices.UN_VOTE)
|
||||
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, verbose_name="数据集", blank=True, null=True)
|
||||
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, verbose_name="段落id", blank=True, null=True)
|
||||
source_id = models.UUIDField(max_length=128, verbose_name="资源id 段落/问题 id ", null=True)
|
||||
source_type = models.CharField(verbose_name='资源类型', max_length=2, choices=SourceType.choices,
|
||||
default=SourceType.PROBLEM, blank=True, null=True)
|
||||
message_tokens = models.IntegerField(verbose_name="请求token数量", default=0)
|
||||
answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0)
|
||||
problem_text = models.CharField(max_length=1024, verbose_name="问题")
|
||||
answer_text = models.CharField(max_length=1024, verbose_name="答案")
|
||||
improve_problem_id_list = ArrayField(verbose_name="改进标注列表",
|
||||
base_field=models.UUIDField(max_length=128, blank=True)
|
||||
, default=list)
|
||||
|
||||
index = models.IntegerField(verbose_name="对话下标")
|
||||
|
||||
class Meta:
|
||||
db_table = "application_chat_record"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,366 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: application_serializers.py
|
||||
@date:2023/11/7 10:02
|
||||
@desc:
|
||||
"""
|
||||
import hashlib
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict
|
||||
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.core import cache
|
||||
from django.core import signing
|
||||
from django.db import transaction, models
|
||||
from django.db.models import QuerySet
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.models import Application, ApplicationDatasetMapping
|
||||
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.db.search import get_dynamics_model, native_search, native_page_search
|
||||
from common.db.sql_execute import select_list
|
||||
from common.exception.app_exception import AppApiException, NotFound404, AppAuthenticationFailed
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import DataSet
|
||||
from setting.models import AuthOperate
|
||||
from setting.models.model_management import Model
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from smartdoc.settings import JWT_AUTH
|
||||
|
||||
token_cache = cache.caches['token_cache']
|
||||
|
||||
|
||||
class ModelDatasetAssociation(serializers.Serializer):
|
||||
user_id = serializers.UUIDField(required=True)
|
||||
model_id = serializers.CharField(required=True)
|
||||
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
|
||||
|
||||
def is_valid(self, *, raise_exception=True):
|
||||
super().is_valid(raise_exception=True)
|
||||
model_id = self.data.get('model_id')
|
||||
user_id = self.data.get('user_id')
|
||||
if not QuerySet(Model).filter(id=model_id).exists():
|
||||
raise AppApiException(500, f'模型不存在【{model_id}】')
|
||||
dataset_id_list = list(set(self.data.get('dataset_id_list')))
|
||||
exist_dataset_id_list = [str(dataset.id) for dataset in
|
||||
QuerySet(DataSet).filter(id__in=dataset_id_list, user_id=user_id)]
|
||||
for dataset_id in dataset_id_list:
|
||||
if not exist_dataset_id_list.__contains__(dataset_id):
|
||||
raise AppApiException(500, f'数据集id不存在【{dataset_id}】')
|
||||
|
||||
|
||||
class ApplicationSerializerModel(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = Application
|
||||
fields = "__all__"
|
||||
|
||||
|
||||
class ApplicationSerializer(serializers.Serializer):
|
||||
name = serializers.CharField(required=True)
|
||||
desc = serializers.CharField(required=True)
|
||||
model_id = serializers.CharField(required=True)
|
||||
multiple_rounds_dialogue = serializers.BooleanField(required=True)
|
||||
prologue = serializers.CharField(required=True)
|
||||
example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True))
|
||||
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
|
||||
|
||||
class AccessTokenSerializer(serializers.Serializer):
|
||||
application_id = serializers.UUIDField(required=True)
|
||||
|
||||
class AccessTokenEditSerializer(serializers.Serializer):
|
||||
access_token_reset = serializers.UUIDField(required=False)
|
||||
is_active = serializers.BooleanField(required=False)
|
||||
|
||||
def edit(self, instance: Dict, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
ApplicationSerializer.AccessTokenSerializer.AccessTokenEditSerializer(data=instance).is_valid(
|
||||
raise_exception=True)
|
||||
|
||||
application_access_token = QuerySet(ApplicationAccessToken).get(
|
||||
application_id=self.data.get('application_id'))
|
||||
if 'is_active' in instance:
|
||||
application_access_token.is_active = instance.get("is_active")
|
||||
if 'access_token_reset' in instance and instance.get('access_token_reset'):
|
||||
application_access_token.access_token = hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24]
|
||||
application_access_token.save()
|
||||
return self.one(with_valid=False)
|
||||
|
||||
def one(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
application_id = self.data.get("application_id")
|
||||
application_access_token = QuerySet(ApplicationAccessToken).filter(
|
||||
application_id=application_id).first()
|
||||
if application_access_token is None:
|
||||
application_access_token = ApplicationAccessToken(application_id=application_id,
|
||||
access_token=hashlib.md5(
|
||||
str(uuid.uuid1()).encode()).hexdigest()[
|
||||
8:24], is_active=True)
|
||||
application_access_token.save()
|
||||
return {'application_id': application_access_token.application_id,
|
||||
'access_token': application_access_token.access_token,
|
||||
"is_active": application_access_token.is_active}
|
||||
|
||||
class Authentication(serializers.Serializer):
|
||||
access_token = serializers.CharField(required=True)
|
||||
|
||||
def auth(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
access_token = self.data.get("access_token")
|
||||
application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first()
|
||||
if application_access_token is not None and application_access_token.is_active:
|
||||
token = signing.dumps({'application_id': str(application_access_token.application_id),
|
||||
'user_id': str(application_access_token.application.user.id),
|
||||
'access_token': application_access_token.access_token,
|
||||
'type': AuthenticationType.APPLICATION_ACCESS_TOKEN.value})
|
||||
token_cache.set(token, application_access_token, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'])
|
||||
return token
|
||||
else:
|
||||
raise AppAuthenticationFailed(401, "无效的access_token")
|
||||
|
||||
class Edit(serializers.Serializer):
|
||||
name = serializers.CharField(required=False)
|
||||
desc = serializers.CharField(required=False)
|
||||
model_id = serializers.CharField(required=False)
|
||||
multiple_rounds_dialogue = serializers.BooleanField(required=False)
|
||||
prologue = serializers.CharField(required=False)
|
||||
example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True))
|
||||
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
|
||||
|
||||
def is_valid(self, *, user_id=None, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
ModelDatasetAssociation(data={'user_id': user_id, 'model_id': self.data.get('model_id'),
|
||||
'dataset_id_list': self.data.get('dataset_id_list')}).is_valid()
|
||||
|
||||
class Create(serializers.Serializer):
|
||||
user_id = serializers.UUIDField(required=True)
|
||||
|
||||
@transaction.atomic
|
||||
def insert(self, application: Dict):
|
||||
self.is_valid(raise_exception=True)
|
||||
user_id = self.data.get('user_id')
|
||||
ApplicationSerializer(data=application).is_valid(user_id=user_id, raise_exception=True)
|
||||
application_model = ApplicationSerializer.Create.to_application_model(user_id, application)
|
||||
dataset_id_list = application.get('dataset_id_list', [])
|
||||
application_dataset_mapping_model_list = [
|
||||
ApplicationSerializer.Create.to_application_dateset_mapping(application_model.id, dataset_id) for
|
||||
dataset_id in dataset_id_list]
|
||||
# 插入应用
|
||||
application_model.save()
|
||||
# 插入认证信息
|
||||
ApplicationAccessToken(application_id=application_model.id,
|
||||
access_token=hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24]).save()
|
||||
# 插入关联数据
|
||||
QuerySet(ApplicationDatasetMapping).bulk_create(application_dataset_mapping_model_list)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def to_application_model(user_id: str, application: Dict):
|
||||
return Application(id=uuid.uuid1(), name=application.get('name'), desc=application.get('desc'),
|
||||
prologue=application.get('prologue'), example=application.get('example'),
|
||||
dialogue_number=3 if application.get('multiple_rounds_dialogue') else 0,
|
||||
status=True, user_id=user_id, model_id=application.get('model_id'),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def to_application_dateset_mapping(application_id: str, dataset_id: str):
|
||||
return ApplicationDatasetMapping(id=uuid.uuid1(), application_id=application_id, dataset_id=dataset_id)
|
||||
|
||||
class Query(serializers.Serializer):
|
||||
name = serializers.CharField(required=False)
|
||||
|
||||
desc = serializers.CharField(required=False)
|
||||
|
||||
user_id = serializers.UUIDField(required=True)
|
||||
|
||||
def get_query_set(self):
|
||||
user_id = self.data.get("user_id")
|
||||
query_set_dict = {}
|
||||
query_set = QuerySet(model=get_dynamics_model(
|
||||
{'temp_application.name': models.CharField(), 'temp_application.desc': models.CharField()}))
|
||||
if "desc" in self.data and self.data.get('desc') is not None:
|
||||
query_set = query_set.filter(**{'temp_application.desc__contains': self.data.get("desc")})
|
||||
if "name" in self.data and self.data.get('name') is not None:
|
||||
query_set = query_set.filter(**{'temp_application.name__contains': self.data.get("name")})
|
||||
|
||||
query_set_dict['default_sql'] = query_set
|
||||
|
||||
query_set_dict['application_custom_sql'] = QuerySet(model=get_dynamics_model(
|
||||
{'application.user_id': models.CharField(),
|
||||
})).filter(
|
||||
**{'application.user_id': user_id}
|
||||
)
|
||||
|
||||
query_set_dict['team_member_permission_custom_sql'] = QuerySet(model=get_dynamics_model(
|
||||
{'user_id': models.CharField(),
|
||||
'team_member_permission.auth_target_type': models.CharField(),
|
||||
'team_member_permission.operate': ArrayField(verbose_name="权限操作列表",
|
||||
base_field=models.CharField(max_length=256,
|
||||
blank=True,
|
||||
choices=AuthOperate.choices,
|
||||
default=AuthOperate.USE)
|
||||
)})).filter(
|
||||
**{'user_id': user_id, 'team_member_permission.operate__contains': ['USE'],
|
||||
'team_member_permission.auth_target_type': 'APPLICATION'})
|
||||
|
||||
return query_set_dict
|
||||
|
||||
def list(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
return [ApplicationSerializer.Query.reset_application(a) for a in
|
||||
native_search(self.get_query_set(), select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application.sql')))]
|
||||
|
||||
@staticmethod
|
||||
def reset_application(application: Dict):
|
||||
application['multiple_rounds_dialogue'] = True if application.get('dialogue_number') > 0 else False
|
||||
del application['dialogue_number']
|
||||
return application
|
||||
|
||||
def page(self, current_page: int, page_size: int, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
return native_page_search(current_page, page_size, self.get_query_set(), select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application.sql')),
|
||||
post_records_handler=ApplicationSerializer.Query.reset_application)
|
||||
|
||||
class ApplicationModel(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = Application
|
||||
fields = ['id', 'name', 'desc', 'prologue', 'example', 'dialogue_number', 'status']
|
||||
|
||||
class Operate(serializers.Serializer):
|
||||
application_id = serializers.UUIDField(required=True)
|
||||
user_id = serializers.UUIDField(required=True)
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
if not QuerySet(Application).filter(id=self.data.get('application_id')).exists():
|
||||
raise AppApiException(500, '不存在的应用id')
|
||||
|
||||
def delete(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid()
|
||||
QuerySet(Application).filter(id=self.data.get('application_id')).delete()
|
||||
return True
|
||||
|
||||
def one(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid()
|
||||
application_id = self.data.get("application_id")
|
||||
application = QuerySet(Application).get(id=application_id)
|
||||
dataset_list = self.list_dataset(with_valid=False)
|
||||
mapping_dataset_id_list = [adm.dataset_id for adm in
|
||||
QuerySet(ApplicationDatasetMapping).filter(application_id=application_id)]
|
||||
dataset_id_list = [d.get('id') for d in
|
||||
list(filter(lambda row: mapping_dataset_id_list.__contains__(row.get('id')),
|
||||
dataset_list))]
|
||||
return {**ApplicationSerializer.Query.reset_application(ApplicationSerializerModel(application).data),
|
||||
'dataset_id_list': dataset_id_list}
|
||||
|
||||
def profile(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid()
|
||||
application_id = self.data.get("application_id")
|
||||
application = QuerySet(Application).get(id=application_id)
|
||||
return ApplicationSerializer.Query.reset_application(
|
||||
ApplicationSerializer.ApplicationModel(application).data)
|
||||
|
||||
def edit(self, instance: Dict, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid()
|
||||
ApplicationSerializer.Edit(data=instance).is_valid(
|
||||
raise_exception=True)
|
||||
application_id = self.data.get("application_id")
|
||||
|
||||
application = QuerySet(Application).get(id=application_id)
|
||||
|
||||
model = QuerySet(Model).get(id=instance.get('model_id') if 'model_id' in instance else application.model_id)
|
||||
|
||||
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'example']
|
||||
for update_key in update_keys:
|
||||
if update_key in instance and instance.get(update_key) is not None:
|
||||
if update_key == 'multiple_rounds_dialogue':
|
||||
application.__setattr__('dialogue_number',
|
||||
0 if instance.get(update_key) else ModelProvideConstants[
|
||||
model.provider].value.get_dialogue_number())
|
||||
else:
|
||||
application.__setattr__(update_key, instance.get(update_key))
|
||||
application.save()
|
||||
|
||||
if 'dataset_id_list' in instance:
|
||||
dataset_id_list = instance.get('dataset_id_list')
|
||||
# 当前用户可修改关联的数据集列表
|
||||
application_dataset_id_list = [dataset_dict.get('id') for dataset_dict in
|
||||
self.list_dataset(with_valid=False)]
|
||||
for dataset_id in dataset_id_list:
|
||||
if not application_dataset_id_list.__contains__(dataset_id):
|
||||
raise AppApiException(500, f"未知的数据集id${dataset_id},无法关联")
|
||||
|
||||
# 删除已经关联的id
|
||||
QuerySet(ApplicationDatasetMapping).filter(dataset_id__in=application_dataset_id_list,
|
||||
application_id=application_id).delete()
|
||||
# 插入
|
||||
QuerySet(ApplicationDatasetMapping).bulk_create(
|
||||
[ApplicationDatasetMapping(application_id=application_id, dataset_id=dataset_id) for dataset_id in
|
||||
dataset_id_list]) if len(dataset_id_list) > 0 else None
|
||||
return self.one(with_valid=False)
|
||||
|
||||
def list_dataset(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
application = QuerySet(Application).get(id=self.data.get("application_id"))
|
||||
return select_list(get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application_dataset.sql')),
|
||||
[self.data.get('user_id'), application.user_id, self.data.get('user_id')])
|
||||
|
||||
class ApplicationKeySerializerModel(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = ApplicationApiKey
|
||||
fields = "__all__"
|
||||
|
||||
class ApplicationKeySerializer(serializers.Serializer):
|
||||
user_id = serializers.UUIDField(required=True)
|
||||
|
||||
application_id = serializers.UUIDField(required=True)
|
||||
|
||||
def generate(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
user_id = self.data.get("user_id")
|
||||
application_id = self.data.get("application_id")
|
||||
secret_key = 'application-' + hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()
|
||||
application_api_key = ApplicationApiKey(id=uuid.uuid1(), secret_key=secret_key, user_id=user_id,
|
||||
application_id=application_id)
|
||||
application_api_key.save()
|
||||
return ApplicationSerializer.ApplicationKeySerializerModel(application_api_key).data
|
||||
|
||||
def list(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
user_id = self.data.get("user_id")
|
||||
application_id = self.data.get("application_id")
|
||||
return [ApplicationSerializer.ApplicationKeySerializerModel(application_api_key).data for
|
||||
application_api_key in
|
||||
QuerySet(ApplicationApiKey).filter(user_id=user_id, application_id=application_id)]
|
||||
|
||||
class Operate(serializers.Serializer):
|
||||
application_id = serializers.UUIDField(required=True)
|
||||
|
||||
api_key_id = serializers.CharField(required=True)
|
||||
|
||||
def delete(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
api_key_id = self.data.get("api_key_id")
|
||||
application_id = self.data.get('application_id')
|
||||
QuerySet(ApplicationApiKey).filter(id=api_key_id,
|
||||
application_id=application_id).delete()
|
||||
|
|
@ -0,0 +1,167 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: chat_message_serializers.py
|
||||
@date:2023/11/14 13:51
|
||||
@desc:
|
||||
"""
|
||||
import json
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.http import StreamingHttpResponse
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import HumanMessage
|
||||
from rest_framework import serializers, status
|
||||
from django.core.cache import cache
|
||||
from common import event
|
||||
from common.config.embedding_config import VectorStore, EmbeddingModel
|
||||
from common.response import result
|
||||
from dataset.models import Paragraph
|
||||
from embedding.models import SourceType
|
||||
from setting.models.model_management import Model
|
||||
|
||||
chat_cache = cache
|
||||
|
||||
|
||||
class MessageManagement:
|
||||
@staticmethod
|
||||
def get_message(title: str, content: str, message: str):
|
||||
if content is None:
|
||||
return HumanMessage(content=message)
|
||||
return HumanMessage(content=(
|
||||
f'已知信息:{title}:{content} '
|
||||
'根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从已知信息中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 '
|
||||
f'问题是:{message}'))
|
||||
|
||||
|
||||
class ChatMessage:
|
||||
def __init__(self, id: str, problem: str, title: str, paragraph: str, embedding_id: str, dataset_id: str,
|
||||
document_id: str,
|
||||
paragraph_id,
|
||||
source_type: SourceType,
|
||||
source_id: str,
|
||||
answer: str,
|
||||
message_tokens: int,
|
||||
answer_token: int):
|
||||
self.id = id
|
||||
self.problem = problem
|
||||
self.title = title
|
||||
self.paragraph = paragraph
|
||||
self.embedding_id = embedding_id
|
||||
self.dataset_id = dataset_id
|
||||
self.document_id = document_id
|
||||
self.paragraph_id = paragraph_id
|
||||
self.source_type = source_type
|
||||
self.source_id = source_id
|
||||
self.answer = answer
|
||||
self.message_tokens = message_tokens
|
||||
self.answer_token = answer_token
|
||||
|
||||
def get_chat_message(self):
|
||||
return MessageManagement.get_message(self.problem, self.paragraph, self.problem)
|
||||
|
||||
|
||||
class ChatInfo:
|
||||
def __init__(self,
|
||||
chat_id: str,
|
||||
model: Model,
|
||||
chat_model: BaseChatModel,
|
||||
application_id: str | None,
|
||||
dataset_id_list: List[str],
|
||||
exclude_document_id_list: list[str],
|
||||
dialogue_number: int):
|
||||
self.chat_id = chat_id
|
||||
self.application_id = application_id
|
||||
self.model = model
|
||||
self.chat_model = chat_model
|
||||
self.dataset_id_list = dataset_id_list
|
||||
self.exclude_document_id_list = exclude_document_id_list
|
||||
self.dialogue_number = dialogue_number
|
||||
self.chat_message_list: List[ChatMessage] = []
|
||||
|
||||
def append_chat_message(self, chat_message: ChatMessage):
|
||||
self.chat_message_list.append(chat_message)
|
||||
if self.application_id is not None:
|
||||
event.ListenerChatMessage.record_chat_message_signal.send(
|
||||
event.RecordChatMessageArgs(len(self.chat_message_list) - 1, self.chat_id, self.application_id,
|
||||
chat_message)
|
||||
)
|
||||
|
||||
def get_context_message(self):
|
||||
start_index = len(self.chat_message_list) - self.dialogue_number
|
||||
return [self.chat_message_list[index].get_chat_message() for index in
|
||||
range(start_index if start_index > 0 else 0, len(self.chat_message_list))]
|
||||
|
||||
|
||||
class ChatMessageSerializer(serializers.Serializer):
|
||||
chat_id = serializers.UUIDField(required=True)
|
||||
|
||||
def chat(self, message):
|
||||
self.is_valid(raise_exception=True)
|
||||
chat_id = self.data.get('chat_id')
|
||||
chat_info: ChatInfo = chat_cache.get(chat_id)
|
||||
if chat_info is None:
|
||||
return result.Result(response_status=status.HTTP_404_NOT_FOUND, code=404, message="会话过期")
|
||||
|
||||
chat_model = chat_info.chat_model
|
||||
vector = VectorStore.get_embedding_vector()
|
||||
# 向量库检索
|
||||
_value = vector.search(message, chat_info.dataset_id_list, chat_info.exclude_document_id_list,
|
||||
[chat_message.embedding_id for chat_message in
|
||||
(list(filter(lambda row: row.problem == message, chat_info.chat_message_list)))],
|
||||
True,
|
||||
EmbeddingModel.get_embedding_model())
|
||||
# 查询段落id详情
|
||||
paragraph = None
|
||||
if _value is not None:
|
||||
paragraph = QuerySet(Paragraph).get(id=_value.get('paragraph_id'))
|
||||
if paragraph is None:
|
||||
vector.delete_by_paragraph_id(_value.get('paragraph_id'))
|
||||
|
||||
title, content = (None, None) if paragraph is None else (paragraph.title, paragraph.content)
|
||||
|
||||
embedding_id, dataset_id, document_id, paragraph_id, source_type, source_id = (_value.get(
|
||||
'id'), _value.get(
|
||||
'dataset_id'), _value.get(
|
||||
'document_id'), _value.get(
|
||||
'paragraph_id'), _value.get(
|
||||
'source_type'), _value.get(
|
||||
'source_id')) if _value is not None else (None, None, None, None, None, None)
|
||||
# 获取上下文
|
||||
history_message = chat_info.get_context_message()
|
||||
|
||||
# 构建会话请求问题
|
||||
chat_message = [*history_message, MessageManagement.get_message(title, content, message)]
|
||||
# 对话
|
||||
result_data = chat_model.stream(chat_message)
|
||||
|
||||
_id = str(uuid.uuid1())
|
||||
|
||||
def event_content(response):
|
||||
all_text = ''
|
||||
try:
|
||||
for chunk in response:
|
||||
all_text += chunk.content
|
||||
yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None,
|
||||
'content': chunk.content}) + "\n\n"
|
||||
|
||||
chat_info.append_chat_message(
|
||||
ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id,
|
||||
paragraph_id,
|
||||
source_type,
|
||||
source_id, all_text, chat_model.get_num_tokens_from_messages(chat_message),
|
||||
chat_model.get_num_tokens(all_text)))
|
||||
# 重新设置缓存
|
||||
chat_cache.set(chat_id,
|
||||
chat_info, timeout=60 * 30)
|
||||
except Exception as e:
|
||||
yield e
|
||||
|
||||
r = StreamingHttpResponse(streaming_content=event_content(result_data),
|
||||
content_type='text/event-stream;charset=utf-8')
|
||||
|
||||
r['Cache-Control'] = 'no-cache'
|
||||
return r
|
||||
|
|
@ -0,0 +1,282 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: chat_serializers.py
|
||||
@date:2023/11/14 9:59
|
||||
@desc:
|
||||
"""
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db import transaction
|
||||
from django.db.models import QuerySet
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord
|
||||
from application.serializers.application_serializers import ModelDatasetAssociation
|
||||
from application.serializers.chat_message_serializers import ChatInfo
|
||||
from common.db.search import native_search, native_page_search, page_search
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.util.file_util import get_file_content
|
||||
from common.util.lock import try_lock, un_lock
|
||||
from common.util.rsa_util import decrypt
|
||||
from dataset.models import Document, Problem, Paragraph
|
||||
from embedding.models import SourceType, Embedding
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
from setting.views import Model
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
chat_cache = cache
|
||||
|
||||
|
||||
class ChatSerializers(serializers.Serializer):
|
||||
class Query(serializers.Serializer):
|
||||
abstract = serializers.CharField(required=False)
|
||||
history_day = serializers.IntegerField(required=True)
|
||||
user_id = serializers.UUIDField(required=True)
|
||||
application_id = serializers.UUIDField(required=True)
|
||||
|
||||
def get_end_time(self):
|
||||
history_day = self.data.get('history_day')
|
||||
return datetime.datetime.now() - datetime.timedelta(days=history_day)
|
||||
|
||||
def get_query_set(self):
|
||||
end_time = self.get_end_time()
|
||||
return QuerySet(Chat).filter(application_id=self.data.get("application_id"), create_time__gte=end_time)
|
||||
|
||||
def list(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
return native_search(self.get_query_set(), select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application_chat.sql')),
|
||||
with_table_name=True)
|
||||
|
||||
def page(self, current_page: int, page_size: int, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
return native_page_search(current_page, page_size, self.get_query_set(), select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application_chat.sql')),
|
||||
with_table_name=True)
|
||||
|
||||
class OpenChat(serializers.Serializer):
|
||||
user_id = serializers.UUIDField(required=True)
|
||||
|
||||
application_id = serializers.UUIDField(required=True)
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
user_id = self.data.get('user_id')
|
||||
application_id = self.data.get('application_id')
|
||||
if not QuerySet(Application).filter(id=application_id, user_id=user_id).exists():
|
||||
raise AppApiException(500, '应用不存在')
|
||||
|
||||
def open(self):
|
||||
self.is_valid(raise_exception=True)
|
||||
application_id = self.data.get('application_id')
|
||||
application = QuerySet(Application).get(id=application_id)
|
||||
model = application.model
|
||||
dataset_id_list = [str(row.dataset_id) for row in
|
||||
QuerySet(ApplicationDatasetMapping).filter(
|
||||
application_id=application_id)]
|
||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
decrypt(model.credential)),
|
||||
streaming=True)
|
||||
|
||||
chat_id = str(uuid.uuid1())
|
||||
chat_cache.set(chat_id,
|
||||
ChatInfo(chat_id, model, chat_model, application_id, dataset_id_list,
|
||||
[str(document.id) for document in
|
||||
QuerySet(Document).filter(
|
||||
dataset_id__in=dataset_id_list,
|
||||
is_active=False)],
|
||||
application.dialogue_number), timeout=60 * 30)
|
||||
return chat_id
|
||||
|
||||
class OpenTempChat(serializers.Serializer):
|
||||
user_id = serializers.UUIDField(required=True)
|
||||
|
||||
model_id = serializers.UUIDField(required=True)
|
||||
|
||||
multiple_rounds_dialogue = serializers.BooleanField(required=True)
|
||||
|
||||
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
ModelDatasetAssociation(
|
||||
data={'user_id': self.data.get('user_id'), 'model_id': self.data.get('model_id'),
|
||||
'dataset_id_list': self.data.get('dataset_id_list')}).is_valid()
|
||||
|
||||
def open(self):
|
||||
self.is_valid(raise_exception=True)
|
||||
chat_id = str(uuid.uuid1())
|
||||
model = QuerySet(Model).get(user_id=self.data.get('user_id'), id=self.data.get('model_id'))
|
||||
dataset_id_list = self.data.get('dataset_id_list')
|
||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
decrypt(model.credential)),
|
||||
streaming=True)
|
||||
chat_cache.set(chat_id,
|
||||
ChatInfo(chat_id, model, chat_model, None, dataset_id_list,
|
||||
[str(document.id) for document in
|
||||
QuerySet(Document).filter(
|
||||
dataset_id__in=dataset_id_list,
|
||||
is_active=False)],
|
||||
3 if self.data.get('multiple_rounds_dialogue') else 1), timeout=60 * 30)
|
||||
return chat_id
|
||||
|
||||
|
||||
def vote_exec(source_type: SourceType, source_id: str, field: str, post_handler):
|
||||
if source_type == SourceType.PROBLEM:
|
||||
problem = QuerySet(Problem).get(id=source_id)
|
||||
if problem is not None:
|
||||
problem.__setattr__(field, post_handler(problem))
|
||||
problem.save()
|
||||
embedding = QuerySet(Embedding).get(source_id=source_id, source_type=source_type)
|
||||
embedding.__setattr__(field, problem.__getattribute__(field))
|
||||
embedding.save()
|
||||
if source_type == SourceType.PARAGRAPH:
|
||||
paragraph = QuerySet(Paragraph).get(id=source_id)
|
||||
if paragraph is not None:
|
||||
paragraph.__setattr__(field, post_handler(paragraph))
|
||||
paragraph.save()
|
||||
embedding = QuerySet(Embedding).get(source_id=source_id, source_type=source_type)
|
||||
embedding.__setattr__(field, paragraph.__getattribute__(field))
|
||||
embedding.save()
|
||||
|
||||
|
||||
class ChatRecordSerializerModel(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = ChatRecord
|
||||
fields = "__all__"
|
||||
|
||||
|
||||
class ChatRecordSerializer(serializers.Serializer):
|
||||
class Query(serializers.Serializer):
|
||||
application_id = serializers.UUIDField(required=True)
|
||||
chat_id = serializers.UUIDField(required=True)
|
||||
|
||||
def list(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
return [ChatRecordSerializerModel(chat_record).data for chat_record in
|
||||
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))]
|
||||
|
||||
def page(self, current_page: int, page_size: int, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
return page_search(current_page, page_size,
|
||||
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by("index"),
|
||||
post_records_handler=lambda chat_record: ChatRecordSerializerModel(chat_record).data)
|
||||
|
||||
class Vote(serializers.Serializer):
|
||||
chat_id = serializers.UUIDField(required=True)
|
||||
|
||||
chat_record_id = serializers.UUIDField(required=True)
|
||||
|
||||
vote_status = serializers.ChoiceField(choices=VoteChoices.choices)
|
||||
|
||||
@transaction.atomic
|
||||
def vote(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
if not try_lock(self.data.get('chat_record_id')):
|
||||
raise AppApiException(500, "正在对当前会话纪要进行投票中,请勿重复发送请求")
|
||||
try:
|
||||
chat_record_details_model = QuerySet(ChatRecord).get(id=self.data.get('chat_record_id'),
|
||||
chat_id=self.data.get('chat_id'))
|
||||
if chat_record_details_model is None:
|
||||
raise AppApiException(500, "不存在的对话 chat_record_id")
|
||||
vote_status = self.data.get("vote_status")
|
||||
if chat_record_details_model.vote_status == VoteChoices.UN_VOTE:
|
||||
if vote_status == VoteChoices.STAR:
|
||||
# 点赞
|
||||
chat_record_details_model.vote_status = VoteChoices.STAR
|
||||
# 点赞数量 +1
|
||||
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
|
||||
'star_num',
|
||||
lambda r: r.star_num + 1)
|
||||
|
||||
if vote_status == VoteChoices.TRAMPLE:
|
||||
# 点踩
|
||||
chat_record_details_model.vote_status = VoteChoices.TRAMPLE
|
||||
# 点踩数量+1
|
||||
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
|
||||
'trample_num',
|
||||
lambda r: r.trample_num + 1)
|
||||
chat_record_details_model.save()
|
||||
else:
|
||||
if vote_status == VoteChoices.UN_VOTE:
|
||||
# 取消点赞
|
||||
chat_record_details_model.vote_status = VoteChoices.UN_VOTE
|
||||
chat_record_details_model.save()
|
||||
if chat_record_details_model.vote_status == VoteChoices.STAR:
|
||||
# 点赞数量 -1
|
||||
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
|
||||
'star_num', lambda r: r.star_num - 1)
|
||||
if chat_record_details_model.vote_status == VoteChoices.TRAMPLE:
|
||||
# 点踩数量 -1
|
||||
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
|
||||
'trample_num', lambda r: r.trample_num - 1)
|
||||
|
||||
else:
|
||||
raise AppApiException(500, "已经投票过,请先取消后再进行投票")
|
||||
finally:
|
||||
un_lock(self.data.get('chat_record_id'))
|
||||
|
||||
return True
|
||||
|
||||
class ImproveSerializer(serializers.Serializer):
|
||||
title = serializers.CharField(required=False)
|
||||
content = serializers.CharField(required=True)
|
||||
|
||||
class Improve(serializers.Serializer):
|
||||
chat_id = serializers.UUIDField(required=True)
|
||||
|
||||
chat_record_id = serializers.UUIDField(required=True)
|
||||
|
||||
dataset_id = serializers.UUIDField(required=True)
|
||||
|
||||
document_id = serializers.UUIDField(required=True)
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
if not QuerySet(Document).filter(id=self.data.get('document_id'),
|
||||
dataset_id=self.data.get('dataset_id')).exists():
|
||||
raise AppApiException(500, "文档id不正确")
|
||||
|
||||
@transaction.atomic
|
||||
def improve(self, instance: Dict, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
ChatRecordSerializer.ImproveSerializer(data=instance).is_valid(raise_exception=True)
|
||||
chat_record_id = self.data.get('chat_record_id')
|
||||
chat_id = self.data.get('chat_id')
|
||||
chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
|
||||
if chat_record is None:
|
||||
raise AppApiException(500, '不存在的对话记录')
|
||||
|
||||
document_id = self.data.get("document_id")
|
||||
dataset_id = self.data.get("dataset_id")
|
||||
paragraph = Paragraph(id=uuid.uuid1(),
|
||||
document_id=document_id,
|
||||
content=instance.get("content"),
|
||||
dataset_id=dataset_id,
|
||||
title=instance.get("title") if 'title' in instance else '')
|
||||
|
||||
problem = Problem(id=uuid.uuid1(), content=chat_record.problem_text, paragraph_id=paragraph.id,
|
||||
document_id=document_id, dataset_id=dataset_id)
|
||||
# 插入问题
|
||||
problem.save()
|
||||
# 插入段落
|
||||
paragraph.save()
|
||||
chat_record.improve_problem_id_list.append(problem.id)
|
||||
# 添加标注
|
||||
chat_record.save()
|
||||
return True
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
SELECT * FROM ( SELECT * FROM application ${application_custom_sql} UNION
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
application
|
||||
WHERE
|
||||
application."id" IN ( SELECT team_member_permission.target FROM team_member team_member LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id" ${team_member_permission_custom_sql})
|
||||
) temp_application ${default_sql}
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
SELECT
|
||||
*
|
||||
FROM
|
||||
application_chat application_chat
|
||||
LEFT JOIN (
|
||||
SELECT COUNT
|
||||
( "id" ) AS chat_record_count,
|
||||
SUM ( CASE WHEN "vote_status" = '0' THEN 1 ELSE 0 END ) AS star_num,
|
||||
SUM ( CASE WHEN "vote_status" = '1' THEN 1 ELSE 0 END ) AS trample_num,
|
||||
SUM ( CASE WHEN array_length( application_chat_record.improve_problem_id_list, 1 ) IS NULL THEN 0 ELSE array_length( application_chat_record.improve_problem_id_list, 1 ) END ) AS mark_sum,
|
||||
chat_id
|
||||
FROM
|
||||
application_chat_record
|
||||
GROUP BY
|
||||
application_chat_record.chat_id
|
||||
) chat_record_temp ON application_chat."id" = chat_record_temp.chat_id
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
SELECT
|
||||
*
|
||||
FROM
|
||||
dataset
|
||||
WHERE
|
||||
user_id = %s UNION
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
dataset
|
||||
WHERE
|
||||
"id" IN (
|
||||
SELECT
|
||||
team_member_permission.target
|
||||
FROM
|
||||
team_member team_member
|
||||
LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id"
|
||||
WHERE
|
||||
( "team_member_permission"."auth_target_type" = 'DATASET' AND "team_member_permission"."operate"::text[] @> ARRAY['USE'] AND team_member.team_id = %s AND team_member.user_id =%s )
|
||||
)
|
||||
|
|
@ -0,0 +1,165 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: application_api.py
|
||||
@date:2023/11/7 10:50
|
||||
@desc:
|
||||
"""
|
||||
from drf_yasg import openapi
|
||||
|
||||
from common.mixins.api_mixin import ApiMixin
|
||||
|
||||
"""
|
||||
name = serializers.CharField(required=True)
|
||||
desc = serializers.CharField(required=True)
|
||||
model_id = serializers.CharField(required=True)
|
||||
multiple_rounds_dialogue = serializers.BooleanField(required=True)
|
||||
prologue = serializers.CharField(required=True)
|
||||
example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True))
|
||||
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
|
||||
"""
|
||||
|
||||
|
||||
class ApplicationApi(ApiMixin):
|
||||
class Authentication(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['access_token', ],
|
||||
properties={
|
||||
'access_token': openapi.Schema(type=openapi.TYPE_STRING, title="应用认证token",
|
||||
description="应用认证token"),
|
||||
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_response_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['id', 'name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'user_id', 'status', 'create_time',
|
||||
'update_time'],
|
||||
properties={
|
||||
'id': openapi.Schema(type=openapi.TYPE_STRING, title="", description="主键id"),
|
||||
'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
|
||||
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
|
||||
'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"),
|
||||
"multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话",
|
||||
description="是否开启多轮对话"),
|
||||
'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
|
||||
'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
|
||||
title="示例列表", description="示例列表"),
|
||||
'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户", description="所属用户"),
|
||||
|
||||
'status': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否发布", description='是否发布'),
|
||||
|
||||
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", description='创建时间'),
|
||||
|
||||
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", description='修改时间'),
|
||||
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
|
||||
items=openapi.Schema(type=openapi.TYPE_STRING),
|
||||
title="关联数据集Id列表",
|
||||
description="关联数据集Id列表(查询详情的时候返回)")
|
||||
}
|
||||
)
|
||||
|
||||
class ApiKey(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='application_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='应用id')
|
||||
|
||||
]
|
||||
|
||||
class Operate(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='application_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='应用id'),
|
||||
openapi.Parameter(name='api_key_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='应用api_key id')
|
||||
]
|
||||
|
||||
class AccessToken(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='application_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='应用id')
|
||||
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=[],
|
||||
properties={
|
||||
'access_token_reset': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重置Token",
|
||||
description="重置Token"),
|
||||
|
||||
'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否激活", description="是否激活"),
|
||||
|
||||
}
|
||||
)
|
||||
|
||||
class Create(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue'],
|
||||
properties={
|
||||
'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
|
||||
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
|
||||
'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"),
|
||||
"multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话",
|
||||
description="是否开启多轮对话"),
|
||||
'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
|
||||
'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
|
||||
title="示例列表", description="示例列表"),
|
||||
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
|
||||
items=openapi.Schema(type=openapi.TYPE_STRING),
|
||||
title="关联数据集Id列表", description="关联数据集Id列表")
|
||||
|
||||
}
|
||||
)
|
||||
|
||||
class Query(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='name',
|
||||
in_=openapi.IN_QUERY,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=False,
|
||||
description='应用名称'),
|
||||
openapi.Parameter(name='desc',
|
||||
in_=openapi.IN_QUERY,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=False,
|
||||
description='应用描述')
|
||||
]
|
||||
|
||||
class Operate(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='application_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='应用id'),
|
||||
|
||||
]
|
||||
|
|
@ -0,0 +1,159 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: chat_api.py
|
||||
@date:2023/11/7 17:29
|
||||
@desc:
|
||||
"""
|
||||
from drf_yasg import openapi
|
||||
|
||||
from common.mixins.api_mixin import ApiMixin
|
||||
|
||||
|
||||
class ChatApi(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['message'],
|
||||
properties={
|
||||
'message': openapi.Schema(type=openapi.TYPE_STRING, title="问题", description="问题"),
|
||||
|
||||
}
|
||||
)
|
||||
|
||||
class OpenChat(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='application_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='应用id'),
|
||||
|
||||
]
|
||||
|
||||
class OpenTempChat(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['model_id', 'multiple_rounds_dialogue'],
|
||||
properties={
|
||||
'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"),
|
||||
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
|
||||
items=openapi.Schema(type=openapi.TYPE_STRING),
|
||||
title="关联数据集Id列表", description="关联数据集Id列表"),
|
||||
'multiple_rounds_dialogue': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮会话",
|
||||
description="是否开启多轮会话")
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='application_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='应用id'),
|
||||
openapi.Parameter(name='history_day',
|
||||
in_=openapi.IN_QUERY,
|
||||
type=openapi.TYPE_NUMBER,
|
||||
required=True,
|
||||
description='历史天数')
|
||||
]
|
||||
|
||||
|
||||
class ChatRecordApi(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='application_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='应用id'),
|
||||
openapi.Parameter(name='chat_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='对话id'),
|
||||
]
|
||||
|
||||
|
||||
class ImproveApi(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='application_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='应用id'),
|
||||
openapi.Parameter(name='chat_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='会话id'),
|
||||
openapi.Parameter(name='chat_record_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='会话记录id'),
|
||||
openapi.Parameter(name='dataset_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='数据集id'),
|
||||
openapi.Parameter(name='document_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='文档id'),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['content'],
|
||||
properties={
|
||||
'title': openapi.Schema(type=openapi.TYPE_STRING, title="段落标题",
|
||||
description="段落标题"),
|
||||
'content': openapi.Schema(type=openapi.TYPE_STRING, title="段落内容",
|
||||
description="段落内容")
|
||||
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class VoteApi(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='application_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='应用id'),
|
||||
openapi.Parameter(name='chat_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='会话id'),
|
||||
openapi.Parameter(name='chat_record_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='会话记录id')
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['vote_status'],
|
||||
properties={
|
||||
'vote_status': openapi.Schema(type=openapi.TYPE_STRING, title="投票状态",
|
||||
description="-1:取消投票|0:赞同|1:反对"),
|
||||
|
||||
}
|
||||
)
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
from django.urls import path
|
||||
|
||||
from . import views
|
||||
|
||||
app_name = "application"
|
||||
urlpatterns = [
|
||||
path('application', views.Application.as_view(), name="application"),
|
||||
path('application/profile', views.Application.Profile.as_view()),
|
||||
path('application/authentication', views.Application.Authentication.as_view()),
|
||||
path('application/<str:application_id>/api_key', views.Application.ApplicationKey.as_view()),
|
||||
path("application/<str:application_id>/api_key/<str:api_key_id>",
|
||||
views.Application.ApplicationKey.Operate.as_view()),
|
||||
path('application/<str:application_id>', views.Application.Operate.as_view(), name='application/operate'),
|
||||
path('application/<str:application_id>/list_dataset', views.Application.ListApplicationDataSet.as_view(),
|
||||
name='application/dataset'),
|
||||
path('application/<str:application_id>/access_token', views.Application.AccessToken.as_view(),
|
||||
name='application/access_token'),
|
||||
path('application/<int:current_page>/<int:page_size>', views.Application.Page.as_view(), name='application_page'),
|
||||
path('application/<str:application_id>/chat/open', views.ChatView.Open.as_view()),
|
||||
path("application/chat/open", views.ChatView.OpenTemp.as_view()),
|
||||
path('application/<str:application_id>/chat', views.ChatView.as_view(), name='chats'),
|
||||
path('application/<str:application_id>/chat/<int:current_page>/<int:page_size>', views.ChatView.Page.as_view()),
|
||||
path('application/<str:application_id>/chat/<chat_id>/chat_record/', views.ChatView.ChatRecord.as_view()),
|
||||
path('application/<str:application_id>/chat/<chat_id>/chat_record/<int:current_page>/<int:page_size>',
|
||||
views.ChatView.ChatRecord.Page.as_view()),
|
||||
path('application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>/vote',
|
||||
views.ChatView.ChatRecord.Vote.as_view(),
|
||||
name=''),
|
||||
path(
|
||||
'application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>/dataset/<str:dataset_id>/document_id/<str:document_id>/improve',
|
||||
views.ChatView.ChatRecord.Improve.as_view(),
|
||||
name=''),
|
||||
path('application/chat_message/<str:chat_id>', views.ChatView.Message.as_view())
|
||||
|
||||
]
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
from django.shortcuts import render
|
||||
|
||||
# Create your views here.
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2023/9/25 17:12
|
||||
@desc:
|
||||
"""
|
||||
from .application_views import *
|
||||
from .chat_views import *
|
||||
|
|
@ -0,0 +1,250 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: application_views.py
|
||||
@date:2023/10/27 14:56
|
||||
@desc:
|
||||
"""
|
||||
|
||||
from drf_yasg.utils import swagger_auto_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from application.serializers.application_serializers import ApplicationSerializer
|
||||
from application.swagger_api.application_api import ApplicationApi
|
||||
from common.auth import TokenAuth, has_permissions
|
||||
from common.constants.permission_constants import CompareConstants, PermissionConstants, Permission, Group, Operate, \
|
||||
ViewPermission, RoleConstants
|
||||
from common.exception.app_exception import AppAuthenticationFailed
|
||||
from common.response import result
|
||||
from common.util.common import query_params_to_single_dict
|
||||
from dataset.serializers.dataset_serializers import DataSetSerializers
|
||||
|
||||
|
||||
class Application(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
class Profile(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取应用相关信息",
|
||||
operation_id="获取应用相关信息",
|
||||
tags=["应用/会话"])
|
||||
def get(self, request: Request):
|
||||
if 'application_id' in request.auth.keywords:
|
||||
return result.success(ApplicationSerializer.Operate(
|
||||
data={'application_id': request.auth.keywords.get('application_id'),
|
||||
'user_id': request.user.id}).profile())
|
||||
else:
|
||||
raise AppAuthenticationFailed(401, "身份异常")
|
||||
|
||||
class ApplicationKey(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="新增ApiKey",
|
||||
operation_id="新增ApiKey",
|
||||
tags=['应用/API_KEY'],
|
||||
manual_parameters=ApplicationApi.ApiKey.get_request_params_api())
|
||||
@has_permissions(ViewPermission(
|
||||
[RoleConstants.ADMIN, RoleConstants.USER],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE,
|
||||
dynamic_tag=keywords.get('application_id'))],
|
||||
compare=CompareConstants.AND))
|
||||
def post(self, request: Request, application_id: str):
|
||||
return result.success(
|
||||
ApplicationSerializer.ApplicationKeySerializer(
|
||||
data={'application_id': application_id, 'user_id': request.user.id}).generate())
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取应用API_KEY列表",
|
||||
operation_id="获取应用API_KEY列表",
|
||||
tags=['应用/API_KEY'],
|
||||
manual_parameters=ApplicationApi.ApiKey.get_request_params_api()
|
||||
)
|
||||
@has_permissions(ViewPermission(
|
||||
[RoleConstants.ADMIN, RoleConstants.USER],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE,
|
||||
dynamic_tag=keywords.get('application_id'))],
|
||||
compare=CompareConstants.AND))
|
||||
def get(self, request: Request, application_id: str):
|
||||
return result.success(ApplicationSerializer.ApplicationKeySerializer(
|
||||
data={'application_id': application_id, 'user_id': request.user.id}).list())
|
||||
|
||||
class Operate(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['DELETE'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="删除应用API_KEY",
|
||||
operation_id="删除应用API_KEY",
|
||||
tags=['应用/API_KEY'],
|
||||
manual_parameters=ApplicationApi.ApiKey.Operate.get_request_params_api())
|
||||
@has_permissions(ViewPermission(
|
||||
[RoleConstants.ADMIN, RoleConstants.USER],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE,
|
||||
dynamic_tag=keywords.get('application_id'))],
|
||||
compare=CompareConstants.AND), lambda r, k: Permission(group=Group.APPLICATION, operate=Operate.DELETE,
|
||||
dynamic_tag=k.get('application_id')),
|
||||
compare=CompareConstants.AND)
|
||||
def delete(self, request: Request, application_id: str, api_key_id: str):
|
||||
return result.success(
|
||||
ApplicationSerializer.ApplicationKeySerializer.Operate(
|
||||
data={'application_id': application_id, 'user_id': request.user.id,
|
||||
'api_key_id': api_key_id}).delete())
|
||||
|
||||
class AccessToken(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['PUT'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="修改 应用AccessToken",
|
||||
operation_id="修改 应用AccessToken",
|
||||
tags=['应用/公开访问'],
|
||||
manual_parameters=ApplicationApi.AccessToken.get_request_params_api(),
|
||||
request_body=ApplicationApi.AccessToken.get_request_body_api())
|
||||
@has_permissions(ViewPermission(
|
||||
[RoleConstants.ADMIN, RoleConstants.USER],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE,
|
||||
dynamic_tag=keywords.get('application_id'))],
|
||||
compare=CompareConstants.AND))
|
||||
def put(self, request: Request, application_id: str):
|
||||
return result.success(
|
||||
ApplicationSerializer.AccessTokenSerializer(data={'application_id': application_id}).edit(request.data))
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取应用 AccessToken信息",
|
||||
operation_id="获取应用 AccessToken信息",
|
||||
manual_parameters=ApplicationApi.AccessToken.get_request_params_api(),
|
||||
tags=['应用/公开访问'],
|
||||
)
|
||||
@has_permissions(ViewPermission(
|
||||
[RoleConstants.ADMIN, RoleConstants.USER],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||
dynamic_tag=keywords.get('application_id'))],
|
||||
compare=CompareConstants.AND))
|
||||
def get(self, request: Request, application_id: str):
|
||||
return result.success(
|
||||
ApplicationSerializer.AccessTokenSerializer(data={'application_id': application_id}).one())
|
||||
|
||||
class Authentication(APIView):
|
||||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="应用认证",
|
||||
operation_id="应用认证",
|
||||
request_body=ApplicationApi.Authentication.get_request_body_api(),
|
||||
tags=["应用/认证"],
|
||||
security=[])
|
||||
def post(self, request: Request):
|
||||
return result.success(
|
||||
ApplicationSerializer.Authentication(data={'access_token': request.data.get("access_token")}).auth())
|
||||
|
||||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="创建应用",
|
||||
operation_id="创建应用",
|
||||
request_body=ApplicationApi.Create.get_request_body_api(),
|
||||
tags=['应用'])
|
||||
@has_permissions(PermissionConstants.APPLICATION_CREATE, compare=CompareConstants.AND)
|
||||
def post(self, request: Request):
|
||||
ApplicationSerializer.Create(data={'user_id': request.user.id}).insert(request.data)
|
||||
return result.success(True)
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取应用列表",
|
||||
operation_id="获取应用列表",
|
||||
manual_parameters=ApplicationApi.Query.get_request_params_api(),
|
||||
responses=result.get_api_array_response(ApplicationApi.get_response_body_api()),
|
||||
tags=['应用'])
|
||||
@has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND)
|
||||
def get(self, request: Request):
|
||||
return result.success(
|
||||
ApplicationSerializer.Query(
|
||||
data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).list())
|
||||
|
||||
class Operate(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['DELETE'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="删除应用",
|
||||
operation_id="删除应用",
|
||||
manual_parameters=ApplicationApi.Operate.get_request_params_api(),
|
||||
responses=result.get_default_response(),
|
||||
tags=['应用'])
|
||||
@has_permissions(ViewPermission(
|
||||
[RoleConstants.ADMIN, RoleConstants.USER],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE,
|
||||
dynamic_tag=keywords.get('application_id'))],
|
||||
compare=CompareConstants.AND),
|
||||
lambda r, k: Permission(group=Group.APPLICATION, operate=Operate.DELETE,
|
||||
dynamic_tag=k.get('application_id')), compare=CompareConstants.AND)
|
||||
def delete(self, request: Request, application_id: str):
|
||||
return result.success(ApplicationSerializer.Operate(
|
||||
data={'application_id': application_id, 'user_id': request.user.id}).delete(
|
||||
with_valid=True))
|
||||
|
||||
@action(methods=['PUT'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="修改应用",
|
||||
operation_id="修改应用",
|
||||
manual_parameters=ApplicationApi.Operate.get_request_params_api(),
|
||||
request_body=ApplicationApi.Create.get_request_body_api(),
|
||||
responses=result.get_api_array_response(ApplicationApi.get_response_body_api()),
|
||||
tags=['应用'])
|
||||
@has_permissions(ViewPermission(
|
||||
[RoleConstants.ADMIN, RoleConstants.USER],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||
dynamic_tag=keywords.get('application_id'))],
|
||||
compare=CompareConstants.AND))
|
||||
def put(self, request: Request, application_id: str):
|
||||
return result.success(
|
||||
ApplicationSerializer.Operate(data={'application_id': application_id, 'user_id': request.user.id}).edit(
|
||||
request.data))
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取应用详情",
|
||||
operation_id="获取应用详情",
|
||||
manual_parameters=ApplicationApi.Operate.get_request_params_api(),
|
||||
responses=result.get_api_array_response(ApplicationApi.get_response_body_api()),
|
||||
tags=['应用'])
|
||||
@has_permissions(ViewPermission(
|
||||
[RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN,
|
||||
RoleConstants.APPLICATION_KEY],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||
dynamic_tag=keywords.get('application_id'))],
|
||||
compare=CompareConstants.AND))
|
||||
def get(self, request: Request, application_id: str):
|
||||
return result.success(ApplicationSerializer.Operate(
|
||||
data={'application_id': application_id, 'user_id': request.user.id}).one())
|
||||
|
||||
class ListApplicationDataSet(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取当前应用可使用的数据集",
|
||||
operation_id="获取当前应用可使用的数据集",
|
||||
manual_parameters=ApplicationApi.Operate.get_request_params_api(),
|
||||
responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api()),
|
||||
tags=['应用'])
|
||||
@has_permissions(ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||
dynamic_tag=keywords.get('application_id'))],
|
||||
compare=CompareConstants.AND))
|
||||
def get(self, request: Request, application_id: str):
|
||||
return result.success(ApplicationSerializer.Operate(
|
||||
data={'application_id': application_id, 'user_id': request.user.id}).list_dataset())
|
||||
|
||||
class Page(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="分页获取应用列表",
|
||||
operation_id="分页获取应用列表",
|
||||
manual_parameters=result.get_page_request_params(
|
||||
ApplicationApi.Query.get_request_params_api()),
|
||||
responses=result.get_page_api_response(ApplicationApi.get_response_body_api()),
|
||||
tags=['应用'])
|
||||
@has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND)
|
||||
def get(self, request: Request, current_page: int, page_size: int):
|
||||
return result.success(
|
||||
ApplicationSerializer.Query(
|
||||
data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).page(
|
||||
current_page, page_size))
|
||||
|
|
@ -0,0 +1,198 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: chat_views.py
|
||||
@date:2023/11/14 9:53
|
||||
@desc:
|
||||
"""
|
||||
from drf_yasg.utils import swagger_auto_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from application.serializers.chat_message_serializers import ChatMessageSerializer
|
||||
from application.serializers.chat_serializers import ChatSerializers, ChatRecordSerializer
|
||||
from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi
|
||||
from common.auth import TokenAuth, has_permissions
|
||||
from common.constants.permission_constants import Permission, Group, Operate, \
|
||||
RoleConstants, ViewPermission, CompareConstants
|
||||
from common.exception.app_exception import AppAuthenticationFailed
|
||||
from common.response import result
|
||||
from common.util.common import query_params_to_single_dict
|
||||
|
||||
|
||||
class ChatView(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
class Open(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取会话id,根据应用id",
|
||||
operation_id="获取会话id,根据应用id",
|
||||
manual_parameters=ChatApi.OpenChat.get_request_params_api(),
|
||||
tags=["应用/会话"])
|
||||
@has_permissions(
|
||||
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN,
|
||||
RoleConstants.APPLICATION_KEY],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||
dynamic_tag=keywords.get('application_id'))],
|
||||
compare=CompareConstants.AND)
|
||||
)
|
||||
def get(self, request: Request, application_id: str):
|
||||
return result.success(ChatSerializers.OpenChat(
|
||||
data={'user_id': request.user.id, 'application_id': application_id}).open())
|
||||
|
||||
class OpenTemp(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取会话id(根据模型id,数据集列表,是否多轮会话)",
|
||||
operation_id="获取会话id",
|
||||
request_body=ChatApi.OpenTempChat.get_request_body_api(),
|
||||
tags=["应用/会话"])
|
||||
@has_permissions(RoleConstants.ADMIN, RoleConstants.USER)
|
||||
def post(self, request: Request):
|
||||
return result.success(ChatSerializers.OpenTempChat(
|
||||
data={**request.data, 'user_id': request.user.id}).open())
|
||||
|
||||
class Message(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="对话",
|
||||
operation_id="对话",
|
||||
request_body=ChatApi.get_request_body_api(),
|
||||
tags=["应用/会话"])
|
||||
@has_permissions(
|
||||
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY,
|
||||
RoleConstants.APPLICATION_ACCESS_TOKEN],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||
dynamic_tag=keywords.get('application_id'))])
|
||||
)
|
||||
def post(self, request: Request, chat_id: str):
|
||||
return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'))
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取对话列表",
|
||||
operation_id="获取对话列表",
|
||||
manual_parameters=ChatApi.get_request_params_api(),
|
||||
tags=["应用/对话日志"]
|
||||
)
|
||||
@has_permissions(
|
||||
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||
dynamic_tag=keywords.get('application_id'))])
|
||||
)
|
||||
def get(self, request: Request, application_id: str):
|
||||
return result.success(ChatSerializers.Query(
|
||||
data={**query_params_to_single_dict(request.query_params), 'application_id': application_id,
|
||||
'user_id': request.user.id}).list())
|
||||
|
||||
class Page(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="分页获取对话列表",
|
||||
operation_id="分页获取对话列表",
|
||||
manual_parameters=result.get_page_request_params(ChatApi.get_request_params_api()),
|
||||
tags=["应用/对话日志"]
|
||||
)
|
||||
@has_permissions(
|
||||
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||
dynamic_tag=keywords.get('application_id'))])
|
||||
)
|
||||
def get(self, request: Request, application_id: str, current_page: int, page_size: int):
|
||||
return result.success(ChatSerializers.Query(
|
||||
data={**query_params_to_single_dict(request.query_params), 'application_id': application_id,
|
||||
'user_id': request.user.id}).page(current_page=current_page,
|
||||
page_size=page_size))
|
||||
|
||||
class ChatRecord(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取对话记录列表",
|
||||
operation_id="获取对话记录列表",
|
||||
manual_parameters=ChatRecordApi.get_request_params_api(),
|
||||
tags=["应用/对话日志"]
|
||||
)
|
||||
@has_permissions(
|
||||
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||
dynamic_tag=keywords.get('application_id'))])
|
||||
)
|
||||
def get(self, request: Request, application_id: str, chat_id: str):
|
||||
return result.success(ChatRecordSerializer.Query(
|
||||
data={'application_id': application_id,
|
||||
'chat_id': chat_id}).list())
|
||||
|
||||
class Page(APIView):
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取对话记录列表",
|
||||
operation_id="获取对话记录列表",
|
||||
manual_parameters=result.get_page_request_params(
|
||||
ChatRecordApi.get_request_params_api()),
|
||||
tags=["应用/对话日志"]
|
||||
)
|
||||
@has_permissions(
|
||||
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||
dynamic_tag=keywords.get('application_id'))])
|
||||
)
|
||||
def get(self, request: Request, application_id: str, chat_id: str, current_page: int, page_size: int):
|
||||
return result.success(ChatRecordSerializer.Query(
|
||||
data={'application_id': application_id,
|
||||
'chat_id': chat_id}).page(current_page, page_size))
|
||||
|
||||
class Vote(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['PUT'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="点赞,点踩",
|
||||
operation_id="点赞,点踩",
|
||||
manual_parameters=VoteApi.get_request_params_api(),
|
||||
request_body=VoteApi.get_request_body_api(),
|
||||
responses=result.get_default_response(),
|
||||
tags=["应用/会话"]
|
||||
)
|
||||
@has_permissions(
|
||||
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY,
|
||||
RoleConstants.APPLICATION_ACCESS_TOKEN],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||
dynamic_tag=keywords.get('application_id'))])
|
||||
)
|
||||
def put(self, request: Request, application_id: str, chat_id: str, chat_record_id: str):
|
||||
return result.success(ChatRecordSerializer.Vote(
|
||||
data={'vote_status': request.data.get('vote_status'), 'chat_id': chat_id,
|
||||
'chat_record_id': chat_record_id}).vote())
|
||||
|
||||
class Improve(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['PUT'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="标注",
|
||||
operation_id="标注",
|
||||
manual_parameters=ImproveApi.get_request_params_api(),
|
||||
request_body=ImproveApi.get_request_body_api(),
|
||||
responses=result.get_default_response(),
|
||||
tags=["应用/对话日志/标注"]
|
||||
)
|
||||
@has_permissions(
|
||||
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||
dynamic_tag=keywords.get('application_id'))],
|
||||
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION,
|
||||
operate=Operate.MANAGE,
|
||||
dynamic_tag=keywords.get(
|
||||
'dataset_id'))],
|
||||
)
|
||||
))
|
||||
def put(self, request: Request, application_id: str, chat_id: str, chat_record_id: str, dataset_id: str,
|
||||
document_id: str):
|
||||
return result.success(ChatRecordSerializer.Improve(
|
||||
data={'chat_id': chat_id, 'chat_record_id': chat_record_id,
|
||||
'dataset_id': dataset_id, 'document_id': document_id}).improve(request.data))
|
||||
|
|
@ -12,7 +12,10 @@ from django.core import signing
|
|||
from django.db.models import QuerySet
|
||||
from rest_framework.authentication import TokenAuthentication
|
||||
|
||||
from common.constants.permission_constants import Auth, get_permission_list_by_role, RoleConstants
|
||||
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.constants.permission_constants import Auth, get_permission_list_by_role, RoleConstants, Permission, Group, \
|
||||
Operate
|
||||
from common.exception.app_exception import AppAuthenticationFailed
|
||||
from smartdoc.settings import JWT_AUTH
|
||||
from users.models.user import User, get_user_dynamics_permission
|
||||
|
|
@ -34,13 +37,29 @@ class TokenAuth(TokenAuthentication):
|
|||
if auth is None:
|
||||
raise AppAuthenticationFailed(1003, '未登录,请先登录')
|
||||
try:
|
||||
if str(auth).startswith("application-"):
|
||||
application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=auth).first()
|
||||
if application_api_key is None:
|
||||
raise AppAuthenticationFailed(500, "secret_key 无效")
|
||||
permission_list = [Permission(group=Group.APPLICATION,
|
||||
operate=Operate.USE,
|
||||
dynamic_tag=str(
|
||||
application_api_key.application_id)),
|
||||
Permission(group=Group.APPLICATION,
|
||||
operate=Operate.MANAGE,
|
||||
dynamic_tag=str(
|
||||
application_api_key.application_id))
|
||||
]
|
||||
return application_api_key.user, Auth(role_list=[RoleConstants.APPLICATION_KEY],
|
||||
permission_list=permission_list,
|
||||
application_id=application_api_key.application_id)
|
||||
# 解析 token
|
||||
user = signing.loads(auth)
|
||||
if 'id' in user:
|
||||
cache_token = token_cache.get(auth)
|
||||
if cache_token is None:
|
||||
raise AppAuthenticationFailed(1002, "登录过期")
|
||||
user = QuerySet(User).get(id=user['id'])
|
||||
auth_details = signing.loads(auth)
|
||||
cache_token = token_cache.get(auth)
|
||||
if cache_token is None:
|
||||
raise AppAuthenticationFailed(1002, "登录过期")
|
||||
if 'id' in auth_details and auth_details.get('type') == AuthenticationType.USER.value:
|
||||
user = QuerySet(User).get(id=auth_details['id'])
|
||||
# 续期
|
||||
token_cache.touch(auth, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'].total_seconds())
|
||||
rule = RoleConstants[user.role]
|
||||
|
|
@ -49,6 +68,26 @@ class TokenAuth(TokenAuthentication):
|
|||
permission_list += get_user_dynamics_permission(str(user.id))
|
||||
return user, Auth(role_list=[rule],
|
||||
permission_list=permission_list)
|
||||
if 'application_id' in auth_details and 'access_token' in auth_details and auth_details.get(
|
||||
'type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
|
||||
application_access_token = QuerySet(ApplicationAccessToken).filter(
|
||||
application_id=auth_details.get('application_id')).first()
|
||||
if application_access_token is None:
|
||||
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
|
||||
if not application_access_token.is_active:
|
||||
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
|
||||
if not application_access_token.access_token == auth_details.get('access_token'):
|
||||
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
|
||||
return application_access_token.application.user, Auth(
|
||||
role_list=[RoleConstants.APPLICATION_ACCESS_TOKEN],
|
||||
permission_list=[
|
||||
Permission(group=Group.APPLICATION,
|
||||
operate=Operate.USE,
|
||||
dynamic_tag=str(
|
||||
application_access_token.application_id))],
|
||||
application_id=application_access_token.application_id
|
||||
)
|
||||
|
||||
else:
|
||||
raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户")
|
||||
|
||||
|
|
|
|||
|
|
@ -35,23 +35,30 @@ def exist_role_by_role_constants(user_role: List[RoleConstants],
|
|||
return any(list(map(lambda up: role_list.__contains__(up), user_role)))
|
||||
|
||||
|
||||
def exist_permissions_by_view_permission(user_role: List[RoleConstants], user_permission: List[PermissionConstants],
|
||||
permission: ViewPermission):
|
||||
def exist_permissions_by_view_permission(user_role: List[RoleConstants],
|
||||
user_permission: List[PermissionConstants | object],
|
||||
permission: ViewPermission, request, **kwargs):
|
||||
"""
|
||||
用户是否存在这些权限
|
||||
:param request:
|
||||
:param user_role: 用户角色
|
||||
:param user_permission: 用户权限
|
||||
:param permission: 所属权限
|
||||
:return: 是否存在 True False
|
||||
"""
|
||||
role_ok = any(list(map(lambda ur: permission.roleList.__contains__(ur), user_role)))
|
||||
permission_ok = any(list(map(lambda up: permission.permissionList.__contains__(up), user_permission)))
|
||||
permission_list = [user_p(request, kwargs) if callable(user_p) else user_p for user_p in
|
||||
permission.permissionList
|
||||
]
|
||||
permission_ok = any(list(map(lambda up: permission_list.__contains__(up),
|
||||
user_permission)))
|
||||
return role_ok | permission_ok if permission.compare == CompareConstants.OR else role_ok & permission_ok
|
||||
|
||||
|
||||
def exist_permissions(user_role: List[RoleConstants], user_permission: List[PermissionConstants], permission):
|
||||
def exist_permissions(user_role: List[RoleConstants], user_permission: List[PermissionConstants], permission, request,
|
||||
**kwargs):
|
||||
if isinstance(permission, ViewPermission):
|
||||
return exist_permissions_by_view_permission(user_role, user_permission, permission)
|
||||
return exist_permissions_by_view_permission(user_role, user_permission, permission, request, **kwargs)
|
||||
elif isinstance(permission, RoleConstants):
|
||||
return exist_role_by_role_constants(user_role, [permission])
|
||||
elif isinstance(permission, PermissionConstants):
|
||||
|
|
@ -64,9 +71,9 @@ def exist_permissions(user_role: List[RoleConstants], user_permission: List[Perm
|
|||
def exist(user_role: List[RoleConstants], user_permission: List[PermissionConstants], permission, request, **kwargs):
|
||||
if callable(permission):
|
||||
p = permission(request, kwargs)
|
||||
return exist_permissions(user_role, user_permission, p)
|
||||
return exist_permissions(user_role, user_permission, p, request)
|
||||
else:
|
||||
return exist_permissions(user_role, user_permission, permission)
|
||||
return exist_permissions(user_role, user_permission, permission, request, **kwargs)
|
||||
|
||||
|
||||
def has_permissions(*permission, compare=CompareConstants.OR):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: authentication_type.py
|
||||
@date:2023/11/14 20:03
|
||||
@desc:
|
||||
"""
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class AuthenticationType(Enum):
|
||||
# 或者
|
||||
USER = "USER"
|
||||
# 并且
|
||||
APPLICATION_ACCESS_TOKEN = "APPLICATION_ACCESS_TOKEN"
|
||||
|
|
@ -21,6 +21,10 @@ class Group(Enum):
|
|||
|
||||
SETTING = "SETTING"
|
||||
|
||||
MODEL = "MODEL"
|
||||
|
||||
TEAM = "TEAM"
|
||||
|
||||
|
||||
class Operate(Enum):
|
||||
"""
|
||||
|
|
@ -40,16 +44,24 @@ class Operate(Enum):
|
|||
USE = "USE"
|
||||
|
||||
|
||||
class RoleGroup(Enum):
|
||||
USER = 'USER'
|
||||
APPLICATION_KEY = "APPLICATION_KEY"
|
||||
APPLICATION_ACCESS_TOKEN = "APPLICATION_ACCESS_TOKEN"
|
||||
|
||||
|
||||
class Role:
|
||||
def __init__(self, name: str, decs: str):
|
||||
def __init__(self, name: str, decs: str, group: RoleGroup):
|
||||
self.name = name
|
||||
self.decs = decs
|
||||
self.group = group
|
||||
|
||||
|
||||
class RoleConstants(Enum):
|
||||
ADMIN = Role("管理员", "管理员,预制目前不会使用")
|
||||
USER = Role("用户", "用户所有权限")
|
||||
SESSION = Role("会话", decs="只拥有应用会话框接口权限")
|
||||
ADMIN = Role("管理员", "管理员,预制目前不会使用", RoleGroup.USER)
|
||||
USER = Role("用户", "用户所有权限", RoleGroup.USER)
|
||||
APPLICATION_ACCESS_TOKEN = Role("会话", "只拥有应用会话框接口权限", RoleGroup.APPLICATION_ACCESS_TOKEN),
|
||||
APPLICATION_KEY = Role("应用私钥", "应用私钥", RoleGroup.APPLICATION_KEY)
|
||||
|
||||
|
||||
class Permission:
|
||||
|
|
@ -90,9 +102,29 @@ class PermissionConstants(Enum):
|
|||
APPLICATION_READ = Permission(group=Group.APPLICATION, operate=Operate.READ,
|
||||
roles=[RoleConstants.ADMIN, RoleConstants.USER])
|
||||
|
||||
APPLICATION_CREATE = Permission(group=Group.APPLICATION, operate=Operate.CREATE,
|
||||
roles=[RoleConstants.ADMIN, RoleConstants.USER])
|
||||
|
||||
SETTING_READ = Permission(group=Group.SETTING, operate=Operate.READ,
|
||||
roles=[RoleConstants.ADMIN, RoleConstants.USER])
|
||||
|
||||
MODEL_READ = Permission(group=Group.MODEL, operate=Operate.READ, roles=[RoleConstants.ADMIN, RoleConstants.USER])
|
||||
|
||||
MODEL_EDIT = Permission(group=Group.MODEL, operate=Operate.EDIT, roles=[RoleConstants.ADMIN, RoleConstants.USER])
|
||||
|
||||
MODEL_DELETE = Permission(group=Group.MODEL, operate=Operate.DELETE,
|
||||
roles=[RoleConstants.ADMIN, RoleConstants.USER])
|
||||
MODEL_CREATE = Permission(group=Group.MODEL, operate=Operate.CREATE,
|
||||
roles=[RoleConstants.ADMIN, RoleConstants.USER])
|
||||
|
||||
TEAM_READ = Permission(group=Group.TEAM, operate=Operate.READ, roles=[RoleConstants.ADMIN, RoleConstants.USER])
|
||||
|
||||
TEAM_CREATE = Permission(group=Group.TEAM, operate=Operate.CREATE, roles=[RoleConstants.ADMIN, RoleConstants.USER])
|
||||
|
||||
TEAM_DELETE = Permission(group=Group.TEAM, operate=Operate.DELETE, roles=[RoleConstants.ADMIN, RoleConstants.USER])
|
||||
|
||||
TEAM_EDIT = Permission(group=Group.TEAM, operate=Operate.EDIT, roles=[RoleConstants.ADMIN, RoleConstants.USER])
|
||||
|
||||
|
||||
def get_permission_list_by_role(role: RoleConstants):
|
||||
"""
|
||||
|
|
@ -110,9 +142,11 @@ class Auth:
|
|||
用于存储当前用户的角色和权限
|
||||
"""
|
||||
|
||||
def __init__(self, role_list: List[RoleConstants], permission_list: List[PermissionConstants]):
|
||||
def __init__(self, role_list: List[RoleConstants], permission_list: List[PermissionConstants | Permission],
|
||||
**keywords):
|
||||
self.role_list = role_list
|
||||
self.permission_list = permission_list
|
||||
self.keywords = keywords
|
||||
|
||||
|
||||
class CompareConstants(Enum):
|
||||
|
|
@ -123,7 +157,7 @@ class CompareConstants(Enum):
|
|||
|
||||
|
||||
class ViewPermission:
|
||||
def __init__(self, roleList: List[RoleConstants], permissionList: List[PermissionConstants],
|
||||
def __init__(self, roleList: List[RoleConstants], permissionList: List[PermissionConstants | object],
|
||||
compare=CompareConstants.OR):
|
||||
self.roleList = roleList
|
||||
self.permissionList = permissionList
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ class AppSQLCompiler(SQLCompiler):
|
|||
field_replace_dict = {}
|
||||
self.field_replace_dict = field_replace_dict
|
||||
|
||||
def get_query_str(self, with_limits=True, with_table_name=True):
|
||||
def get_query_str(self, with_limits=True, with_table_name=False):
|
||||
refcounts_before = self.query.alias_refcount.copy()
|
||||
try:
|
||||
extra_select, order_by, group_by = self.pre_sql_setup()
|
||||
|
|
|
|||
|
|
@ -32,9 +32,10 @@ def get_dynamics_model(attr: dict, table_name='dynamics'):
|
|||
|
||||
|
||||
def generate_sql_by_query_dict(queryset_dict: Dict[str, QuerySet], select_string: str,
|
||||
field_replace_dict: None | Dict[str, Dict[str, str]] = None):
|
||||
field_replace_dict: None | Dict[str, Dict[str, str]] = None, with_table_name=False):
|
||||
"""
|
||||
生成 查询sql
|
||||
:param with_table_name:
|
||||
:param queryset_dict: 多条件 查询条件
|
||||
:param select_string: 查询sql
|
||||
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
|
||||
|
|
@ -45,7 +46,8 @@ def generate_sql_by_query_dict(queryset_dict: Dict[str, QuerySet], select_string
|
|||
result_params = []
|
||||
for key in queryset_dict.keys():
|
||||
value = queryset_dict.get(key)
|
||||
sql, params = compiler_queryset(value, None if field_replace_dict is None else field_replace_dict.get(key))
|
||||
sql, params = compiler_queryset(value, None if field_replace_dict is None else field_replace_dict.get(key),
|
||||
with_table_name)
|
||||
params_dict = {**params_dict, select_string.index("${" + key + "}"): params}
|
||||
select_string = select_string.replace("${" + key + "}", sql)
|
||||
|
||||
|
|
@ -55,7 +57,7 @@ def generate_sql_by_query_dict(queryset_dict: Dict[str, QuerySet], select_string
|
|||
|
||||
|
||||
def generate_sql_by_query(queryset: QuerySet, select_string: str,
|
||||
field_replace_dict: None | Dict[str, str] = None):
|
||||
field_replace_dict: None | Dict[str, str] = None, with_table_name=False):
|
||||
"""
|
||||
生成 查询sql
|
||||
:param queryset: 查询条件
|
||||
|
|
@ -63,13 +65,14 @@ def generate_sql_by_query(queryset: QuerySet, select_string: str,
|
|||
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
|
||||
:return: sql:需要查询的sql params: sql 参数
|
||||
"""
|
||||
sql, params = compiler_queryset(queryset, field_replace_dict)
|
||||
sql, params = compiler_queryset(queryset, field_replace_dict, with_table_name)
|
||||
return select_string + " " + sql, params
|
||||
|
||||
|
||||
def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, str] = None):
|
||||
def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, str] = None, with_table_name=False):
|
||||
"""
|
||||
解析 queryset查询对象
|
||||
:param with_table_name:
|
||||
:param queryset: 查询对象
|
||||
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
|
||||
:return: sql:需要查询的sql params: sql 参数
|
||||
|
|
@ -80,15 +83,16 @@ def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, s
|
|||
field_replace_dict = get_field_replace_dict(queryset)
|
||||
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
|
||||
field_replace_dict=field_replace_dict)
|
||||
sql, params = app_sql_compiler.get_query_str(with_table_name=False)
|
||||
sql, params = app_sql_compiler.get_query_str(with_table_name)
|
||||
return sql, params
|
||||
|
||||
|
||||
def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
|
||||
field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None,
|
||||
with_search_one=False):
|
||||
with_search_one=False, with_table_name=False):
|
||||
"""
|
||||
复杂查询
|
||||
:param with_table_name: 生成sql是否包含表名
|
||||
:param queryset: 查询条件构造器
|
||||
:param select_string: 查询前缀 不包括 where limit 等信息
|
||||
:param field_replace_dict: 需要替换的字段
|
||||
|
|
@ -96,9 +100,9 @@ def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
|
|||
:return: 查询结果
|
||||
"""
|
||||
if isinstance(queryset, Dict):
|
||||
exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict)
|
||||
exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name)
|
||||
else:
|
||||
exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict)
|
||||
exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name)
|
||||
if with_search_one:
|
||||
return select_one(exec_sql, exec_params)
|
||||
else:
|
||||
|
|
@ -121,9 +125,11 @@ def page_search(current_page: int, page_size: int, queryset: QuerySet, post_reco
|
|||
|
||||
def native_page_search(current_page: int, page_size: int, queryset: QuerySet | Dict[str, QuerySet], select_string: str,
|
||||
field_replace_dict=None,
|
||||
post_records_handler=lambda r: r):
|
||||
post_records_handler=lambda r: r,
|
||||
with_table_name=False):
|
||||
"""
|
||||
复杂分页查询
|
||||
:param with_table_name:
|
||||
:param current_page: 当前页
|
||||
:param page_size: 每页大小
|
||||
:param queryset: 查询条件
|
||||
|
|
@ -133,9 +139,9 @@ def native_page_search(current_page: int, page_size: int, queryset: QuerySet | D
|
|||
:return: 分页结果
|
||||
"""
|
||||
if isinstance(queryset, Dict):
|
||||
exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict)
|
||||
exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name)
|
||||
else:
|
||||
exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict)
|
||||
exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name)
|
||||
total_sql = "SELECT \"count\"(*) FROM (%s) temp" % exec_sql
|
||||
total = select_one(total_sql, exec_params)
|
||||
limit_sql = connections[DEFAULT_DB_ALIAS].ops.limit_offset_sql(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2023/11/10 10:43
|
||||
@desc:
|
||||
"""
|
||||
from .listener_manage import *
|
||||
from .listener_chat_message import *
|
||||
|
||||
|
||||
def run():
|
||||
listener_manage.ListenerManagement().run()
|
||||
listener_chat_message.ListenerChatMessage().run()
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: common.py
|
||||
@date:2023/11/10 10:41
|
||||
@desc:
|
||||
"""
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
work_thread_pool = ThreadPoolExecutor(5)
|
||||
|
||||
|
||||
def poxy(poxy_function):
|
||||
def inner(args):
|
||||
work_thread_pool.submit(poxy_function, args)
|
||||
|
||||
return inner
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: listener_manage.py
|
||||
@date:2023/10/20 14:01
|
||||
@desc:
|
||||
"""
|
||||
|
||||
from blinker import signal
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from application.models import ChatRecord, Chat
|
||||
from application.serializers.chat_message_serializers import ChatMessage
|
||||
from common.event.common import poxy
|
||||
|
||||
|
||||
class RecordChatMessageArgs:
|
||||
def __init__(self, index: int, chat_id: str, application_id: str, chat_message: ChatMessage):
|
||||
self.index = index
|
||||
self.chat_id = chat_id
|
||||
self.application_id = application_id
|
||||
self.chat_message = chat_message
|
||||
|
||||
|
||||
class ListenerChatMessage:
|
||||
record_chat_message_signal = signal("record_chat_message")
|
||||
|
||||
@staticmethod
|
||||
@poxy
|
||||
def record_chat_message(args: RecordChatMessageArgs):
|
||||
if not QuerySet(Chat).filter(id=args.chat_id).exists():
|
||||
Chat(id=args.chat_id, application_id=args.application_id, abstract=args.chat_message.problem).save()
|
||||
# 插入会话记录
|
||||
try:
|
||||
chat_record = ChatRecord(
|
||||
id=args.chat_message.id,
|
||||
chat_id=args.chat_id,
|
||||
dataset_id=args.chat_message.dataset_id,
|
||||
paragraph_id=args.chat_message.paragraph_id,
|
||||
source_id=args.chat_message.source_id,
|
||||
source_type=args.chat_message.source_type,
|
||||
problem_text=args.chat_message.problem,
|
||||
answer_text=args.chat_message.answer,
|
||||
index=args.index,
|
||||
message_tokens=args.chat_message.message_tokens,
|
||||
answer_tokens=args.chat_message.answer_token)
|
||||
chat_record.save()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
def run(self):
|
||||
# 记录会话
|
||||
ListenerChatMessage.record_chat_message_signal.connect(self.record_chat_message)
|
||||
|
|
@ -7,7 +7,6 @@
|
|||
@desc:
|
||||
"""
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import django.db.models
|
||||
from blinker import signal
|
||||
|
|
@ -15,22 +14,14 @@ from django.db.models import QuerySet
|
|||
|
||||
from common.config.embedding_config import VectorStore, EmbeddingModel
|
||||
from common.db.search import native_search, get_dynamics_model
|
||||
from common.event.common import poxy
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import Paragraph, Status, Document
|
||||
from embedding.models import SourceType
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
def poxy(poxy_function):
|
||||
def inner(args):
|
||||
ListenerManagement.work_thread_pool.submit(poxy_function, args)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
class ListenerManagement:
|
||||
work_thread_pool = ThreadPoolExecutor(5)
|
||||
|
||||
embedding_by_problem_signal = signal("embedding_by_problem")
|
||||
embedding_by_paragraph_signal = signal("embedding_by_paragraph")
|
||||
embedding_by_dataset_signal = signal("embedding_by_dataset")
|
||||
|
|
|
|||
|
|
@ -20,6 +20,17 @@ class AppApiException(Exception):
|
|||
self.message = message
|
||||
|
||||
|
||||
class NotFound404(AppApiException):
|
||||
"""
|
||||
未认证(未登录)异常
|
||||
"""
|
||||
status_code = status.HTTP_404_NOT_FOUND
|
||||
|
||||
def __init__(self, code, message):
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
|
||||
class AppAuthenticationFailed(AppApiException):
|
||||
"""
|
||||
未认证(未登录)异常
|
||||
|
|
|
|||
|
|
@ -0,0 +1,22 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2023/10/31 17:56
|
||||
@desc:
|
||||
"""
|
||||
from .array_card import *
|
||||
from .base_field import *
|
||||
from .base_form import *
|
||||
from .combobox_field import *
|
||||
from .multi_select import *
|
||||
from .number_input_field import *
|
||||
from .object_card import *
|
||||
from .password_input import *
|
||||
from .radio_field import *
|
||||
from .single_select_field import *
|
||||
from .switch_btn import *
|
||||
from .tab_card import *
|
||||
from .table_radio import *
|
||||
from .text_input_field import *
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: array_card.py
|
||||
@date:2023/10/31 18:03
|
||||
@desc:
|
||||
"""
|
||||
from typing import List, Dict
|
||||
|
||||
from common.froms.base_field import BaseExecField, TriggerType
|
||||
|
||||
|
||||
class ArrayCard(BaseExecField):
|
||||
"""
|
||||
收集List[Object]
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
label: str,
|
||||
text_field: str,
|
||||
value_field: str,
|
||||
provider: str,
|
||||
method: str,
|
||||
required: bool = False,
|
||||
default_value: object = None,
|
||||
relation_show_field_list: List[str] = None,
|
||||
relation_show_value_list: List[str] = None,
|
||||
relation_trigger_field_list: List[str] = None,
|
||||
relation_trigger_value_list: List[str] = None,
|
||||
trigger_type: TriggerType = TriggerType.OPTION_LIST,
|
||||
attrs: Dict[str, object] = None,
|
||||
props_info: Dict[str, object] = None):
|
||||
super().__init__("ArrayCard", label, text_field, value_field, provider, method, required, default_value,
|
||||
relation_show_field_list, relation_show_value_list, relation_trigger_field_list,
|
||||
relation_trigger_value_list, trigger_type, attrs, props_info)
|
||||
|
|
@ -0,0 +1,159 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_field.py
|
||||
@date:2023/10/31 18:07
|
||||
@desc:
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
class TriggerType(Enum):
|
||||
# 执行函数获取 OptionList数据
|
||||
OPTION_LIST = 'OPTION_LIST'
|
||||
# 执行函数获取子表单
|
||||
CHILD_FORMS = 'CHILD_FORMS'
|
||||
|
||||
|
||||
class BaseField:
|
||||
def __init__(self,
|
||||
input_type: str,
|
||||
label: str,
|
||||
required: bool = False,
|
||||
default_value: object = None,
|
||||
relation_show_field_list: List[str] = None,
|
||||
relation_show_value_list: List[str] = None,
|
||||
relation_trigger_field_list: List[str] = None,
|
||||
relation_trigger_value_list: List[str] = None,
|
||||
trigger_type: TriggerType = TriggerType.OPTION_LIST,
|
||||
attrs: Dict[str, object] = None,
|
||||
props_info: Dict[str, object] = None):
|
||||
"""
|
||||
|
||||
:param input_type: 字段
|
||||
:param label: 提示
|
||||
:param default_value: 默认值
|
||||
:param relation_show_field_list: 指定那些当那些字段有值的时候 当前字段显示
|
||||
:param relation_show_value_list: 指定字段有值 并且值在relation_show_value_list列表中则显示当前字段
|
||||
:param relation_trigger_field_list: 指定那些字段有值的时候 调用当前字段的 执行函数获取optionList数据
|
||||
:param relation_trigger_value_list: 指定那些字段有值 并且值在relation_trigger_value_list列表中 则执行函数获取optionList数据
|
||||
:param trigger_type: 执行器类型 OPTION_LIST请求Option_list数据 CHILD_FORMS请求子表单
|
||||
:param attrs: 前端attr数据
|
||||
:param props_info: 其他额外信息
|
||||
"""
|
||||
if props_info is None:
|
||||
props_info = {}
|
||||
if attrs is None:
|
||||
attrs = {}
|
||||
self.label = label
|
||||
self.attrs = attrs
|
||||
self.props_info = props_info
|
||||
self.default_value = default_value
|
||||
self.input_type = input_type
|
||||
self.relation_show_field_list = [] if relation_show_field_list is None else relation_show_field_list
|
||||
self.relation_show_value_list = [] if relation_show_value_list is None else relation_show_value_list
|
||||
self.relation_trigger_field_list = [] if relation_trigger_field_list is None else relation_trigger_field_list
|
||||
self.relation_trigger_value_field_list = [] if relation_trigger_value_list is None else relation_trigger_value_list
|
||||
self.required = required
|
||||
self.trigger_type = trigger_type
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'input_type': self.input_type,
|
||||
'label': self.label,
|
||||
'required': self.required,
|
||||
'default_value': self.default_value,
|
||||
'relation_show_field_list': self.relation_show_field_list,
|
||||
'relation_show_value_list': self.relation_show_value_list,
|
||||
'relation_trigger_field_list': self.relation_trigger_field_list,
|
||||
'relation_trigger_value_field_list': self.relation_trigger_value_field_list,
|
||||
'trigger_type': self.trigger_type.value,
|
||||
'attrs': self.attrs,
|
||||
'props_info': self.props_info,
|
||||
}
|
||||
|
||||
|
||||
class BaseDefaultOptionField(BaseField):
|
||||
def __init__(self, input_type: str,
|
||||
label: str,
|
||||
text_field: str,
|
||||
value_field: str,
|
||||
option_list: List[dict],
|
||||
required: bool = False,
|
||||
default_value: object = None,
|
||||
relation_show_field_list: List[str] = None,
|
||||
relation_show_value_list: List[str] = None,
|
||||
attrs: Dict[str, object] = None,
|
||||
props_info: Dict[str, object] = None):
|
||||
"""
|
||||
|
||||
:param input_type: 字段
|
||||
:param label: label
|
||||
:param text_field: 文本字段
|
||||
:param value_field: 值字段
|
||||
:param option_list: 可选列表
|
||||
:param required: 是否必填
|
||||
:param default_value: 默认值
|
||||
:param relation_show_field_list: 指定那些当那些字段有值的时候 当前字段显示
|
||||
:param relation_show_value_list: 指定字段有值 并且值在relation_show_value_list列表中则显示当前字段
|
||||
:param attrs: 前端attr数据
|
||||
:param props_info: 其他额外信息
|
||||
"""
|
||||
super().__init__(input_type, label, required, default_value, relation_show_field_list, relation_show_value_list,
|
||||
[], [], TriggerType.OPTION_LIST, attrs, props_info)
|
||||
self.text_field = text_field
|
||||
self.value_field = value_field
|
||||
self.option_list = option_list
|
||||
|
||||
def to_dict(self):
|
||||
return {**super().to_dict(), 'text_field': self.text_field, 'value_field': self.value_field,
|
||||
'option_list': self.option_list}
|
||||
|
||||
|
||||
class BaseExecField(BaseField):
|
||||
def __init__(self,
|
||||
input_type: str,
|
||||
label: str,
|
||||
text_field: str,
|
||||
value_field: str,
|
||||
provider: str,
|
||||
method: str,
|
||||
required: bool = False,
|
||||
default_value: object = None,
|
||||
relation_show_field_list: List[str] = None,
|
||||
relation_show_value_list: List[str] = None,
|
||||
relation_trigger_field_list: List[str] = None,
|
||||
relation_trigger_value_list: List[str] = None,
|
||||
trigger_type: TriggerType = TriggerType.OPTION_LIST,
|
||||
attrs: Dict[str, object] = None,
|
||||
props_info: Dict[str, object] = None):
|
||||
"""
|
||||
|
||||
:param input_type: 字段
|
||||
:param label: 提示
|
||||
:param text_field: 文本字段
|
||||
:param value_field: 值字段
|
||||
:param provider: 指定供应商
|
||||
:param method: 执行供应商函数 method
|
||||
:param required: 是否必填
|
||||
:param default_value: 默认值
|
||||
:param relation_show_field_list: 指定那些当那些字段有值的时候 当前字段显示
|
||||
:param relation_show_value_list: 指定字段有值 并且值在relation_show_value_list列表中则显示当前字段
|
||||
:param relation_trigger_field_list: 指定那些字段有值的时候 调用当前字段的 执行函数获取optionList数据
|
||||
:param relation_trigger_value_list: 指定那些字段有值 并且值在relation_trigger_value_list列表中 则执行函数获取optionList数据
|
||||
:param trigger_type: 执行器类型 OPTION_LIST请求Option_list数据 CHILD_FORMS请求子表单
|
||||
:param attrs: 前端attr数据
|
||||
:param props_info: 其他额外信息
|
||||
"""
|
||||
super().__init__(input_type, label, required, default_value, relation_show_field_list, relation_show_value_list,
|
||||
relation_trigger_field_list, relation_trigger_value_list, trigger_type, attrs, props_info)
|
||||
self.text_field = text_field
|
||||
self.value_field = value_field
|
||||
self.provider = provider
|
||||
self.method = method
|
||||
|
||||
def to_dict(self):
|
||||
return {**super().to_dict(), 'text_field': self.text_field, 'value_field': self.value_field,
|
||||
'provider': self.provider, 'method': self.method}
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_form.py
|
||||
@date:2023/11/1 16:04
|
||||
@desc:
|
||||
"""
|
||||
from common.froms import BaseField
|
||||
|
||||
|
||||
class BaseForm:
|
||||
def to_form_list(self):
|
||||
return [{**self.__getattribute__(key).to_dict(), 'field': key} for key in
|
||||
list(filter(lambda key: isinstance(self.__getattribute__(key), BaseField),
|
||||
[attr for attr in vars(self.__class__) if not attr.startswith("__")]))]
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: combobox_field.py
|
||||
@date:2023/10/31 17:59
|
||||
@desc:
|
||||
"""
|
||||
from typing import List, Dict
|
||||
|
||||
from common.froms.base_field import BaseExecField, TriggerType
|
||||
|
||||
|
||||
class Combobox(BaseExecField):
|
||||
"""
|
||||
多选框
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
label: str,
|
||||
text_field: str,
|
||||
value_field: str,
|
||||
option_list: List[str:object],
|
||||
provider: str = None,
|
||||
method: str = None,
|
||||
required: bool = False,
|
||||
default_value: object = None,
|
||||
relation_show_field_list: List[str] = None,
|
||||
relation_show_value_list: List[str] = None,
|
||||
relation_trigger_field_list: List[str] = None,
|
||||
relation_trigger_value_list: List[str] = None,
|
||||
trigger_type: TriggerType = TriggerType.OPTION_LIST,
|
||||
attrs: Dict[str, object] = None,
|
||||
props_info: Dict[str, object] = None):
|
||||
super().__init__("Combobox", label, text_field, value_field, provider, method, required, default_value,
|
||||
relation_show_field_list, relation_show_value_list, relation_trigger_field_list,
|
||||
relation_trigger_value_list, trigger_type, attrs, props_info)
|
||||
self.option_list = option_list
|
||||
|
||||
def to_dict(self):
|
||||
return {**super().to_dict(), 'option_list': self.option_list}
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: multi_select.py
|
||||
@date:2023/10/31 18:00
|
||||
@desc:
|
||||
"""
|
||||
from typing import List, Dict
|
||||
|
||||
from common.froms.base_field import BaseExecField, TriggerType
|
||||
|
||||
|
||||
class MultiSelect(BaseExecField):
|
||||
"""
|
||||
下拉单选
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
label: str,
|
||||
text_field: str,
|
||||
value_field: str,
|
||||
option_list: List[str:object],
|
||||
provider: str = None,
|
||||
method: str = None,
|
||||
required: bool = False,
|
||||
default_value: object = None,
|
||||
relation_show_field_list: List[str] = None,
|
||||
relation_show_value_list: List[str] = None,
|
||||
relation_trigger_field_list: List[str] = None,
|
||||
relation_trigger_value_list: List[str] = None,
|
||||
trigger_type: TriggerType = TriggerType.OPTION_LIST,
|
||||
attrs: Dict[str, object] = None,
|
||||
props_info: Dict[str, object] = None):
|
||||
super().__init__("MultiSelect", label, text_field, value_field, provider, method, required, default_value,
|
||||
relation_show_field_list, relation_show_value_list, relation_trigger_field_list,
|
||||
relation_trigger_value_list, trigger_type, attrs, props_info)
|
||||
self.option_list = option_list
|
||||
|
||||
def to_dict(self):
|
||||
return {**super().to_dict(), 'option_list': self.option_list}
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: number_input_field.py
|
||||
@date:2023/10/31 17:58
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from common.froms.base_field import BaseField, TriggerType
|
||||
|
||||
|
||||
class NumberInput(BaseField):
|
||||
"""
|
||||
文本输入框
|
||||
"""
|
||||
|
||||
def __init__(self, label: str,
|
||||
required: bool = False,
|
||||
default_value=None,
|
||||
relation_show_field_list: List[str] = None,
|
||||
relation_show_value_list: List[str] = None,
|
||||
attrs=None, props_info=None):
|
||||
super().__init__('NumberInput', label, required, default_value, relation_show_field_list,
|
||||
relation_show_value_list, [], [],
|
||||
TriggerType.OPTION_LIST, attrs, props_info)
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: object_card.py
|
||||
@date:2023/10/31 18:02
|
||||
@desc:
|
||||
"""
|
||||
from typing import List, Dict
|
||||
|
||||
from common.froms.base_field import BaseExecField, TriggerType
|
||||
|
||||
|
||||
class ObjectCard(BaseExecField):
|
||||
"""
|
||||
收集对象子表卡片
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
label: str,
|
||||
text_field: str,
|
||||
value_field: str,
|
||||
provider: str,
|
||||
method: str,
|
||||
required: bool = False,
|
||||
default_value: object = None,
|
||||
relation_show_field_list: List[str] = None,
|
||||
relation_show_value_list: List[str] = None,
|
||||
relation_trigger_field_list: List[str] = None,
|
||||
relation_trigger_value_list: List[str] = None,
|
||||
trigger_type: TriggerType = TriggerType.OPTION_LIST,
|
||||
attrs: Dict[str, object] = None,
|
||||
props_info: Dict[str, object] = None):
|
||||
super().__init__("ObjectCard", label, text_field, value_field, provider, method, required, default_value,
|
||||
relation_show_field_list, relation_show_value_list, relation_trigger_field_list,
|
||||
relation_trigger_value_list, trigger_type, attrs, props_info)
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: password_input.py
|
||||
@date:2023/11/1 14:48
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from common.froms import BaseField, TriggerType
|
||||
|
||||
|
||||
class PasswordInputField(BaseField):
|
||||
"""
|
||||
文本输入框
|
||||
"""
|
||||
|
||||
def __init__(self, label: str,
|
||||
required: bool = False,
|
||||
default_value=None,
|
||||
relation_show_field_list: List[str] = None,
|
||||
relation_show_value_list: List[str] = None,
|
||||
attrs=None, props_info=None):
|
||||
super().__init__('TextInput', label, required, default_value, relation_show_field_list,
|
||||
relation_show_value_list, [], [],
|
||||
TriggerType.OPTION_LIST, attrs, props_info)
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: radio_field.py
|
||||
@date:2023/10/31 17:59
|
||||
@desc:
|
||||
"""
|
||||
from typing import List, Dict
|
||||
|
||||
from common.froms.base_field import BaseExecField, TriggerType
|
||||
|
||||
|
||||
class Radio(BaseExecField):
|
||||
"""
|
||||
下拉单选
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
label: str,
|
||||
text_field: str,
|
||||
value_field: str,
|
||||
option_list: List[str:object],
|
||||
provider: str,
|
||||
method: str,
|
||||
required: bool = False,
|
||||
default_value: object = None,
|
||||
relation_show_field_list: List[str] = None,
|
||||
relation_show_value_list: List[str] = None,
|
||||
relation_trigger_field_list: List[str] = None,
|
||||
relation_trigger_value_list: List[str] = None,
|
||||
trigger_type: TriggerType = TriggerType.OPTION_LIST,
|
||||
attrs: Dict[str, object] = None,
|
||||
props_info: Dict[str, object] = None):
|
||||
super().__init__("Radio", label, text_field, value_field, provider, method, required, default_value,
|
||||
relation_show_field_list, relation_show_value_list, relation_trigger_field_list,
|
||||
relation_trigger_value_list, trigger_type, attrs, props_info)
|
||||
self.option_list = option_list
|
||||
|
||||
def to_dict(self):
|
||||
return {**super().to_dict(), 'option_list': self.option_list}
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: single_select_field.py
|
||||
@date:2023/10/31 18:00
|
||||
@desc:
|
||||
"""
|
||||
from typing import List, Dict
|
||||
|
||||
from common.froms.base_field import TriggerType, BaseExecField
|
||||
|
||||
|
||||
class SingleSelect(BaseExecField):
|
||||
"""
|
||||
下拉单选
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
label: str,
|
||||
text_field: str,
|
||||
value_field: str,
|
||||
option_list: List[str:object],
|
||||
provider: str = None,
|
||||
method: str = None,
|
||||
required: bool = False,
|
||||
default_value: object = None,
|
||||
relation_show_field_list: List[str] = None,
|
||||
relation_show_value_list: List[str] = None,
|
||||
relation_trigger_field_list: List[str] = None,
|
||||
relation_trigger_value_list: List[str] = None,
|
||||
trigger_type: TriggerType = TriggerType.OPTION_LIST,
|
||||
attrs: Dict[str, object] = None,
|
||||
props_info: Dict[str, object] = None):
|
||||
super().__init__("SingleSelect", label, text_field, value_field, provider, method, required, default_value,
|
||||
relation_show_field_list, relation_show_value_list, relation_trigger_field_list,
|
||||
relation_trigger_value_list, trigger_type, attrs, props_info)
|
||||
self.option_list = option_list
|
||||
|
||||
def to_dict(self):
|
||||
return {**super().to_dict(), 'option_list': self.option_list}
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: switch_btn.py
|
||||
@date:2023/10/31 18:00
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from common.froms.base_field import TriggerType, BaseField
|
||||
|
||||
|
||||
class SwitchBtn(BaseField):
|
||||
"""
|
||||
开关
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
label: str,
|
||||
required: bool = False,
|
||||
default_value=None,
|
||||
relation_show_field_list: List[str] = None,
|
||||
relation_show_value_list: List[str] = None,
|
||||
attrs=None, props_info=None):
|
||||
super().__init__('SwitchBtn', label, required, default_value, relation_show_field_list,
|
||||
relation_show_value_list, [], [],
|
||||
TriggerType.OPTION_LIST, attrs, props_info)
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: tab_card.py
|
||||
@date:2023/10/31 18:03
|
||||
@desc:
|
||||
"""
|
||||
from typing import List, Dict
|
||||
|
||||
from common.froms.base_field import BaseExecField, TriggerType
|
||||
|
||||
|
||||
class TabCard(BaseExecField):
|
||||
"""
|
||||
收集 Tab类型数据 tab1:{},tab2:{}
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
label: str,
|
||||
text_field: str,
|
||||
value_field: str,
|
||||
provider: str,
|
||||
method: str,
|
||||
required: bool = False,
|
||||
default_value: object = None,
|
||||
relation_show_field_list: List[str] = None,
|
||||
relation_show_value_list: List[str] = None,
|
||||
relation_trigger_field_list: List[str] = None,
|
||||
relation_trigger_value_list: List[str] = None,
|
||||
trigger_type: TriggerType = TriggerType.OPTION_LIST,
|
||||
attrs: Dict[str, object] = None,
|
||||
props_info: Dict[str, object] = None):
|
||||
super().__init__("TabCard", label, text_field, value_field, provider, method, required, default_value,
|
||||
relation_show_field_list, relation_show_value_list, relation_trigger_field_list,
|
||||
relation_trigger_value_list, trigger_type, attrs, props_info)
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: table_radio.py
|
||||
@date:2023/10/31 18:01
|
||||
@desc:
|
||||
"""
|
||||
from typing import List, Dict
|
||||
|
||||
from common.froms.base_field import TriggerType, BaseExecField
|
||||
|
||||
|
||||
class TableRadio(BaseExecField):
|
||||
"""
|
||||
table 单选
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
label: str,
|
||||
text_field: str,
|
||||
value_field: str,
|
||||
provider: str,
|
||||
method: str,
|
||||
required: bool = False,
|
||||
default_value: object = None,
|
||||
relation_show_field_list: List[str] = None,
|
||||
relation_show_value_list: List[str] = None,
|
||||
relation_trigger_field_list: List[str] = None,
|
||||
relation_trigger_value_list: List[str] = None,
|
||||
trigger_type: TriggerType = TriggerType.OPTION_LIST,
|
||||
attrs: Dict[str, object] = None,
|
||||
props_info: Dict[str, object] = None):
|
||||
super().__init__("TableRadio", label, text_field, value_field, provider, method, required, default_value,
|
||||
relation_show_field_list, relation_show_value_list, relation_trigger_field_list,
|
||||
relation_trigger_value_list, trigger_type, attrs, props_info)
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: text_input_field.py
|
||||
@date:2023/10/31 17:58
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from common.froms.base_field import BaseField, TriggerType
|
||||
|
||||
|
||||
class TextInputField(BaseField):
|
||||
"""
|
||||
文本输入框
|
||||
"""
|
||||
|
||||
def __init__(self, label: str,
|
||||
required: bool = False,
|
||||
default_value=None,
|
||||
relation_show_field_list: List[str] = None,
|
||||
relation_show_value_list: List[str] = None,
|
||||
attrs=None, props_info=None):
|
||||
super().__init__('TextInput', label, required, default_value, relation_show_field_list,
|
||||
relation_show_value_list, [], [],
|
||||
TriggerType.OPTION_LIST, attrs, props_info)
|
||||
|
|
@ -5,7 +5,9 @@ SELECT
|
|||
problem.dataset_id AS dataset_id,
|
||||
0 AS source_type,
|
||||
problem."content" AS "text",
|
||||
paragraph.is_active AS is_active
|
||||
paragraph.is_active AS is_active,
|
||||
problem.star_num as star_num,
|
||||
problem.trample_num as trample_num
|
||||
FROM
|
||||
problem problem
|
||||
LEFT JOIN paragraph paragraph ON paragraph."id" = problem.paragraph_id
|
||||
|
|
@ -18,8 +20,10 @@ SELECT
|
|||
paragraph."id" AS paragraph_id,
|
||||
paragraph.dataset_id AS dataset_id,
|
||||
1 AS source_type,
|
||||
paragraph."content" AS "text",
|
||||
paragraph.is_active AS is_active
|
||||
concat_ws(':',paragraph."title",paragraph."content") AS "text",
|
||||
paragraph.is_active AS is_active,
|
||||
paragraph.star_num as star_num,
|
||||
paragraph.trample_num as trample_num
|
||||
FROM
|
||||
paragraph paragraph
|
||||
|
||||
|
|
|
|||
|
|
@ -6,14 +6,27 @@
|
|||
@date:2023/10/16 16:42
|
||||
@desc:
|
||||
"""
|
||||
import importlib
|
||||
from functools import reduce
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def query_params_to_single_dict(query_params: Dict):
|
||||
return reduce(lambda x, y: {**x, y[0]: y[1]}, list(filter(lambda row: row[1] is not None,
|
||||
list(map(lambda row: (
|
||||
row[0], row[1][0] if isinstance(row[1][0],
|
||||
list) and len(
|
||||
row[1][0]) > 0 else row[1][0]),
|
||||
query_params.items())))), {})
|
||||
return reduce(lambda x, y: {**x, **y}, list(
|
||||
filter(lambda item: item is not None, [({key: value} if value is not None and len(value) > 0 else None) for
|
||||
key, value in
|
||||
query_params.items()])), {})
|
||||
|
||||
|
||||
def get_exec_method(clazz_: str, method_: str):
|
||||
"""
|
||||
根据 class 和method函数 获取执行函数
|
||||
:param clazz_: class 字符串
|
||||
:param method_: 执行函数
|
||||
:return: 执行函数
|
||||
"""
|
||||
clazz_split = clazz_.split('.')
|
||||
clazz_name = clazz_split[-1]
|
||||
package = ".".join([clazz_split[index] for index in range(len(clazz_split) - 1)])
|
||||
package_model = importlib.import_module(package)
|
||||
return getattr(getattr(package_model, clazz_name), method_)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,70 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: rsa_util.py
|
||||
@date:2023/11/3 11:13
|
||||
@desc:
|
||||
"""
|
||||
import base64
|
||||
import os
|
||||
|
||||
from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher
|
||||
from Crypto.PublicKey import RSA
|
||||
|
||||
# 对密钥加密的密码
|
||||
secret_code = "mac_kb_password"
|
||||
|
||||
|
||||
def generate():
|
||||
"""
|
||||
生成 私钥秘钥对
|
||||
:return:{key:'公钥',value:'私钥'}
|
||||
"""
|
||||
# 生成一个 2048 位的密钥
|
||||
key = RSA.generate(2048)
|
||||
|
||||
# 获取私钥
|
||||
encrypted_key = key.export_key(passphrase=secret_code, pkcs=8,
|
||||
protection="scryptAndAES128-CBC")
|
||||
return {'key': key.publickey().export_key(), 'value': encrypted_key}
|
||||
|
||||
|
||||
def get_key_pair():
|
||||
if not os.path.exists("/opt/maxkb/conf/receiver.pem"):
|
||||
kv = generate()
|
||||
private_file_out = open("/opt/maxkb/conf/private.pem", "wb")
|
||||
private_file_out.write(kv.get('value'))
|
||||
private_file_out.close()
|
||||
receiver_file_out = open("/opt/maxkb/conf/receiver.pem", "wb")
|
||||
receiver_file_out.write(kv.get('key'))
|
||||
receiver_file_out.close()
|
||||
return {'key': open("/opt/maxkb/conf/receiver.pem").read(), 'value': open("/opt/maxkb/conf/private.pem").read()}
|
||||
|
||||
|
||||
def encrypt(msg, public_key: str | None = None):
|
||||
"""
|
||||
加密
|
||||
:param msg: 加密数据
|
||||
:param public_key: 公钥
|
||||
:return: 加密后的数据
|
||||
"""
|
||||
if public_key is None:
|
||||
public_key = get_key_pair().get('key')
|
||||
cipher = PKCS1_cipher.new(RSA.importKey(public_key))
|
||||
encrypt_msg = cipher.encrypt(msg.encode("utf-8"))
|
||||
return base64.b64encode(encrypt_msg).decode()
|
||||
|
||||
|
||||
def decrypt(msg, pri_key: str | None = None):
|
||||
"""
|
||||
解密
|
||||
:param msg: 需要解密的数据
|
||||
:param pri_key: 私钥
|
||||
:return: 解密后数据
|
||||
"""
|
||||
if pri_key is None:
|
||||
pri_key = get_key_pair().get('value')
|
||||
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
|
||||
decrypt_data = cipher.decrypt(base64.b64decode(msg), 0)
|
||||
return decrypt_data.decode("utf-8")
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: test.py
|
||||
@date:2023/11/15 15:13
|
||||
@desc:
|
||||
"""
|
||||
import time
|
||||
from django.core import signing
|
||||
import hashlib
|
||||
from django.core.cache import cache
|
||||
|
||||
# alg使用的算法
|
||||
HEADER = {'typ': 'JWP', 'alg': 'default'}
|
||||
TOKEN_KEY = 'solomon_world_token'
|
||||
TOKEN_SALT = 'solomonwanc@gmail.com'
|
||||
TIME_OUT = 30 * 60
|
||||
|
||||
# 加密
|
||||
def encrypt(obj):
|
||||
value = signing.dumps(obj, key=TOKEN_KEY, salt=TOKEN_SALT)
|
||||
value = signing.b64_encode(value.encode()).decode()
|
||||
return value
|
||||
|
||||
|
||||
# 解密
|
||||
def decrypt(src):
|
||||
src = signing.b64_decode(src.encode()).decode()
|
||||
raw = signing.loads(src, key=TOKEN_KEY, salt=TOKEN_SALT)
|
||||
print(type(raw))
|
||||
return raw
|
||||
|
||||
|
||||
# 生成token信息
|
||||
def create_token(username, password):
|
||||
# 1. 加密头信息
|
||||
header = encrypt(HEADER)
|
||||
# 2. 构造Payload
|
||||
payload = {
|
||||
"username": username,
|
||||
"password": password,
|
||||
"iat": time.time()
|
||||
}
|
||||
payload = encrypt(payload)
|
||||
# 3. 生成签名
|
||||
md5 = hashlib.md5()
|
||||
md5.update(("%s.%s" % (header, payload)).encode())
|
||||
signature = md5.hexdigest()
|
||||
token = "%s.%s.%s" % (header, payload, signature)
|
||||
# 4.存储到缓存中
|
||||
cache.set(username, token, TIME_OUT)
|
||||
return token
|
||||
|
||||
|
||||
def get_payload(token):
|
||||
payload = str(token).split('.')[1]
|
||||
payload = decrypt(payload)
|
||||
return payload
|
||||
|
||||
|
||||
# 通过token获取用户名
|
||||
def get_username(token):
|
||||
payload = get_payload(token)
|
||||
return payload['username']
|
||||
pass
|
||||
|
||||
|
||||
def check_token(token):
|
||||
username = get_username(token)
|
||||
print('username', username)
|
||||
last_token = cache.get(username)
|
||||
if last_token:
|
||||
return last_token == token
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
token = create_token('zhangsan', 'lisi')
|
||||
|
|
@ -98,13 +98,15 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
|
||||
query_set_dict['team_member_permission_custom_sql'] = QuerySet(model=get_dynamics_model(
|
||||
{'user_id': models.CharField(),
|
||||
'team_member_permission.auth_target_type': models.CharField(),
|
||||
'team_member_permission.operate': ArrayField(verbose_name="权限操作列表",
|
||||
base_field=models.CharField(max_length=256,
|
||||
blank=True,
|
||||
choices=AuthOperate.choices,
|
||||
default=AuthOperate.USE)
|
||||
)})).filter(
|
||||
**{'user_id': user_id, 'team_member_permission.operate__contains': ['USE']})
|
||||
**{'user_id': user_id, 'team_member_permission.operate__contains': ['USE'],
|
||||
'team_member_permission.auth_target_type': 'DATASET'})
|
||||
|
||||
return query_set_dict
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from common.db.search import page_search
|
|||
from common.event.listener_manage import ListenerManagement
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.mixins.api_mixin import ApiMixin
|
||||
from dataset.models import Paragraph, Problem
|
||||
from dataset.models import Paragraph, Problem, Document
|
||||
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer
|
||||
|
||||
|
||||
|
|
@ -41,7 +41,7 @@ class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer):
|
|||
message="段落在1-1024个字符之间")
|
||||
])
|
||||
|
||||
title = serializers.CharField(required=False)
|
||||
title = serializers.CharField(required=False, allow_null=True, allow_blank=True)
|
||||
|
||||
problem_list = ProblemInstanceSerializer(required=False, many=True)
|
||||
|
||||
|
|
@ -165,6 +165,12 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
|
||||
document_id = serializers.UUIDField(required=True)
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
if not QuerySet(Document).filter(id=self.data.get('document_id'),
|
||||
dataset_id=self.data.get('dataset_id')).exists():
|
||||
raise AppApiException(500, "文档id不正确")
|
||||
|
||||
def save(self, instance: Dict, with_valid=True, with_embedding=True):
|
||||
if with_valid:
|
||||
ParagraphSerializers(data=instance).is_valid(raise_exception=True)
|
||||
|
|
|
|||
|
|
@ -54,6 +54,13 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
|
|||
|
||||
paragraph_id = serializers.UUIDField(required=True)
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id'),
|
||||
document_id=self.data.get('document_id'),
|
||||
dataset_id=self.data.get('dataset_id')).exists():
|
||||
raise AppApiException(500, "段落id不正确")
|
||||
|
||||
def save(self, instance: Dict, with_valid=True, with_embedding=True):
|
||||
if with_valid:
|
||||
self.is_valid()
|
||||
|
|
|
|||
|
|
@ -26,7 +26,8 @@ class Dataset(APIView):
|
|||
@swagger_auto_schema(operation_summary="获取数据集列表",
|
||||
operation_id="获取数据集列表",
|
||||
manual_parameters=DataSetSerializers.Query.get_request_params_api(),
|
||||
responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api()))
|
||||
responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api()),
|
||||
tags=["数据集"])
|
||||
@has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND)
|
||||
def get(self, request: Request):
|
||||
d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)})
|
||||
|
|
@ -37,7 +38,9 @@ class Dataset(APIView):
|
|||
@swagger_auto_schema(operation_summary="创建数据集",
|
||||
operation_id="创建数据集",
|
||||
request_body=DataSetSerializers.Create.get_request_body_api(),
|
||||
responses=get_api_response(DataSetSerializers.Create.get_response_body_api()))
|
||||
responses=get_api_response(DataSetSerializers.Create.get_response_body_api()),
|
||||
tags=["数据集"]
|
||||
)
|
||||
@has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND)
|
||||
def post(self, request: Request):
|
||||
s = DataSetSerializers.Create(data=request.data)
|
||||
|
|
@ -50,7 +53,8 @@ class Dataset(APIView):
|
|||
@action(methods="DELETE", detail=False)
|
||||
@swagger_auto_schema(operation_summary="删除数据集", operation_id="删除数据集",
|
||||
manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
|
||||
responses=result.get_default_response())
|
||||
responses=result.get_default_response(),
|
||||
tags=["数据集"])
|
||||
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=keywords.get('dataset_id')),
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.DELETE,
|
||||
|
|
@ -62,7 +66,8 @@ class Dataset(APIView):
|
|||
@action(methods="GET", detail=False)
|
||||
@swagger_auto_schema(operation_summary="查询数据集详情根据数据集id", operation_id="查询数据集详情根据数据集id",
|
||||
manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
|
||||
responses=get_api_response(DataSetSerializers.Operate.get_response_body_api()))
|
||||
responses=get_api_response(DataSetSerializers.Operate.get_response_body_api()),
|
||||
tags=["数据集"])
|
||||
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=keywords.get('dataset_id')))
|
||||
def get(self, request: Request, dataset_id: str):
|
||||
|
|
@ -72,7 +77,9 @@ class Dataset(APIView):
|
|||
@swagger_auto_schema(operation_summary="修改数据集信息", operation_id="修改数据集信息",
|
||||
manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
|
||||
request_body=DataSetSerializers.Operate.get_request_body_api(),
|
||||
responses=get_api_response(DataSetSerializers.Operate.get_response_body_api()))
|
||||
responses=get_api_response(DataSetSerializers.Operate.get_response_body_api()),
|
||||
tags=["数据集"]
|
||||
)
|
||||
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=keywords.get('dataset_id')))
|
||||
def put(self, request: Request, dataset_id: str):
|
||||
|
|
@ -87,7 +94,9 @@ class Dataset(APIView):
|
|||
operation_id="获取数据集分页列表",
|
||||
manual_parameters=get_page_request_params(
|
||||
DataSetSerializers.Query.get_request_params_api()),
|
||||
responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api()))
|
||||
responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api()),
|
||||
tags=["数据集"]
|
||||
)
|
||||
@has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND)
|
||||
def get(self, request: Request, current_page, page_size):
|
||||
d = DataSetSerializers.Query(
|
||||
|
|
|
|||
|
|
@ -28,7 +28,8 @@ class Document(APIView):
|
|||
operation_id="创建文档",
|
||||
request_body=DocumentSerializers.Create.get_request_body_api(),
|
||||
manual_parameters=DocumentSerializers.Create.get_request_params_api(),
|
||||
responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()))
|
||||
responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()),
|
||||
tags=["数据集/文档"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
|
|
@ -40,7 +41,8 @@ class Document(APIView):
|
|||
@swagger_auto_schema(operation_summary="文档列表",
|
||||
operation_id="文档列表",
|
||||
manual_parameters=DocumentSerializers.Query.get_request_params_api(),
|
||||
responses=result.get_api_response(DocumentSerializers.Query.get_response_body_api()))
|
||||
responses=result.get_api_response(DocumentSerializers.Query.get_response_body_api()),
|
||||
tags=["数据集/文档"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
|
|
@ -57,7 +59,8 @@ class Document(APIView):
|
|||
@swagger_auto_schema(operation_summary="获取文档详情",
|
||||
operation_id="获取文档详情",
|
||||
manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
|
||||
responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()))
|
||||
responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()),
|
||||
tags=["数据集/文档"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
|
|
@ -71,7 +74,8 @@ class Document(APIView):
|
|||
operation_id="修改文档",
|
||||
manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
|
||||
request_body=DocumentSerializers.Operate.get_request_body_api(),
|
||||
responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api())
|
||||
responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()),
|
||||
tags=["数据集/文档"]
|
||||
)
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
|
|
@ -86,7 +90,8 @@ class Document(APIView):
|
|||
@swagger_auto_schema(operation_summary="删除文档",
|
||||
operation_id="删除文档",
|
||||
manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
|
||||
responses=result.get_default_response())
|
||||
responses=result.get_default_response(),
|
||||
tags=["数据集/文档"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
|
|
@ -101,7 +106,9 @@ class Document(APIView):
|
|||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="分段文档",
|
||||
operation_id="分段文档",
|
||||
manual_parameters=DocumentSerializers.Split.get_request_params_api())
|
||||
manual_parameters=DocumentSerializers.Split.get_request_params_api(),
|
||||
tags=["数据集/文档"],
|
||||
security=[])
|
||||
def post(self, request: Request):
|
||||
ds = DocumentSerializers.Split(
|
||||
data={'file': request.FILES.getlist('file'),
|
||||
|
|
@ -116,7 +123,8 @@ class Document(APIView):
|
|||
@swagger_auto_schema(operation_summary="获取数据集分页列表",
|
||||
operation_id="获取数据集分页列表",
|
||||
manual_parameters=DocumentSerializers.Query.get_request_params_api(),
|
||||
responses=result.get_page_api_response(DocumentSerializers.Query.get_response_body_api()))
|
||||
responses=result.get_page_api_response(DocumentSerializers.Query.get_response_body_api()),
|
||||
tags=["数据集/文档"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
|
|
|
|||
|
|
@ -25,7 +25,9 @@ class Paragraph(APIView):
|
|||
@swagger_auto_schema(operation_summary="段落列表",
|
||||
operation_id="段落列表",
|
||||
manual_parameters=ParagraphSerializers.Query.get_request_params_api(),
|
||||
responses=result.get_api_array_response(ParagraphSerializers.Query.get_response_body_api()))
|
||||
responses=result.get_api_array_response(ParagraphSerializers.Query.get_response_body_api()),
|
||||
tags=["数据集/文档/段落"]
|
||||
)
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
|
|
@ -41,7 +43,8 @@ class Paragraph(APIView):
|
|||
operation_id="创建段落",
|
||||
manual_parameters=ParagraphSerializers.Create.get_request_params_api(),
|
||||
request_body=ParagraphSerializers.Create.get_request_body_api(),
|
||||
responses=result.get_api_response(ParagraphSerializers.Query.get_response_body_api()))
|
||||
responses=result.get_api_response(ParagraphSerializers.Query.get_response_body_api()),
|
||||
tags=["数据集/文档/段落"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
|
|
@ -57,7 +60,8 @@ class Paragraph(APIView):
|
|||
operation_id="修改段落数据",
|
||||
manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
|
||||
request_body=ParagraphSerializers.Operate.get_request_body_api(),
|
||||
responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api()))
|
||||
responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api())
|
||||
,tags=["数据集/文档/段落"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
|
|
@ -71,7 +75,8 @@ class Paragraph(APIView):
|
|||
@swagger_auto_schema(operation_summary="获取段落详情",
|
||||
operation_id="获取段落详情",
|
||||
manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
|
||||
responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api()))
|
||||
responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api()),
|
||||
tags=["数据集/文档/段落"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
|
|
@ -85,7 +90,8 @@ class Paragraph(APIView):
|
|||
@swagger_auto_schema(operation_summary="删除段落",
|
||||
operation_id="删除段落",
|
||||
manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
|
||||
responses=result.get_default_response())
|
||||
responses=result.get_default_response(),
|
||||
tags=["数据集/文档/段落"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
|
|
@ -103,7 +109,8 @@ class Paragraph(APIView):
|
|||
operation_id="分页获取段落列表",
|
||||
manual_parameters=result.get_page_request_params(
|
||||
ParagraphSerializers.Query.get_request_params_api()),
|
||||
responses=result.get_page_api_response(ParagraphSerializers.Query.get_response_body_api()))
|
||||
responses=result.get_page_api_response(ParagraphSerializers.Query.get_response_body_api()),
|
||||
tags=["数据集/文档/段落"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
|
|
|
|||
|
|
@ -25,7 +25,8 @@ class Problem(APIView):
|
|||
operation_id="添加段落关联问题",
|
||||
manual_parameters=ProblemSerializers.Create.get_request_params_api(),
|
||||
request_body=ProblemSerializers.Create.get_request_body_api(),
|
||||
responses=result.get_api_response(ProblemSerializers.Operate.get_response_body_api()))
|
||||
responses=result.get_api_response(ProblemSerializers.Operate.get_response_body_api()),
|
||||
tags=["数据集/文档/段落/问题"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
|
|
@ -38,7 +39,8 @@ class Problem(APIView):
|
|||
@swagger_auto_schema(operation_summary="获取段落问题列表",
|
||||
operation_id="获取段落问题列表",
|
||||
manual_parameters=ProblemSerializers.Query.get_request_params_api(),
|
||||
responses=result.get_api_array_response(ProblemSerializers.Operate.get_response_body_api()))
|
||||
responses=result.get_api_array_response(ProblemSerializers.Operate.get_response_body_api()),
|
||||
tags=["数据集/文档/段落/问题"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
|
|
@ -54,7 +56,8 @@ class Problem(APIView):
|
|||
@swagger_auto_schema(operation_summary="删除段落问题",
|
||||
operation_id="删除段落问题",
|
||||
manual_parameters=ProblemSerializers.Query.get_request_params_api(),
|
||||
responses=result.get_default_response())
|
||||
responses=result.get_default_response(),
|
||||
tags=["数据集/文档/段落/问题"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
# Generated by Django 4.1.10 on 2023-11-09 07:45
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('embedding', '0001_initial'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='embedding',
|
||||
name='star_num',
|
||||
field=models.IntegerField(default=0, verbose_name='点赞数量'),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='embedding',
|
||||
name='trample_num',
|
||||
field=models.IntegerField(default=0, verbose_name='点踩数量'),
|
||||
),
|
||||
]
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
# Generated by Django 4.1.10 on 2023-11-14 07:30
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('embedding', '0002_embedding_star_num_embedding_trample_num'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterUniqueTogether(
|
||||
name='embedding',
|
||||
unique_together={('source_id', 'source_type')},
|
||||
),
|
||||
]
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
# Generated by Django 4.1.10 on 2023-11-14 08:09
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('embedding', '0003_alter_embedding_unique_together'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='embedding',
|
||||
name='source_type',
|
||||
field=models.CharField(choices=[('0', '问题'), ('1', '段落'), ('-1', '其他')], default='0', max_length=5, verbose_name='资源类型'),
|
||||
),
|
||||
]
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
# Generated by Django 4.1.10 on 2023-11-14 08:11
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('embedding', '0004_alter_embedding_source_type'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='embedding',
|
||||
name='source_type',
|
||||
field=models.CharField(choices=[('0', '问题'), ('1', '段落')], default='0', max_length=5, verbose_name='资源类型'),
|
||||
),
|
||||
]
|
||||
|
|
@ -23,7 +23,7 @@ class Embedding(models.Model):
|
|||
|
||||
source_id = models.CharField(max_length=128, verbose_name="资源id")
|
||||
|
||||
source_type = models.CharField(verbose_name='资源类型', max_length=1, choices=SourceType.choices,
|
||||
source_type = models.CharField(verbose_name='资源类型', max_length=5, choices=SourceType.choices,
|
||||
default=SourceType.PROBLEM)
|
||||
|
||||
is_active = models.BooleanField(verbose_name="是否可用", max_length=1, default=True)
|
||||
|
|
@ -36,5 +36,11 @@ class Embedding(models.Model):
|
|||
|
||||
embedding = VectorField(verbose_name="向量")
|
||||
|
||||
star_num = models.IntegerField(default=0, verbose_name="点赞数量")
|
||||
|
||||
trample_num = models.IntegerField(default=0,
|
||||
verbose_name="点踩数量")
|
||||
|
||||
class Meta:
|
||||
db_table = "embedding"
|
||||
unique_together = ['source_id', 'source_type']
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
SELECT * FROM (SELECT
|
||||
*,
|
||||
( 1 - ( embedding.embedding <=> %s ) ) AS similarity,
|
||||
CASE
|
||||
|
||||
WHEN embedding.star_num - embedding.trample_num = 0 THEN
|
||||
0 ELSE ( ( ( embedding.star_num - embedding.trample_num ) - aggs.min_value ) / ( aggs.max_value - aggs.min_value ) )
|
||||
END AS score
|
||||
FROM
|
||||
embedding,
|
||||
( SELECT MIN ( star_num - trample_num ) AS min_value, MAX ( star_num - trample_num ) AS max_value FROM embedding ${embedding_query}) aggs
|
||||
${embedding_query}
|
||||
) temp
|
||||
WHERE similarity>0.5
|
||||
ORDER BY (similarity + score) DESC LIMIT 1
|
||||
|
|
@ -47,6 +47,8 @@ class BaseVectorStore(ABC):
|
|||
|
||||
def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
|
||||
is_active: bool,
|
||||
star_num: int,
|
||||
trample_num: int,
|
||||
embedding=None):
|
||||
"""
|
||||
插入向量数据
|
||||
|
|
@ -58,12 +60,15 @@ class BaseVectorStore(ABC):
|
|||
:param is_active: 是否禁用
|
||||
:param embedding: 向量化处理器
|
||||
:param paragraph_id 段落id
|
||||
:param star_num 点赞数量
|
||||
:param trample_num 点踩数量
|
||||
:return: bool
|
||||
"""
|
||||
if embedding is None:
|
||||
embedding = EmbeddingModel.get_embedding_model()
|
||||
self.save_pre_handler()
|
||||
self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding)
|
||||
self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, star_num,
|
||||
trample_num, embedding)
|
||||
|
||||
def batch_save(self, data_list: List[Dict], embedding=None):
|
||||
"""
|
||||
|
|
@ -81,6 +86,8 @@ class BaseVectorStore(ABC):
|
|||
@abstractmethod
|
||||
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
|
||||
is_active: bool,
|
||||
star_num: int,
|
||||
trample_num: int,
|
||||
embedding: HuggingFaceEmbeddings):
|
||||
pass
|
||||
|
||||
|
|
@ -89,7 +96,10 @@ class BaseVectorStore(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(self, query_text, dataset_id_list: list[str], is_active: bool, embedding: HuggingFaceEmbeddings):
|
||||
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||
exclude_id_list: list[str],
|
||||
is_active: bool,
|
||||
embedding: HuggingFaceEmbeddings):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -6,14 +6,20 @@
|
|||
@date:2023/10/19 15:28
|
||||
@desc:
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict, List
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
from common.db.search import native_search, generate_sql_by_query_dict
|
||||
from common.db.sql_execute import select_one
|
||||
from common.util.file_util import get_file_content
|
||||
from embedding.models import Embedding, SourceType
|
||||
from embedding.vector.base_vector import BaseVectorStore
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
class PGVector(BaseVectorStore):
|
||||
|
|
@ -27,6 +33,8 @@ class PGVector(BaseVectorStore):
|
|||
|
||||
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
|
||||
is_active: bool,
|
||||
star_num: int,
|
||||
trample_num: int,
|
||||
embedding: HuggingFaceEmbeddings):
|
||||
text_embedding = embedding.embed_query(text)
|
||||
embedding = Embedding(id=uuid.uuid1(),
|
||||
|
|
@ -37,6 +45,8 @@ class PGVector(BaseVectorStore):
|
|||
source_id=source_id,
|
||||
embedding=text_embedding,
|
||||
source_type=source_type,
|
||||
star_num=star_num,
|
||||
trample_num=trample_num
|
||||
)
|
||||
embedding.save()
|
||||
return True
|
||||
|
|
@ -44,19 +54,41 @@ class PGVector(BaseVectorStore):
|
|||
def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
|
||||
texts = [row.get('text') for row in text_list]
|
||||
embeddings = embedding.embed_documents(texts)
|
||||
QuerySet(Embedding).bulk_create([Embedding(id=uuid.uuid1(),
|
||||
document_id=text_list[index].get('document_id'),
|
||||
paragraph_id=text_list[index].get('paragraph_id'),
|
||||
dataset_id=text_list[index].get('dataset_id'),
|
||||
is_active=text_list[index].get('is_active', True),
|
||||
source_id=text_list[index].get('source_id'),
|
||||
source_type=text_list[index].get('source_type'),
|
||||
embedding=embeddings[index]) for index in
|
||||
range(0, len(text_list))]) if len(text_list) > 0 else None
|
||||
embedding_list = [Embedding(id=uuid.uuid1(),
|
||||
document_id=text_list[index].get('document_id'),
|
||||
paragraph_id=text_list[index].get('paragraph_id'),
|
||||
dataset_id=text_list[index].get('dataset_id'),
|
||||
is_active=text_list[index].get('is_active', True),
|
||||
source_id=text_list[index].get('source_id'),
|
||||
source_type=text_list[index].get('source_type'),
|
||||
star_num=text_list[index].get('star_num'),
|
||||
trample_num=text_list[index].get('trample_num'),
|
||||
embedding=embeddings[index]) for index in
|
||||
range(0, len(text_list))]
|
||||
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
|
||||
return True
|
||||
|
||||
def search(self, query_text, dataset_id_list: list[str], is_active: bool, embedding: HuggingFaceEmbeddings):
|
||||
pass
|
||||
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||
exclude_id_list: list[str],
|
||||
is_active: bool,
|
||||
embedding: HuggingFaceEmbeddings):
|
||||
exclude_dict = {}
|
||||
if dataset_id_list is None or len(dataset_id_list) == 0:
|
||||
return None
|
||||
query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=is_active)
|
||||
embedding_query = embedding.embed_query(query_text)
|
||||
if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
|
||||
exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
|
||||
if exclude_id_list is not None and len(exclude_id_list) > 0:
|
||||
exclude_dict.__setitem__('id__in', exclude_id_list)
|
||||
query_set = query_set.exclude(**exclude_dict)
|
||||
exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
|
||||
select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
|
||||
'embedding_search.sql')),
|
||||
with_table_name=True)
|
||||
embedding_model = select_one(exec_sql, (json.dumps(embedding_query), *exec_params, *exec_params))
|
||||
return embedding_model
|
||||
|
||||
def update_by_source_id(self, source_id: str, instance: Dict):
|
||||
QuerySet(Embedding).filter(source_id=source_id).update(**instance)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,34 @@
|
|||
# Generated by Django 4.1.10 on 2023-11-06 08:51
|
||||
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
import uuid
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('users', '0001_initial'),
|
||||
('setting', '0001_initial'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='Model',
|
||||
fields=[
|
||||
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
|
||||
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
|
||||
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
|
||||
('name', models.CharField(max_length=128, verbose_name='名称')),
|
||||
('model_type', models.CharField(max_length=128, verbose_name='模型类型')),
|
||||
('model_name', models.CharField(max_length=128, verbose_name='模型名称')),
|
||||
('provider', models.CharField(max_length=20, verbose_name='供应商')),
|
||||
('credential', models.CharField(max_length=5120, verbose_name='模型认证信息')),
|
||||
('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user', verbose_name='成员用户id')),
|
||||
],
|
||||
options={
|
||||
'db_table': 'model',
|
||||
'unique_together': {('name', 'user_id')},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
# Generated by Django 4.1.10 on 2023-11-13 05:55
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('setting', '0002_model'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='model',
|
||||
name='provider',
|
||||
field=models.CharField(max_length=128, verbose_name='供应商'),
|
||||
),
|
||||
]
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2023/10/31 17:16
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -0,0 +1,130 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_model_provider.py
|
||||
@date:2023/10/31 16:19
|
||||
@desc:
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import reduce
|
||||
from typing import Dict, List
|
||||
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import BaseMessage
|
||||
from langchain.schema.language_model import LanguageModelInput
|
||||
|
||||
|
||||
class IModelProvider(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def get_model_provide_info(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_type_list(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_list(self, model_type):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_credential(self, model_type, model_name):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseChatModel:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dialogue_number(self):
|
||||
pass
|
||||
|
||||
|
||||
class BaseModelCredential(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def is_valid(self, model_type: str, model_name, model: Dict[str, object], raise_exception=False):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def encryption_dict(self, model_info: Dict[str, object]):
|
||||
"""
|
||||
:param model_info: 模型数据
|
||||
:return: 加密后数据
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def encryption(message: str):
|
||||
"""
|
||||
加密敏感字段数据 加密方式是 如果密码是 1234567890 那么给前端则是 123******890
|
||||
:param message:
|
||||
:return:
|
||||
"""
|
||||
max_pre_len = 8
|
||||
max_post_len = 4
|
||||
message_len = len(message)
|
||||
pre_len = int(message_len / 5 * 2)
|
||||
post_len = int(message_len / 5 * 1)
|
||||
pre_str = "".join([message[index] for index in
|
||||
range(0, max_pre_len if pre_len > max_pre_len else 1 if pre_len <= 0 else int(pre_len))])
|
||||
end_str = "".join(
|
||||
[message[index] for index in
|
||||
range(message_len - (int(post_len) if pre_len < max_post_len else max_post_len), message_len)])
|
||||
content = "***************"
|
||||
return pre_str + content + end_str
|
||||
|
||||
|
||||
class ModelTypeConst(Enum):
|
||||
LLM = {'code': 'LLM', 'message': '大语言模型'}
|
||||
|
||||
|
||||
class ModelInfo:
|
||||
def __init__(self, name: str, desc: str, model_type: ModelTypeConst, model_credential: BaseModelCredential,
|
||||
**keywords):
|
||||
self.name = name
|
||||
self.desc = desc
|
||||
self.model_type = model_type.name
|
||||
self.model_credential = model_credential
|
||||
if keywords is not None:
|
||||
for key in keywords.keys():
|
||||
self.__setattr__(key, keywords.get(key))
|
||||
|
||||
def get_name(self):
|
||||
"""
|
||||
获取模型名称
|
||||
:return: 模型名称
|
||||
"""
|
||||
return self.name
|
||||
|
||||
def get_desc(self):
|
||||
"""
|
||||
获取模型描述
|
||||
:return: 模型描述
|
||||
"""
|
||||
return self.desc
|
||||
|
||||
def get_model_type(self):
|
||||
return self.model_type
|
||||
|
||||
def to_dict(self):
|
||||
return reduce(lambda x, y: {**x, **y},
|
||||
[{attr: self.__getattribute__(attr)} for attr in vars(self) if
|
||||
not attr.startswith("__") and not attr == 'model_credential'], {})
|
||||
|
||||
|
||||
class ModelProvideInfo:
|
||||
def __init__(self, provider: str, name: str, icon: str):
|
||||
self.provider = provider
|
||||
|
||||
self.name = name
|
||||
|
||||
self.icon = icon
|
||||
|
||||
def to_dict(self):
|
||||
return reduce(lambda x, y: {**x, **y},
|
||||
[{attr: self.__getattribute__(attr)} for attr in vars(self) if
|
||||
not attr.startswith("__")], {})
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: model_provider_constants.py
|
||||
@date:2023/11/2 14:55
|
||||
@desc:
|
||||
"""
|
||||
from enum import Enum
|
||||
|
||||
from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider
|
||||
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
|
||||
|
||||
|
||||
class ModelProvideConstants(Enum):
|
||||
model_azure_provider = AzureModelProvider()
|
||||
model_wenxin_provider = WenxinModelProvider()
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2023/10/31 17:16
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -0,0 +1,110 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: azure_model_provider.py
|
||||
@date:2023/10/31 16:19
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
from typing import Dict, List
|
||||
|
||||
from langchain.chat_models import AzureChatOpenAI
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import HumanMessage, BaseMessage
|
||||
from langchain.schema.language_model import LanguageModelInput
|
||||
|
||||
from common import froms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.froms import BaseForm
|
||||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
|
||||
ModelInfo, \
|
||||
ModelTypeConst
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
||||
model_type_list = AzureModelProvider().get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(500, f'{model_type} 模型类型不支持')
|
||||
|
||||
if model_name not in model_dict:
|
||||
raise AppApiException(500, f'{model_name} 模型名称不支持')
|
||||
|
||||
for key in ['api_base', 'api_key', 'deployment_name']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(500, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
AzureModelProvider().query(model_type, model_name, model_credential,
|
||||
message=[HumanMessage(content='valid')])
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(500, '校验失败,请检查参数是否正确')
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
api_base = froms.TextInputField('API 域名', required=True)
|
||||
|
||||
api_key = froms.PasswordInputField("API Key", required=True)
|
||||
|
||||
deployment_name = froms.TextInputField("部署名", required=True)
|
||||
|
||||
|
||||
azure_llm_model_credential = AzureLLMModelCredential()
|
||||
|
||||
model_dict = {
|
||||
'gpt-3.5-turbo-0613': ModelInfo('gpt-3.5-turbo-0613', '', ModelTypeConst.LLM, azure_llm_model_credential,
|
||||
api_version='2023-07-01-preview'),
|
||||
'gpt-3.5-turbo-0301': ModelInfo('gpt-3.5-turbo-0301', '', ModelTypeConst.LLM, azure_llm_model_credential,
|
||||
api_version='2023-07-01-preview'),
|
||||
'gpt-3.5-turbo-16k-0613': ModelInfo('gpt-3.5-turbo-16k-0613', '', ModelTypeConst.LLM, azure_llm_model_credential,
|
||||
api_version='2023-07-01-preview')
|
||||
}
|
||||
|
||||
|
||||
class AzureModelProvider(IModelProvider):
|
||||
|
||||
def get_dialogue_number(self):
|
||||
return 3
|
||||
|
||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> AzureChatOpenAI:
|
||||
model_info: ModelInfo = model_dict.get(model_name)
|
||||
azure_chat_open_ai = AzureChatOpenAI(
|
||||
openai_api_base=model_credential.get('api_base'),
|
||||
openai_api_version=model_info.api_version,
|
||||
deployment_name=model_credential.get('deployment_name'),
|
||||
openai_api_key=model_credential.get('api_key'),
|
||||
openai_api_type="azure",
|
||||
tiktoken_model_name=model_name
|
||||
)
|
||||
return azure_chat_open_ai
|
||||
|
||||
def get_model_credential(self, model_type, model_name):
|
||||
return model_dict.get(model_name).model_credential
|
||||
|
||||
def get_model_provide_info(self):
|
||||
return ModelProvideInfo(provider='model_azure_provider', name='Azure OpenAI', icon=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'azure_model_provider', 'icon',
|
||||
'azure_icon_svg')))
|
||||
|
||||
def get_model_list(self, model_type: str):
|
||||
if model_type is None:
|
||||
raise AppApiException(500, '模型类型不能为空')
|
||||
return [model_dict.get(key).to_dict() for key in
|
||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
||||
|
||||
def get_model_type_list(self):
|
||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" version="1.1" id="Layer_1" x="0px" y="0px" width="64px" height="64px" viewBox="0 0 64 64" enable-background="new 0 0 64 64" xml:space="preserve"> <image id="image0" width="64" height="64" x="0" y="0" href="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAYAAACqaXHeAAAABGdBTUEAALGPC/xhBQAAACBjSFJN AAB6JgAAgIQAAPoAAACA6AAAdTAAAOpgAAA6mAAAF3CculE8AAAABmJLR0QA/wD/AP+gvaeTAAAA CXBIWXMAAA7EAAAOxAGVKw4bAAAJXUlEQVR42u2aa2wcVxXH/+fO7K4fa6epX2snDSqNHb8TQ4HE Nn0EldR5v2wnLRQZCSHxqSFqCUJCfEAVihBCSFRIICEB3RabkFYVAipFgEgjhNUQYjs2SYGw1LFj O8l67V17X3P5MLO7szuvuxsnMeArbZydk8zM+c3/f+69xwOsjbWxNtbG//Gge3HSosd2biCZNZY8 4vMxt1xc4Gl48Ya6v2t/DwZ++Mql/woA3o9/8Rvy+k1fQzIhk6TgkecPoLypIc/UtR+c64+9N/Kl np7QxfOzK3m/bCVPVlT/zGOuis1fBzEZRFDiSXzw+i+hJBJiSXMAnINrn8wxABwfbfv+r19eyftd cQByxebjYEREDGASQBIS4SgWxicckkY6aZ6ddC6U58o7uldUtSsNoJ+DAGIAMRCTAMmF+b9eMU2c Wz/tLCXooNSt6/hk96oEUNSwqwlErSAC1wBwYiAmIzR6FUo87iRxoxIAA5SNnz/VvyoByBWbj6k1 lQDSVMAkcCZBSXCExiZEJC6ihCPl27pX7L5XDkBlfZ+aOFQATLMBaTa4PG5MDCK+N8R85R3dT64q AEVbnm0DUWNaASD11EwrhkxGaPwalHhctNiZxHk6vvFzp/pWFQCpsv4YB4Gnnr5mASJJs4IMJQmE xsbzTBpZStDFj5Rv7ZZWDQC5sqE3lTiH9iFoyTMQYyDmwvzIuIjERZRQVb6t6+lVAcDTuLsDoHr1 G+kUQOBEAEngJAGSjIW//QNKLOYocRElbHjhVO+qACBX1vfrk86srjPrAbUWyEgqhPmxcRGJG6HA EDtStrVLfqAApMp6SFVbetWc9VOgDgZjgFYLSHJhfnRCuNjZQ0FFeXv3px4oALmy4XGAPqx6PiX5 HBhg6TUBmIyF968jGY2KFjtbKBs++5W7tsFdAfA07+tPz/ta0gYYRKD00lgG5wyhKxMiEheBcris /e5sUDAAqbIBIOrLlr4FDJbZHJHkwvyVqyISF1HC+rL2rl0PBIBctWU7gE3GAmgCgxg4mAbChcV/ /gvJ5WXBp81toWx47uW7skHBANwt+/u4PmkA1jBI3RlqH4XL6hZZrNg5QTlY1tbluq8ApKotBKAX Or87wqBMMSTZheDEtTxnAF3SyIKyrqytq+d+A+gE0cZMgoAIDEpDkBG+/gGSS0uOEs/ErZVQd/yl gvcGBQHwtB7IFD/D3G8Bg7T1AJNATAInGaGJqyISz4lzMyj7va2dnvsCQJU/9WaSBoRhaFvk1GwQ mZp+X0DilkrQQSkra+3afX8AVDc+wYlqubbxyU7aAYaWPNeWxsvXx74Jzm87SVxECbXHCrNB3gDc rYd0cz8yT1cUBktbIXLn9z8b4hxnBSQuooS93pbOvH8HkRcAqbqRgeioYarLB0a6Y8x+pdy5Hrnx 2ukhAYnbQEkf8Ja1du67xwCangJRte287wiDgYghdvWdQQBYGDl/DpzfEk2amyghFff1vnT0ngJw tR3uM93x5QNDhbCYnLv2tgrgQoJznBEsdsY4UnEOcL7X29xZek8AsOomCUSH1URg8mTzgMHYW8m5 a9HUuW/4Tw8JSFxECcXels799wSAVNO0E6AqfdKFwoiNvT2kP/fC6Lu/4xyzpokBeUHxHT2Z195A GIC7/Wi/2Y6vABih5OzEb/TnXhy9kATHGfunzWHmexMl9Hibd5StKABW0yyD6JBxWisIxpvJmYlo 7jWm3jg9KFrsHKAUlTaJ20AIgFTT/AwHPZzV9i4QRmz07JDZNRZGL/yBc0w7+547QvEdOSm8KBIC 4Nraa9H2zhtGMDEz8VuzayyOXVDA+RkBiYtAebbm8Ml1KwKA1TS7ADqkfstpe+cP46wyMx63utbU 4LcHxX1vEsscdpc27jggAsCxnyb5WnaB6KH0AZ76g9K1DdqFkyXrkPR4Lc+luIvrcWr0x1n/SUkQ ElEgsUzTyyG5Ymk56SrySKm3RFL/LPvamR9Z8RQFAN7GHX0AfnLXAFzb+nvVE6cquS6YAyO6fqNu F2g6urWPdi7dNlnbKgdv3ETVo5t0SekyNE065wtPH/p09cEvPzTz5neCdjdkawHma/Fw0MFUc9O4 rs+e6uTIHSeexpHaIDEJkNwITs9Z+x5i9lB/zc5dpY07Djtd3gFAaw8I5akEM543h+Gen4a8eCtf Apl2meRCeDGGeCQi7HvzuHrcu2WH42xgC8DVceyoVUGzguEOzcAVupmfAvQ2kD0ITs1aVHlYPW2L ON9ZfeBERUEAmK+1GMABkbZ3LgxX+DbcwakcozpASNvAheDNORGJC0CBq7TB3gbWAGpbd3Mir22n 1waGtDQPT3AS4IoIgSwbRCJxxCIRR4nbQoEaL3GwgSUAueN4fm1vExhsOQz37ck5AAuOCtC9WgfZ g/mbcyISF4HyVNW+E1V5AWC+1lIQ7c+37W22CJLikUEEhp8GYF8d9RAkN4IzcyISF4EilzZst2yU mAOoa9sLULFlc9MKhsmKMH5pcBD+gfdw/tVuAAEbAjobyFhaSiAWDjtKPBWzg1JSv91yi2wKQP7I 8wW1vWGEMaVMj/4RAHD+1Qn4BzoBjInaIDh7S7TYmSpBF3+ias+LNUIAqLbNy4n2cDJL2gKGZgVu hDGkTI9lqmBgeBL+gScB/FnEBvOzt/L0vaUSpJJ6cxsYALDa9v0AFUHb5Bhl7rwiTMGIX3rDuPUN DN+Cf2AngHOWADQbLC8nEQ2HHSUuooSS+u2ms4EBgPz4Z/rMiloaBoRhTCpTo++aPunAcBj+gd0A fpFDIMcGbszPzolIPCduGuuu3P1inS0Aqm0vB1GP01QnCGNImR61XgkFhmPwD/QD+JGtDeZuCxc7 y9dtAXDOWcnmTxhskAWA1bXvAcgt3Om1gZH4y+s/h9MIDCvwD3wBgeFvWdkgGlcQjYRFi52pElJQ SjZvP2JvAaJHLft5AosgHYxA4qL/T44AVAiAf+CrCAyfUG9bbwMZkDyIR2MiErdXgjo22QMAzpn1 88wLnTUMZWrkB0LJ64d/4LsIDL8AIJGxAYPkKUJRaZmIxI1QjOOd3ANZ79vyG5cnqW7bv6nc1wLQ w8bFDUCW0yIAohvK1MgriYuvncbijOBOSDdG3roM4AI+9LEWgNd5XBLqaivgduvegNH3PkwaIRZj AcBPA987djJ+ezKW932tjbWxNtbG/+r4D87QQ0KXWZJ6AAAAMnRFWHRDb21tZW50AHhyOmQ6REFG allncE9SVXM6MixqOjUwNTg3NDI3NTEsdDoyMzA1MTkxOOBfQ5wAAAAldEVYdGRhdGU6Y3JlYXRl ADIwMjMtMTAtMzFUMTA6MTQ6MjgrMDE6MDAxbYpSAAAAJXRFWHRkYXRlOm1vZGlmeQAyMDIzLTEw LTMxVDEwOjE0OjI4KzAxOjAwQDAy7gAAAABJRU5ErkJggg=="/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 3.9 KiB |
|
|
@ -0,0 +1,8 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2023/10/31 17:16
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" version="1.1" id="Layer_1" x="0px" y="0px" width="64px" height="64px" viewBox="0 0 64 64" enable-background="new 0 0 64 64" xml:space="preserve"> <image id="image0" width="64" height="64" x="0" y="0" href="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAYAAACqaXHeAAAABGdBTUEAALGPC/xhBQAAACBjSFJN AAB6JgAAgIQAAPoAAACA6AAAdTAAAOpgAAA6mAAAF3CculE8AAAABmJLR0QA/wD/AP+gvaeTAAAA CXBIWXMAAA7EAAAOxAGVKw4bAAAJXUlEQVR42u2aa2wcVxXH/+fO7K4fa6epX2snDSqNHb8TQ4HE Nn0EldR5v2wnLRQZCSHxqSFqCUJCfEAVihBCSFRIICEB3RabkFYVAipFgEgjhNUQYjs2SYGw1LFj O8l67V17X3P5MLO7szuvuxsnMeArbZydk8zM+c3/f+69xwOsjbWxNtbG//Gge3HSosd2biCZNZY8 4vMxt1xc4Gl48Ya6v2t/DwZ++Mql/woA3o9/8Rvy+k1fQzIhk6TgkecPoLypIc/UtR+c64+9N/Kl np7QxfOzK3m/bCVPVlT/zGOuis1fBzEZRFDiSXzw+i+hJBJiSXMAnINrn8wxABwfbfv+r19eyftd cQByxebjYEREDGASQBIS4SgWxicckkY6aZ6ddC6U58o7uldUtSsNoJ+DAGIAMRCTAMmF+b9eMU2c Wz/tLCXooNSt6/hk96oEUNSwqwlErSAC1wBwYiAmIzR6FUo87iRxoxIAA5SNnz/VvyoByBWbj6k1 lQDSVMAkcCZBSXCExiZEJC6ihCPl27pX7L5XDkBlfZ+aOFQATLMBaTa4PG5MDCK+N8R85R3dT64q AEVbnm0DUWNaASD11EwrhkxGaPwalHhctNiZxHk6vvFzp/pWFQCpsv4YB4Gnnr5mASJJs4IMJQmE xsbzTBpZStDFj5Rv7ZZWDQC5sqE3lTiH9iFoyTMQYyDmwvzIuIjERZRQVb6t6+lVAcDTuLsDoHr1 G+kUQOBEAEngJAGSjIW//QNKLOYocRElbHjhVO+qACBX1vfrk86srjPrAbUWyEgqhPmxcRGJG6HA EDtStrVLfqAApMp6SFVbetWc9VOgDgZjgFYLSHJhfnRCuNjZQ0FFeXv3px4oALmy4XGAPqx6PiX5 HBhg6TUBmIyF968jGY2KFjtbKBs++5W7tsFdAfA07+tPz/ta0gYYRKD00lgG5wyhKxMiEheBcris /e5sUDAAqbIBIOrLlr4FDJbZHJHkwvyVqyISF1HC+rL2rl0PBIBctWU7gE3GAmgCgxg4mAbChcV/ /gvJ5WXBp81toWx47uW7skHBANwt+/u4PmkA1jBI3RlqH4XL6hZZrNg5QTlY1tbluq8ApKotBKAX Or87wqBMMSTZheDEtTxnAF3SyIKyrqytq+d+A+gE0cZMgoAIDEpDkBG+/gGSS0uOEs/ErZVQd/yl gvcGBQHwtB7IFD/D3G8Bg7T1AJNATAInGaGJqyISz4lzMyj7va2dnvsCQJU/9WaSBoRhaFvk1GwQ mZp+X0DilkrQQSkra+3afX8AVDc+wYlqubbxyU7aAYaWPNeWxsvXx74Jzm87SVxECbXHCrNB3gDc rYd0cz8yT1cUBktbIXLn9z8b4hxnBSQuooS93pbOvH8HkRcAqbqRgeioYarLB0a6Y8x+pdy5Hrnx 2ukhAYnbQEkf8Ja1du67xwCangJRte287wiDgYghdvWdQQBYGDl/DpzfEk2amyghFff1vnT0ngJw tR3uM93x5QNDhbCYnLv2tgrgQoJznBEsdsY4UnEOcL7X29xZek8AsOomCUSH1URg8mTzgMHYW8m5 a9HUuW/4Tw8JSFxECcXels799wSAVNO0E6AqfdKFwoiNvT2kP/fC6Lu/4xyzpokBeUHxHT2Z195A GIC7/Wi/2Y6vABih5OzEb/TnXhy9kATHGfunzWHmexMl9Hibd5StKABW0yyD6JBxWisIxpvJmYlo 7jWm3jg9KFrsHKAUlTaJ20AIgFTT/AwHPZzV9i4QRmz07JDZNRZGL/yBc0w7+547QvEdOSm8KBIC 4Nraa9H2zhtGMDEz8VuzayyOXVDA+RkBiYtAebbm8Ml1KwKA1TS7ADqkfstpe+cP46wyMx63utbU 4LcHxX1vEsscdpc27jggAsCxnyb5WnaB6KH0AZ76g9K1DdqFkyXrkPR4Lc+luIvrcWr0x1n/SUkQ ElEgsUzTyyG5Ymk56SrySKm3RFL/LPvamR9Z8RQFAN7GHX0AfnLXAFzb+nvVE6cquS6YAyO6fqNu F2g6urWPdi7dNlnbKgdv3ETVo5t0SekyNE065wtPH/p09cEvPzTz5neCdjdkawHma/Fw0MFUc9O4 rs+e6uTIHSeexpHaIDEJkNwITs9Z+x5i9lB/zc5dpY07Djtd3gFAaw8I5akEM543h+Gen4a8eCtf Apl2meRCeDGGeCQi7HvzuHrcu2WH42xgC8DVceyoVUGzguEOzcAVupmfAvQ2kD0ITs1aVHlYPW2L ON9ZfeBERUEAmK+1GMABkbZ3LgxX+DbcwakcozpASNvAheDNORGJC0CBq7TB3gbWAGpbd3Mir22n 1waGtDQPT3AS4IoIgSwbRCJxxCIRR4nbQoEaL3GwgSUAueN4fm1vExhsOQz37ck5AAuOCtC9WgfZ g/mbcyISF4HyVNW+E1V5AWC+1lIQ7c+37W22CJLikUEEhp8GYF8d9RAkN4IzcyISF4EilzZst2yU mAOoa9sLULFlc9MKhsmKMH5pcBD+gfdw/tVuAAEbAjobyFhaSiAWDjtKPBWzg1JSv91yi2wKQP7I 8wW1vWGEMaVMj/4RAHD+1Qn4BzoBjInaIDh7S7TYmSpBF3+ias+LNUIAqLbNy4n2cDJL2gKGZgVu hDGkTI9lqmBgeBL+gScB/FnEBvOzt/L0vaUSpJJ6cxsYALDa9v0AFUHb5Bhl7rwiTMGIX3rDuPUN DN+Cf2AngHOWADQbLC8nEQ2HHSUuooSS+u2ms4EBgPz4Z/rMiloaBoRhTCpTo++aPunAcBj+gd0A fpFDIMcGbszPzolIPCduGuuu3P1inS0Aqm0vB1GP01QnCGNImR61XgkFhmPwD/QD+JGtDeZuCxc7 y9dtAXDOWcnmTxhskAWA1bXvAcgt3Om1gZH4y+s/h9MIDCvwD3wBgeFvWdkgGlcQjYRFi52pElJQ SjZvP2JvAaJHLft5AosgHYxA4qL/T44AVAiAf+CrCAyfUG9bbwMZkDyIR2MiErdXgjo22QMAzpn1 88wLnTUMZWrkB0LJ64d/4LsIDL8AIJGxAYPkKUJRaZmIxI1QjOOd3ANZ79vyG5cnqW7bv6nc1wLQ w8bFDUCW0yIAohvK1MgriYuvncbijOBOSDdG3roM4AI+9LEWgNd5XBLqaivgduvegNH3PkwaIRZj AcBPA987djJ+ezKW932tjbWxNtbG/+r4D87QQ0KXWZJ6AAAAMnRFWHRDb21tZW50AHhyOmQ6REFG allncE9SVXM6MixqOjUwNTg3NDI3NTEsdDoyMzA1MTkxOOBfQ5wAAAAldEVYdGRhdGU6Y3JlYXRl ADIwMjMtMTAtMzFUMTA6MTQ6MjgrMDE6MDAxbYpSAAAAJXRFWHRkYXRlOm1vZGlmeQAyMDIzLTEw LTMxVDEwOjE0OjI4KzAxOjAwQDAy7gAAAABJRU5ErkJggg=="/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 3.9 KiB |
|
|
@ -0,0 +1,78 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: qian_fan_chat_model.py
|
||||
@date:2023/11/10 17:45
|
||||
@desc:
|
||||
"""
|
||||
from typing import Optional, List, Any, Iterator, cast
|
||||
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.chat_models import QianfanChatEndpoint
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.load import dumpd
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema.language_model import LanguageModelInput
|
||||
from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage
|
||||
from langchain.schema.output import ChatGenerationChunk
|
||||
from langchain.schema.runnable import RunnableConfig
|
||||
|
||||
|
||||
class QianfanChatModel(QianfanChatEndpoint):
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[BaseMessageChunk]:
|
||||
if len(input) % 2 == 0:
|
||||
input = [HumanMessage(content='占位'), *input]
|
||||
input = [
|
||||
HumanMessage(content=input[index].content) if index % 2 == 0 else AIMessage(content=input[index].content)
|
||||
for index in range(0, len(input))]
|
||||
if type(self)._stream == BaseChatModel._stream:
|
||||
# model doesn't implement streaming, so use default implementation
|
||||
yield cast(
|
||||
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
|
||||
)
|
||||
else:
|
||||
config = config or {}
|
||||
messages = self._convert_input(input).to_messages()
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
options = {"stop": stop, **kwargs}
|
||||
callback_manager = CallbackManager.configure(
|
||||
config.get("callbacks"),
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
config.get("tags"),
|
||||
self.tags,
|
||||
config.get("metadata"),
|
||||
self.metadata,
|
||||
)
|
||||
(run_manager,) = callback_manager.on_chat_model_start(
|
||||
dumpd(self),
|
||||
[messages],
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
for chunk in self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
yield chunk.message
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
assert generation is not None
|
||||
except BaseException as e:
|
||||
run_manager.on_llm_error(e)
|
||||
raise e
|
||||
else:
|
||||
run_manager.on_llm_end(
|
||||
LLMResult(generations=[[generation]]),
|
||||
)
|
||||
|
|
@ -0,0 +1,134 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: wenxin_model_provider.py
|
||||
@date:2023/10/31 16:19
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from langchain.chat_models import QianfanChatEndpoint
|
||||
from langchain.chat_models.baidu_qianfan_endpoint import convert_message_to_dict
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from common import froms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.froms import BaseForm
|
||||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
|
||||
ModelInfo, IModelProvider
|
||||
from setting.models_provider.impl.wenxin_model_provider.model.qian_fan_chat_model import QianfanChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
||||
model_type_list = WenxinModelProvider().get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(500, f'{model_type} 模型类型不支持')
|
||||
|
||||
if model_name not in model_dict:
|
||||
raise AppApiException(500, f'{model_name} 模型名称不支持')
|
||||
|
||||
for key in ['api_key', 'secret_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(500, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
WenxinModelProvider().get_model(model_type, model_name, model_credential)([HumanMessage(content='valid')])
|
||||
except Exception as e:
|
||||
if raise_exception:
|
||||
raise AppApiException(500, "校验失败,请检查 api_key secret_key 是否正确")
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model_info: Dict[str, object]):
|
||||
return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))}
|
||||
|
||||
def build_model(self, model_info: Dict[str, object]):
|
||||
for key in ['api_key', 'secret_key', 'model']:
|
||||
if key not in model_info:
|
||||
raise AppApiException(500, f'{key} 字段为必填字段')
|
||||
self.api_key = model_info.get('api_key')
|
||||
self.secret_key = model_info.get('secret_key')
|
||||
return self
|
||||
|
||||
api_key = froms.PasswordInputField('API Key', required=True)
|
||||
|
||||
secret_key = froms.PasswordInputField("Secret Key", required=True)
|
||||
|
||||
|
||||
win_xin_llm_model_credential = WenxinLLMModelCredential()
|
||||
model_dict = {
|
||||
'ERNIE-Bot-4': ModelInfo('ERNIE-Bot-4',
|
||||
'ERNIE-Bot-4是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
||||
|
||||
'ERNIE-Bot': ModelInfo('ERNIE-Bot',
|
||||
'ERNIE-Bot是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
||||
|
||||
'ERNIE-Bot-turbo': ModelInfo('ERNIE-Bot-turbo',
|
||||
'ERNIE-Bot-turbo是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力,响应速度更快。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
||||
|
||||
'BLOOMZ-7B': ModelInfo('BLOOMZ-7B',
|
||||
'BLOOMZ-7B是业内知名的大语言模型,由BigScience研发并开源,能够以46种语言和13种编程语言输出文本。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
||||
|
||||
'Llama-2-7b-chat': ModelInfo('Llama-2-7b-chat',
|
||||
'Llama-2-7b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-7b-chat是高性能原生开源版本,适用于对话场景。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
||||
|
||||
'Llama-2-13b-chat': ModelInfo('Llama-2-13b-chat',
|
||||
'Llama-2-13b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-13b-chat是性能与效果均衡的原生开源版本,适用于对话场景。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
||||
|
||||
'Llama-2-70b-chat': ModelInfo('Llama-2-70b-chat',
|
||||
'Llama-2-70b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-70b-chat是高精度效果的原生开源版本。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
||||
|
||||
'Qianfan-Chinese-Llama-2-7B': ModelInfo('Qianfan-Chinese-Llama-2-7B',
|
||||
'千帆团队在Llama-2-7b基础上的中文增强版本,在CMMLU、C-EVAL等中文数据集上表现优异。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
||||
|
||||
'Qianfan-Chinese-Llama-2-13B': ModelInfo('Qianfan-Chinese-Llama-2-13B',
|
||||
'千帆团队在Llama-2-13b基础上的中文增强版本,在CMMLU、C-EVAL等中文数据集上表现优异。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential)
|
||||
|
||||
}
|
||||
|
||||
|
||||
class WenxinModelProvider(IModelProvider):
|
||||
|
||||
def get_dialogue_number(self):
|
||||
return 2
|
||||
|
||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object],
|
||||
**model_kwargs) -> QianfanChatEndpoint:
|
||||
return QianfanChatModel(model=model_name,
|
||||
qianfan_ak=model_credential.get('api_key'),
|
||||
qianfan_sk=model_credential.get('secret_key'),
|
||||
streaming=model_kwargs.get('streaming', False))
|
||||
|
||||
def get_model_type_list(self):
|
||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
||||
|
||||
def get_model_list(self, model_type):
|
||||
if model_type is None:
|
||||
raise AppApiException(500, '模型类型不能为空')
|
||||
return [model_dict.get(key).to_dict() for key in
|
||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
||||
|
||||
def get_model_credential(self, model_type, model_name):
|
||||
if model_name in model_dict:
|
||||
return model_dict.get(model_name).model_credential
|
||||
raise AppApiException(500, f'不支持的模型:{model_name}')
|
||||
|
||||
def get_model_provide_info(self):
|
||||
return ModelProvideInfo(provider='model_wenxin_provider', name='Azure OpenAI', icon=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'wenxin_model_provider', 'icon',
|
||||
'azure_icon_svg')))
|
||||
|
|
@ -0,0 +1,119 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: provider_serializers.py
|
||||
@date:2023/11/2 14:01
|
||||
@desc:
|
||||
"""
|
||||
import json
|
||||
import uuid
|
||||
from typing import Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from rest_framework import serializers
|
||||
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.util.rsa_util import encrypt, decrypt
|
||||
from setting.models.model_management import Model
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
|
||||
|
||||
class ModelSerializer(serializers.Serializer):
|
||||
class Query(serializers.Serializer):
|
||||
user_id = serializers.UUIDField(required=True)
|
||||
|
||||
name = serializers.CharField(required=False)
|
||||
|
||||
model_type = serializers.CharField(required=False)
|
||||
|
||||
model_name = serializers.CharField(required=False)
|
||||
|
||||
def list(self, with_valid):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
user_id = self.data.get('user_id')
|
||||
name = self.data.get('name')
|
||||
model_query_set = QuerySet(Model).filter(user_id=user_id)
|
||||
query_params = {}
|
||||
if name is not None:
|
||||
query_params['name__contains'] = name
|
||||
if self.data.get('model_type') is not None:
|
||||
query_params['model_type'] = self.data.get('model_type')
|
||||
if self.data.get('model_name') is not None:
|
||||
query_params['model_name'] = self.data.get('model_name')
|
||||
return [ModelSerializer.model_to_dict(model) for model in model_query_set.filter(**query_params)]
|
||||
|
||||
class Create(serializers.Serializer):
|
||||
user_id = serializers.CharField(required=True)
|
||||
|
||||
name = serializers.CharField(required=True)
|
||||
|
||||
provider = serializers.CharField(required=True)
|
||||
|
||||
model_type = serializers.CharField(required=True)
|
||||
|
||||
model_name = serializers.CharField(required=True)
|
||||
|
||||
credential = serializers.DictField(required=True)
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
if QuerySet(Model).filter(user_id=self.data.get('user_id'),
|
||||
name=self.data.get('name')).exists():
|
||||
raise AppApiException(500, f'模型名称【{self.data.get("name")}】已存在')
|
||||
# 校验模型认证数据
|
||||
ModelProvideConstants[self.data.get('provider')].value.get_model_credential(self.data.get('model_type'),
|
||||
self.data.get(
|
||||
'model_name')).is_valid(
|
||||
self.data.get('model_type'),
|
||||
self.data.get('model_name'),
|
||||
self.data.get('credential'),
|
||||
raise_exception=True)
|
||||
|
||||
def insert(self, user_id, with_valid=False):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
credential = self.data.get('credential')
|
||||
name = self.data.get('name')
|
||||
provider = self.data.get('provider')
|
||||
model_type = self.data.get('model_type')
|
||||
model_name = self.data.get('model_name')
|
||||
model_credential_str = json.dumps(credential)
|
||||
model = Model(id=uuid.uuid1(), user_id=user_id, name=name,
|
||||
credential=encrypt(model_credential_str),
|
||||
provider=provider, model_type=model_type, model_name=model_name)
|
||||
model.save()
|
||||
return ModelSerializer.Operate(data={'id': model.id}).one(user_id, with_valid=True)
|
||||
|
||||
@staticmethod
|
||||
def model_to_dict(model: Model):
|
||||
credential = json.loads(decrypt(model.credential))
|
||||
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
|
||||
'model_name': model.model_name,
|
||||
'credential': ModelProvideConstants[model.provider].value.get_model_credential(model.model_type,
|
||||
model.model_name).encryption_dict(
|
||||
credential)}
|
||||
|
||||
class Operate(serializers.Serializer):
|
||||
id = serializers.UUIDField(required=True)
|
||||
|
||||
def one(self, user_id, with_valid=False):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
model = QuerySet(Model).get(id=self.data.get('id'), user_id=user_id)
|
||||
return ModelSerializer.model_to_dict(model)
|
||||
|
||||
|
||||
class ProviderSerializer(serializers.Serializer):
|
||||
provider = serializers.CharField(required=True)
|
||||
|
||||
method = serializers.CharField(required=True)
|
||||
|
||||
def exec(self, exec_params: Dict[str, object], with_valid=False):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
|
||||
provider = self.data.get('provider')
|
||||
method = self.data.get('method')
|
||||
return getattr(ModelProvideConstants[provider].value, method)(exec_params)
|
||||
|
|
@ -233,12 +233,12 @@ class TeamMemberSerializer(ApiMixin, serializers.Serializer):
|
|||
member_permission_list))
|
||||
# 分为 APPLICATION DATASET俩组
|
||||
groups = itertools.groupby(
|
||||
list(map(lambda m: {**m, 'member_id': member_id,
|
||||
'operate': dict(
|
||||
map(lambda key: (key, True if m.get('operate') is not None and m.get(
|
||||
'operate').__contains__(key) else False),
|
||||
[Operate.USE.value, Operate.MANAGE.value]))},
|
||||
member_permission_list)),
|
||||
sorted(list(map(lambda m: {**m, 'member_id': member_id,
|
||||
'operate': dict(
|
||||
map(lambda key: (key, True if m.get('operate') is not None and m.get(
|
||||
'operate').__contains__(key) else False),
|
||||
[Operate.USE.value, Operate.MANAGE.value]))},
|
||||
member_permission_list)), key=lambda x: x.get('type')),
|
||||
key=lambda x: x.get('type'))
|
||||
return dict([(key, list(group)) for key, group in groups])
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,156 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: provide_api.py
|
||||
@date:2023/11/2 14:25
|
||||
@desc:
|
||||
"""
|
||||
from drf_yasg import openapi
|
||||
|
||||
from common.mixins.api_mixin import ApiMixin
|
||||
|
||||
|
||||
class ModelQueryApi(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='name',
|
||||
in_=openapi.IN_QUERY,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=False,
|
||||
description='模型名称'),
|
||||
openapi.Parameter(name='model_type', in_=openapi.IN_QUERY,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=False,
|
||||
description='模型类型'),
|
||||
openapi.Parameter(name='model_name', in_=openapi.IN_QUERY,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=False,
|
||||
description='基础模型名称')
|
||||
]
|
||||
|
||||
|
||||
class ModelCreateApi(ApiMixin):
|
||||
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return openapi.Schema(type=openapi.TYPE_OBJECT,
|
||||
title="调用函数所需要的参数",
|
||||
description="调用函数所需要的参数",
|
||||
required=['provide', 'model_info'],
|
||||
properties={
|
||||
'name': openapi.Schema(type=openapi.TYPE_STRING,
|
||||
title="模型名称",
|
||||
description="模型名称"),
|
||||
'provider': openapi.Schema(type=openapi.TYPE_STRING,
|
||||
title="供应商",
|
||||
description="供应商"),
|
||||
'model_type': openapi.Schema(type=openapi.TYPE_STRING,
|
||||
title="供应商",
|
||||
description="供应商"),
|
||||
'model_name': openapi.Schema(type=openapi.TYPE_STRING,
|
||||
title="供应商",
|
||||
description="供应商"),
|
||||
'credential': openapi.Schema(type=openapi.TYPE_OBJECT,
|
||||
title="模型证书信息",
|
||||
description="模型证书信息")
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ProvideApi(ApiMixin):
|
||||
class ModelTypeList(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='provider',
|
||||
in_=openapi.IN_QUERY,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='供应名称'),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_response_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['key', 'value'],
|
||||
properties={
|
||||
'key': openapi.Schema(type=openapi.TYPE_STRING, title="模型类型描述",
|
||||
description="模型类型描述", default="大语言模型"),
|
||||
'value': openapi.Schema(type=openapi.TYPE_STRING, title="模型类型值",
|
||||
description="模型类型值", default="LLM"),
|
||||
|
||||
}
|
||||
)
|
||||
|
||||
class ModelList(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='provider',
|
||||
in_=openapi.IN_QUERY,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='供应名称'),
|
||||
openapi.Parameter(name='model_type',
|
||||
in_=openapi.IN_QUERY,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='模型类型'),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_response_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['name', 'desc', 'model_type'],
|
||||
properties={
|
||||
'name': openapi.Schema(type=openapi.TYPE_STRING, title="模型名称",
|
||||
description="模型名称", default="模型名称"),
|
||||
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="模型描述",
|
||||
description="模型描述", default="xxx模型"),
|
||||
'model_type': openapi.Schema(type=openapi.TYPE_STRING, title="模型类型值",
|
||||
description="模型类型值", default="LLM"),
|
||||
|
||||
}
|
||||
)
|
||||
|
||||
class ModelForm(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='provider',
|
||||
in_=openapi.IN_QUERY,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='供应名称'),
|
||||
openapi.Parameter(name='model_type',
|
||||
in_=openapi.IN_QUERY,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='模型类型'),
|
||||
openapi.Parameter(name='model_name',
|
||||
in_=openapi.IN_QUERY,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='模型名称'),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='provider',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='供应商'),
|
||||
openapi.Parameter(name='method',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='需要执行的函数'),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return openapi.Schema(type=openapi.TYPE_OBJECT,
|
||||
title="调用函数所需要的参数",
|
||||
description="调用函数所需要的参数",
|
||||
)
|
||||
|
|
@ -5,5 +5,14 @@ from . import views
|
|||
app_name = "team"
|
||||
urlpatterns = [
|
||||
path('team/member', views.TeamMember.as_view(), name="team"),
|
||||
path('team/member/<str:member_id>', views.TeamMember.Operate.as_view(), name='member')
|
||||
path('team/member/<str:member_id>', views.TeamMember.Operate.as_view(), name='member'),
|
||||
path('provider/<str:provider>/<str:method>', views.Provide.Exec.as_view(), name='provide_exec'),
|
||||
path('provider', views.Provide.as_view(), name='provide'),
|
||||
path('provider/model_type_list', views.Provide.ModelTypeList.as_view(), name="provider/model_type_list"),
|
||||
path('provider/model_list', views.Provide.ModelList.as_view(),
|
||||
name="provider/model_name_list"),
|
||||
path('provider/model_form', views.Provide.ModelForm.as_view(),
|
||||
name="provider/model_form"),
|
||||
path('model', views.Model.as_view(), name='model'),
|
||||
|
||||
]
|
||||
|
|
|
|||
|
|
@ -11,7 +11,8 @@ from rest_framework.decorators import action
|
|||
from rest_framework.views import APIView
|
||||
from rest_framework.views import Request
|
||||
|
||||
from common.auth import TokenAuth
|
||||
from common.auth import TokenAuth, has_permissions
|
||||
from common.constants.permission_constants import PermissionConstants
|
||||
from common.response import result
|
||||
from setting.serializers.team_serializers import TeamMemberSerializer, get_response_body_api, \
|
||||
UpdateTeamMemberPermissionSerializer
|
||||
|
|
@ -23,14 +24,18 @@ class TeamMember(APIView):
|
|||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取团队成员列表",
|
||||
operation_id="获取团员成员列表",
|
||||
responses=result.get_api_response(get_response_body_api()))
|
||||
responses=result.get_api_response(get_response_body_api()),
|
||||
tags=["团队"])
|
||||
@has_permissions(PermissionConstants.TEAM_READ)
|
||||
def get(self, request: Request):
|
||||
return result.success(TeamMemberSerializer(data={'team_id': str(request.user.id)}).list_member())
|
||||
|
||||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="添加成员",
|
||||
operation_id="添加成员",
|
||||
request_body=TeamMemberSerializer().get_request_body_api())
|
||||
request_body=TeamMemberSerializer().get_request_body_api(),
|
||||
tags=["团队"])
|
||||
@has_permissions(PermissionConstants.TEAM_CREATE)
|
||||
def post(self, request: Request):
|
||||
team = TeamMemberSerializer(data={'team_id': str(request.user.id)})
|
||||
return result.success((team.add_member(**request.data)))
|
||||
|
|
@ -41,7 +46,9 @@ class TeamMember(APIView):
|
|||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取团队成员权限",
|
||||
operation_id="获取团队成员权限",
|
||||
manual_parameters=TeamMemberSerializer.Operate.get_request_params_api())
|
||||
manual_parameters=TeamMemberSerializer.Operate.get_request_params_api(),
|
||||
tags=["团队"])
|
||||
@has_permissions(PermissionConstants.TEAM_READ)
|
||||
def get(self, request: Request, member_id: str):
|
||||
return result.success(TeamMemberSerializer.Operate(
|
||||
data={'member_id': member_id, 'team_id': str(request.user.id)}).list_member_permission())
|
||||
|
|
@ -50,8 +57,10 @@ class TeamMember(APIView):
|
|||
@swagger_auto_schema(operation_summary="修改团队成员权限",
|
||||
operation_id="修改团队成员权限",
|
||||
request_body=UpdateTeamMemberPermissionSerializer().get_request_body_api(),
|
||||
manual_parameters=TeamMemberSerializer.Operate.get_request_params_api()
|
||||
manual_parameters=TeamMemberSerializer.Operate.get_request_params_api(),
|
||||
tags=["团队"]
|
||||
)
|
||||
@has_permissions(PermissionConstants.TEAM_EDIT)
|
||||
def put(self, request: Request, member_id: str):
|
||||
return result.success(TeamMemberSerializer.Operate(
|
||||
data={'member_id': member_id, 'team_id': str(request.user.id)}).edit(request.data))
|
||||
|
|
@ -59,8 +68,10 @@ class TeamMember(APIView):
|
|||
@action(methods=['DELETE'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="移除成员",
|
||||
operation_id="移除成员",
|
||||
manual_parameters=TeamMemberSerializer.Operate.get_request_params_api()
|
||||
manual_parameters=TeamMemberSerializer.Operate.get_request_params_api(),
|
||||
tags=["团队"]
|
||||
)
|
||||
@has_permissions(PermissionConstants.TEAM_DELETE)
|
||||
def delete(self, request: Request, member_id: str):
|
||||
return result.success(TeamMemberSerializer.Operate(
|
||||
data={'member_id': member_id, 'team_id': str(request.user.id)}).delete())
|
||||
|
|
|
|||
|
|
@ -7,3 +7,4 @@
|
|||
@desc:
|
||||
"""
|
||||
from .Team import *
|
||||
from .model import *
|
||||
|
|
|
|||
|
|
@ -0,0 +1,122 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: model.py
|
||||
@date:2023/11/2 13:55
|
||||
@desc:
|
||||
"""
|
||||
from drf_yasg.utils import swagger_auto_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.views import Request
|
||||
|
||||
from common.auth import TokenAuth, has_permissions
|
||||
from common.constants.permission_constants import PermissionConstants
|
||||
from common.response import result
|
||||
from common.util.common import query_params_to_single_dict
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
from setting.serializers.provider_serializers import ProviderSerializer, ModelSerializer
|
||||
from setting.swagger_api.provide_api import ProvideApi, ModelCreateApi, ModelQueryApi
|
||||
|
||||
|
||||
class Model(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="创建模型",
|
||||
operation_id="创建模型",
|
||||
request_body=ModelCreateApi.get_request_body_api()
|
||||
, tags=["模型"])
|
||||
@has_permissions(PermissionConstants.MODEL_CREATE)
|
||||
def post(self, request: Request):
|
||||
return result.success(
|
||||
ModelSerializer.Create(data={**request.data, 'user_id': str(request.user.id)}).insert(request.user.id,
|
||||
with_valid=True))
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取模型列表",
|
||||
operation_id="获取模型列表",
|
||||
manual_parameters=ModelQueryApi.get_request_params_api()
|
||||
, tags=["模型"])
|
||||
@has_permissions(PermissionConstants.MODEL_READ)
|
||||
def get(self, request: Request):
|
||||
return result.success(
|
||||
ModelSerializer.Query(
|
||||
data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).list(
|
||||
with_valid=True))
|
||||
|
||||
|
||||
class Provide(APIView):
|
||||
class Exec(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="调用供应商函数,获取表单数据",
|
||||
operation_id="调用供应商函数,获取表单数据",
|
||||
manual_parameters=ProvideApi.get_request_params_api(),
|
||||
request_body=ProvideApi.get_request_body_api()
|
||||
, tags=["模型"])
|
||||
@has_permissions(PermissionConstants.MODEL_READ)
|
||||
def post(self, request: Request, provider: str, method: str):
|
||||
return result.success(
|
||||
ProviderSerializer(data={'provider': provider, 'method': method}).exec(request.data, with_valid=True))
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取模型供应商数据",
|
||||
operation_id="获取模型供应商列表"
|
||||
, tags=["模型"])
|
||||
@has_permissions(PermissionConstants.MODEL_READ)
|
||||
def get(self, request: Request):
|
||||
return result.success(
|
||||
[ModelProvideConstants[key].value.get_model_provide_info().to_dict() for key in
|
||||
ModelProvideConstants.__members__])
|
||||
|
||||
class ModelTypeList(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取模型类型列表",
|
||||
operation_id="获取模型类型类型列表",
|
||||
manual_parameters=ProvideApi.ModelTypeList.get_request_params_api(),
|
||||
responses=result.get_api_array_response(ProvideApi.ModelTypeList.get_response_body_api())
|
||||
, tags=["模型"])
|
||||
@has_permissions(PermissionConstants.MODEL_READ)
|
||||
def get(self, request: Request):
|
||||
provider = request.query_params.get('provider')
|
||||
return result.success(ModelProvideConstants[provider].value.get_model_type_list())
|
||||
|
||||
class ModelList(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取模型列表",
|
||||
operation_id="获取模型创建表单",
|
||||
manual_parameters=ProvideApi.ModelList.get_request_params_api(),
|
||||
responses=result.get_api_array_response(ProvideApi.ModelList.get_response_body_api())
|
||||
, tags=["模型"]
|
||||
)
|
||||
@has_permissions(PermissionConstants.MODEL_READ)
|
||||
def get(self, request: Request):
|
||||
provider = request.query_params.get('provider')
|
||||
model_type = request.query_params.get('model_type')
|
||||
|
||||
return result.success(
|
||||
ModelProvideConstants[provider].value.get_model_list(
|
||||
model_type))
|
||||
|
||||
class ModelForm(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取模型创建表单",
|
||||
operation_id="获取模型创建表单",
|
||||
manual_parameters=ProvideApi.ModelForm.get_request_params_api(),
|
||||
tags=["模型"])
|
||||
@has_permissions(PermissionConstants.MODEL_READ)
|
||||
def get(self, request: Request):
|
||||
provider = request.query_params.get('provider')
|
||||
model_type = request.query_params.get('model_type')
|
||||
model_name = request.query_params.get('model_name')
|
||||
return result.success(
|
||||
ModelProvideConstants[provider].value.get_model_credential(model_type, model_name).to_form_list())
|
||||
|
|
@ -99,6 +99,10 @@ CACHES = {
|
|||
"token_cache": {
|
||||
'BACKEND': 'common.cache.file_cache.FileCache',
|
||||
'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "token_cache") # 文件夹路径
|
||||
},
|
||||
"chat_cache": {
|
||||
'BACKEND': 'common.cache.file_cache.FileCache',
|
||||
'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "chat_cache") # 文件夹路径
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,13 +5,13 @@ The `urlpatterns` list routes URLs to views. For more information please see:
|
|||
https://docs.djangoproject.com/en/4.2/topics/http/urls/
|
||||
Examples:
|
||||
Function views
|
||||
1. Add an import: from my_app import views
|
||||
1. Add an import: froms my_app import views
|
||||
2. Add a URL to urlpatterns: path('', views.home, name='home')
|
||||
Class-based views
|
||||
1. Add an import: from other_app.views import Home
|
||||
1. Add an import: froms other_app.views import Home
|
||||
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
|
||||
Including another URLconf
|
||||
1. Import the include() function: from django.urls import include, path
|
||||
1. Import the include() function: froms django.urls import include, path
|
||||
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
|
||||
"""
|
||||
import os
|
||||
|
|
@ -43,7 +43,8 @@ schema_view = get_schema_view(
|
|||
urlpatterns = [
|
||||
path("api/", include("users.urls")),
|
||||
path("api/", include("dataset.urls")),
|
||||
path("api/", include("setting.urls"))
|
||||
path("api/", include("setting.urls")),
|
||||
path("api/", include("application.urls"))
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -17,9 +17,9 @@ application = get_wsgi_application()
|
|||
|
||||
|
||||
def post_handler():
|
||||
from common.event.listener_manage import ListenerManagement
|
||||
ListenerManagement().run()
|
||||
ListenerManagement.init_embedding_model_signal.send()
|
||||
from common import event
|
||||
event.run()
|
||||
event.ListenerManagement.init_embedding_model_signal.send()
|
||||
|
||||
|
||||
post_handler()
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from django.db.models import Q
|
|||
from drf_yasg import openapi
|
||||
from rest_framework import serializers
|
||||
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.constants.exception_code_constants import ExceptionCodeConstants
|
||||
from common.constants.permission_constants import RoleConstants, get_permission_list_by_role
|
||||
from common.exception.app_exception import AppApiException
|
||||
|
|
@ -66,7 +67,8 @@ class LoginSerializer(ApiMixin, serializers.Serializer):
|
|||
:return: 用户Token(认证信息)
|
||||
"""
|
||||
user = self.is_valid()
|
||||
token = signing.dumps({'username': user.username, 'id': str(user.id), 'email': user.email})
|
||||
token = signing.dumps({'username': user.username, 'id': str(user.id), 'email': user.email,
|
||||
'type': AuthenticationType.USER.value})
|
||||
return token
|
||||
|
||||
class Meta:
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
@desc:
|
||||
"""
|
||||
from django.core import cache
|
||||
from django.db.models import QuerySet
|
||||
from drf_yasg import openapi
|
||||
from drf_yasg.utils import swagger_auto_schema
|
||||
from rest_framework.decorators import action
|
||||
|
|
@ -21,7 +20,6 @@ from common.auth.authentication import has_permissions
|
|||
from common.constants.permission_constants import PermissionConstants, CompareConstants
|
||||
from common.response import result
|
||||
from smartdoc.settings import JWT_AUTH
|
||||
from users.models.user import User as UserModel
|
||||
from users.serializers.user_serializers import RegisterSerializer, LoginSerializer, CheckCodeSerializer, \
|
||||
RePasswordSerializer, \
|
||||
SendEmailSerializer, UserProfile
|
||||
|
|
@ -36,7 +34,8 @@ class User(APIView):
|
|||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取当前用户信息",
|
||||
operation_id="获取当前用户信息",
|
||||
responses=result.get_api_response(UserProfile.get_response_body_api()))
|
||||
responses=result.get_api_response(UserProfile.get_response_body_api()),
|
||||
tags=['用户'])
|
||||
@has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND)
|
||||
def get(self, request: Request):
|
||||
return result.success(UserProfile.get_user_profile(request.user))
|
||||
|
|
@ -58,7 +57,8 @@ class ResetCurrentUserPasswordView(APIView):
|
|||
description="密码")
|
||||
}
|
||||
),
|
||||
responses=RePasswordSerializer().get_response_body_api())
|
||||
responses=RePasswordSerializer().get_response_body_api(),
|
||||
tags=['用户'])
|
||||
def post(self, request: Request):
|
||||
data = {'email': request.user.email}
|
||||
data.update(request.data)
|
||||
|
|
@ -77,7 +77,8 @@ class SendEmailToCurrentUserView(APIView):
|
|||
@permission_classes((AllowAny,))
|
||||
@swagger_auto_schema(operation_summary="发送邮件到当前用户",
|
||||
operation_id="发送邮件到当前用户",
|
||||
responses=SendEmailSerializer().get_response_body_api())
|
||||
responses=SendEmailSerializer().get_response_body_api(),
|
||||
tags=['用户'])
|
||||
def post(self, request: Request):
|
||||
serializer_obj = SendEmailSerializer(data={'email': request.user.email, 'type': "reset_password"})
|
||||
if serializer_obj.is_valid(raise_exception=True):
|
||||
|
|
@ -91,7 +92,8 @@ class Logout(APIView):
|
|||
@permission_classes((AllowAny,))
|
||||
@swagger_auto_schema(operation_summary="登出",
|
||||
operation_id="登出",
|
||||
responses=SendEmailSerializer().get_response_body_api())
|
||||
responses=SendEmailSerializer().get_response_body_api(),
|
||||
tags=['用户'])
|
||||
def post(self, request: Request):
|
||||
token_cache.delete(request.META.get('HTTP_AUTHORIZATION', None
|
||||
))
|
||||
|
|
@ -104,7 +106,9 @@ class Login(APIView):
|
|||
@swagger_auto_schema(operation_summary="登录",
|
||||
operation_id="登录",
|
||||
request_body=LoginSerializer().get_request_body_api(),
|
||||
responses=LoginSerializer().get_response_body_api())
|
||||
responses=LoginSerializer().get_response_body_api(),
|
||||
security=[],
|
||||
tags=['用户'])
|
||||
def post(self, request: Request):
|
||||
login_request = LoginSerializer(data=request.data)
|
||||
# 校验请求参数
|
||||
|
|
@ -121,7 +125,9 @@ class Register(APIView):
|
|||
@swagger_auto_schema(operation_summary="用户注册",
|
||||
operation_id="用户注册",
|
||||
request_body=RegisterSerializer().get_request_body_api(),
|
||||
responses=RegisterSerializer().get_response_body_api())
|
||||
responses=RegisterSerializer().get_response_body_api(),
|
||||
security=[],
|
||||
tags=['用户'])
|
||||
def post(self, request: Request):
|
||||
serializer_obj = RegisterSerializer(data=request.data)
|
||||
if serializer_obj.is_valid(raise_exception=True):
|
||||
|
|
@ -136,7 +142,9 @@ class RePasswordView(APIView):
|
|||
@swagger_auto_schema(operation_summary="修改密码",
|
||||
operation_id="修改密码",
|
||||
request_body=RePasswordSerializer().get_request_body_api(),
|
||||
responses=RePasswordSerializer().get_response_body_api())
|
||||
responses=RePasswordSerializer().get_response_body_api(),
|
||||
security=[],
|
||||
tags=['用户'])
|
||||
def post(self, request: Request):
|
||||
serializer_obj = RePasswordSerializer(data=request.data)
|
||||
return result.success(serializer_obj.reset_password())
|
||||
|
|
@ -149,7 +157,9 @@ class CheckCode(APIView):
|
|||
@swagger_auto_schema(operation_summary="校验验证码是否正确",
|
||||
operation_id="校验验证码是否正确",
|
||||
request_body=CheckCodeSerializer().get_request_body_api(),
|
||||
responses=CheckCodeSerializer().get_response_body_api())
|
||||
responses=CheckCodeSerializer().get_response_body_api(),
|
||||
security=[],
|
||||
tags=['用户'])
|
||||
def post(self, request: Request):
|
||||
return result.success(CheckCodeSerializer(data=request.data).is_valid(raise_exception=True))
|
||||
|
||||
|
|
@ -160,7 +170,9 @@ class SendEmail(APIView):
|
|||
@swagger_auto_schema(operation_summary="发送邮件",
|
||||
operation_id="发送邮件",
|
||||
request_body=SendEmailSerializer().get_request_body_api(),
|
||||
responses=SendEmailSerializer().get_response_body_api())
|
||||
responses=SendEmailSerializer().get_response_body_api(),
|
||||
security=[],
|
||||
tags=['用户'])
|
||||
def post(self, request: Request):
|
||||
serializer_obj = SendEmailSerializer(data=request.data)
|
||||
if serializer_obj.is_valid(raise_exception=True):
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ djangorestframework = "3.14.0"
|
|||
drf-yasg = "1.21.7"
|
||||
django-filter = "23.2"
|
||||
elasticsearch = "8.9.0"
|
||||
langchain = "0.0.274"
|
||||
langchain = "^0.0.321"
|
||||
psycopg2-binary = "2.9.7"
|
||||
jieba = "^0.42.1"
|
||||
diskcache = "^5.6.3"
|
||||
|
|
@ -22,6 +22,10 @@ chardet = "^5.2.0"
|
|||
torch = "^2.1.0"
|
||||
sentence-transformers = "^2.2.2"
|
||||
blinker = "^1.6.3"
|
||||
openai = "^0.28.1"
|
||||
tiktoken = "^0.5.1"
|
||||
qianfan = "^0.1.1"
|
||||
pycryptodome = "^3.19.0"
|
||||
|
||||
|
||||
[build-system]
|
||||
|
|
|
|||
Loading…
Reference in New Issue