feat: 应用相关接口,模型相关接口

This commit is contained in:
shaohuzhang1 2023-11-16 13:16:27 +08:00
parent d41f3214ec
commit f28b222388
91 changed files with 4173 additions and 138 deletions

2
.gitignore vendored
View File

@ -27,7 +27,7 @@ share/python-wheels/
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# Usually these files are written by a python script froms a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

View File

@ -1,4 +1,4 @@
# Generated by Django 4.1.10 on 2023-10-24 12:13
# Generated by Django 4.1.10 on 2023-11-14 07:30
import django.contrib.postgres.fields
from django.db import migrations, models
@ -11,6 +11,7 @@ class Migration(migrations.Migration):
initial = True
dependencies = [
('setting', '0003_alter_model_provider'),
('users', '0001_initial'),
('dataset', '0001_initial'),
]
@ -26,8 +27,9 @@ class Migration(migrations.Migration):
('desc', models.CharField(max_length=128, verbose_name='引用描述')),
('prologue', models.CharField(max_length=1024, verbose_name='开场白')),
('example', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(blank=True, max_length=256), size=None, verbose_name='示例列表')),
('dialogue_number', models.IntegerField(default=0, verbose_name='会话数量')),
('status', models.BooleanField(default=True, verbose_name='是否发布')),
('is_active', models.BooleanField(default=True)),
('model', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='setting.model')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user')),
],
options={
@ -35,13 +37,49 @@ class Migration(migrations.Migration):
},
),
migrations.CreateModel(
name='ApplicationDatasetMapping',
name='Chat',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
('application', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='application.application')),
('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset')),
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
('abstract', models.CharField(max_length=256, verbose_name='摘要')),
('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application')),
],
options={
'db_table': 'application_chat',
},
),
migrations.CreateModel(
name='ChatRecord',
fields=[
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
('vote_status', models.CharField(choices=[('-1', '未投票'), ('0', '赞同'), ('1', '反对')], default='-1', max_length=10, verbose_name='投票')),
('source_id', models.UUIDField(verbose_name='资源id 段落/问题 id ')),
('source_type', models.CharField(choices=[('0', '问题'), ('1', '段落')], default='0', max_length=1, verbose_name='资源类型')),
('message_tokens', models.IntegerField(default=0, verbose_name='请求token数量')),
('answer_tokens', models.IntegerField(default=0, verbose_name='响应token数量')),
('problem_text', models.CharField(max_length=1024, verbose_name='问题')),
('answer_text', models.CharField(max_length=1024, verbose_name='答案')),
('improve_problem_id_list', django.contrib.postgres.fields.ArrayField(base_field=models.UUIDField(blank=True), default=list, size=None, verbose_name='改进标注列表')),
('index', models.IntegerField(verbose_name='对话下标')),
('chat', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.chat')),
('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset', verbose_name='数据集')),
('paragraph', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph', verbose_name='段落id')),
],
options={
'db_table': 'application_chat_record',
},
),
migrations.CreateModel(
name='ApplicationDatasetMapping',
fields=[
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application')),
('dataset', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='dataset.dataset')),
],
options={
'db_table': 'application_dataset_mapping',

View File

@ -0,0 +1,25 @@
# Generated by Django 4.1.10 on 2023-11-14 08:02
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('dataset', '0001_initial'),
('application', '0001_initial'),
]
operations = [
migrations.AlterField(
model_name='chatrecord',
name='dataset',
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset', verbose_name='数据集'),
),
migrations.AlterField(
model_name='chatrecord',
name='paragraph',
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph', verbose_name='段落id'),
),
]

View File

@ -0,0 +1,25 @@
# Generated by Django 4.1.10 on 2023-11-14 08:03
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('dataset', '0001_initial'),
('application', '0002_alter_chatrecord_dataset_alter_chatrecord_paragraph'),
]
operations = [
migrations.AlterField(
model_name='chatrecord',
name='dataset',
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset', verbose_name='数据集'),
),
migrations.AlterField(
model_name='chatrecord',
name='paragraph',
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph', verbose_name='段落id'),
),
]

View File

@ -0,0 +1,23 @@
# Generated by Django 4.1.10 on 2023-11-14 08:09
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('application', '0003_alter_chatrecord_dataset_alter_chatrecord_paragraph'),
]
operations = [
migrations.AlterField(
model_name='chatrecord',
name='source_id',
field=models.UUIDField(null=True, verbose_name='资源id 段落/问题 id '),
),
migrations.AlterField(
model_name='chatrecord',
name='source_type',
field=models.CharField(choices=[('0', '问题'), ('1', '段落'), ('-1', '其他')], default='-1', max_length=5, verbose_name='资源类型'),
),
]

View File

@ -0,0 +1,18 @@
# Generated by Django 4.1.10 on 2023-11-14 08:11
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('application', '0004_alter_chatrecord_source_id_and_more'),
]
operations = [
migrations.AlterField(
model_name='chatrecord',
name='source_type',
field=models.CharField(blank=True, choices=[('0', '问题'), ('1', '段落')], default='0', max_length=2, null=True, verbose_name='资源类型'),
),
]

View File

@ -0,0 +1,43 @@
# Generated by Django 4.1.10 on 2023-11-15 09:55
from django.db import migrations, models
import django.db.models.deletion
import uuid
class Migration(migrations.Migration):
dependencies = [
('users', '0001_initial'),
('application', '0005_alter_chatrecord_source_type'),
]
operations = [
migrations.CreateModel(
name='ApplicationAccessToken',
fields=[
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, primary_key=True, serialize=False, to='application.application', verbose_name='应用id')),
('access_token', models.CharField(max_length=128, unique=True, verbose_name='用户公开访问 认证token')),
('is_active', models.BooleanField(default=True, verbose_name='是否开启公开访问')),
],
options={
'db_table': 'application_access_token',
},
),
migrations.CreateModel(
name='ApplicationApiKey',
fields=[
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
('secret_key', models.CharField(max_length=1024, unique=True, verbose_name='秘钥')),
('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application', verbose_name='应用id')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='users.user', verbose_name='用户id')),
],
options={
'db_table': 'application_api_key',
},
),
]

View File

@ -12,7 +12,9 @@ from django.contrib.postgres.fields import ArrayField
from django.db import models
from common.mixins.app_model_mixin import AppModelMixin
from dataset.models.data_set import DataSet
from dataset.models.data_set import DataSet, Paragraph
from embedding.models import SourceType
from setting.models.model_management import Model
from users.models import User
@ -22,16 +24,62 @@ class Application(AppModelMixin):
desc = models.CharField(max_length=128, verbose_name="引用描述")
prologue = models.CharField(max_length=1024, verbose_name="开场白")
example = ArrayField(verbose_name="示例列表", base_field=models.CharField(max_length=256, blank=True))
dialogue_number = models.IntegerField(default=0, verbose_name="会话数量")
status = models.BooleanField(default=True, verbose_name="是否发布")
user = models.ForeignKey(User, on_delete=models.DO_NOTHING)
is_active = models.BooleanField(default=True)
model = models.ForeignKey(Model, on_delete=models.DO_NOTHING, db_constraint=False)
class Meta:
db_table = "application"
class ApplicationDatasetMapping(AppModelMixin):
application = models.ForeignKey(Application, on_delete=models.DO_NOTHING)
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
application = models.ForeignKey(Application, on_delete=models.CASCADE)
dataset = models.ForeignKey(DataSet, on_delete=models.CASCADE)
class Meta:
db_table = "application_dataset_mapping"
class Chat(AppModelMixin):
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
application = models.ForeignKey(Application, on_delete=models.CASCADE)
abstract = models.CharField(max_length=256, verbose_name="摘要")
class Meta:
db_table = "application_chat"
class VoteChoices(models.TextChoices):
"""订单类型"""
UN_VOTE = -1, '未投票'
STAR = 0, '赞同'
TRAMPLE = 1, '反对'
class ChatRecord(AppModelMixin):
"""
对话日志 详情
"""
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
chat = models.ForeignKey(Chat, on_delete=models.CASCADE)
vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices,
default=VoteChoices.UN_VOTE)
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, verbose_name="数据集", blank=True, null=True)
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, verbose_name="段落id", blank=True, null=True)
source_id = models.UUIDField(max_length=128, verbose_name="资源id 段落/问题 id ", null=True)
source_type = models.CharField(verbose_name='资源类型', max_length=2, choices=SourceType.choices,
default=SourceType.PROBLEM, blank=True, null=True)
message_tokens = models.IntegerField(verbose_name="请求token数量", default=0)
answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0)
problem_text = models.CharField(max_length=1024, verbose_name="问题")
answer_text = models.CharField(max_length=1024, verbose_name="答案")
improve_problem_id_list = ArrayField(verbose_name="改进标注列表",
base_field=models.UUIDField(max_length=128, blank=True)
, default=list)
index = models.IntegerField(verbose_name="对话下标")
class Meta:
db_table = "application_chat_record"

View File

@ -0,0 +1,366 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file application_serializers.py
@date2023/11/7 10:02
@desc:
"""
import hashlib
import os
import uuid
from typing import Dict
from django.contrib.postgres.fields import ArrayField
from django.core import cache
from django.core import signing
from django.db import transaction, models
from django.db.models import QuerySet
from rest_framework import serializers
from application.models import Application, ApplicationDatasetMapping
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
from common.constants.authentication_type import AuthenticationType
from common.db.search import get_dynamics_model, native_search, native_page_search
from common.db.sql_execute import select_list
from common.exception.app_exception import AppApiException, NotFound404, AppAuthenticationFailed
from common.util.file_util import get_file_content
from dataset.models import DataSet
from setting.models import AuthOperate
from setting.models.model_management import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from smartdoc.conf import PROJECT_DIR
from smartdoc.settings import JWT_AUTH
token_cache = cache.caches['token_cache']
class ModelDatasetAssociation(serializers.Serializer):
user_id = serializers.UUIDField(required=True)
model_id = serializers.CharField(required=True)
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
def is_valid(self, *, raise_exception=True):
super().is_valid(raise_exception=True)
model_id = self.data.get('model_id')
user_id = self.data.get('user_id')
if not QuerySet(Model).filter(id=model_id).exists():
raise AppApiException(500, f'模型不存在【{model_id}')
dataset_id_list = list(set(self.data.get('dataset_id_list')))
exist_dataset_id_list = [str(dataset.id) for dataset in
QuerySet(DataSet).filter(id__in=dataset_id_list, user_id=user_id)]
for dataset_id in dataset_id_list:
if not exist_dataset_id_list.__contains__(dataset_id):
raise AppApiException(500, f'数据集id不存在【{dataset_id}')
class ApplicationSerializerModel(serializers.ModelSerializer):
class Meta:
model = Application
fields = "__all__"
class ApplicationSerializer(serializers.Serializer):
name = serializers.CharField(required=True)
desc = serializers.CharField(required=True)
model_id = serializers.CharField(required=True)
multiple_rounds_dialogue = serializers.BooleanField(required=True)
prologue = serializers.CharField(required=True)
example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True))
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
class AccessTokenSerializer(serializers.Serializer):
application_id = serializers.UUIDField(required=True)
class AccessTokenEditSerializer(serializers.Serializer):
access_token_reset = serializers.UUIDField(required=False)
is_active = serializers.BooleanField(required=False)
def edit(self, instance: Dict, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
ApplicationSerializer.AccessTokenSerializer.AccessTokenEditSerializer(data=instance).is_valid(
raise_exception=True)
application_access_token = QuerySet(ApplicationAccessToken).get(
application_id=self.data.get('application_id'))
if 'is_active' in instance:
application_access_token.is_active = instance.get("is_active")
if 'access_token_reset' in instance and instance.get('access_token_reset'):
application_access_token.access_token = hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24]
application_access_token.save()
return self.one(with_valid=False)
def one(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
application_id = self.data.get("application_id")
application_access_token = QuerySet(ApplicationAccessToken).filter(
application_id=application_id).first()
if application_access_token is None:
application_access_token = ApplicationAccessToken(application_id=application_id,
access_token=hashlib.md5(
str(uuid.uuid1()).encode()).hexdigest()[
8:24], is_active=True)
application_access_token.save()
return {'application_id': application_access_token.application_id,
'access_token': application_access_token.access_token,
"is_active": application_access_token.is_active}
class Authentication(serializers.Serializer):
access_token = serializers.CharField(required=True)
def auth(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
access_token = self.data.get("access_token")
application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first()
if application_access_token is not None and application_access_token.is_active:
token = signing.dumps({'application_id': str(application_access_token.application_id),
'user_id': str(application_access_token.application.user.id),
'access_token': application_access_token.access_token,
'type': AuthenticationType.APPLICATION_ACCESS_TOKEN.value})
token_cache.set(token, application_access_token, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'])
return token
else:
raise AppAuthenticationFailed(401, "无效的access_token")
class Edit(serializers.Serializer):
name = serializers.CharField(required=False)
desc = serializers.CharField(required=False)
model_id = serializers.CharField(required=False)
multiple_rounds_dialogue = serializers.BooleanField(required=False)
prologue = serializers.CharField(required=False)
example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True))
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
def is_valid(self, *, user_id=None, raise_exception=False):
super().is_valid(raise_exception=True)
ModelDatasetAssociation(data={'user_id': user_id, 'model_id': self.data.get('model_id'),
'dataset_id_list': self.data.get('dataset_id_list')}).is_valid()
class Create(serializers.Serializer):
user_id = serializers.UUIDField(required=True)
@transaction.atomic
def insert(self, application: Dict):
self.is_valid(raise_exception=True)
user_id = self.data.get('user_id')
ApplicationSerializer(data=application).is_valid(user_id=user_id, raise_exception=True)
application_model = ApplicationSerializer.Create.to_application_model(user_id, application)
dataset_id_list = application.get('dataset_id_list', [])
application_dataset_mapping_model_list = [
ApplicationSerializer.Create.to_application_dateset_mapping(application_model.id, dataset_id) for
dataset_id in dataset_id_list]
# 插入应用
application_model.save()
# 插入认证信息
ApplicationAccessToken(application_id=application_model.id,
access_token=hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24]).save()
# 插入关联数据
QuerySet(ApplicationDatasetMapping).bulk_create(application_dataset_mapping_model_list)
return True
@staticmethod
def to_application_model(user_id: str, application: Dict):
return Application(id=uuid.uuid1(), name=application.get('name'), desc=application.get('desc'),
prologue=application.get('prologue'), example=application.get('example'),
dialogue_number=3 if application.get('multiple_rounds_dialogue') else 0,
status=True, user_id=user_id, model_id=application.get('model_id'),
)
@staticmethod
def to_application_dateset_mapping(application_id: str, dataset_id: str):
return ApplicationDatasetMapping(id=uuid.uuid1(), application_id=application_id, dataset_id=dataset_id)
class Query(serializers.Serializer):
name = serializers.CharField(required=False)
desc = serializers.CharField(required=False)
user_id = serializers.UUIDField(required=True)
def get_query_set(self):
user_id = self.data.get("user_id")
query_set_dict = {}
query_set = QuerySet(model=get_dynamics_model(
{'temp_application.name': models.CharField(), 'temp_application.desc': models.CharField()}))
if "desc" in self.data and self.data.get('desc') is not None:
query_set = query_set.filter(**{'temp_application.desc__contains': self.data.get("desc")})
if "name" in self.data and self.data.get('name') is not None:
query_set = query_set.filter(**{'temp_application.name__contains': self.data.get("name")})
query_set_dict['default_sql'] = query_set
query_set_dict['application_custom_sql'] = QuerySet(model=get_dynamics_model(
{'application.user_id': models.CharField(),
})).filter(
**{'application.user_id': user_id}
)
query_set_dict['team_member_permission_custom_sql'] = QuerySet(model=get_dynamics_model(
{'user_id': models.CharField(),
'team_member_permission.auth_target_type': models.CharField(),
'team_member_permission.operate': ArrayField(verbose_name="权限操作列表",
base_field=models.CharField(max_length=256,
blank=True,
choices=AuthOperate.choices,
default=AuthOperate.USE)
)})).filter(
**{'user_id': user_id, 'team_member_permission.operate__contains': ['USE'],
'team_member_permission.auth_target_type': 'APPLICATION'})
return query_set_dict
def list(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
return [ApplicationSerializer.Query.reset_application(a) for a in
native_search(self.get_query_set(), select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application.sql')))]
@staticmethod
def reset_application(application: Dict):
application['multiple_rounds_dialogue'] = True if application.get('dialogue_number') > 0 else False
del application['dialogue_number']
return application
def page(self, current_page: int, page_size: int, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
return native_page_search(current_page, page_size, self.get_query_set(), select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application.sql')),
post_records_handler=ApplicationSerializer.Query.reset_application)
class ApplicationModel(serializers.ModelSerializer):
class Meta:
model = Application
fields = ['id', 'name', 'desc', 'prologue', 'example', 'dialogue_number', 'status']
class Operate(serializers.Serializer):
application_id = serializers.UUIDField(required=True)
user_id = serializers.UUIDField(required=True)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
if not QuerySet(Application).filter(id=self.data.get('application_id')).exists():
raise AppApiException(500, '不存在的应用id')
def delete(self, with_valid=True):
if with_valid:
self.is_valid()
QuerySet(Application).filter(id=self.data.get('application_id')).delete()
return True
def one(self, with_valid=True):
if with_valid:
self.is_valid()
application_id = self.data.get("application_id")
application = QuerySet(Application).get(id=application_id)
dataset_list = self.list_dataset(with_valid=False)
mapping_dataset_id_list = [adm.dataset_id for adm in
QuerySet(ApplicationDatasetMapping).filter(application_id=application_id)]
dataset_id_list = [d.get('id') for d in
list(filter(lambda row: mapping_dataset_id_list.__contains__(row.get('id')),
dataset_list))]
return {**ApplicationSerializer.Query.reset_application(ApplicationSerializerModel(application).data),
'dataset_id_list': dataset_id_list}
def profile(self, with_valid=True):
if with_valid:
self.is_valid()
application_id = self.data.get("application_id")
application = QuerySet(Application).get(id=application_id)
return ApplicationSerializer.Query.reset_application(
ApplicationSerializer.ApplicationModel(application).data)
def edit(self, instance: Dict, with_valid=True):
if with_valid:
self.is_valid()
ApplicationSerializer.Edit(data=instance).is_valid(
raise_exception=True)
application_id = self.data.get("application_id")
application = QuerySet(Application).get(id=application_id)
model = QuerySet(Model).get(id=instance.get('model_id') if 'model_id' in instance else application.model_id)
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'example']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
if update_key == 'multiple_rounds_dialogue':
application.__setattr__('dialogue_number',
0 if instance.get(update_key) else ModelProvideConstants[
model.provider].value.get_dialogue_number())
else:
application.__setattr__(update_key, instance.get(update_key))
application.save()
if 'dataset_id_list' in instance:
dataset_id_list = instance.get('dataset_id_list')
# 当前用户可修改关联的数据集列表
application_dataset_id_list = [dataset_dict.get('id') for dataset_dict in
self.list_dataset(with_valid=False)]
for dataset_id in dataset_id_list:
if not application_dataset_id_list.__contains__(dataset_id):
raise AppApiException(500, f"未知的数据集id${dataset_id},无法关联")
# 删除已经关联的id
QuerySet(ApplicationDatasetMapping).filter(dataset_id__in=application_dataset_id_list,
application_id=application_id).delete()
# 插入
QuerySet(ApplicationDatasetMapping).bulk_create(
[ApplicationDatasetMapping(application_id=application_id, dataset_id=dataset_id) for dataset_id in
dataset_id_list]) if len(dataset_id_list) > 0 else None
return self.one(with_valid=False)
def list_dataset(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
application = QuerySet(Application).get(id=self.data.get("application_id"))
return select_list(get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application_dataset.sql')),
[self.data.get('user_id'), application.user_id, self.data.get('user_id')])
class ApplicationKeySerializerModel(serializers.ModelSerializer):
class Meta:
model = ApplicationApiKey
fields = "__all__"
class ApplicationKeySerializer(serializers.Serializer):
user_id = serializers.UUIDField(required=True)
application_id = serializers.UUIDField(required=True)
def generate(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
user_id = self.data.get("user_id")
application_id = self.data.get("application_id")
secret_key = 'application-' + hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()
application_api_key = ApplicationApiKey(id=uuid.uuid1(), secret_key=secret_key, user_id=user_id,
application_id=application_id)
application_api_key.save()
return ApplicationSerializer.ApplicationKeySerializerModel(application_api_key).data
def list(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
user_id = self.data.get("user_id")
application_id = self.data.get("application_id")
return [ApplicationSerializer.ApplicationKeySerializerModel(application_api_key).data for
application_api_key in
QuerySet(ApplicationApiKey).filter(user_id=user_id, application_id=application_id)]
class Operate(serializers.Serializer):
application_id = serializers.UUIDField(required=True)
api_key_id = serializers.CharField(required=True)
def delete(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
api_key_id = self.data.get("api_key_id")
application_id = self.data.get('application_id')
QuerySet(ApplicationApiKey).filter(id=api_key_id,
application_id=application_id).delete()

View File

@ -0,0 +1,167 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file chat_message_serializers.py
@date2023/11/14 13:51
@desc:
"""
import json
import uuid
from typing import List
from django.db.models import QuerySet
from django.http import StreamingHttpResponse
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage
from rest_framework import serializers, status
from django.core.cache import cache
from common import event
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.response import result
from dataset.models import Paragraph
from embedding.models import SourceType
from setting.models.model_management import Model
chat_cache = cache
class MessageManagement:
@staticmethod
def get_message(title: str, content: str, message: str):
if content is None:
return HumanMessage(content=message)
return HumanMessage(content=(
f'已知信息:{title}:{content} '
'根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从已知信息中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 '
f'问题是:{message}'))
class ChatMessage:
def __init__(self, id: str, problem: str, title: str, paragraph: str, embedding_id: str, dataset_id: str,
document_id: str,
paragraph_id,
source_type: SourceType,
source_id: str,
answer: str,
message_tokens: int,
answer_token: int):
self.id = id
self.problem = problem
self.title = title
self.paragraph = paragraph
self.embedding_id = embedding_id
self.dataset_id = dataset_id
self.document_id = document_id
self.paragraph_id = paragraph_id
self.source_type = source_type
self.source_id = source_id
self.answer = answer
self.message_tokens = message_tokens
self.answer_token = answer_token
def get_chat_message(self):
return MessageManagement.get_message(self.problem, self.paragraph, self.problem)
class ChatInfo:
def __init__(self,
chat_id: str,
model: Model,
chat_model: BaseChatModel,
application_id: str | None,
dataset_id_list: List[str],
exclude_document_id_list: list[str],
dialogue_number: int):
self.chat_id = chat_id
self.application_id = application_id
self.model = model
self.chat_model = chat_model
self.dataset_id_list = dataset_id_list
self.exclude_document_id_list = exclude_document_id_list
self.dialogue_number = dialogue_number
self.chat_message_list: List[ChatMessage] = []
def append_chat_message(self, chat_message: ChatMessage):
self.chat_message_list.append(chat_message)
if self.application_id is not None:
event.ListenerChatMessage.record_chat_message_signal.send(
event.RecordChatMessageArgs(len(self.chat_message_list) - 1, self.chat_id, self.application_id,
chat_message)
)
def get_context_message(self):
start_index = len(self.chat_message_list) - self.dialogue_number
return [self.chat_message_list[index].get_chat_message() for index in
range(start_index if start_index > 0 else 0, len(self.chat_message_list))]
class ChatMessageSerializer(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)
def chat(self, message):
self.is_valid(raise_exception=True)
chat_id = self.data.get('chat_id')
chat_info: ChatInfo = chat_cache.get(chat_id)
if chat_info is None:
return result.Result(response_status=status.HTTP_404_NOT_FOUND, code=404, message="会话过期")
chat_model = chat_info.chat_model
vector = VectorStore.get_embedding_vector()
# 向量库检索
_value = vector.search(message, chat_info.dataset_id_list, chat_info.exclude_document_id_list,
[chat_message.embedding_id for chat_message in
(list(filter(lambda row: row.problem == message, chat_info.chat_message_list)))],
True,
EmbeddingModel.get_embedding_model())
# 查询段落id详情
paragraph = None
if _value is not None:
paragraph = QuerySet(Paragraph).get(id=_value.get('paragraph_id'))
if paragraph is None:
vector.delete_by_paragraph_id(_value.get('paragraph_id'))
title, content = (None, None) if paragraph is None else (paragraph.title, paragraph.content)
embedding_id, dataset_id, document_id, paragraph_id, source_type, source_id = (_value.get(
'id'), _value.get(
'dataset_id'), _value.get(
'document_id'), _value.get(
'paragraph_id'), _value.get(
'source_type'), _value.get(
'source_id')) if _value is not None else (None, None, None, None, None, None)
# 获取上下文
history_message = chat_info.get_context_message()
# 构建会话请求问题
chat_message = [*history_message, MessageManagement.get_message(title, content, message)]
# 对话
result_data = chat_model.stream(chat_message)
_id = str(uuid.uuid1())
def event_content(response):
all_text = ''
try:
for chunk in response:
all_text += chunk.content
yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None,
'content': chunk.content}) + "\n\n"
chat_info.append_chat_message(
ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id,
paragraph_id,
source_type,
source_id, all_text, chat_model.get_num_tokens_from_messages(chat_message),
chat_model.get_num_tokens(all_text)))
# 重新设置缓存
chat_cache.set(chat_id,
chat_info, timeout=60 * 30)
except Exception as e:
yield e
r = StreamingHttpResponse(streaming_content=event_content(result_data),
content_type='text/event-stream;charset=utf-8')
r['Cache-Control'] = 'no-cache'
return r

View File

@ -0,0 +1,282 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file chat_serializers.py
@date2023/11/14 9:59
@desc:
"""
import datetime
import json
import os
import uuid
from typing import Dict
from django.core.cache import cache
from django.db import transaction
from django.db.models import QuerySet
from rest_framework import serializers
from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord
from application.serializers.application_serializers import ModelDatasetAssociation
from application.serializers.chat_message_serializers import ChatInfo
from common.db.search import native_search, native_page_search, page_search
from common.exception.app_exception import AppApiException
from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from common.util.rsa_util import decrypt
from dataset.models import Document, Problem, Paragraph
from embedding.models import SourceType, Embedding
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.views import Model
from smartdoc.conf import PROJECT_DIR
chat_cache = cache
class ChatSerializers(serializers.Serializer):
class Query(serializers.Serializer):
abstract = serializers.CharField(required=False)
history_day = serializers.IntegerField(required=True)
user_id = serializers.UUIDField(required=True)
application_id = serializers.UUIDField(required=True)
def get_end_time(self):
history_day = self.data.get('history_day')
return datetime.datetime.now() - datetime.timedelta(days=history_day)
def get_query_set(self):
end_time = self.get_end_time()
return QuerySet(Chat).filter(application_id=self.data.get("application_id"), create_time__gte=end_time)
def list(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
return native_search(self.get_query_set(), select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application_chat.sql')),
with_table_name=True)
def page(self, current_page: int, page_size: int, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
return native_page_search(current_page, page_size, self.get_query_set(), select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application_chat.sql')),
with_table_name=True)
class OpenChat(serializers.Serializer):
user_id = serializers.UUIDField(required=True)
application_id = serializers.UUIDField(required=True)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
user_id = self.data.get('user_id')
application_id = self.data.get('application_id')
if not QuerySet(Application).filter(id=application_id, user_id=user_id).exists():
raise AppApiException(500, '应用不存在')
def open(self):
self.is_valid(raise_exception=True)
application_id = self.data.get('application_id')
application = QuerySet(Application).get(id=application_id)
model = application.model
dataset_id_list = [str(row.dataset_id) for row in
QuerySet(ApplicationDatasetMapping).filter(
application_id=application_id)]
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
decrypt(model.credential)),
streaming=True)
chat_id = str(uuid.uuid1())
chat_cache.set(chat_id,
ChatInfo(chat_id, model, chat_model, application_id, dataset_id_list,
[str(document.id) for document in
QuerySet(Document).filter(
dataset_id__in=dataset_id_list,
is_active=False)],
application.dialogue_number), timeout=60 * 30)
return chat_id
class OpenTempChat(serializers.Serializer):
user_id = serializers.UUIDField(required=True)
model_id = serializers.UUIDField(required=True)
multiple_rounds_dialogue = serializers.BooleanField(required=True)
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
ModelDatasetAssociation(
data={'user_id': self.data.get('user_id'), 'model_id': self.data.get('model_id'),
'dataset_id_list': self.data.get('dataset_id_list')}).is_valid()
def open(self):
self.is_valid(raise_exception=True)
chat_id = str(uuid.uuid1())
model = QuerySet(Model).get(user_id=self.data.get('user_id'), id=self.data.get('model_id'))
dataset_id_list = self.data.get('dataset_id_list')
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
decrypt(model.credential)),
streaming=True)
chat_cache.set(chat_id,
ChatInfo(chat_id, model, chat_model, None, dataset_id_list,
[str(document.id) for document in
QuerySet(Document).filter(
dataset_id__in=dataset_id_list,
is_active=False)],
3 if self.data.get('multiple_rounds_dialogue') else 1), timeout=60 * 30)
return chat_id
def vote_exec(source_type: SourceType, source_id: str, field: str, post_handler):
if source_type == SourceType.PROBLEM:
problem = QuerySet(Problem).get(id=source_id)
if problem is not None:
problem.__setattr__(field, post_handler(problem))
problem.save()
embedding = QuerySet(Embedding).get(source_id=source_id, source_type=source_type)
embedding.__setattr__(field, problem.__getattribute__(field))
embedding.save()
if source_type == SourceType.PARAGRAPH:
paragraph = QuerySet(Paragraph).get(id=source_id)
if paragraph is not None:
paragraph.__setattr__(field, post_handler(paragraph))
paragraph.save()
embedding = QuerySet(Embedding).get(source_id=source_id, source_type=source_type)
embedding.__setattr__(field, paragraph.__getattribute__(field))
embedding.save()
class ChatRecordSerializerModel(serializers.ModelSerializer):
class Meta:
model = ChatRecord
fields = "__all__"
class ChatRecordSerializer(serializers.Serializer):
class Query(serializers.Serializer):
application_id = serializers.UUIDField(required=True)
chat_id = serializers.UUIDField(required=True)
def list(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
return [ChatRecordSerializerModel(chat_record).data for chat_record in
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))]
def page(self, current_page: int, page_size: int, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
return page_search(current_page, page_size,
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by("index"),
post_records_handler=lambda chat_record: ChatRecordSerializerModel(chat_record).data)
class Vote(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)
chat_record_id = serializers.UUIDField(required=True)
vote_status = serializers.ChoiceField(choices=VoteChoices.choices)
@transaction.atomic
def vote(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
if not try_lock(self.data.get('chat_record_id')):
raise AppApiException(500, "正在对当前会话纪要进行投票中,请勿重复发送请求")
try:
chat_record_details_model = QuerySet(ChatRecord).get(id=self.data.get('chat_record_id'),
chat_id=self.data.get('chat_id'))
if chat_record_details_model is None:
raise AppApiException(500, "不存在的对话 chat_record_id")
vote_status = self.data.get("vote_status")
if chat_record_details_model.vote_status == VoteChoices.UN_VOTE:
if vote_status == VoteChoices.STAR:
# 点赞
chat_record_details_model.vote_status = VoteChoices.STAR
# 点赞数量 +1
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
'star_num',
lambda r: r.star_num + 1)
if vote_status == VoteChoices.TRAMPLE:
# 点踩
chat_record_details_model.vote_status = VoteChoices.TRAMPLE
# 点踩数量+1
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
'trample_num',
lambda r: r.trample_num + 1)
chat_record_details_model.save()
else:
if vote_status == VoteChoices.UN_VOTE:
# 取消点赞
chat_record_details_model.vote_status = VoteChoices.UN_VOTE
chat_record_details_model.save()
if chat_record_details_model.vote_status == VoteChoices.STAR:
# 点赞数量 -1
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
'star_num', lambda r: r.star_num - 1)
if chat_record_details_model.vote_status == VoteChoices.TRAMPLE:
# 点踩数量 -1
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
'trample_num', lambda r: r.trample_num - 1)
else:
raise AppApiException(500, "已经投票过,请先取消后再进行投票")
finally:
un_lock(self.data.get('chat_record_id'))
return True
class ImproveSerializer(serializers.Serializer):
title = serializers.CharField(required=False)
content = serializers.CharField(required=True)
class Improve(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)
chat_record_id = serializers.UUIDField(required=True)
dataset_id = serializers.UUIDField(required=True)
document_id = serializers.UUIDField(required=True)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
if not QuerySet(Document).filter(id=self.data.get('document_id'),
dataset_id=self.data.get('dataset_id')).exists():
raise AppApiException(500, "文档id不正确")
@transaction.atomic
def improve(self, instance: Dict, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
ChatRecordSerializer.ImproveSerializer(data=instance).is_valid(raise_exception=True)
chat_record_id = self.data.get('chat_record_id')
chat_id = self.data.get('chat_id')
chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
if chat_record is None:
raise AppApiException(500, '不存在的对话记录')
document_id = self.data.get("document_id")
dataset_id = self.data.get("dataset_id")
paragraph = Paragraph(id=uuid.uuid1(),
document_id=document_id,
content=instance.get("content"),
dataset_id=dataset_id,
title=instance.get("title") if 'title' in instance else '')
problem = Problem(id=uuid.uuid1(), content=chat_record.problem_text, paragraph_id=paragraph.id,
document_id=document_id, dataset_id=dataset_id)
# 插入问题
problem.save()
# 插入段落
paragraph.save()
chat_record.improve_problem_id_list.append(problem.id)
# 添加标注
chat_record.save()
return True

View File

@ -0,0 +1,8 @@
SELECT * FROM ( SELECT * FROM application ${application_custom_sql} UNION
SELECT
*
FROM
application
WHERE
application."id" IN ( SELECT team_member_permission.target FROM team_member team_member LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id" ${team_member_permission_custom_sql})
) temp_application ${default_sql}

View File

@ -0,0 +1,16 @@
SELECT
*
FROM
application_chat application_chat
LEFT JOIN (
SELECT COUNT
( "id" ) AS chat_record_count,
SUM ( CASE WHEN "vote_status" = '0' THEN 1 ELSE 0 END ) AS star_num,
SUM ( CASE WHEN "vote_status" = '1' THEN 1 ELSE 0 END ) AS trample_num,
SUM ( CASE WHEN array_length( application_chat_record.improve_problem_id_list, 1 ) IS NULL THEN 0 ELSE array_length( application_chat_record.improve_problem_id_list, 1 ) END ) AS mark_sum,
chat_id
FROM
application_chat_record
GROUP BY
application_chat_record.chat_id
) chat_record_temp ON application_chat."id" = chat_record_temp.chat_id

View File

@ -0,0 +1,20 @@
SELECT
*
FROM
dataset
WHERE
user_id = %s UNION
SELECT
*
FROM
dataset
WHERE
"id" IN (
SELECT
team_member_permission.target
FROM
team_member team_member
LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id"
WHERE
( "team_member_permission"."auth_target_type" = 'DATASET' AND "team_member_permission"."operate"::text[] @> ARRAY['USE'] AND team_member.team_id = %s AND team_member.user_id =%s )
)

View File

@ -0,0 +1,165 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file application_api.py
@date2023/11/7 10:50
@desc:
"""
from drf_yasg import openapi
from common.mixins.api_mixin import ApiMixin
"""
name = serializers.CharField(required=True)
desc = serializers.CharField(required=True)
model_id = serializers.CharField(required=True)
multiple_rounds_dialogue = serializers.BooleanField(required=True)
prologue = serializers.CharField(required=True)
example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True))
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
"""
class ApplicationApi(ApiMixin):
class Authentication(ApiMixin):
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['access_token', ],
properties={
'access_token': openapi.Schema(type=openapi.TYPE_STRING, title="应用认证token",
description="应用认证token"),
}
)
@staticmethod
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['id', 'name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'user_id', 'status', 'create_time',
'update_time'],
properties={
'id': openapi.Schema(type=openapi.TYPE_STRING, title="", description="主键id"),
'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"),
"multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话",
description="是否开启多轮对话"),
'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
title="示例列表", description="示例列表"),
'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户", description="所属用户"),
'status': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否发布", description='是否发布'),
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", description='创建时间'),
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", description='修改时间'),
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING),
title="关联数据集Id列表",
description="关联数据集Id列表(查询详情的时候返回)")
}
)
class ApiKey(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='application_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用id')
]
class Operate(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='application_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用id'),
openapi.Parameter(name='api_key_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用api_key id')
]
class AccessToken(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='application_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用id')
]
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=[],
properties={
'access_token_reset': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重置Token",
description="重置Token"),
'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否激活", description="是否激活"),
}
)
class Create(ApiMixin):
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue'],
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"),
"multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话",
description="是否开启多轮对话"),
'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
title="示例列表", description="示例列表"),
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING),
title="关联数据集Id列表", description="关联数据集Id列表")
}
)
class Query(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='name',
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=False,
description='应用名称'),
openapi.Parameter(name='desc',
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=False,
description='应用描述')
]
class Operate(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='application_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用id'),
]

View File

@ -0,0 +1,159 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file chat_api.py
@date2023/11/7 17:29
@desc:
"""
from drf_yasg import openapi
from common.mixins.api_mixin import ApiMixin
class ChatApi(ApiMixin):
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['message'],
properties={
'message': openapi.Schema(type=openapi.TYPE_STRING, title="问题", description="问题"),
}
)
class OpenChat(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='application_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用id'),
]
class OpenTempChat(ApiMixin):
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['model_id', 'multiple_rounds_dialogue'],
properties={
'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"),
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING),
title="关联数据集Id列表", description="关联数据集Id列表"),
'multiple_rounds_dialogue': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮会话",
description="是否开启多轮会话")
}
)
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='application_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用id'),
openapi.Parameter(name='history_day',
in_=openapi.IN_QUERY,
type=openapi.TYPE_NUMBER,
required=True,
description='历史天数')
]
class ChatRecordApi(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='application_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用id'),
openapi.Parameter(name='chat_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='对话id'),
]
class ImproveApi(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='application_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用id'),
openapi.Parameter(name='chat_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='会话id'),
openapi.Parameter(name='chat_record_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='会话记录id'),
openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='数据集id'),
openapi.Parameter(name='document_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='文档id'),
]
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['content'],
properties={
'title': openapi.Schema(type=openapi.TYPE_STRING, title="段落标题",
description="段落标题"),
'content': openapi.Schema(type=openapi.TYPE_STRING, title="段落内容",
description="段落内容")
}
)
class VoteApi(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='application_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用id'),
openapi.Parameter(name='chat_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='会话id'),
openapi.Parameter(name='chat_record_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='会话记录id')
]
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['vote_status'],
properties={
'vote_status': openapi.Schema(type=openapi.TYPE_STRING, title="投票状态",
description="-1:取消投票|0:赞同|1:反对"),
}
)

35
apps/application/urls.py Normal file
View File

@ -0,0 +1,35 @@
from django.urls import path
from . import views
app_name = "application"
urlpatterns = [
path('application', views.Application.as_view(), name="application"),
path('application/profile', views.Application.Profile.as_view()),
path('application/authentication', views.Application.Authentication.as_view()),
path('application/<str:application_id>/api_key', views.Application.ApplicationKey.as_view()),
path("application/<str:application_id>/api_key/<str:api_key_id>",
views.Application.ApplicationKey.Operate.as_view()),
path('application/<str:application_id>', views.Application.Operate.as_view(), name='application/operate'),
path('application/<str:application_id>/list_dataset', views.Application.ListApplicationDataSet.as_view(),
name='application/dataset'),
path('application/<str:application_id>/access_token', views.Application.AccessToken.as_view(),
name='application/access_token'),
path('application/<int:current_page>/<int:page_size>', views.Application.Page.as_view(), name='application_page'),
path('application/<str:application_id>/chat/open', views.ChatView.Open.as_view()),
path("application/chat/open", views.ChatView.OpenTemp.as_view()),
path('application/<str:application_id>/chat', views.ChatView.as_view(), name='chats'),
path('application/<str:application_id>/chat/<int:current_page>/<int:page_size>', views.ChatView.Page.as_view()),
path('application/<str:application_id>/chat/<chat_id>/chat_record/', views.ChatView.ChatRecord.as_view()),
path('application/<str:application_id>/chat/<chat_id>/chat_record/<int:current_page>/<int:page_size>',
views.ChatView.ChatRecord.Page.as_view()),
path('application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>/vote',
views.ChatView.ChatRecord.Vote.as_view(),
name=''),
path(
'application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>/dataset/<str:dataset_id>/document_id/<str:document_id>/improve',
views.ChatView.ChatRecord.Improve.as_view(),
name=''),
path('application/chat_message/<str:chat_id>', views.ChatView.Message.as_view())
]

View File

@ -1,3 +0,0 @@
from django.shortcuts import render
# Create your views here.

View File

@ -0,0 +1,10 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2023/9/25 17:12
@desc:
"""
from .application_views import *
from .chat_views import *

View File

@ -0,0 +1,250 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file application_views.py
@date2023/10/27 14:56
@desc:
"""
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.views import APIView
from application.serializers.application_serializers import ApplicationSerializer
from application.swagger_api.application_api import ApplicationApi
from common.auth import TokenAuth, has_permissions
from common.constants.permission_constants import CompareConstants, PermissionConstants, Permission, Group, Operate, \
ViewPermission, RoleConstants
from common.exception.app_exception import AppAuthenticationFailed
from common.response import result
from common.util.common import query_params_to_single_dict
from dataset.serializers.dataset_serializers import DataSetSerializers
class Application(APIView):
authentication_classes = [TokenAuth]
class Profile(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取应用相关信息",
operation_id="获取应用相关信息",
tags=["应用/会话"])
def get(self, request: Request):
if 'application_id' in request.auth.keywords:
return result.success(ApplicationSerializer.Operate(
data={'application_id': request.auth.keywords.get('application_id'),
'user_id': request.user.id}).profile())
else:
raise AppAuthenticationFailed(401, "身份异常")
class ApplicationKey(APIView):
authentication_classes = [TokenAuth]
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="新增ApiKey",
operation_id="新增ApiKey",
tags=['应用/API_KEY'],
manual_parameters=ApplicationApi.ApiKey.get_request_params_api())
@has_permissions(ViewPermission(
[RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE,
dynamic_tag=keywords.get('application_id'))],
compare=CompareConstants.AND))
def post(self, request: Request, application_id: str):
return result.success(
ApplicationSerializer.ApplicationKeySerializer(
data={'application_id': application_id, 'user_id': request.user.id}).generate())
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取应用API_KEY列表",
operation_id="获取应用API_KEY列表",
tags=['应用/API_KEY'],
manual_parameters=ApplicationApi.ApiKey.get_request_params_api()
)
@has_permissions(ViewPermission(
[RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE,
dynamic_tag=keywords.get('application_id'))],
compare=CompareConstants.AND))
def get(self, request: Request, application_id: str):
return result.success(ApplicationSerializer.ApplicationKeySerializer(
data={'application_id': application_id, 'user_id': request.user.id}).list())
class Operate(APIView):
authentication_classes = [TokenAuth]
@action(methods=['DELETE'], detail=False)
@swagger_auto_schema(operation_summary="删除应用API_KEY",
operation_id="删除应用API_KEY",
tags=['应用/API_KEY'],
manual_parameters=ApplicationApi.ApiKey.Operate.get_request_params_api())
@has_permissions(ViewPermission(
[RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE,
dynamic_tag=keywords.get('application_id'))],
compare=CompareConstants.AND), lambda r, k: Permission(group=Group.APPLICATION, operate=Operate.DELETE,
dynamic_tag=k.get('application_id')),
compare=CompareConstants.AND)
def delete(self, request: Request, application_id: str, api_key_id: str):
return result.success(
ApplicationSerializer.ApplicationKeySerializer.Operate(
data={'application_id': application_id, 'user_id': request.user.id,
'api_key_id': api_key_id}).delete())
class AccessToken(APIView):
authentication_classes = [TokenAuth]
@action(methods=['PUT'], detail=False)
@swagger_auto_schema(operation_summary="修改 应用AccessToken",
operation_id="修改 应用AccessToken",
tags=['应用/公开访问'],
manual_parameters=ApplicationApi.AccessToken.get_request_params_api(),
request_body=ApplicationApi.AccessToken.get_request_body_api())
@has_permissions(ViewPermission(
[RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE,
dynamic_tag=keywords.get('application_id'))],
compare=CompareConstants.AND))
def put(self, request: Request, application_id: str):
return result.success(
ApplicationSerializer.AccessTokenSerializer(data={'application_id': application_id}).edit(request.data))
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取应用 AccessToken信息",
operation_id="获取应用 AccessToken信息",
manual_parameters=ApplicationApi.AccessToken.get_request_params_api(),
tags=['应用/公开访问'],
)
@has_permissions(ViewPermission(
[RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))],
compare=CompareConstants.AND))
def get(self, request: Request, application_id: str):
return result.success(
ApplicationSerializer.AccessTokenSerializer(data={'application_id': application_id}).one())
class Authentication(APIView):
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="应用认证",
operation_id="应用认证",
request_body=ApplicationApi.Authentication.get_request_body_api(),
tags=["应用/认证"],
security=[])
def post(self, request: Request):
return result.success(
ApplicationSerializer.Authentication(data={'access_token': request.data.get("access_token")}).auth())
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="创建应用",
operation_id="创建应用",
request_body=ApplicationApi.Create.get_request_body_api(),
tags=['应用'])
@has_permissions(PermissionConstants.APPLICATION_CREATE, compare=CompareConstants.AND)
def post(self, request: Request):
ApplicationSerializer.Create(data={'user_id': request.user.id}).insert(request.data)
return result.success(True)
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取应用列表",
operation_id="获取应用列表",
manual_parameters=ApplicationApi.Query.get_request_params_api(),
responses=result.get_api_array_response(ApplicationApi.get_response_body_api()),
tags=['应用'])
@has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND)
def get(self, request: Request):
return result.success(
ApplicationSerializer.Query(
data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).list())
class Operate(APIView):
authentication_classes = [TokenAuth]
@action(methods=['DELETE'], detail=False)
@swagger_auto_schema(operation_summary="删除应用",
operation_id="删除应用",
manual_parameters=ApplicationApi.Operate.get_request_params_api(),
responses=result.get_default_response(),
tags=['应用'])
@has_permissions(ViewPermission(
[RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE,
dynamic_tag=keywords.get('application_id'))],
compare=CompareConstants.AND),
lambda r, k: Permission(group=Group.APPLICATION, operate=Operate.DELETE,
dynamic_tag=k.get('application_id')), compare=CompareConstants.AND)
def delete(self, request: Request, application_id: str):
return result.success(ApplicationSerializer.Operate(
data={'application_id': application_id, 'user_id': request.user.id}).delete(
with_valid=True))
@action(methods=['PUT'], detail=False)
@swagger_auto_schema(operation_summary="修改应用",
operation_id="修改应用",
manual_parameters=ApplicationApi.Operate.get_request_params_api(),
request_body=ApplicationApi.Create.get_request_body_api(),
responses=result.get_api_array_response(ApplicationApi.get_response_body_api()),
tags=['应用'])
@has_permissions(ViewPermission(
[RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))],
compare=CompareConstants.AND))
def put(self, request: Request, application_id: str):
return result.success(
ApplicationSerializer.Operate(data={'application_id': application_id, 'user_id': request.user.id}).edit(
request.data))
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取应用详情",
operation_id="获取应用详情",
manual_parameters=ApplicationApi.Operate.get_request_params_api(),
responses=result.get_api_array_response(ApplicationApi.get_response_body_api()),
tags=['应用'])
@has_permissions(ViewPermission(
[RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN,
RoleConstants.APPLICATION_KEY],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))],
compare=CompareConstants.AND))
def get(self, request: Request, application_id: str):
return result.success(ApplicationSerializer.Operate(
data={'application_id': application_id, 'user_id': request.user.id}).one())
class ListApplicationDataSet(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取当前应用可使用的数据集",
operation_id="获取当前应用可使用的数据集",
manual_parameters=ApplicationApi.Operate.get_request_params_api(),
responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api()),
tags=['应用'])
@has_permissions(ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))],
compare=CompareConstants.AND))
def get(self, request: Request, application_id: str):
return result.success(ApplicationSerializer.Operate(
data={'application_id': application_id, 'user_id': request.user.id}).list_dataset())
class Page(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="分页获取应用列表",
operation_id="分页获取应用列表",
manual_parameters=result.get_page_request_params(
ApplicationApi.Query.get_request_params_api()),
responses=result.get_page_api_response(ApplicationApi.get_response_body_api()),
tags=['应用'])
@has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND)
def get(self, request: Request, current_page: int, page_size: int):
return result.success(
ApplicationSerializer.Query(
data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).page(
current_page, page_size))

View File

@ -0,0 +1,198 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file chat_views.py
@date2023/11/14 9:53
@desc:
"""
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.views import APIView
from application.serializers.chat_message_serializers import ChatMessageSerializer
from application.serializers.chat_serializers import ChatSerializers, ChatRecordSerializer
from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi
from common.auth import TokenAuth, has_permissions
from common.constants.permission_constants import Permission, Group, Operate, \
RoleConstants, ViewPermission, CompareConstants
from common.exception.app_exception import AppAuthenticationFailed
from common.response import result
from common.util.common import query_params_to_single_dict
class ChatView(APIView):
authentication_classes = [TokenAuth]
class Open(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取会话id,根据应用id",
operation_id="获取会话id,根据应用id",
manual_parameters=ChatApi.OpenChat.get_request_params_api(),
tags=["应用/会话"])
@has_permissions(
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN,
RoleConstants.APPLICATION_KEY],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))],
compare=CompareConstants.AND)
)
def get(self, request: Request, application_id: str):
return result.success(ChatSerializers.OpenChat(
data={'user_id': request.user.id, 'application_id': application_id}).open())
class OpenTemp(APIView):
authentication_classes = [TokenAuth]
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="获取会话id(根据模型id,数据集列表,是否多轮会话)",
operation_id="获取会话id",
request_body=ChatApi.OpenTempChat.get_request_body_api(),
tags=["应用/会话"])
@has_permissions(RoleConstants.ADMIN, RoleConstants.USER)
def post(self, request: Request):
return result.success(ChatSerializers.OpenTempChat(
data={**request.data, 'user_id': request.user.id}).open())
class Message(APIView):
authentication_classes = [TokenAuth]
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="对话",
operation_id="对话",
request_body=ChatApi.get_request_body_api(),
tags=["应用/会话"])
@has_permissions(
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY,
RoleConstants.APPLICATION_ACCESS_TOKEN],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))])
)
def post(self, request: Request, chat_id: str):
return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'))
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取对话列表",
operation_id="获取对话列表",
manual_parameters=ChatApi.get_request_params_api(),
tags=["应用/对话日志"]
)
@has_permissions(
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))])
)
def get(self, request: Request, application_id: str):
return result.success(ChatSerializers.Query(
data={**query_params_to_single_dict(request.query_params), 'application_id': application_id,
'user_id': request.user.id}).list())
class Page(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="分页获取对话列表",
operation_id="分页获取对话列表",
manual_parameters=result.get_page_request_params(ChatApi.get_request_params_api()),
tags=["应用/对话日志"]
)
@has_permissions(
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))])
)
def get(self, request: Request, application_id: str, current_page: int, page_size: int):
return result.success(ChatSerializers.Query(
data={**query_params_to_single_dict(request.query_params), 'application_id': application_id,
'user_id': request.user.id}).page(current_page=current_page,
page_size=page_size))
class ChatRecord(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取对话记录列表",
operation_id="获取对话记录列表",
manual_parameters=ChatRecordApi.get_request_params_api(),
tags=["应用/对话日志"]
)
@has_permissions(
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))])
)
def get(self, request: Request, application_id: str, chat_id: str):
return result.success(ChatRecordSerializer.Query(
data={'application_id': application_id,
'chat_id': chat_id}).list())
class Page(APIView):
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取对话记录列表",
operation_id="获取对话记录列表",
manual_parameters=result.get_page_request_params(
ChatRecordApi.get_request_params_api()),
tags=["应用/对话日志"]
)
@has_permissions(
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))])
)
def get(self, request: Request, application_id: str, chat_id: str, current_page: int, page_size: int):
return result.success(ChatRecordSerializer.Query(
data={'application_id': application_id,
'chat_id': chat_id}).page(current_page, page_size))
class Vote(APIView):
authentication_classes = [TokenAuth]
@action(methods=['PUT'], detail=False)
@swagger_auto_schema(operation_summary="点赞,点踩",
operation_id="点赞,点踩",
manual_parameters=VoteApi.get_request_params_api(),
request_body=VoteApi.get_request_body_api(),
responses=result.get_default_response(),
tags=["应用/会话"]
)
@has_permissions(
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY,
RoleConstants.APPLICATION_ACCESS_TOKEN],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))])
)
def put(self, request: Request, application_id: str, chat_id: str, chat_record_id: str):
return result.success(ChatRecordSerializer.Vote(
data={'vote_status': request.data.get('vote_status'), 'chat_id': chat_id,
'chat_record_id': chat_record_id}).vote())
class Improve(APIView):
authentication_classes = [TokenAuth]
@action(methods=['PUT'], detail=False)
@swagger_auto_schema(operation_summary="标注",
operation_id="标注",
manual_parameters=ImproveApi.get_request_params_api(),
request_body=ImproveApi.get_request_body_api(),
responses=result.get_default_response(),
tags=["应用/对话日志/标注"]
)
@has_permissions(
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))],
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION,
operate=Operate.MANAGE,
dynamic_tag=keywords.get(
'dataset_id'))],
)
))
def put(self, request: Request, application_id: str, chat_id: str, chat_record_id: str, dataset_id: str,
document_id: str):
return result.success(ChatRecordSerializer.Improve(
data={'chat_id': chat_id, 'chat_record_id': chat_record_id,
'dataset_id': dataset_id, 'document_id': document_id}).improve(request.data))

View File

@ -12,7 +12,10 @@ from django.core import signing
from django.db.models import QuerySet
from rest_framework.authentication import TokenAuthentication
from common.constants.permission_constants import Auth, get_permission_list_by_role, RoleConstants
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
from common.constants.authentication_type import AuthenticationType
from common.constants.permission_constants import Auth, get_permission_list_by_role, RoleConstants, Permission, Group, \
Operate
from common.exception.app_exception import AppAuthenticationFailed
from smartdoc.settings import JWT_AUTH
from users.models.user import User, get_user_dynamics_permission
@ -34,13 +37,29 @@ class TokenAuth(TokenAuthentication):
if auth is None:
raise AppAuthenticationFailed(1003, '未登录,请先登录')
try:
if str(auth).startswith("application-"):
application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=auth).first()
if application_api_key is None:
raise AppAuthenticationFailed(500, "secret_key 无效")
permission_list = [Permission(group=Group.APPLICATION,
operate=Operate.USE,
dynamic_tag=str(
application_api_key.application_id)),
Permission(group=Group.APPLICATION,
operate=Operate.MANAGE,
dynamic_tag=str(
application_api_key.application_id))
]
return application_api_key.user, Auth(role_list=[RoleConstants.APPLICATION_KEY],
permission_list=permission_list,
application_id=application_api_key.application_id)
# 解析 token
user = signing.loads(auth)
if 'id' in user:
cache_token = token_cache.get(auth)
if cache_token is None:
raise AppAuthenticationFailed(1002, "登录过期")
user = QuerySet(User).get(id=user['id'])
auth_details = signing.loads(auth)
cache_token = token_cache.get(auth)
if cache_token is None:
raise AppAuthenticationFailed(1002, "登录过期")
if 'id' in auth_details and auth_details.get('type') == AuthenticationType.USER.value:
user = QuerySet(User).get(id=auth_details['id'])
# 续期
token_cache.touch(auth, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'].total_seconds())
rule = RoleConstants[user.role]
@ -49,6 +68,26 @@ class TokenAuth(TokenAuthentication):
permission_list += get_user_dynamics_permission(str(user.id))
return user, Auth(role_list=[rule],
permission_list=permission_list)
if 'application_id' in auth_details and 'access_token' in auth_details and auth_details.get(
'type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
application_access_token = QuerySet(ApplicationAccessToken).filter(
application_id=auth_details.get('application_id')).first()
if application_access_token is None:
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
if not application_access_token.is_active:
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
if not application_access_token.access_token == auth_details.get('access_token'):
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
return application_access_token.application.user, Auth(
role_list=[RoleConstants.APPLICATION_ACCESS_TOKEN],
permission_list=[
Permission(group=Group.APPLICATION,
operate=Operate.USE,
dynamic_tag=str(
application_access_token.application_id))],
application_id=application_access_token.application_id
)
else:
raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户")

View File

@ -35,23 +35,30 @@ def exist_role_by_role_constants(user_role: List[RoleConstants],
return any(list(map(lambda up: role_list.__contains__(up), user_role)))
def exist_permissions_by_view_permission(user_role: List[RoleConstants], user_permission: List[PermissionConstants],
permission: ViewPermission):
def exist_permissions_by_view_permission(user_role: List[RoleConstants],
user_permission: List[PermissionConstants | object],
permission: ViewPermission, request, **kwargs):
"""
用户是否存在这些权限
:param request:
:param user_role: 用户角色
:param user_permission: 用户权限
:param permission: 所属权限
:return: 是否存在 True False
"""
role_ok = any(list(map(lambda ur: permission.roleList.__contains__(ur), user_role)))
permission_ok = any(list(map(lambda up: permission.permissionList.__contains__(up), user_permission)))
permission_list = [user_p(request, kwargs) if callable(user_p) else user_p for user_p in
permission.permissionList
]
permission_ok = any(list(map(lambda up: permission_list.__contains__(up),
user_permission)))
return role_ok | permission_ok if permission.compare == CompareConstants.OR else role_ok & permission_ok
def exist_permissions(user_role: List[RoleConstants], user_permission: List[PermissionConstants], permission):
def exist_permissions(user_role: List[RoleConstants], user_permission: List[PermissionConstants], permission, request,
**kwargs):
if isinstance(permission, ViewPermission):
return exist_permissions_by_view_permission(user_role, user_permission, permission)
return exist_permissions_by_view_permission(user_role, user_permission, permission, request, **kwargs)
elif isinstance(permission, RoleConstants):
return exist_role_by_role_constants(user_role, [permission])
elif isinstance(permission, PermissionConstants):
@ -64,9 +71,9 @@ def exist_permissions(user_role: List[RoleConstants], user_permission: List[Perm
def exist(user_role: List[RoleConstants], user_permission: List[PermissionConstants], permission, request, **kwargs):
if callable(permission):
p = permission(request, kwargs)
return exist_permissions(user_role, user_permission, p)
return exist_permissions(user_role, user_permission, p, request)
else:
return exist_permissions(user_role, user_permission, permission)
return exist_permissions(user_role, user_permission, permission, request, **kwargs)
def has_permissions(*permission, compare=CompareConstants.OR):

View File

@ -0,0 +1,16 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file authentication_type.py
@date2023/11/14 20:03
@desc:
"""
from enum import Enum
class AuthenticationType(Enum):
# 或者
USER = "USER"
# 并且
APPLICATION_ACCESS_TOKEN = "APPLICATION_ACCESS_TOKEN"

View File

@ -21,6 +21,10 @@ class Group(Enum):
SETTING = "SETTING"
MODEL = "MODEL"
TEAM = "TEAM"
class Operate(Enum):
"""
@ -40,16 +44,24 @@ class Operate(Enum):
USE = "USE"
class RoleGroup(Enum):
USER = 'USER'
APPLICATION_KEY = "APPLICATION_KEY"
APPLICATION_ACCESS_TOKEN = "APPLICATION_ACCESS_TOKEN"
class Role:
def __init__(self, name: str, decs: str):
def __init__(self, name: str, decs: str, group: RoleGroup):
self.name = name
self.decs = decs
self.group = group
class RoleConstants(Enum):
ADMIN = Role("管理员", "管理员,预制目前不会使用")
USER = Role("用户", "用户所有权限")
SESSION = Role("会话", decs="只拥有应用会话框接口权限")
ADMIN = Role("管理员", "管理员,预制目前不会使用", RoleGroup.USER)
USER = Role("用户", "用户所有权限", RoleGroup.USER)
APPLICATION_ACCESS_TOKEN = Role("会话", "只拥有应用会话框接口权限", RoleGroup.APPLICATION_ACCESS_TOKEN),
APPLICATION_KEY = Role("应用私钥", "应用私钥", RoleGroup.APPLICATION_KEY)
class Permission:
@ -90,9 +102,29 @@ class PermissionConstants(Enum):
APPLICATION_READ = Permission(group=Group.APPLICATION, operate=Operate.READ,
roles=[RoleConstants.ADMIN, RoleConstants.USER])
APPLICATION_CREATE = Permission(group=Group.APPLICATION, operate=Operate.CREATE,
roles=[RoleConstants.ADMIN, RoleConstants.USER])
SETTING_READ = Permission(group=Group.SETTING, operate=Operate.READ,
roles=[RoleConstants.ADMIN, RoleConstants.USER])
MODEL_READ = Permission(group=Group.MODEL, operate=Operate.READ, roles=[RoleConstants.ADMIN, RoleConstants.USER])
MODEL_EDIT = Permission(group=Group.MODEL, operate=Operate.EDIT, roles=[RoleConstants.ADMIN, RoleConstants.USER])
MODEL_DELETE = Permission(group=Group.MODEL, operate=Operate.DELETE,
roles=[RoleConstants.ADMIN, RoleConstants.USER])
MODEL_CREATE = Permission(group=Group.MODEL, operate=Operate.CREATE,
roles=[RoleConstants.ADMIN, RoleConstants.USER])
TEAM_READ = Permission(group=Group.TEAM, operate=Operate.READ, roles=[RoleConstants.ADMIN, RoleConstants.USER])
TEAM_CREATE = Permission(group=Group.TEAM, operate=Operate.CREATE, roles=[RoleConstants.ADMIN, RoleConstants.USER])
TEAM_DELETE = Permission(group=Group.TEAM, operate=Operate.DELETE, roles=[RoleConstants.ADMIN, RoleConstants.USER])
TEAM_EDIT = Permission(group=Group.TEAM, operate=Operate.EDIT, roles=[RoleConstants.ADMIN, RoleConstants.USER])
def get_permission_list_by_role(role: RoleConstants):
"""
@ -110,9 +142,11 @@ class Auth:
用于存储当前用户的角色和权限
"""
def __init__(self, role_list: List[RoleConstants], permission_list: List[PermissionConstants]):
def __init__(self, role_list: List[RoleConstants], permission_list: List[PermissionConstants | Permission],
**keywords):
self.role_list = role_list
self.permission_list = permission_list
self.keywords = keywords
class CompareConstants(Enum):
@ -123,7 +157,7 @@ class CompareConstants(Enum):
class ViewPermission:
def __init__(self, roleList: List[RoleConstants], permissionList: List[PermissionConstants],
def __init__(self, roleList: List[RoleConstants], permissionList: List[PermissionConstants | object],
compare=CompareConstants.OR):
self.roleList = roleList
self.permissionList = permissionList

View File

@ -19,7 +19,7 @@ class AppSQLCompiler(SQLCompiler):
field_replace_dict = {}
self.field_replace_dict = field_replace_dict
def get_query_str(self, with_limits=True, with_table_name=True):
def get_query_str(self, with_limits=True, with_table_name=False):
refcounts_before = self.query.alias_refcount.copy()
try:
extra_select, order_by, group_by = self.pre_sql_setup()

View File

@ -32,9 +32,10 @@ def get_dynamics_model(attr: dict, table_name='dynamics'):
def generate_sql_by_query_dict(queryset_dict: Dict[str, QuerySet], select_string: str,
field_replace_dict: None | Dict[str, Dict[str, str]] = None):
field_replace_dict: None | Dict[str, Dict[str, str]] = None, with_table_name=False):
"""
生成 查询sql
:param with_table_name:
:param queryset_dict: 多条件 查询条件
:param select_string: 查询sql
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
@ -45,7 +46,8 @@ def generate_sql_by_query_dict(queryset_dict: Dict[str, QuerySet], select_string
result_params = []
for key in queryset_dict.keys():
value = queryset_dict.get(key)
sql, params = compiler_queryset(value, None if field_replace_dict is None else field_replace_dict.get(key))
sql, params = compiler_queryset(value, None if field_replace_dict is None else field_replace_dict.get(key),
with_table_name)
params_dict = {**params_dict, select_string.index("${" + key + "}"): params}
select_string = select_string.replace("${" + key + "}", sql)
@ -55,7 +57,7 @@ def generate_sql_by_query_dict(queryset_dict: Dict[str, QuerySet], select_string
def generate_sql_by_query(queryset: QuerySet, select_string: str,
field_replace_dict: None | Dict[str, str] = None):
field_replace_dict: None | Dict[str, str] = None, with_table_name=False):
"""
生成 查询sql
:param queryset: 查询条件
@ -63,13 +65,14 @@ def generate_sql_by_query(queryset: QuerySet, select_string: str,
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
:return: sql:需要查询的sql params: sql 参数
"""
sql, params = compiler_queryset(queryset, field_replace_dict)
sql, params = compiler_queryset(queryset, field_replace_dict, with_table_name)
return select_string + " " + sql, params
def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, str] = None):
def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, str] = None, with_table_name=False):
"""
解析 queryset查询对象
:param with_table_name:
:param queryset: 查询对象
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
:return: sql:需要查询的sql params: sql 参数
@ -80,15 +83,16 @@ def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, s
field_replace_dict = get_field_replace_dict(queryset)
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
field_replace_dict=field_replace_dict)
sql, params = app_sql_compiler.get_query_str(with_table_name=False)
sql, params = app_sql_compiler.get_query_str(with_table_name)
return sql, params
def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None,
with_search_one=False):
with_search_one=False, with_table_name=False):
"""
复杂查询
:param with_table_name: 生成sql是否包含表名
:param queryset: 查询条件构造器
:param select_string: 查询前缀 不包括 where limit 等信息
:param field_replace_dict: 需要替换的字段
@ -96,9 +100,9 @@ def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
:return: 查询结果
"""
if isinstance(queryset, Dict):
exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict)
exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name)
else:
exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict)
exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name)
if with_search_one:
return select_one(exec_sql, exec_params)
else:
@ -121,9 +125,11 @@ def page_search(current_page: int, page_size: int, queryset: QuerySet, post_reco
def native_page_search(current_page: int, page_size: int, queryset: QuerySet | Dict[str, QuerySet], select_string: str,
field_replace_dict=None,
post_records_handler=lambda r: r):
post_records_handler=lambda r: r,
with_table_name=False):
"""
复杂分页查询
:param with_table_name:
:param current_page: 当前页
:param page_size: 每页大小
:param queryset: 查询条件
@ -133,9 +139,9 @@ def native_page_search(current_page: int, page_size: int, queryset: QuerySet | D
:return: 分页结果
"""
if isinstance(queryset, Dict):
exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict)
exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name)
else:
exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict)
exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name)
total_sql = "SELECT \"count\"(*) FROM (%s) temp" % exec_sql
total = select_one(total_sql, exec_params)
limit_sql = connections[DEFAULT_DB_ALIAS].ops.limit_offset_sql(

View File

@ -0,0 +1,15 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2023/11/10 10:43
@desc:
"""
from .listener_manage import *
from .listener_chat_message import *
def run():
listener_manage.ListenerManagement().run()
listener_chat_message.ListenerChatMessage().run()

View File

@ -0,0 +1,18 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file common.py
@date2023/11/10 10:41
@desc:
"""
from concurrent.futures import ThreadPoolExecutor
work_thread_pool = ThreadPoolExecutor(5)
def poxy(poxy_function):
def inner(args):
work_thread_pool.submit(poxy_function, args)
return inner

View File

@ -0,0 +1,54 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file listener_manage.py
@date2023/10/20 14:01
@desc:
"""
from blinker import signal
from django.db.models import QuerySet
from application.models import ChatRecord, Chat
from application.serializers.chat_message_serializers import ChatMessage
from common.event.common import poxy
class RecordChatMessageArgs:
def __init__(self, index: int, chat_id: str, application_id: str, chat_message: ChatMessage):
self.index = index
self.chat_id = chat_id
self.application_id = application_id
self.chat_message = chat_message
class ListenerChatMessage:
record_chat_message_signal = signal("record_chat_message")
@staticmethod
@poxy
def record_chat_message(args: RecordChatMessageArgs):
if not QuerySet(Chat).filter(id=args.chat_id).exists():
Chat(id=args.chat_id, application_id=args.application_id, abstract=args.chat_message.problem).save()
# 插入会话记录
try:
chat_record = ChatRecord(
id=args.chat_message.id,
chat_id=args.chat_id,
dataset_id=args.chat_message.dataset_id,
paragraph_id=args.chat_message.paragraph_id,
source_id=args.chat_message.source_id,
source_type=args.chat_message.source_type,
problem_text=args.chat_message.problem,
answer_text=args.chat_message.answer,
index=args.index,
message_tokens=args.chat_message.message_tokens,
answer_tokens=args.chat_message.answer_token)
chat_record.save()
except Exception as e:
print(e)
def run(self):
# 记录会话
ListenerChatMessage.record_chat_message_signal.connect(self.record_chat_message)

View File

@ -7,7 +7,6 @@
@desc:
"""
import os
from concurrent.futures import ThreadPoolExecutor
import django.db.models
from blinker import signal
@ -15,22 +14,14 @@ from django.db.models import QuerySet
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.db.search import native_search, get_dynamics_model
from common.event.common import poxy
from common.util.file_util import get_file_content
from dataset.models import Paragraph, Status, Document
from embedding.models import SourceType
from smartdoc.conf import PROJECT_DIR
def poxy(poxy_function):
def inner(args):
ListenerManagement.work_thread_pool.submit(poxy_function, args)
return inner
class ListenerManagement:
work_thread_pool = ThreadPoolExecutor(5)
embedding_by_problem_signal = signal("embedding_by_problem")
embedding_by_paragraph_signal = signal("embedding_by_paragraph")
embedding_by_dataset_signal = signal("embedding_by_dataset")

View File

@ -20,6 +20,17 @@ class AppApiException(Exception):
self.message = message
class NotFound404(AppApiException):
"""
未认证(未登录)异常
"""
status_code = status.HTTP_404_NOT_FOUND
def __init__(self, code, message):
self.code = code
self.message = message
class AppAuthenticationFailed(AppApiException):
"""
未认证(未登录)异常

View File

@ -0,0 +1,22 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2023/10/31 17:56
@desc:
"""
from .array_card import *
from .base_field import *
from .base_form import *
from .combobox_field import *
from .multi_select import *
from .number_input_field import *
from .object_card import *
from .password_input import *
from .radio_field import *
from .single_select_field import *
from .switch_btn import *
from .tab_card import *
from .table_radio import *
from .text_input_field import *

View File

@ -0,0 +1,36 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file array_card.py
@date2023/10/31 18:03
@desc:
"""
from typing import List, Dict
from common.froms.base_field import BaseExecField, TriggerType
class ArrayCard(BaseExecField):
"""
收集List[Object]
"""
def __init__(self,
label: str,
text_field: str,
value_field: str,
provider: str,
method: str,
required: bool = False,
default_value: object = None,
relation_show_field_list: List[str] = None,
relation_show_value_list: List[str] = None,
relation_trigger_field_list: List[str] = None,
relation_trigger_value_list: List[str] = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
super().__init__("ArrayCard", label, text_field, value_field, provider, method, required, default_value,
relation_show_field_list, relation_show_value_list, relation_trigger_field_list,
relation_trigger_value_list, trigger_type, attrs, props_info)

View File

@ -0,0 +1,159 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_field.py
@date2023/10/31 18:07
@desc:
"""
from enum import Enum
from typing import List, Dict
class TriggerType(Enum):
# 执行函数获取 OptionList数据
OPTION_LIST = 'OPTION_LIST'
# 执行函数获取子表单
CHILD_FORMS = 'CHILD_FORMS'
class BaseField:
def __init__(self,
input_type: str,
label: str,
required: bool = False,
default_value: object = None,
relation_show_field_list: List[str] = None,
relation_show_value_list: List[str] = None,
relation_trigger_field_list: List[str] = None,
relation_trigger_value_list: List[str] = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
"""
:param input_type: 字段
:param label: 提示
:param default_value: 默认值
:param relation_show_field_list: 指定那些当那些字段有值的时候 当前字段显示
:param relation_show_value_list: 指定字段有值 并且值在relation_show_value_list列表中则显示当前字段
:param relation_trigger_field_list: 指定那些字段有值的时候 调用当前字段的 执行函数获取optionList数据
:param relation_trigger_value_list: 指定那些字段有值 并且值在relation_trigger_value_list列表中 则执行函数获取optionList数据
:param trigger_type: 执行器类型 OPTION_LIST请求Option_list数据 CHILD_FORMS请求子表单
:param attrs: 前端attr数据
:param props_info: 其他额外信息
"""
if props_info is None:
props_info = {}
if attrs is None:
attrs = {}
self.label = label
self.attrs = attrs
self.props_info = props_info
self.default_value = default_value
self.input_type = input_type
self.relation_show_field_list = [] if relation_show_field_list is None else relation_show_field_list
self.relation_show_value_list = [] if relation_show_value_list is None else relation_show_value_list
self.relation_trigger_field_list = [] if relation_trigger_field_list is None else relation_trigger_field_list
self.relation_trigger_value_field_list = [] if relation_trigger_value_list is None else relation_trigger_value_list
self.required = required
self.trigger_type = trigger_type
def to_dict(self):
return {
'input_type': self.input_type,
'label': self.label,
'required': self.required,
'default_value': self.default_value,
'relation_show_field_list': self.relation_show_field_list,
'relation_show_value_list': self.relation_show_value_list,
'relation_trigger_field_list': self.relation_trigger_field_list,
'relation_trigger_value_field_list': self.relation_trigger_value_field_list,
'trigger_type': self.trigger_type.value,
'attrs': self.attrs,
'props_info': self.props_info,
}
class BaseDefaultOptionField(BaseField):
def __init__(self, input_type: str,
label: str,
text_field: str,
value_field: str,
option_list: List[dict],
required: bool = False,
default_value: object = None,
relation_show_field_list: List[str] = None,
relation_show_value_list: List[str] = None,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
"""
:param input_type: 字段
:param label: label
:param text_field: 文本字段
:param value_field: 值字段
:param option_list: 可选列表
:param required: 是否必填
:param default_value: 默认值
:param relation_show_field_list: 指定那些当那些字段有值的时候 当前字段显示
:param relation_show_value_list: 指定字段有值 并且值在relation_show_value_list列表中则显示当前字段
:param attrs: 前端attr数据
:param props_info: 其他额外信息
"""
super().__init__(input_type, label, required, default_value, relation_show_field_list, relation_show_value_list,
[], [], TriggerType.OPTION_LIST, attrs, props_info)
self.text_field = text_field
self.value_field = value_field
self.option_list = option_list
def to_dict(self):
return {**super().to_dict(), 'text_field': self.text_field, 'value_field': self.value_field,
'option_list': self.option_list}
class BaseExecField(BaseField):
def __init__(self,
input_type: str,
label: str,
text_field: str,
value_field: str,
provider: str,
method: str,
required: bool = False,
default_value: object = None,
relation_show_field_list: List[str] = None,
relation_show_value_list: List[str] = None,
relation_trigger_field_list: List[str] = None,
relation_trigger_value_list: List[str] = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
"""
:param input_type: 字段
:param label: 提示
:param text_field: 文本字段
:param value_field: 值字段
:param provider: 指定供应商
:param method: 执行供应商函数 method
:param required: 是否必填
:param default_value: 默认值
:param relation_show_field_list: 指定那些当那些字段有值的时候 当前字段显示
:param relation_show_value_list: 指定字段有值 并且值在relation_show_value_list列表中则显示当前字段
:param relation_trigger_field_list: 指定那些字段有值的时候 调用当前字段的 执行函数获取optionList数据
:param relation_trigger_value_list: 指定那些字段有值 并且值在relation_trigger_value_list列表中 则执行函数获取optionList数据
:param trigger_type: 执行器类型 OPTION_LIST请求Option_list数据 CHILD_FORMS请求子表单
:param attrs: 前端attr数据
:param props_info: 其他额外信息
"""
super().__init__(input_type, label, required, default_value, relation_show_field_list, relation_show_value_list,
relation_trigger_field_list, relation_trigger_value_list, trigger_type, attrs, props_info)
self.text_field = text_field
self.value_field = value_field
self.provider = provider
self.method = method
def to_dict(self):
return {**super().to_dict(), 'text_field': self.text_field, 'value_field': self.value_field,
'provider': self.provider, 'method': self.method}

View File

@ -0,0 +1,16 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_form.py
@date2023/11/1 16:04
@desc:
"""
from common.froms import BaseField
class BaseForm:
def to_form_list(self):
return [{**self.__getattribute__(key).to_dict(), 'field': key} for key in
list(filter(lambda key: isinstance(self.__getattribute__(key), BaseField),
[attr for attr in vars(self.__class__) if not attr.startswith("__")]))]

View File

@ -0,0 +1,41 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file combobox_field.py
@date2023/10/31 17:59
@desc:
"""
from typing import List, Dict
from common.froms.base_field import BaseExecField, TriggerType
class Combobox(BaseExecField):
"""
多选框
"""
def __init__(self,
label: str,
text_field: str,
value_field: str,
option_list: List[str:object],
provider: str = None,
method: str = None,
required: bool = False,
default_value: object = None,
relation_show_field_list: List[str] = None,
relation_show_value_list: List[str] = None,
relation_trigger_field_list: List[str] = None,
relation_trigger_value_list: List[str] = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
super().__init__("Combobox", label, text_field, value_field, provider, method, required, default_value,
relation_show_field_list, relation_show_value_list, relation_trigger_field_list,
relation_trigger_value_list, trigger_type, attrs, props_info)
self.option_list = option_list
def to_dict(self):
return {**super().to_dict(), 'option_list': self.option_list}

View File

@ -0,0 +1,41 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file multi_select.py
@date2023/10/31 18:00
@desc:
"""
from typing import List, Dict
from common.froms.base_field import BaseExecField, TriggerType
class MultiSelect(BaseExecField):
"""
下拉单选
"""
def __init__(self,
label: str,
text_field: str,
value_field: str,
option_list: List[str:object],
provider: str = None,
method: str = None,
required: bool = False,
default_value: object = None,
relation_show_field_list: List[str] = None,
relation_show_value_list: List[str] = None,
relation_trigger_field_list: List[str] = None,
relation_trigger_value_list: List[str] = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
super().__init__("MultiSelect", label, text_field, value_field, provider, method, required, default_value,
relation_show_field_list, relation_show_value_list, relation_trigger_field_list,
relation_trigger_value_list, trigger_type, attrs, props_info)
self.option_list = option_list
def to_dict(self):
return {**super().to_dict(), 'option_list': self.option_list}

View File

@ -0,0 +1,27 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file number_input_field.py
@date2023/10/31 17:58
@desc:
"""
from typing import List
from common.froms.base_field import BaseField, TriggerType
class NumberInput(BaseField):
"""
文本输入框
"""
def __init__(self, label: str,
required: bool = False,
default_value=None,
relation_show_field_list: List[str] = None,
relation_show_value_list: List[str] = None,
attrs=None, props_info=None):
super().__init__('NumberInput', label, required, default_value, relation_show_field_list,
relation_show_value_list, [], [],
TriggerType.OPTION_LIST, attrs, props_info)

View File

@ -0,0 +1,36 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file object_card.py
@date2023/10/31 18:02
@desc:
"""
from typing import List, Dict
from common.froms.base_field import BaseExecField, TriggerType
class ObjectCard(BaseExecField):
"""
收集对象子表卡片
"""
def __init__(self,
label: str,
text_field: str,
value_field: str,
provider: str,
method: str,
required: bool = False,
default_value: object = None,
relation_show_field_list: List[str] = None,
relation_show_value_list: List[str] = None,
relation_trigger_field_list: List[str] = None,
relation_trigger_value_list: List[str] = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
super().__init__("ObjectCard", label, text_field, value_field, provider, method, required, default_value,
relation_show_field_list, relation_show_value_list, relation_trigger_field_list,
relation_trigger_value_list, trigger_type, attrs, props_info)

View File

@ -0,0 +1,27 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file password_input.py
@date2023/11/1 14:48
@desc:
"""
from typing import List
from common.froms import BaseField, TriggerType
class PasswordInputField(BaseField):
"""
文本输入框
"""
def __init__(self, label: str,
required: bool = False,
default_value=None,
relation_show_field_list: List[str] = None,
relation_show_value_list: List[str] = None,
attrs=None, props_info=None):
super().__init__('TextInput', label, required, default_value, relation_show_field_list,
relation_show_value_list, [], [],
TriggerType.OPTION_LIST, attrs, props_info)

View File

@ -0,0 +1,41 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file radio_field.py
@date2023/10/31 17:59
@desc:
"""
from typing import List, Dict
from common.froms.base_field import BaseExecField, TriggerType
class Radio(BaseExecField):
"""
下拉单选
"""
def __init__(self,
label: str,
text_field: str,
value_field: str,
option_list: List[str:object],
provider: str,
method: str,
required: bool = False,
default_value: object = None,
relation_show_field_list: List[str] = None,
relation_show_value_list: List[str] = None,
relation_trigger_field_list: List[str] = None,
relation_trigger_value_list: List[str] = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
super().__init__("Radio", label, text_field, value_field, provider, method, required, default_value,
relation_show_field_list, relation_show_value_list, relation_trigger_field_list,
relation_trigger_value_list, trigger_type, attrs, props_info)
self.option_list = option_list
def to_dict(self):
return {**super().to_dict(), 'option_list': self.option_list}

View File

@ -0,0 +1,41 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file single_select_field.py
@date2023/10/31 18:00
@desc:
"""
from typing import List, Dict
from common.froms.base_field import TriggerType, BaseExecField
class SingleSelect(BaseExecField):
"""
下拉单选
"""
def __init__(self,
label: str,
text_field: str,
value_field: str,
option_list: List[str:object],
provider: str = None,
method: str = None,
required: bool = False,
default_value: object = None,
relation_show_field_list: List[str] = None,
relation_show_value_list: List[str] = None,
relation_trigger_field_list: List[str] = None,
relation_trigger_value_list: List[str] = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
super().__init__("SingleSelect", label, text_field, value_field, provider, method, required, default_value,
relation_show_field_list, relation_show_value_list, relation_trigger_field_list,
relation_trigger_value_list, trigger_type, attrs, props_info)
self.option_list = option_list
def to_dict(self):
return {**super().to_dict(), 'option_list': self.option_list}

View File

@ -0,0 +1,28 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file switch_btn.py
@date2023/10/31 18:00
@desc:
"""
from typing import List
from common.froms.base_field import TriggerType, BaseField
class SwitchBtn(BaseField):
"""
开关
"""
def __init__(self,
label: str,
required: bool = False,
default_value=None,
relation_show_field_list: List[str] = None,
relation_show_value_list: List[str] = None,
attrs=None, props_info=None):
super().__init__('SwitchBtn', label, required, default_value, relation_show_field_list,
relation_show_value_list, [], [],
TriggerType.OPTION_LIST, attrs, props_info)

View File

@ -0,0 +1,36 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file tab_card.py
@date2023/10/31 18:03
@desc:
"""
from typing import List, Dict
from common.froms.base_field import BaseExecField, TriggerType
class TabCard(BaseExecField):
"""
收集 Tab类型数据 tab1:{},tab2:{}
"""
def __init__(self,
label: str,
text_field: str,
value_field: str,
provider: str,
method: str,
required: bool = False,
default_value: object = None,
relation_show_field_list: List[str] = None,
relation_show_value_list: List[str] = None,
relation_trigger_field_list: List[str] = None,
relation_trigger_value_list: List[str] = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
super().__init__("TabCard", label, text_field, value_field, provider, method, required, default_value,
relation_show_field_list, relation_show_value_list, relation_trigger_field_list,
relation_trigger_value_list, trigger_type, attrs, props_info)

View File

@ -0,0 +1,36 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file table_radio.py
@date2023/10/31 18:01
@desc:
"""
from typing import List, Dict
from common.froms.base_field import TriggerType, BaseExecField
class TableRadio(BaseExecField):
"""
table 单选
"""
def __init__(self,
label: str,
text_field: str,
value_field: str,
provider: str,
method: str,
required: bool = False,
default_value: object = None,
relation_show_field_list: List[str] = None,
relation_show_value_list: List[str] = None,
relation_trigger_field_list: List[str] = None,
relation_trigger_value_list: List[str] = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
super().__init__("TableRadio", label, text_field, value_field, provider, method, required, default_value,
relation_show_field_list, relation_show_value_list, relation_trigger_field_list,
relation_trigger_value_list, trigger_type, attrs, props_info)

View File

@ -0,0 +1,27 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file text_input_field.py
@date2023/10/31 17:58
@desc:
"""
from typing import List
from common.froms.base_field import BaseField, TriggerType
class TextInputField(BaseField):
"""
文本输入框
"""
def __init__(self, label: str,
required: bool = False,
default_value=None,
relation_show_field_list: List[str] = None,
relation_show_value_list: List[str] = None,
attrs=None, props_info=None):
super().__init__('TextInput', label, required, default_value, relation_show_field_list,
relation_show_value_list, [], [],
TriggerType.OPTION_LIST, attrs, props_info)

View File

@ -5,7 +5,9 @@ SELECT
problem.dataset_id AS dataset_id,
0 AS source_type,
problem."content" AS "text",
paragraph.is_active AS is_active
paragraph.is_active AS is_active,
problem.star_num as star_num,
problem.trample_num as trample_num
FROM
problem problem
LEFT JOIN paragraph paragraph ON paragraph."id" = problem.paragraph_id
@ -18,8 +20,10 @@ SELECT
paragraph."id" AS paragraph_id,
paragraph.dataset_id AS dataset_id,
1 AS source_type,
paragraph."content" AS "text",
paragraph.is_active AS is_active
concat_ws(':',paragraph."title",paragraph."content") AS "text",
paragraph.is_active AS is_active,
paragraph.star_num as star_num,
paragraph.trample_num as trample_num
FROM
paragraph paragraph

View File

@ -6,14 +6,27 @@
@date2023/10/16 16:42
@desc:
"""
import importlib
from functools import reduce
from typing import Dict
def query_params_to_single_dict(query_params: Dict):
return reduce(lambda x, y: {**x, y[0]: y[1]}, list(filter(lambda row: row[1] is not None,
list(map(lambda row: (
row[0], row[1][0] if isinstance(row[1][0],
list) and len(
row[1][0]) > 0 else row[1][0]),
query_params.items())))), {})
return reduce(lambda x, y: {**x, **y}, list(
filter(lambda item: item is not None, [({key: value} if value is not None and len(value) > 0 else None) for
key, value in
query_params.items()])), {})
def get_exec_method(clazz_: str, method_: str):
"""
根据 class 和method函数 获取执行函数
:param clazz_: class 字符串
:param method_: 执行函数
:return: 执行函数
"""
clazz_split = clazz_.split('.')
clazz_name = clazz_split[-1]
package = ".".join([clazz_split[index] for index in range(len(clazz_split) - 1)])
package_model = importlib.import_module(package)
return getattr(getattr(package_model, clazz_name), method_)

View File

@ -0,0 +1,70 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file rsa_util.py
@date2023/11/3 11:13
@desc:
"""
import base64
import os
from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher
from Crypto.PublicKey import RSA
# 对密钥加密的密码
secret_code = "mac_kb_password"
def generate():
"""
生成 私钥秘钥对
:return:{key:'公钥',value:'私钥'}
"""
# 生成一个 2048 位的密钥
key = RSA.generate(2048)
# 获取私钥
encrypted_key = key.export_key(passphrase=secret_code, pkcs=8,
protection="scryptAndAES128-CBC")
return {'key': key.publickey().export_key(), 'value': encrypted_key}
def get_key_pair():
if not os.path.exists("/opt/maxkb/conf/receiver.pem"):
kv = generate()
private_file_out = open("/opt/maxkb/conf/private.pem", "wb")
private_file_out.write(kv.get('value'))
private_file_out.close()
receiver_file_out = open("/opt/maxkb/conf/receiver.pem", "wb")
receiver_file_out.write(kv.get('key'))
receiver_file_out.close()
return {'key': open("/opt/maxkb/conf/receiver.pem").read(), 'value': open("/opt/maxkb/conf/private.pem").read()}
def encrypt(msg, public_key: str | None = None):
"""
加密
:param msg: 加密数据
:param public_key: 公钥
:return: 加密后的数据
"""
if public_key is None:
public_key = get_key_pair().get('key')
cipher = PKCS1_cipher.new(RSA.importKey(public_key))
encrypt_msg = cipher.encrypt(msg.encode("utf-8"))
return base64.b64encode(encrypt_msg).decode()
def decrypt(msg, pri_key: str | None = None):
"""
解密
:param msg: 需要解密的数据
:param pri_key: 私钥
:return: 解密后数据
"""
if pri_key is None:
pri_key = get_key_pair().get('value')
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
decrypt_data = cipher.decrypt(base64.b64decode(msg), 0)
return decrypt_data.decode("utf-8")

79
apps/common/util/test.py Normal file
View File

@ -0,0 +1,79 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file test.py
@date2023/11/15 15:13
@desc:
"""
import time
from django.core import signing
import hashlib
from django.core.cache import cache
# alg使用的算法
HEADER = {'typ': 'JWP', 'alg': 'default'}
TOKEN_KEY = 'solomon_world_token'
TOKEN_SALT = 'solomonwanc@gmail.com'
TIME_OUT = 30 * 60
# 加密
def encrypt(obj):
value = signing.dumps(obj, key=TOKEN_KEY, salt=TOKEN_SALT)
value = signing.b64_encode(value.encode()).decode()
return value
# 解密
def decrypt(src):
src = signing.b64_decode(src.encode()).decode()
raw = signing.loads(src, key=TOKEN_KEY, salt=TOKEN_SALT)
print(type(raw))
return raw
# 生成token信息
def create_token(username, password):
# 1. 加密头信息
header = encrypt(HEADER)
# 2. 构造Payload
payload = {
"username": username,
"password": password,
"iat": time.time()
}
payload = encrypt(payload)
# 3. 生成签名
md5 = hashlib.md5()
md5.update(("%s.%s" % (header, payload)).encode())
signature = md5.hexdigest()
token = "%s.%s.%s" % (header, payload, signature)
# 4.存储到缓存中
cache.set(username, token, TIME_OUT)
return token
def get_payload(token):
payload = str(token).split('.')[1]
payload = decrypt(payload)
return payload
# 通过token获取用户名
def get_username(token):
payload = get_payload(token)
return payload['username']
pass
def check_token(token):
username = get_username(token)
print('username', username)
last_token = cache.get(username)
if last_token:
return last_token == token
return False
if __name__ == '__main__':
token = create_token('zhangsan', 'lisi')

View File

@ -98,13 +98,15 @@ class DataSetSerializers(serializers.ModelSerializer):
query_set_dict['team_member_permission_custom_sql'] = QuerySet(model=get_dynamics_model(
{'user_id': models.CharField(),
'team_member_permission.auth_target_type': models.CharField(),
'team_member_permission.operate': ArrayField(verbose_name="权限操作列表",
base_field=models.CharField(max_length=256,
blank=True,
choices=AuthOperate.choices,
default=AuthOperate.USE)
)})).filter(
**{'user_id': user_id, 'team_member_permission.operate__contains': ['USE']})
**{'user_id': user_id, 'team_member_permission.operate__contains': ['USE'],
'team_member_permission.auth_target_type': 'DATASET'})
return query_set_dict

View File

@ -19,7 +19,7 @@ from common.db.search import page_search
from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from dataset.models import Paragraph, Problem
from dataset.models import Paragraph, Problem, Document
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer
@ -41,7 +41,7 @@ class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer):
message="段落在1-1024个字符之间")
])
title = serializers.CharField(required=False)
title = serializers.CharField(required=False, allow_null=True, allow_blank=True)
problem_list = ProblemInstanceSerializer(required=False, many=True)
@ -165,6 +165,12 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
document_id = serializers.UUIDField(required=True)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
if not QuerySet(Document).filter(id=self.data.get('document_id'),
dataset_id=self.data.get('dataset_id')).exists():
raise AppApiException(500, "文档id不正确")
def save(self, instance: Dict, with_valid=True, with_embedding=True):
if with_valid:
ParagraphSerializers(data=instance).is_valid(raise_exception=True)

View File

@ -54,6 +54,13 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
paragraph_id = serializers.UUIDField(required=True)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id'),
document_id=self.data.get('document_id'),
dataset_id=self.data.get('dataset_id')).exists():
raise AppApiException(500, "段落id不正确")
def save(self, instance: Dict, with_valid=True, with_embedding=True):
if with_valid:
self.is_valid()

View File

@ -26,7 +26,8 @@ class Dataset(APIView):
@swagger_auto_schema(operation_summary="获取数据集列表",
operation_id="获取数据集列表",
manual_parameters=DataSetSerializers.Query.get_request_params_api(),
responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api()))
responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api()),
tags=["数据集"])
@has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND)
def get(self, request: Request):
d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)})
@ -37,7 +38,9 @@ class Dataset(APIView):
@swagger_auto_schema(operation_summary="创建数据集",
operation_id="创建数据集",
request_body=DataSetSerializers.Create.get_request_body_api(),
responses=get_api_response(DataSetSerializers.Create.get_response_body_api()))
responses=get_api_response(DataSetSerializers.Create.get_response_body_api()),
tags=["数据集"]
)
@has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND)
def post(self, request: Request):
s = DataSetSerializers.Create(data=request.data)
@ -50,7 +53,8 @@ class Dataset(APIView):
@action(methods="DELETE", detail=False)
@swagger_auto_schema(operation_summary="删除数据集", operation_id="删除数据集",
manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
responses=result.get_default_response())
responses=result.get_default_response(),
tags=["数据集"])
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=keywords.get('dataset_id')),
lambda r, k: Permission(group=Group.DATASET, operate=Operate.DELETE,
@ -62,7 +66,8 @@ class Dataset(APIView):
@action(methods="GET", detail=False)
@swagger_auto_schema(operation_summary="查询数据集详情根据数据集id", operation_id="查询数据集详情根据数据集id",
manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
responses=get_api_response(DataSetSerializers.Operate.get_response_body_api()))
responses=get_api_response(DataSetSerializers.Operate.get_response_body_api()),
tags=["数据集"])
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=keywords.get('dataset_id')))
def get(self, request: Request, dataset_id: str):
@ -72,7 +77,9 @@ class Dataset(APIView):
@swagger_auto_schema(operation_summary="修改数据集信息", operation_id="修改数据集信息",
manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
request_body=DataSetSerializers.Operate.get_request_body_api(),
responses=get_api_response(DataSetSerializers.Operate.get_response_body_api()))
responses=get_api_response(DataSetSerializers.Operate.get_response_body_api()),
tags=["数据集"]
)
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=keywords.get('dataset_id')))
def put(self, request: Request, dataset_id: str):
@ -87,7 +94,9 @@ class Dataset(APIView):
operation_id="获取数据集分页列表",
manual_parameters=get_page_request_params(
DataSetSerializers.Query.get_request_params_api()),
responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api()))
responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api()),
tags=["数据集"]
)
@has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND)
def get(self, request: Request, current_page, page_size):
d = DataSetSerializers.Query(

View File

@ -28,7 +28,8 @@ class Document(APIView):
operation_id="创建文档",
request_body=DocumentSerializers.Create.get_request_body_api(),
manual_parameters=DocumentSerializers.Create.get_request_params_api(),
responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()))
responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()),
tags=["数据集/文档"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
@ -40,7 +41,8 @@ class Document(APIView):
@swagger_auto_schema(operation_summary="文档列表",
operation_id="文档列表",
manual_parameters=DocumentSerializers.Query.get_request_params_api(),
responses=result.get_api_response(DocumentSerializers.Query.get_response_body_api()))
responses=result.get_api_response(DocumentSerializers.Query.get_response_body_api()),
tags=["数据集/文档"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=k.get('dataset_id')))
@ -57,7 +59,8 @@ class Document(APIView):
@swagger_auto_schema(operation_summary="获取文档详情",
operation_id="获取文档详情",
manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()))
responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()),
tags=["数据集/文档"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=k.get('dataset_id')))
@ -71,7 +74,8 @@ class Document(APIView):
operation_id="修改文档",
manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
request_body=DocumentSerializers.Operate.get_request_body_api(),
responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api())
responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()),
tags=["数据集/文档"]
)
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
@ -86,7 +90,8 @@ class Document(APIView):
@swagger_auto_schema(operation_summary="删除文档",
operation_id="删除文档",
manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
responses=result.get_default_response())
responses=result.get_default_response(),
tags=["数据集/文档"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
@ -101,7 +106,9 @@ class Document(APIView):
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="分段文档",
operation_id="分段文档",
manual_parameters=DocumentSerializers.Split.get_request_params_api())
manual_parameters=DocumentSerializers.Split.get_request_params_api(),
tags=["数据集/文档"],
security=[])
def post(self, request: Request):
ds = DocumentSerializers.Split(
data={'file': request.FILES.getlist('file'),
@ -116,7 +123,8 @@ class Document(APIView):
@swagger_auto_schema(operation_summary="获取数据集分页列表",
operation_id="获取数据集分页列表",
manual_parameters=DocumentSerializers.Query.get_request_params_api(),
responses=result.get_page_api_response(DocumentSerializers.Query.get_response_body_api()))
responses=result.get_page_api_response(DocumentSerializers.Query.get_response_body_api()),
tags=["数据集/文档"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=k.get('dataset_id')))

View File

@ -25,7 +25,9 @@ class Paragraph(APIView):
@swagger_auto_schema(operation_summary="段落列表",
operation_id="段落列表",
manual_parameters=ParagraphSerializers.Query.get_request_params_api(),
responses=result.get_api_array_response(ParagraphSerializers.Query.get_response_body_api()))
responses=result.get_api_array_response(ParagraphSerializers.Query.get_response_body_api()),
tags=["数据集/文档/段落"]
)
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=k.get('dataset_id')))
@ -41,7 +43,8 @@ class Paragraph(APIView):
operation_id="创建段落",
manual_parameters=ParagraphSerializers.Create.get_request_params_api(),
request_body=ParagraphSerializers.Create.get_request_body_api(),
responses=result.get_api_response(ParagraphSerializers.Query.get_response_body_api()))
responses=result.get_api_response(ParagraphSerializers.Query.get_response_body_api()),
tags=["数据集/文档/段落"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
@ -57,7 +60,8 @@ class Paragraph(APIView):
operation_id="修改段落数据",
manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
request_body=ParagraphSerializers.Operate.get_request_body_api(),
responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api()))
responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api())
,tags=["数据集/文档/段落"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
@ -71,7 +75,8 @@ class Paragraph(APIView):
@swagger_auto_schema(operation_summary="获取段落详情",
operation_id="获取段落详情",
manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api()))
responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api()),
tags=["数据集/文档/段落"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=k.get('dataset_id')))
@ -85,7 +90,8 @@ class Paragraph(APIView):
@swagger_auto_schema(operation_summary="删除段落",
operation_id="删除段落",
manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
responses=result.get_default_response())
responses=result.get_default_response(),
tags=["数据集/文档/段落"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
@ -103,7 +109,8 @@ class Paragraph(APIView):
operation_id="分页获取段落列表",
manual_parameters=result.get_page_request_params(
ParagraphSerializers.Query.get_request_params_api()),
responses=result.get_page_api_response(ParagraphSerializers.Query.get_response_body_api()))
responses=result.get_page_api_response(ParagraphSerializers.Query.get_response_body_api()),
tags=["数据集/文档/段落"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=k.get('dataset_id')))

View File

@ -25,7 +25,8 @@ class Problem(APIView):
operation_id="添加段落关联问题",
manual_parameters=ProblemSerializers.Create.get_request_params_api(),
request_body=ProblemSerializers.Create.get_request_body_api(),
responses=result.get_api_response(ProblemSerializers.Operate.get_response_body_api()))
responses=result.get_api_response(ProblemSerializers.Operate.get_response_body_api()),
tags=["数据集/文档/段落/问题"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
@ -38,7 +39,8 @@ class Problem(APIView):
@swagger_auto_schema(operation_summary="获取段落问题列表",
operation_id="获取段落问题列表",
manual_parameters=ProblemSerializers.Query.get_request_params_api(),
responses=result.get_api_array_response(ProblemSerializers.Operate.get_response_body_api()))
responses=result.get_api_array_response(ProblemSerializers.Operate.get_response_body_api()),
tags=["数据集/文档/段落/问题"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=k.get('dataset_id')))
@ -54,7 +56,8 @@ class Problem(APIView):
@swagger_auto_schema(operation_summary="删除段落问题",
operation_id="删除段落问题",
manual_parameters=ProblemSerializers.Query.get_request_params_api(),
responses=result.get_default_response())
responses=result.get_default_response(),
tags=["数据集/文档/段落/问题"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))

View File

@ -0,0 +1,23 @@
# Generated by Django 4.1.10 on 2023-11-09 07:45
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('embedding', '0001_initial'),
]
operations = [
migrations.AddField(
model_name='embedding',
name='star_num',
field=models.IntegerField(default=0, verbose_name='点赞数量'),
),
migrations.AddField(
model_name='embedding',
name='trample_num',
field=models.IntegerField(default=0, verbose_name='点踩数量'),
),
]

View File

@ -0,0 +1,17 @@
# Generated by Django 4.1.10 on 2023-11-14 07:30
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('embedding', '0002_embedding_star_num_embedding_trample_num'),
]
operations = [
migrations.AlterUniqueTogether(
name='embedding',
unique_together={('source_id', 'source_type')},
),
]

View File

@ -0,0 +1,18 @@
# Generated by Django 4.1.10 on 2023-11-14 08:09
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('embedding', '0003_alter_embedding_unique_together'),
]
operations = [
migrations.AlterField(
model_name='embedding',
name='source_type',
field=models.CharField(choices=[('0', '问题'), ('1', '段落'), ('-1', '其他')], default='0', max_length=5, verbose_name='资源类型'),
),
]

View File

@ -0,0 +1,18 @@
# Generated by Django 4.1.10 on 2023-11-14 08:11
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('embedding', '0004_alter_embedding_source_type'),
]
operations = [
migrations.AlterField(
model_name='embedding',
name='source_type',
field=models.CharField(choices=[('0', '问题'), ('1', '段落')], default='0', max_length=5, verbose_name='资源类型'),
),
]

View File

@ -23,7 +23,7 @@ class Embedding(models.Model):
source_id = models.CharField(max_length=128, verbose_name="资源id")
source_type = models.CharField(verbose_name='资源类型', max_length=1, choices=SourceType.choices,
source_type = models.CharField(verbose_name='资源类型', max_length=5, choices=SourceType.choices,
default=SourceType.PROBLEM)
is_active = models.BooleanField(verbose_name="是否可用", max_length=1, default=True)
@ -36,5 +36,11 @@ class Embedding(models.Model):
embedding = VectorField(verbose_name="向量")
star_num = models.IntegerField(default=0, verbose_name="点赞数量")
trample_num = models.IntegerField(default=0,
verbose_name="点踩数量")
class Meta:
db_table = "embedding"
unique_together = ['source_id', 'source_type']

View File

@ -0,0 +1,15 @@
SELECT * FROM (SELECT
*,
( 1 - ( embedding.embedding <=> %s ) ) AS similarity,
CASE
WHEN embedding.star_num - embedding.trample_num = 0 THEN
0 ELSE ( ( ( embedding.star_num - embedding.trample_num ) - aggs.min_value ) / ( aggs.max_value - aggs.min_value ) )
END AS score
FROM
embedding,
( SELECT MIN ( star_num - trample_num ) AS min_value, MAX ( star_num - trample_num ) AS max_value FROM embedding ${embedding_query}) aggs
${embedding_query}
) temp
WHERE similarity>0.5
ORDER BY (similarity + score) DESC LIMIT 1

View File

@ -47,6 +47,8 @@ class BaseVectorStore(ABC):
def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
is_active: bool,
star_num: int,
trample_num: int,
embedding=None):
"""
插入向量数据
@ -58,12 +60,15 @@ class BaseVectorStore(ABC):
:param is_active: 是否禁用
:param embedding: 向量化处理器
:param paragraph_id 段落id
:param star_num 点赞数量
:param trample_num 点踩数量
:return: bool
"""
if embedding is None:
embedding = EmbeddingModel.get_embedding_model()
self.save_pre_handler()
self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding)
self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, star_num,
trample_num, embedding)
def batch_save(self, data_list: List[Dict], embedding=None):
"""
@ -81,6 +86,8 @@ class BaseVectorStore(ABC):
@abstractmethod
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
is_active: bool,
star_num: int,
trample_num: int,
embedding: HuggingFaceEmbeddings):
pass
@ -89,7 +96,10 @@ class BaseVectorStore(ABC):
pass
@abstractmethod
def search(self, query_text, dataset_id_list: list[str], is_active: bool, embedding: HuggingFaceEmbeddings):
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_id_list: list[str],
is_active: bool,
embedding: HuggingFaceEmbeddings):
pass
@abstractmethod

View File

@ -6,14 +6,20 @@
@date2023/10/19 15:28
@desc:
"""
import json
import os
import uuid
from typing import Dict, List
from django.db.models import QuerySet
from langchain.embeddings import HuggingFaceEmbeddings
from common.db.search import native_search, generate_sql_by_query_dict
from common.db.sql_execute import select_one
from common.util.file_util import get_file_content
from embedding.models import Embedding, SourceType
from embedding.vector.base_vector import BaseVectorStore
from smartdoc.conf import PROJECT_DIR
class PGVector(BaseVectorStore):
@ -27,6 +33,8 @@ class PGVector(BaseVectorStore):
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
is_active: bool,
star_num: int,
trample_num: int,
embedding: HuggingFaceEmbeddings):
text_embedding = embedding.embed_query(text)
embedding = Embedding(id=uuid.uuid1(),
@ -37,6 +45,8 @@ class PGVector(BaseVectorStore):
source_id=source_id,
embedding=text_embedding,
source_type=source_type,
star_num=star_num,
trample_num=trample_num
)
embedding.save()
return True
@ -44,19 +54,41 @@ class PGVector(BaseVectorStore):
def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
texts = [row.get('text') for row in text_list]
embeddings = embedding.embed_documents(texts)
QuerySet(Embedding).bulk_create([Embedding(id=uuid.uuid1(),
document_id=text_list[index].get('document_id'),
paragraph_id=text_list[index].get('paragraph_id'),
dataset_id=text_list[index].get('dataset_id'),
is_active=text_list[index].get('is_active', True),
source_id=text_list[index].get('source_id'),
source_type=text_list[index].get('source_type'),
embedding=embeddings[index]) for index in
range(0, len(text_list))]) if len(text_list) > 0 else None
embedding_list = [Embedding(id=uuid.uuid1(),
document_id=text_list[index].get('document_id'),
paragraph_id=text_list[index].get('paragraph_id'),
dataset_id=text_list[index].get('dataset_id'),
is_active=text_list[index].get('is_active', True),
source_id=text_list[index].get('source_id'),
source_type=text_list[index].get('source_type'),
star_num=text_list[index].get('star_num'),
trample_num=text_list[index].get('trample_num'),
embedding=embeddings[index]) for index in
range(0, len(text_list))]
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
return True
def search(self, query_text, dataset_id_list: list[str], is_active: bool, embedding: HuggingFaceEmbeddings):
pass
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_id_list: list[str],
is_active: bool,
embedding: HuggingFaceEmbeddings):
exclude_dict = {}
if dataset_id_list is None or len(dataset_id_list) == 0:
return None
query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=is_active)
embedding_query = embedding.embed_query(query_text)
if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
if exclude_id_list is not None and len(exclude_id_list) > 0:
exclude_dict.__setitem__('id__in', exclude_id_list)
query_set = query_set.exclude(**exclude_dict)
exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
'embedding_search.sql')),
with_table_name=True)
embedding_model = select_one(exec_sql, (json.dumps(embedding_query), *exec_params, *exec_params))
return embedding_model
def update_by_source_id(self, source_id: str, instance: Dict):
QuerySet(Embedding).filter(source_id=source_id).update(**instance)

View File

@ -0,0 +1,34 @@
# Generated by Django 4.1.10 on 2023-11-06 08:51
from django.db import migrations, models
import django.db.models.deletion
import uuid
class Migration(migrations.Migration):
dependencies = [
('users', '0001_initial'),
('setting', '0001_initial'),
]
operations = [
migrations.CreateModel(
name='Model',
fields=[
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
('name', models.CharField(max_length=128, verbose_name='名称')),
('model_type', models.CharField(max_length=128, verbose_name='模型类型')),
('model_name', models.CharField(max_length=128, verbose_name='模型名称')),
('provider', models.CharField(max_length=20, verbose_name='供应商')),
('credential', models.CharField(max_length=5120, verbose_name='模型认证信息')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user', verbose_name='成员用户id')),
],
options={
'db_table': 'model',
'unique_together': {('name', 'user_id')},
},
),
]

View File

@ -0,0 +1,18 @@
# Generated by Django 4.1.10 on 2023-11-13 05:55
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('setting', '0002_model'),
]
operations = [
migrations.AlterField(
model_name='model',
name='provider',
field=models.CharField(max_length=128, verbose_name='供应商'),
),
]

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2023/10/31 17:16
@desc:
"""

View File

@ -0,0 +1,130 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_model_provider.py
@date2023/10/31 16:19
@desc:
"""
from abc import ABC, abstractmethod
from enum import Enum
from functools import reduce
from typing import Dict, List
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage
from langchain.schema.language_model import LanguageModelInput
class IModelProvider(ABC):
@abstractmethod
def get_model_provide_info(self):
pass
@abstractmethod
def get_model_type_list(self):
pass
@abstractmethod
def get_model_list(self, model_type):
pass
@abstractmethod
def get_model_credential(self, model_type, model_name):
pass
@abstractmethod
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseChatModel:
pass
@abstractmethod
def get_dialogue_number(self):
pass
class BaseModelCredential(ABC):
@abstractmethod
def is_valid(self, model_type: str, model_name, model: Dict[str, object], raise_exception=False):
pass
@abstractmethod
def encryption_dict(self, model_info: Dict[str, object]):
"""
:param model_info: 模型数据
:return: 加密后数据
"""
pass
@staticmethod
def encryption(message: str):
"""
加密敏感字段数据 加密方式是 如果密码是 1234567890 那么给前端则是 123******890
:param message:
:return:
"""
max_pre_len = 8
max_post_len = 4
message_len = len(message)
pre_len = int(message_len / 5 * 2)
post_len = int(message_len / 5 * 1)
pre_str = "".join([message[index] for index in
range(0, max_pre_len if pre_len > max_pre_len else 1 if pre_len <= 0 else int(pre_len))])
end_str = "".join(
[message[index] for index in
range(message_len - (int(post_len) if pre_len < max_post_len else max_post_len), message_len)])
content = "***************"
return pre_str + content + end_str
class ModelTypeConst(Enum):
LLM = {'code': 'LLM', 'message': '大语言模型'}
class ModelInfo:
def __init__(self, name: str, desc: str, model_type: ModelTypeConst, model_credential: BaseModelCredential,
**keywords):
self.name = name
self.desc = desc
self.model_type = model_type.name
self.model_credential = model_credential
if keywords is not None:
for key in keywords.keys():
self.__setattr__(key, keywords.get(key))
def get_name(self):
"""
获取模型名称
:return: 模型名称
"""
return self.name
def get_desc(self):
"""
获取模型描述
:return: 模型描述
"""
return self.desc
def get_model_type(self):
return self.model_type
def to_dict(self):
return reduce(lambda x, y: {**x, **y},
[{attr: self.__getattribute__(attr)} for attr in vars(self) if
not attr.startswith("__") and not attr == 'model_credential'], {})
class ModelProvideInfo:
def __init__(self, provider: str, name: str, icon: str):
self.provider = provider
self.name = name
self.icon = icon
def to_dict(self):
return reduce(lambda x, y: {**x, **y},
[{attr: self.__getattribute__(attr)} for attr in vars(self) if
not attr.startswith("__")], {})

View File

@ -0,0 +1,17 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file model_provider_constants.py
@date2023/11/2 14:55
@desc:
"""
from enum import Enum
from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
class ModelProvideConstants(Enum):
model_azure_provider = AzureModelProvider()
model_wenxin_provider = WenxinModelProvider()

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2023/10/31 17:16
@desc:
"""

View File

@ -0,0 +1,110 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file azure_model_provider.py
@date2023/10/31 16:19
@desc:
"""
import os
from typing import Dict, List
from langchain.chat_models import AzureChatOpenAI
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage, BaseMessage
from langchain.schema.language_model import LanguageModelInput
from common import froms
from common.exception.app_exception import AppApiException
from common.froms import BaseForm
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
ModelInfo, \
ModelTypeConst
from smartdoc.conf import PROJECT_DIR
class AzureLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = AzureModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(500, f'{model_type} 模型类型不支持')
if model_name not in model_dict:
raise AppApiException(500, f'{model_name} 模型名称不支持')
for key in ['api_base', 'api_key', 'deployment_name']:
if key not in model_credential:
if raise_exception:
raise AppApiException(500, f'{key} 字段为必填字段')
else:
return False
try:
AzureModelProvider().query(model_type, model_name, model_credential,
message=[HumanMessage(content='valid')])
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(500, '校验失败,请检查参数是否正确')
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_base = froms.TextInputField('API 域名', required=True)
api_key = froms.PasswordInputField("API Key", required=True)
deployment_name = froms.TextInputField("部署名", required=True)
azure_llm_model_credential = AzureLLMModelCredential()
model_dict = {
'gpt-3.5-turbo-0613': ModelInfo('gpt-3.5-turbo-0613', '', ModelTypeConst.LLM, azure_llm_model_credential,
api_version='2023-07-01-preview'),
'gpt-3.5-turbo-0301': ModelInfo('gpt-3.5-turbo-0301', '', ModelTypeConst.LLM, azure_llm_model_credential,
api_version='2023-07-01-preview'),
'gpt-3.5-turbo-16k-0613': ModelInfo('gpt-3.5-turbo-16k-0613', '', ModelTypeConst.LLM, azure_llm_model_credential,
api_version='2023-07-01-preview')
}
class AzureModelProvider(IModelProvider):
def get_dialogue_number(self):
return 3
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> AzureChatOpenAI:
model_info: ModelInfo = model_dict.get(model_name)
azure_chat_open_ai = AzureChatOpenAI(
openai_api_base=model_credential.get('api_base'),
openai_api_version=model_info.api_version,
deployment_name=model_credential.get('deployment_name'),
openai_api_key=model_credential.get('api_key'),
openai_api_type="azure",
tiktoken_model_name=model_name
)
return azure_chat_open_ai
def get_model_credential(self, model_type, model_name):
return model_dict.get(model_name).model_credential
def get_model_provide_info(self):
return ModelProvideInfo(provider='model_azure_provider', name='Azure OpenAI', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'azure_model_provider', 'icon',
'azure_icon_svg')))
def get_model_list(self, model_type: str):
if model_type is None:
raise AppApiException(500, '模型类型不能为空')
return [model_dict.get(key).to_dict() for key in
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
def get_model_type_list(self):
return [{'key': "大语言模型", 'value': "LLM"}]

View File

@ -0,0 +1,2 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" version="1.1" id="Layer_1" x="0px" y="0px" width="64px" height="64px" viewBox="0 0 64 64" enable-background="new 0 0 64 64" xml:space="preserve"> <image id="image0" width="64" height="64" x="0" y="0" href=" AAB6JgAAgIQAAPoAAACA6AAAdTAAAOpgAAA6mAAAF3CculE8AAAABmJLR0QA/wD/AP+gvaeTAAAA CXBIWXMAAA7EAAAOxAGVKw4bAAAJXUlEQVR42u2aa2wcVxXH/+fO7K4fa6epX2snDSqNHb8TQ4HE Nn0EldR5v2wnLRQZCSHxqSFqCUJCfEAVihBCSFRIICEB3RabkFYVAipFgEgjhNUQYjs2SYGw1LFj O8l67V17X3P5MLO7szuvuxsnMeArbZydk8zM+c3/f+69xwOsjbWxNtbG//Gge3HSosd2biCZNZY8 4vMxt1xc4Gl48Ya6v2t/DwZ++Mql/woA3o9/8Rvy+k1fQzIhk6TgkecPoLypIc/UtR+c64+9N/Kl np7QxfOzK3m/bCVPVlT/zGOuis1fBzEZRFDiSXzw+i+hJBJiSXMAnINrn8wxABwfbfv+r19eyftd cQByxebjYEREDGASQBIS4SgWxicckkY6aZ6ddC6U58o7uldUtSsNoJ+DAGIAMRCTAMmF+b9eMU2c Wz/tLCXooNSt6/hk96oEUNSwqwlErSAC1wBwYiAmIzR6FUo87iRxoxIAA5SNnz/VvyoByBWbj6k1 lQDSVMAkcCZBSXCExiZEJC6ihCPl27pX7L5XDkBlfZ+aOFQATLMBaTa4PG5MDCK+N8R85R3dT64q AEVbnm0DUWNaASD11EwrhkxGaPwalHhctNiZxHk6vvFzp/pWFQCpsv4YB4Gnnr5mASJJs4IMJQmE xsbzTBpZStDFj5Rv7ZZWDQC5sqE3lTiH9iFoyTMQYyDmwvzIuIjERZRQVb6t6+lVAcDTuLsDoHr1 G+kUQOBEAEngJAGSjIW//QNKLOYocRElbHjhVO+qACBX1vfrk86srjPrAbUWyEgqhPmxcRGJG6HA EDtStrVLfqAApMp6SFVbetWc9VOgDgZjgFYLSHJhfnRCuNjZQ0FFeXv3px4oALmy4XGAPqx6PiX5 HBhg6TUBmIyF968jGY2KFjtbKBs++5W7tsFdAfA07+tPz/ta0gYYRKD00lgG5wyhKxMiEheBcris /e5sUDAAqbIBIOrLlr4FDJbZHJHkwvyVqyISF1HC+rL2rl0PBIBctWU7gE3GAmgCgxg4mAbChcV/ /gvJ5WXBp81toWx47uW7skHBANwt+/u4PmkA1jBI3RlqH4XL6hZZrNg5QTlY1tbluq8ApKotBKAX Or87wqBMMSTZheDEtTxnAF3SyIKyrqytq+d+A+gE0cZMgoAIDEpDkBG+/gGSS0uOEs/ErZVQd/yl gvcGBQHwtB7IFD/D3G8Bg7T1AJNATAInGaGJqyISz4lzMyj7va2dnvsCQJU/9WaSBoRhaFvk1GwQ mZp+X0DilkrQQSkra+3afX8AVDc+wYlqubbxyU7aAYaWPNeWxsvXx74Jzm87SVxECbXHCrNB3gDc rYd0cz8yT1cUBktbIXLn9z8b4hxnBSQuooS93pbOvH8HkRcAqbqRgeioYarLB0a6Y8x+pdy5Hrnx 2ukhAYnbQEkf8Ja1du67xwCangJRte287wiDgYghdvWdQQBYGDl/DpzfEk2amyghFff1vnT0ngJw tR3uM93x5QNDhbCYnLv2tgrgQoJznBEsdsY4UnEOcL7X29xZek8AsOomCUSH1URg8mTzgMHYW8m5 a9HUuW/4Tw8JSFxECcXels799wSAVNO0E6AqfdKFwoiNvT2kP/fC6Lu/4xyzpokBeUHxHT2Z195A GIC7/Wi/2Y6vABih5OzEb/TnXhy9kATHGfunzWHmexMl9Hibd5StKABW0yyD6JBxWisIxpvJmYlo 7jWm3jg9KFrsHKAUlTaJ20AIgFTT/AwHPZzV9i4QRmz07JDZNRZGL/yBc0w7+547QvEdOSm8KBIC 4Nraa9H2zhtGMDEz8VuzayyOXVDA+RkBiYtAebbm8Ml1KwKA1TS7ADqkfstpe+cP46wyMx63utbU 4LcHxX1vEsscdpc27jggAsCxnyb5WnaB6KH0AZ76g9K1DdqFkyXrkPR4Lc+luIvrcWr0x1n/SUkQ ElEgsUzTyyG5Ymk56SrySKm3RFL/LPvamR9Z8RQFAN7GHX0AfnLXAFzb+nvVE6cquS6YAyO6fqNu F2g6urWPdi7dNlnbKgdv3ETVo5t0SekyNE065wtPH/p09cEvPzTz5neCdjdkawHma/Fw0MFUc9O4 rs+e6uTIHSeexpHaIDEJkNwITs9Z+x5i9lB/zc5dpY07Djtd3gFAaw8I5akEM543h+Gen4a8eCtf Apl2meRCeDGGeCQi7HvzuHrcu2WH42xgC8DVceyoVUGzguEOzcAVupmfAvQ2kD0ITs1aVHlYPW2L ON9ZfeBERUEAmK+1GMABkbZ3LgxX+DbcwakcozpASNvAheDNORGJC0CBq7TB3gbWAGpbd3Mir22n 1waGtDQPT3AS4IoIgSwbRCJxxCIRR4nbQoEaL3GwgSUAueN4fm1vExhsOQz37ck5AAuOCtC9WgfZ g/mbcyISF4HyVNW+E1V5AWC+1lIQ7c+37W22CJLikUEEhp8GYF8d9RAkN4IzcyISF4EilzZst2yU mAOoa9sLULFlc9MKhsmKMH5pcBD+gfdw/tVuAAEbAjobyFhaSiAWDjtKPBWzg1JSv91yi2wKQP7I 8wW1vWGEMaVMj/4RAHD+1Qn4BzoBjInaIDh7S7TYmSpBF3+ias+LNUIAqLbNy4n2cDJL2gKGZgVu hDGkTI9lqmBgeBL+gScB/FnEBvOzt/L0vaUSpJJ6cxsYALDa9v0AFUHb5Bhl7rwiTMGIX3rDuPUN DN+Cf2AngHOWADQbLC8nEQ2HHSUuooSS+u2ms4EBgPz4Z/rMiloaBoRhTCpTo++aPunAcBj+gd0A fpFDIMcGbszPzolIPCduGuuu3P1inS0Aqm0vB1GP01QnCGNImR61XgkFhmPwD/QD+JGtDeZuCxc7 y9dtAXDOWcnmTxhskAWA1bXvAcgt3Om1gZH4y+s/h9MIDCvwD3wBgeFvWdkgGlcQjYRFi52pElJQ SjZvP2JvAaJHLft5AosgHYxA4qL/T44AVAiAf+CrCAyfUG9bbwMZkDyIR2MiErdXgjo22QMAzpn1 88wLnTUMZWrkB0LJ64d/4LsIDL8AIJGxAYPkKUJRaZmIxI1QjOOd3ANZ79vyG5cnqW7bv6nc1wLQ w8bFDUCW0yIAohvK1MgriYuvncbijOBOSDdG3roM4AI+9LEWgNd5XBLqaivgduvegNH3PkwaIRZj AcBPA987djJ+ezKW932tjbWxNtbG/+r4D87QQ0KXWZJ6AAAAMnRFWHRDb21tZW50AHhyOmQ6REFG allncE9SVXM6MixqOjUwNTg3NDI3NTEsdDoyMzA1MTkxOOBfQ5wAAAAldEVYdGRhdGU6Y3JlYXRl ADIwMjMtMTAtMzFUMTA6MTQ6MjgrMDE6MDAxbYpSAAAAJXRFWHRkYXRlOm1vZGlmeQAyMDIzLTEw LTMxVDEwOjE0OjI4KzAxOjAwQDAy7gAAAABJRU5ErkJggg=="/>
</svg>

After

Width:  |  Height:  |  Size: 3.9 KiB

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2023/10/31 17:16
@desc:
"""

View File

@ -0,0 +1,2 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" version="1.1" id="Layer_1" x="0px" y="0px" width="64px" height="64px" viewBox="0 0 64 64" enable-background="new 0 0 64 64" xml:space="preserve"> <image id="image0" width="64" height="64" x="0" y="0" href=" AAB6JgAAgIQAAPoAAACA6AAAdTAAAOpgAAA6mAAAF3CculE8AAAABmJLR0QA/wD/AP+gvaeTAAAA CXBIWXMAAA7EAAAOxAGVKw4bAAAJXUlEQVR42u2aa2wcVxXH/+fO7K4fa6epX2snDSqNHb8TQ4HE Nn0EldR5v2wnLRQZCSHxqSFqCUJCfEAVihBCSFRIICEB3RabkFYVAipFgEgjhNUQYjs2SYGw1LFj O8l67V17X3P5MLO7szuvuxsnMeArbZydk8zM+c3/f+69xwOsjbWxNtbG//Gge3HSosd2biCZNZY8 4vMxt1xc4Gl48Ya6v2t/DwZ++Mql/woA3o9/8Rvy+k1fQzIhk6TgkecPoLypIc/UtR+c64+9N/Kl np7QxfOzK3m/bCVPVlT/zGOuis1fBzEZRFDiSXzw+i+hJBJiSXMAnINrn8wxABwfbfv+r19eyftd cQByxebjYEREDGASQBIS4SgWxicckkY6aZ6ddC6U58o7uldUtSsNoJ+DAGIAMRCTAMmF+b9eMU2c Wz/tLCXooNSt6/hk96oEUNSwqwlErSAC1wBwYiAmIzR6FUo87iRxoxIAA5SNnz/VvyoByBWbj6k1 lQDSVMAkcCZBSXCExiZEJC6ihCPl27pX7L5XDkBlfZ+aOFQATLMBaTa4PG5MDCK+N8R85R3dT64q AEVbnm0DUWNaASD11EwrhkxGaPwalHhctNiZxHk6vvFzp/pWFQCpsv4YB4Gnnr5mASJJs4IMJQmE xsbzTBpZStDFj5Rv7ZZWDQC5sqE3lTiH9iFoyTMQYyDmwvzIuIjERZRQVb6t6+lVAcDTuLsDoHr1 G+kUQOBEAEngJAGSjIW//QNKLOYocRElbHjhVO+qACBX1vfrk86srjPrAbUWyEgqhPmxcRGJG6HA EDtStrVLfqAApMp6SFVbetWc9VOgDgZjgFYLSHJhfnRCuNjZQ0FFeXv3px4oALmy4XGAPqx6PiX5 HBhg6TUBmIyF968jGY2KFjtbKBs++5W7tsFdAfA07+tPz/ta0gYYRKD00lgG5wyhKxMiEheBcris /e5sUDAAqbIBIOrLlr4FDJbZHJHkwvyVqyISF1HC+rL2rl0PBIBctWU7gE3GAmgCgxg4mAbChcV/ /gvJ5WXBp81toWx47uW7skHBANwt+/u4PmkA1jBI3RlqH4XL6hZZrNg5QTlY1tbluq8ApKotBKAX Or87wqBMMSTZheDEtTxnAF3SyIKyrqytq+d+A+gE0cZMgoAIDEpDkBG+/gGSS0uOEs/ErZVQd/yl gvcGBQHwtB7IFD/D3G8Bg7T1AJNATAInGaGJqyISz4lzMyj7va2dnvsCQJU/9WaSBoRhaFvk1GwQ mZp+X0DilkrQQSkra+3afX8AVDc+wYlqubbxyU7aAYaWPNeWxsvXx74Jzm87SVxECbXHCrNB3gDc rYd0cz8yT1cUBktbIXLn9z8b4hxnBSQuooS93pbOvH8HkRcAqbqRgeioYarLB0a6Y8x+pdy5Hrnx 2ukhAYnbQEkf8Ja1du67xwCangJRte287wiDgYghdvWdQQBYGDl/DpzfEk2amyghFff1vnT0ngJw tR3uM93x5QNDhbCYnLv2tgrgQoJznBEsdsY4UnEOcL7X29xZek8AsOomCUSH1URg8mTzgMHYW8m5 a9HUuW/4Tw8JSFxECcXels799wSAVNO0E6AqfdKFwoiNvT2kP/fC6Lu/4xyzpokBeUHxHT2Z195A GIC7/Wi/2Y6vABih5OzEb/TnXhy9kATHGfunzWHmexMl9Hibd5StKABW0yyD6JBxWisIxpvJmYlo 7jWm3jg9KFrsHKAUlTaJ20AIgFTT/AwHPZzV9i4QRmz07JDZNRZGL/yBc0w7+547QvEdOSm8KBIC 4Nraa9H2zhtGMDEz8VuzayyOXVDA+RkBiYtAebbm8Ml1KwKA1TS7ADqkfstpe+cP46wyMx63utbU 4LcHxX1vEsscdpc27jggAsCxnyb5WnaB6KH0AZ76g9K1DdqFkyXrkPR4Lc+luIvrcWr0x1n/SUkQ ElEgsUzTyyG5Ymk56SrySKm3RFL/LPvamR9Z8RQFAN7GHX0AfnLXAFzb+nvVE6cquS6YAyO6fqNu F2g6urWPdi7dNlnbKgdv3ETVo5t0SekyNE065wtPH/p09cEvPzTz5neCdjdkawHma/Fw0MFUc9O4 rs+e6uTIHSeexpHaIDEJkNwITs9Z+x5i9lB/zc5dpY07Djtd3gFAaw8I5akEM543h+Gen4a8eCtf Apl2meRCeDGGeCQi7HvzuHrcu2WH42xgC8DVceyoVUGzguEOzcAVupmfAvQ2kD0ITs1aVHlYPW2L ON9ZfeBERUEAmK+1GMABkbZ3LgxX+DbcwakcozpASNvAheDNORGJC0CBq7TB3gbWAGpbd3Mir22n 1waGtDQPT3AS4IoIgSwbRCJxxCIRR4nbQoEaL3GwgSUAueN4fm1vExhsOQz37ck5AAuOCtC9WgfZ g/mbcyISF4HyVNW+E1V5AWC+1lIQ7c+37W22CJLikUEEhp8GYF8d9RAkN4IzcyISF4EilzZst2yU mAOoa9sLULFlc9MKhsmKMH5pcBD+gfdw/tVuAAEbAjobyFhaSiAWDjtKPBWzg1JSv91yi2wKQP7I 8wW1vWGEMaVMj/4RAHD+1Qn4BzoBjInaIDh7S7TYmSpBF3+ias+LNUIAqLbNy4n2cDJL2gKGZgVu hDGkTI9lqmBgeBL+gScB/FnEBvOzt/L0vaUSpJJ6cxsYALDa9v0AFUHb5Bhl7rwiTMGIX3rDuPUN DN+Cf2AngHOWADQbLC8nEQ2HHSUuooSS+u2ms4EBgPz4Z/rMiloaBoRhTCpTo++aPunAcBj+gd0A fpFDIMcGbszPzolIPCduGuuu3P1inS0Aqm0vB1GP01QnCGNImR61XgkFhmPwD/QD+JGtDeZuCxc7 y9dtAXDOWcnmTxhskAWA1bXvAcgt3Om1gZH4y+s/h9MIDCvwD3wBgeFvWdkgGlcQjYRFi52pElJQ SjZvP2JvAaJHLft5AosgHYxA4qL/T44AVAiAf+CrCAyfUG9bbwMZkDyIR2MiErdXgjo22QMAzpn1 88wLnTUMZWrkB0LJ64d/4LsIDL8AIJGxAYPkKUJRaZmIxI1QjOOd3ANZ79vyG5cnqW7bv6nc1wLQ w8bFDUCW0yIAohvK1MgriYuvncbijOBOSDdG3roM4AI+9LEWgNd5XBLqaivgduvegNH3PkwaIRZj AcBPA987djJ+ezKW932tjbWxNtbG/+r4D87QQ0KXWZJ6AAAAMnRFWHRDb21tZW50AHhyOmQ6REFG allncE9SVXM6MixqOjUwNTg3NDI3NTEsdDoyMzA1MTkxOOBfQ5wAAAAldEVYdGRhdGU6Y3JlYXRl ADIwMjMtMTAtMzFUMTA6MTQ6MjgrMDE6MDAxbYpSAAAAJXRFWHRkYXRlOm1vZGlmeQAyMDIzLTEw LTMxVDEwOjE0OjI4KzAxOjAwQDAy7gAAAABJRU5ErkJggg=="/>
</svg>

After

Width:  |  Height:  |  Size: 3.9 KiB

View File

@ -0,0 +1,78 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file qian_fan_chat_model.py
@date2023/11/10 17:45
@desc:
"""
from typing import Optional, List, Any, Iterator, cast
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models import QianfanChatEndpoint
from langchain.chat_models.base import BaseChatModel
from langchain.load import dumpd
from langchain.schema import LLMResult
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.runnable import RunnableConfig
class QianfanChatModel(QianfanChatEndpoint):
def stream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[BaseMessageChunk]:
if len(input) % 2 == 0:
input = [HumanMessage(content='占位'), *input]
input = [
HumanMessage(content=input[index].content) if index % 2 == 0 else AIMessage(content=input[index].content)
for index in range(0, len(input))]
if type(self)._stream == BaseChatModel._stream:
# model doesn't implement streaming, so use default implementation
yield cast(
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
)
else:
config = config or {}
messages = self._convert_input(input).to_messages()
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop, **kwargs}
callback_manager = CallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
self.metadata,
)
(run_manager,) = callback_manager.on_chat_model_start(
dumpd(self),
[messages],
invocation_params=params,
options=options,
name=config.get("run_name"),
)
try:
generation: Optional[ChatGenerationChunk] = None
for chunk in self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
):
yield chunk.message
if generation is None:
generation = chunk
assert generation is not None
except BaseException as e:
run_manager.on_llm_error(e)
raise e
else:
run_manager.on_llm_end(
LLMResult(generations=[[generation]]),
)

View File

@ -0,0 +1,134 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file wenxin_model_provider.py
@date2023/10/31 16:19
@desc:
"""
import os
from typing import Dict
from langchain.chat_models import QianfanChatEndpoint
from langchain.chat_models.baidu_qianfan_endpoint import convert_message_to_dict
from langchain.schema import HumanMessage
from common import froms
from common.exception.app_exception import AppApiException
from common.froms import BaseForm
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
ModelInfo, IModelProvider
from setting.models_provider.impl.wenxin_model_provider.model.qian_fan_chat_model import QianfanChatModel
from smartdoc.conf import PROJECT_DIR
class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = WenxinModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(500, f'{model_type} 模型类型不支持')
if model_name not in model_dict:
raise AppApiException(500, f'{model_name} 模型名称不支持')
for key in ['api_key', 'secret_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(500, f'{key} 字段为必填字段')
else:
return False
try:
WenxinModelProvider().get_model(model_type, model_name, model_credential)([HumanMessage(content='valid')])
except Exception as e:
if raise_exception:
raise AppApiException(500, "校验失败,请检查 api_key secret_key 是否正确")
return True
def encryption_dict(self, model_info: Dict[str, object]):
return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))}
def build_model(self, model_info: Dict[str, object]):
for key in ['api_key', 'secret_key', 'model']:
if key not in model_info:
raise AppApiException(500, f'{key} 字段为必填字段')
self.api_key = model_info.get('api_key')
self.secret_key = model_info.get('secret_key')
return self
api_key = froms.PasswordInputField('API Key', required=True)
secret_key = froms.PasswordInputField("Secret Key", required=True)
win_xin_llm_model_credential = WenxinLLMModelCredential()
model_dict = {
'ERNIE-Bot-4': ModelInfo('ERNIE-Bot-4',
'ERNIE-Bot-4是百度自行研发的大语言模型覆盖海量中文数据具有更强的对话问答、内容创作生成等能力。',
ModelTypeConst.LLM, win_xin_llm_model_credential),
'ERNIE-Bot': ModelInfo('ERNIE-Bot',
'ERNIE-Bot是百度自行研发的大语言模型覆盖海量中文数据具有更强的对话问答、内容创作生成等能力。',
ModelTypeConst.LLM, win_xin_llm_model_credential),
'ERNIE-Bot-turbo': ModelInfo('ERNIE-Bot-turbo',
'ERNIE-Bot-turbo是百度自行研发的大语言模型覆盖海量中文数据具有更强的对话问答、内容创作生成等能力响应速度更快。',
ModelTypeConst.LLM, win_xin_llm_model_credential),
'BLOOMZ-7B': ModelInfo('BLOOMZ-7B',
'BLOOMZ-7B是业内知名的大语言模型由BigScience研发并开源能够以46种语言和13种编程语言输出文本。',
ModelTypeConst.LLM, win_xin_llm_model_credential),
'Llama-2-7b-chat': ModelInfo('Llama-2-7b-chat',
'Llama-2-7b-chat由Meta AI研发并开源在编码、推理及知识应用等场景表现优秀Llama-2-7b-chat是高性能原生开源版本适用于对话场景。',
ModelTypeConst.LLM, win_xin_llm_model_credential),
'Llama-2-13b-chat': ModelInfo('Llama-2-13b-chat',
'Llama-2-13b-chat由Meta AI研发并开源在编码、推理及知识应用等场景表现优秀Llama-2-13b-chat是性能与效果均衡的原生开源版本适用于对话场景。',
ModelTypeConst.LLM, win_xin_llm_model_credential),
'Llama-2-70b-chat': ModelInfo('Llama-2-70b-chat',
'Llama-2-70b-chat由Meta AI研发并开源在编码、推理及知识应用等场景表现优秀Llama-2-70b-chat是高精度效果的原生开源版本。',
ModelTypeConst.LLM, win_xin_llm_model_credential),
'Qianfan-Chinese-Llama-2-7B': ModelInfo('Qianfan-Chinese-Llama-2-7B',
'千帆团队在Llama-2-7b基础上的中文增强版本在CMMLU、C-EVAL等中文数据集上表现优异。',
ModelTypeConst.LLM, win_xin_llm_model_credential),
'Qianfan-Chinese-Llama-2-13B': ModelInfo('Qianfan-Chinese-Llama-2-13B',
'千帆团队在Llama-2-13b基础上的中文增强版本在CMMLU、C-EVAL等中文数据集上表现优异。',
ModelTypeConst.LLM, win_xin_llm_model_credential)
}
class WenxinModelProvider(IModelProvider):
def get_dialogue_number(self):
return 2
def get_model(self, model_type, model_name, model_credential: Dict[str, object],
**model_kwargs) -> QianfanChatEndpoint:
return QianfanChatModel(model=model_name,
qianfan_ak=model_credential.get('api_key'),
qianfan_sk=model_credential.get('secret_key'),
streaming=model_kwargs.get('streaming', False))
def get_model_type_list(self):
return [{'key': "大语言模型", 'value': "LLM"}]
def get_model_list(self, model_type):
if model_type is None:
raise AppApiException(500, '模型类型不能为空')
return [model_dict.get(key).to_dict() for key in
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
def get_model_credential(self, model_type, model_name):
if model_name in model_dict:
return model_dict.get(model_name).model_credential
raise AppApiException(500, f'不支持的模型:{model_name}')
def get_model_provide_info(self):
return ModelProvideInfo(provider='model_wenxin_provider', name='Azure OpenAI', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'wenxin_model_provider', 'icon',
'azure_icon_svg')))

View File

@ -0,0 +1,119 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file provider_serializers.py
@date2023/11/2 14:01
@desc:
"""
import json
import uuid
from typing import Dict
from django.db.models import QuerySet
from rest_framework import serializers
from common.exception.app_exception import AppApiException
from common.util.rsa_util import encrypt, decrypt
from setting.models.model_management import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
class ModelSerializer(serializers.Serializer):
class Query(serializers.Serializer):
user_id = serializers.UUIDField(required=True)
name = serializers.CharField(required=False)
model_type = serializers.CharField(required=False)
model_name = serializers.CharField(required=False)
def list(self, with_valid):
if with_valid:
self.is_valid(raise_exception=True)
user_id = self.data.get('user_id')
name = self.data.get('name')
model_query_set = QuerySet(Model).filter(user_id=user_id)
query_params = {}
if name is not None:
query_params['name__contains'] = name
if self.data.get('model_type') is not None:
query_params['model_type'] = self.data.get('model_type')
if self.data.get('model_name') is not None:
query_params['model_name'] = self.data.get('model_name')
return [ModelSerializer.model_to_dict(model) for model in model_query_set.filter(**query_params)]
class Create(serializers.Serializer):
user_id = serializers.CharField(required=True)
name = serializers.CharField(required=True)
provider = serializers.CharField(required=True)
model_type = serializers.CharField(required=True)
model_name = serializers.CharField(required=True)
credential = serializers.DictField(required=True)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
if QuerySet(Model).filter(user_id=self.data.get('user_id'),
name=self.data.get('name')).exists():
raise AppApiException(500, f'模型名称【{self.data.get("name")}】已存在')
# 校验模型认证数据
ModelProvideConstants[self.data.get('provider')].value.get_model_credential(self.data.get('model_type'),
self.data.get(
'model_name')).is_valid(
self.data.get('model_type'),
self.data.get('model_name'),
self.data.get('credential'),
raise_exception=True)
def insert(self, user_id, with_valid=False):
if with_valid:
self.is_valid(raise_exception=True)
credential = self.data.get('credential')
name = self.data.get('name')
provider = self.data.get('provider')
model_type = self.data.get('model_type')
model_name = self.data.get('model_name')
model_credential_str = json.dumps(credential)
model = Model(id=uuid.uuid1(), user_id=user_id, name=name,
credential=encrypt(model_credential_str),
provider=provider, model_type=model_type, model_name=model_name)
model.save()
return ModelSerializer.Operate(data={'id': model.id}).one(user_id, with_valid=True)
@staticmethod
def model_to_dict(model: Model):
credential = json.loads(decrypt(model.credential))
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
'model_name': model.model_name,
'credential': ModelProvideConstants[model.provider].value.get_model_credential(model.model_type,
model.model_name).encryption_dict(
credential)}
class Operate(serializers.Serializer):
id = serializers.UUIDField(required=True)
def one(self, user_id, with_valid=False):
if with_valid:
self.is_valid(raise_exception=True)
model = QuerySet(Model).get(id=self.data.get('id'), user_id=user_id)
return ModelSerializer.model_to_dict(model)
class ProviderSerializer(serializers.Serializer):
provider = serializers.CharField(required=True)
method = serializers.CharField(required=True)
def exec(self, exec_params: Dict[str, object], with_valid=False):
if with_valid:
self.is_valid(raise_exception=True)
provider = self.data.get('provider')
method = self.data.get('method')
return getattr(ModelProvideConstants[provider].value, method)(exec_params)

View File

@ -233,12 +233,12 @@ class TeamMemberSerializer(ApiMixin, serializers.Serializer):
member_permission_list))
# 分为 APPLICATION DATASET俩组
groups = itertools.groupby(
list(map(lambda m: {**m, 'member_id': member_id,
'operate': dict(
map(lambda key: (key, True if m.get('operate') is not None and m.get(
'operate').__contains__(key) else False),
[Operate.USE.value, Operate.MANAGE.value]))},
member_permission_list)),
sorted(list(map(lambda m: {**m, 'member_id': member_id,
'operate': dict(
map(lambda key: (key, True if m.get('operate') is not None and m.get(
'operate').__contains__(key) else False),
[Operate.USE.value, Operate.MANAGE.value]))},
member_permission_list)), key=lambda x: x.get('type')),
key=lambda x: x.get('type'))
return dict([(key, list(group)) for key, group in groups])

View File

@ -0,0 +1,156 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file provide_api.py
@date2023/11/2 14:25
@desc:
"""
from drf_yasg import openapi
from common.mixins.api_mixin import ApiMixin
class ModelQueryApi(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='name',
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=False,
description='模型名称'),
openapi.Parameter(name='model_type', in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=False,
description='模型类型'),
openapi.Parameter(name='model_name', in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=False,
description='基础模型名称')
]
class ModelCreateApi(ApiMixin):
@staticmethod
def get_request_body_api():
return openapi.Schema(type=openapi.TYPE_OBJECT,
title="调用函数所需要的参数",
description="调用函数所需要的参数",
required=['provide', 'model_info'],
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING,
title="模型名称",
description="模型名称"),
'provider': openapi.Schema(type=openapi.TYPE_STRING,
title="供应商",
description="供应商"),
'model_type': openapi.Schema(type=openapi.TYPE_STRING,
title="供应商",
description="供应商"),
'model_name': openapi.Schema(type=openapi.TYPE_STRING,
title="供应商",
description="供应商"),
'credential': openapi.Schema(type=openapi.TYPE_OBJECT,
title="模型证书信息",
description="模型证书信息")
}
)
class ProvideApi(ApiMixin):
class ModelTypeList(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='provider',
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=True,
description='供应名称'),
]
@staticmethod
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['key', 'value'],
properties={
'key': openapi.Schema(type=openapi.TYPE_STRING, title="模型类型描述",
description="模型类型描述", default="大语言模型"),
'value': openapi.Schema(type=openapi.TYPE_STRING, title="模型类型值",
description="模型类型值", default="LLM"),
}
)
class ModelList(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='provider',
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=True,
description='供应名称'),
openapi.Parameter(name='model_type',
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=True,
description='模型类型'),
]
@staticmethod
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['name', 'desc', 'model_type'],
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="模型名称",
description="模型名称", default="模型名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="模型描述",
description="模型描述", default="xxx模型"),
'model_type': openapi.Schema(type=openapi.TYPE_STRING, title="模型类型值",
description="模型类型值", default="LLM"),
}
)
class ModelForm(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='provider',
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=True,
description='供应名称'),
openapi.Parameter(name='model_type',
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=True,
description='模型类型'),
openapi.Parameter(name='model_name',
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=True,
description='模型名称'),
]
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='provider',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='供应商'),
openapi.Parameter(name='method',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='需要执行的函数'),
]
@staticmethod
def get_request_body_api():
return openapi.Schema(type=openapi.TYPE_OBJECT,
title="调用函数所需要的参数",
description="调用函数所需要的参数",
)

View File

@ -5,5 +5,14 @@ from . import views
app_name = "team"
urlpatterns = [
path('team/member', views.TeamMember.as_view(), name="team"),
path('team/member/<str:member_id>', views.TeamMember.Operate.as_view(), name='member')
path('team/member/<str:member_id>', views.TeamMember.Operate.as_view(), name='member'),
path('provider/<str:provider>/<str:method>', views.Provide.Exec.as_view(), name='provide_exec'),
path('provider', views.Provide.as_view(), name='provide'),
path('provider/model_type_list', views.Provide.ModelTypeList.as_view(), name="provider/model_type_list"),
path('provider/model_list', views.Provide.ModelList.as_view(),
name="provider/model_name_list"),
path('provider/model_form', views.Provide.ModelForm.as_view(),
name="provider/model_form"),
path('model', views.Model.as_view(), name='model'),
]

View File

@ -11,7 +11,8 @@ from rest_framework.decorators import action
from rest_framework.views import APIView
from rest_framework.views import Request
from common.auth import TokenAuth
from common.auth import TokenAuth, has_permissions
from common.constants.permission_constants import PermissionConstants
from common.response import result
from setting.serializers.team_serializers import TeamMemberSerializer, get_response_body_api, \
UpdateTeamMemberPermissionSerializer
@ -23,14 +24,18 @@ class TeamMember(APIView):
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取团队成员列表",
operation_id="获取团员成员列表",
responses=result.get_api_response(get_response_body_api()))
responses=result.get_api_response(get_response_body_api()),
tags=["团队"])
@has_permissions(PermissionConstants.TEAM_READ)
def get(self, request: Request):
return result.success(TeamMemberSerializer(data={'team_id': str(request.user.id)}).list_member())
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="添加成员",
operation_id="添加成员",
request_body=TeamMemberSerializer().get_request_body_api())
request_body=TeamMemberSerializer().get_request_body_api(),
tags=["团队"])
@has_permissions(PermissionConstants.TEAM_CREATE)
def post(self, request: Request):
team = TeamMemberSerializer(data={'team_id': str(request.user.id)})
return result.success((team.add_member(**request.data)))
@ -41,7 +46,9 @@ class TeamMember(APIView):
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取团队成员权限",
operation_id="获取团队成员权限",
manual_parameters=TeamMemberSerializer.Operate.get_request_params_api())
manual_parameters=TeamMemberSerializer.Operate.get_request_params_api(),
tags=["团队"])
@has_permissions(PermissionConstants.TEAM_READ)
def get(self, request: Request, member_id: str):
return result.success(TeamMemberSerializer.Operate(
data={'member_id': member_id, 'team_id': str(request.user.id)}).list_member_permission())
@ -50,8 +57,10 @@ class TeamMember(APIView):
@swagger_auto_schema(operation_summary="修改团队成员权限",
operation_id="修改团队成员权限",
request_body=UpdateTeamMemberPermissionSerializer().get_request_body_api(),
manual_parameters=TeamMemberSerializer.Operate.get_request_params_api()
manual_parameters=TeamMemberSerializer.Operate.get_request_params_api(),
tags=["团队"]
)
@has_permissions(PermissionConstants.TEAM_EDIT)
def put(self, request: Request, member_id: str):
return result.success(TeamMemberSerializer.Operate(
data={'member_id': member_id, 'team_id': str(request.user.id)}).edit(request.data))
@ -59,8 +68,10 @@ class TeamMember(APIView):
@action(methods=['DELETE'], detail=False)
@swagger_auto_schema(operation_summary="移除成员",
operation_id="移除成员",
manual_parameters=TeamMemberSerializer.Operate.get_request_params_api()
manual_parameters=TeamMemberSerializer.Operate.get_request_params_api(),
tags=["团队"]
)
@has_permissions(PermissionConstants.TEAM_DELETE)
def delete(self, request: Request, member_id: str):
return result.success(TeamMemberSerializer.Operate(
data={'member_id': member_id, 'team_id': str(request.user.id)}).delete())

View File

@ -7,3 +7,4 @@
@desc:
"""
from .Team import *
from .model import *

122
apps/setting/views/model.py Normal file
View File

@ -0,0 +1,122 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file model.py
@date2023/11/2 13:55
@desc:
"""
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.views import APIView
from rest_framework.views import Request
from common.auth import TokenAuth, has_permissions
from common.constants.permission_constants import PermissionConstants
from common.response import result
from common.util.common import query_params_to_single_dict
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.serializers.provider_serializers import ProviderSerializer, ModelSerializer
from setting.swagger_api.provide_api import ProvideApi, ModelCreateApi, ModelQueryApi
class Model(APIView):
authentication_classes = [TokenAuth]
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="创建模型",
operation_id="创建模型",
request_body=ModelCreateApi.get_request_body_api()
, tags=["模型"])
@has_permissions(PermissionConstants.MODEL_CREATE)
def post(self, request: Request):
return result.success(
ModelSerializer.Create(data={**request.data, 'user_id': str(request.user.id)}).insert(request.user.id,
with_valid=True))
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取模型列表",
operation_id="获取模型列表",
manual_parameters=ModelQueryApi.get_request_params_api()
, tags=["模型"])
@has_permissions(PermissionConstants.MODEL_READ)
def get(self, request: Request):
return result.success(
ModelSerializer.Query(
data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).list(
with_valid=True))
class Provide(APIView):
class Exec(APIView):
authentication_classes = [TokenAuth]
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="调用供应商函数,获取表单数据",
operation_id="调用供应商函数,获取表单数据",
manual_parameters=ProvideApi.get_request_params_api(),
request_body=ProvideApi.get_request_body_api()
, tags=["模型"])
@has_permissions(PermissionConstants.MODEL_READ)
def post(self, request: Request, provider: str, method: str):
return result.success(
ProviderSerializer(data={'provider': provider, 'method': method}).exec(request.data, with_valid=True))
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取模型供应商数据",
operation_id="获取模型供应商列表"
, tags=["模型"])
@has_permissions(PermissionConstants.MODEL_READ)
def get(self, request: Request):
return result.success(
[ModelProvideConstants[key].value.get_model_provide_info().to_dict() for key in
ModelProvideConstants.__members__])
class ModelTypeList(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取模型类型列表",
operation_id="获取模型类型类型列表",
manual_parameters=ProvideApi.ModelTypeList.get_request_params_api(),
responses=result.get_api_array_response(ProvideApi.ModelTypeList.get_response_body_api())
, tags=["模型"])
@has_permissions(PermissionConstants.MODEL_READ)
def get(self, request: Request):
provider = request.query_params.get('provider')
return result.success(ModelProvideConstants[provider].value.get_model_type_list())
class ModelList(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取模型列表",
operation_id="获取模型创建表单",
manual_parameters=ProvideApi.ModelList.get_request_params_api(),
responses=result.get_api_array_response(ProvideApi.ModelList.get_response_body_api())
, tags=["模型"]
)
@has_permissions(PermissionConstants.MODEL_READ)
def get(self, request: Request):
provider = request.query_params.get('provider')
model_type = request.query_params.get('model_type')
return result.success(
ModelProvideConstants[provider].value.get_model_list(
model_type))
class ModelForm(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取模型创建表单",
operation_id="获取模型创建表单",
manual_parameters=ProvideApi.ModelForm.get_request_params_api(),
tags=["模型"])
@has_permissions(PermissionConstants.MODEL_READ)
def get(self, request: Request):
provider = request.query_params.get('provider')
model_type = request.query_params.get('model_type')
model_name = request.query_params.get('model_name')
return result.success(
ModelProvideConstants[provider].value.get_model_credential(model_type, model_name).to_form_list())

View File

@ -99,6 +99,10 @@ CACHES = {
"token_cache": {
'BACKEND': 'common.cache.file_cache.FileCache',
'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "token_cache") # 文件夹路径
},
"chat_cache": {
'BACKEND': 'common.cache.file_cache.FileCache',
'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "chat_cache") # 文件夹路径
}
}

View File

@ -5,13 +5,13 @@ The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/4.2/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
1. Add an import: froms my_app import views
2. Add a URL to urlpatterns: path('', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
1. Add an import: froms other_app.views import Home
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.urls import include, path
1. Import the include() function: froms django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
import os
@ -43,7 +43,8 @@ schema_view = get_schema_view(
urlpatterns = [
path("api/", include("users.urls")),
path("api/", include("dataset.urls")),
path("api/", include("setting.urls"))
path("api/", include("setting.urls")),
path("api/", include("application.urls"))
]

View File

@ -17,9 +17,9 @@ application = get_wsgi_application()
def post_handler():
from common.event.listener_manage import ListenerManagement
ListenerManagement().run()
ListenerManagement.init_embedding_model_signal.send()
from common import event
event.run()
event.ListenerManagement.init_embedding_model_signal.send()
post_handler()

View File

@ -19,6 +19,7 @@ from django.db.models import Q
from drf_yasg import openapi
from rest_framework import serializers
from common.constants.authentication_type import AuthenticationType
from common.constants.exception_code_constants import ExceptionCodeConstants
from common.constants.permission_constants import RoleConstants, get_permission_list_by_role
from common.exception.app_exception import AppApiException
@ -66,7 +67,8 @@ class LoginSerializer(ApiMixin, serializers.Serializer):
:return: 用户Token(认证信息)
"""
user = self.is_valid()
token = signing.dumps({'username': user.username, 'id': str(user.id), 'email': user.email})
token = signing.dumps({'username': user.username, 'id': str(user.id), 'email': user.email,
'type': AuthenticationType.USER.value})
return token
class Meta:

View File

@ -7,7 +7,6 @@
@desc:
"""
from django.core import cache
from django.db.models import QuerySet
from drf_yasg import openapi
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
@ -21,7 +20,6 @@ from common.auth.authentication import has_permissions
from common.constants.permission_constants import PermissionConstants, CompareConstants
from common.response import result
from smartdoc.settings import JWT_AUTH
from users.models.user import User as UserModel
from users.serializers.user_serializers import RegisterSerializer, LoginSerializer, CheckCodeSerializer, \
RePasswordSerializer, \
SendEmailSerializer, UserProfile
@ -36,7 +34,8 @@ class User(APIView):
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取当前用户信息",
operation_id="获取当前用户信息",
responses=result.get_api_response(UserProfile.get_response_body_api()))
responses=result.get_api_response(UserProfile.get_response_body_api()),
tags=['用户'])
@has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND)
def get(self, request: Request):
return result.success(UserProfile.get_user_profile(request.user))
@ -58,7 +57,8 @@ class ResetCurrentUserPasswordView(APIView):
description="密码")
}
),
responses=RePasswordSerializer().get_response_body_api())
responses=RePasswordSerializer().get_response_body_api(),
tags=['用户'])
def post(self, request: Request):
data = {'email': request.user.email}
data.update(request.data)
@ -77,7 +77,8 @@ class SendEmailToCurrentUserView(APIView):
@permission_classes((AllowAny,))
@swagger_auto_schema(operation_summary="发送邮件到当前用户",
operation_id="发送邮件到当前用户",
responses=SendEmailSerializer().get_response_body_api())
responses=SendEmailSerializer().get_response_body_api(),
tags=['用户'])
def post(self, request: Request):
serializer_obj = SendEmailSerializer(data={'email': request.user.email, 'type': "reset_password"})
if serializer_obj.is_valid(raise_exception=True):
@ -91,7 +92,8 @@ class Logout(APIView):
@permission_classes((AllowAny,))
@swagger_auto_schema(operation_summary="登出",
operation_id="登出",
responses=SendEmailSerializer().get_response_body_api())
responses=SendEmailSerializer().get_response_body_api(),
tags=['用户'])
def post(self, request: Request):
token_cache.delete(request.META.get('HTTP_AUTHORIZATION', None
))
@ -104,7 +106,9 @@ class Login(APIView):
@swagger_auto_schema(operation_summary="登录",
operation_id="登录",
request_body=LoginSerializer().get_request_body_api(),
responses=LoginSerializer().get_response_body_api())
responses=LoginSerializer().get_response_body_api(),
security=[],
tags=['用户'])
def post(self, request: Request):
login_request = LoginSerializer(data=request.data)
# 校验请求参数
@ -121,7 +125,9 @@ class Register(APIView):
@swagger_auto_schema(operation_summary="用户注册",
operation_id="用户注册",
request_body=RegisterSerializer().get_request_body_api(),
responses=RegisterSerializer().get_response_body_api())
responses=RegisterSerializer().get_response_body_api(),
security=[],
tags=['用户'])
def post(self, request: Request):
serializer_obj = RegisterSerializer(data=request.data)
if serializer_obj.is_valid(raise_exception=True):
@ -136,7 +142,9 @@ class RePasswordView(APIView):
@swagger_auto_schema(operation_summary="修改密码",
operation_id="修改密码",
request_body=RePasswordSerializer().get_request_body_api(),
responses=RePasswordSerializer().get_response_body_api())
responses=RePasswordSerializer().get_response_body_api(),
security=[],
tags=['用户'])
def post(self, request: Request):
serializer_obj = RePasswordSerializer(data=request.data)
return result.success(serializer_obj.reset_password())
@ -149,7 +157,9 @@ class CheckCode(APIView):
@swagger_auto_schema(operation_summary="校验验证码是否正确",
operation_id="校验验证码是否正确",
request_body=CheckCodeSerializer().get_request_body_api(),
responses=CheckCodeSerializer().get_response_body_api())
responses=CheckCodeSerializer().get_response_body_api(),
security=[],
tags=['用户'])
def post(self, request: Request):
return result.success(CheckCodeSerializer(data=request.data).is_valid(raise_exception=True))
@ -160,7 +170,9 @@ class SendEmail(APIView):
@swagger_auto_schema(operation_summary="发送邮件",
operation_id="发送邮件",
request_body=SendEmailSerializer().get_request_body_api(),
responses=SendEmailSerializer().get_response_body_api())
responses=SendEmailSerializer().get_response_body_api(),
security=[],
tags=['用户'])
def post(self, request: Request):
serializer_obj = SendEmailSerializer(data=request.data)
if serializer_obj.is_valid(raise_exception=True):

View File

@ -12,7 +12,7 @@ djangorestframework = "3.14.0"
drf-yasg = "1.21.7"
django-filter = "23.2"
elasticsearch = "8.9.0"
langchain = "0.0.274"
langchain = "^0.0.321"
psycopg2-binary = "2.9.7"
jieba = "^0.42.1"
diskcache = "^5.6.3"
@ -22,6 +22,10 @@ chardet = "^5.2.0"
torch = "^2.1.0"
sentence-transformers = "^2.2.2"
blinker = "^1.6.3"
openai = "^0.28.1"
tiktoken = "^0.5.1"
qianfan = "^0.1.1"
pycryptodome = "^3.19.0"
[build-system]