From f28b222388492dc293fa873c165799f63f19e3d9 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Thu, 16 Nov 2023 13:16:27 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=BA=94=E7=94=A8=E7=9B=B8=E5=85=B3?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3,=E6=A8=A1=E5=9E=8B=E7=9B=B8=E5=85=B3?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 +- apps/application/migrations/0001_initial.py | 50 ++- ...cord_dataset_alter_chatrecord_paragraph.py | 25 ++ ...cord_dataset_alter_chatrecord_paragraph.py | 25 ++ ...004_alter_chatrecord_source_id_and_more.py | 23 ++ .../0005_alter_chatrecord_source_type.py | 18 + ...pplicationaccesstoken_applicationapikey.py | 43 ++ apps/application/models/application.py | 56 ++- .../serializers/application_serializers.py | 366 ++++++++++++++++++ .../serializers/chat_message_serializers.py | 167 ++++++++ .../serializers/chat_serializers.py | 282 ++++++++++++++ apps/application/sql/list_application.sql | 8 + .../application/sql/list_application_chat.sql | 16 + .../sql/list_application_dataset.sql | 20 + .../swagger_api/application_api.py | 165 ++++++++ apps/application/swagger_api/chat_api.py | 159 ++++++++ apps/application/urls.py | 35 ++ apps/application/views.py | 3 - apps/application/views/__init__.py | 10 + apps/application/views/application_views.py | 250 ++++++++++++ apps/application/views/chat_views.py | 198 ++++++++++ apps/common/auth/authenticate.py | 53 ++- apps/common/auth/authentication.py | 21 +- apps/common/constants/authentication_type.py | 16 + apps/common/constants/permission_constants.py | 46 ++- apps/common/db/compiler.py | 2 +- apps/common/db/search.py | 30 +- apps/common/event/__init__.py | 15 + apps/common/event/common.py | 18 + apps/common/event/listener_chat_message.py | 54 +++ apps/common/event/listener_manage.py | 11 +- apps/common/exception/app_exception.py | 11 + apps/common/froms/__init__.py | 22 ++ apps/common/froms/array_card.py | 36 ++ 
apps/common/froms/base_field.py | 159 ++++++++ apps/common/froms/base_form.py | 16 + apps/common/froms/combobox_field.py | 41 ++ apps/common/froms/multi_select.py | 41 ++ apps/common/froms/number_input_field.py | 27 ++ apps/common/froms/object_card.py | 36 ++ apps/common/froms/password_input.py | 27 ++ apps/common/froms/radio_field.py | 41 ++ apps/common/froms/single_select_field.py | 41 ++ apps/common/froms/switch_btn.py | 28 ++ apps/common/froms/tab_card.py | 36 ++ apps/common/froms/table_radio.py | 36 ++ apps/common/froms/text_input_field.py | 27 ++ apps/common/sql/list_embedding_text.sql | 10 +- apps/common/util/common.py | 25 +- apps/common/util/rsa_util.py | 70 ++++ apps/common/util/test.py | 79 ++++ .../serializers/dataset_serializers.py | 4 +- .../serializers/paragraph_serializers.py | 10 +- .../serializers/problem_serializers.py | 7 + apps/dataset/views/dataset.py | 21 +- apps/dataset/views/document.py | 22 +- apps/dataset/views/paragraph.py | 19 +- apps/dataset/views/problem.py | 9 +- ...mbedding_star_num_embedding_trample_num.py | 23 ++ .../0003_alter_embedding_unique_together.py | 17 + .../0004_alter_embedding_source_type.py | 18 + .../0005_alter_embedding_source_type.py | 18 + apps/embedding/models/embedding.py | 8 +- apps/embedding/sql/embedding_search.sql | 15 + apps/embedding/vector/base_vector.py | 14 +- apps/embedding/vector/pg_vector.py | 54 ++- apps/setting/migrations/0002_model.py | 34 ++ .../migrations/0003_alter_model_provider.py | 18 + apps/setting/models_provider/__init__.py | 8 + .../models_provider/base_model_provider.py | 130 +++++++ .../constants/model_provider_constants.py | 17 + .../impl/azure_model_provider/__init__.py | 8 + .../azure_model_provider.py | 110 ++++++ .../azure_model_provider/icon/azure_icon_svg | 2 + .../impl/wenxin_model_provider/__init__.py | 8 + .../wenxin_model_provider/icon/azure_icon_svg | 2 + .../model/qian_fan_chat_model.py | 78 ++++ .../wenxin_model_provider.py | 134 +++++++ 
.../serializers/provider_serializers.py | 119 ++++++ apps/setting/serializers/team_serializers.py | 12 +- apps/setting/swagger_api/provide_api.py | 156 ++++++++ apps/setting/urls.py | 11 +- apps/setting/views/Team.py | 23 +- apps/setting/views/__init__.py | 1 + apps/setting/views/model.py | 122 ++++++ apps/smartdoc/settings/base.py | 4 + apps/smartdoc/urls.py | 9 +- apps/smartdoc/wsgi.py | 6 +- apps/users/serializers/user_serializers.py | 4 +- apps/users/views/user.py | 34 +- pyproject.toml | 6 +- 91 files changed, 4173 insertions(+), 138 deletions(-) create mode 100644 apps/application/migrations/0002_alter_chatrecord_dataset_alter_chatrecord_paragraph.py create mode 100644 apps/application/migrations/0003_alter_chatrecord_dataset_alter_chatrecord_paragraph.py create mode 100644 apps/application/migrations/0004_alter_chatrecord_source_id_and_more.py create mode 100644 apps/application/migrations/0005_alter_chatrecord_source_type.py create mode 100644 apps/application/migrations/0006_applicationaccesstoken_applicationapikey.py create mode 100644 apps/application/serializers/application_serializers.py create mode 100644 apps/application/serializers/chat_message_serializers.py create mode 100644 apps/application/serializers/chat_serializers.py create mode 100644 apps/application/sql/list_application.sql create mode 100644 apps/application/sql/list_application_chat.sql create mode 100644 apps/application/sql/list_application_dataset.sql create mode 100644 apps/application/swagger_api/application_api.py create mode 100644 apps/application/swagger_api/chat_api.py create mode 100644 apps/application/urls.py delete mode 100644 apps/application/views.py create mode 100644 apps/application/views/__init__.py create mode 100644 apps/application/views/application_views.py create mode 100644 apps/application/views/chat_views.py create mode 100644 apps/common/constants/authentication_type.py create mode 100644 apps/common/event/__init__.py create mode 100644 
apps/common/event/common.py create mode 100644 apps/common/event/listener_chat_message.py create mode 100644 apps/common/froms/__init__.py create mode 100644 apps/common/froms/array_card.py create mode 100644 apps/common/froms/base_field.py create mode 100644 apps/common/froms/base_form.py create mode 100644 apps/common/froms/combobox_field.py create mode 100644 apps/common/froms/multi_select.py create mode 100644 apps/common/froms/number_input_field.py create mode 100644 apps/common/froms/object_card.py create mode 100644 apps/common/froms/password_input.py create mode 100644 apps/common/froms/radio_field.py create mode 100644 apps/common/froms/single_select_field.py create mode 100644 apps/common/froms/switch_btn.py create mode 100644 apps/common/froms/tab_card.py create mode 100644 apps/common/froms/table_radio.py create mode 100644 apps/common/froms/text_input_field.py create mode 100644 apps/common/util/rsa_util.py create mode 100644 apps/common/util/test.py create mode 100644 apps/embedding/migrations/0002_embedding_star_num_embedding_trample_num.py create mode 100644 apps/embedding/migrations/0003_alter_embedding_unique_together.py create mode 100644 apps/embedding/migrations/0004_alter_embedding_source_type.py create mode 100644 apps/embedding/migrations/0005_alter_embedding_source_type.py create mode 100644 apps/embedding/sql/embedding_search.sql create mode 100644 apps/setting/migrations/0002_model.py create mode 100644 apps/setting/migrations/0003_alter_model_provider.py create mode 100644 apps/setting/models_provider/__init__.py create mode 100644 apps/setting/models_provider/base_model_provider.py create mode 100644 apps/setting/models_provider/constants/model_provider_constants.py create mode 100644 apps/setting/models_provider/impl/azure_model_provider/__init__.py create mode 100644 apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py create mode 100644 
apps/setting/models_provider/impl/azure_model_provider/icon/azure_icon_svg create mode 100644 apps/setting/models_provider/impl/wenxin_model_provider/__init__.py create mode 100644 apps/setting/models_provider/impl/wenxin_model_provider/icon/azure_icon_svg create mode 100644 apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py create mode 100644 apps/setting/models_provider/impl/wenxin_model_provider/wenxin_model_provider.py create mode 100644 apps/setting/serializers/provider_serializers.py create mode 100644 apps/setting/swagger_api/provide_api.py create mode 100644 apps/setting/views/model.py diff --git a/.gitignore b/.gitignore index 136ad6f4e..2ad357d57 100644 --- a/.gitignore +++ b/.gitignore @@ -27,7 +27,7 @@ share/python-wheels/ MANIFEST # PyInstaller -# Usually these files are written by a python script from a template +# Usually these files are written by a python script froms a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec diff --git a/apps/application/migrations/0001_initial.py b/apps/application/migrations/0001_initial.py index 8837f441f..ebf6d47ec 100644 --- a/apps/application/migrations/0001_initial.py +++ b/apps/application/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.1.10 on 2023-10-24 12:13 +# Generated by Django 4.1.10 on 2023-11-14 07:30 import django.contrib.postgres.fields from django.db import migrations, models @@ -11,6 +11,7 @@ class Migration(migrations.Migration): initial = True dependencies = [ + ('setting', '0003_alter_model_provider'), ('users', '0001_initial'), ('dataset', '0001_initial'), ] @@ -26,8 +27,9 @@ class Migration(migrations.Migration): ('desc', models.CharField(max_length=128, verbose_name='引用描述')), ('prologue', models.CharField(max_length=1024, verbose_name='开场白')), ('example', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(blank=True, max_length=256), size=None, verbose_name='示例列表')), + ('dialogue_number', models.IntegerField(default=0, verbose_name='会话数量')), ('status', models.BooleanField(default=True, verbose_name='是否发布')), - ('is_active', models.BooleanField(default=True)), + ('model', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='setting.model')), ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user')), ], options={ @@ -35,13 +37,49 @@ class Migration(migrations.Migration): }, ), migrations.CreateModel( - name='ApplicationDatasetMapping', + name='Chat', fields=[ - ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), - ('application', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='application.application')), - ('dataset', 
models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('abstract', models.CharField(max_length=256, verbose_name='摘要')), + ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application')), + ], + options={ + 'db_table': 'application_chat', + }, + ), + migrations.CreateModel( + name='ChatRecord', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('vote_status', models.CharField(choices=[('-1', '未投票'), ('0', '赞同'), ('1', '反对')], default='-1', max_length=10, verbose_name='投票')), + ('source_id', models.UUIDField(verbose_name='资源id 段落/问题 id ')), + ('source_type', models.CharField(choices=[('0', '问题'), ('1', '段落')], default='0', max_length=1, verbose_name='资源类型')), + ('message_tokens', models.IntegerField(default=0, verbose_name='请求token数量')), + ('answer_tokens', models.IntegerField(default=0, verbose_name='响应token数量')), + ('problem_text', models.CharField(max_length=1024, verbose_name='问题')), + ('answer_text', models.CharField(max_length=1024, verbose_name='答案')), + ('improve_problem_id_list', django.contrib.postgres.fields.ArrayField(base_field=models.UUIDField(blank=True), default=list, size=None, verbose_name='改进标注列表')), + ('index', models.IntegerField(verbose_name='对话下标')), + ('chat', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.chat')), + ('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset', verbose_name='数据集')), + ('paragraph', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph', verbose_name='段落id')), + ], + 
options={ + 'db_table': 'application_chat_record', + }, + ), + migrations.CreateModel( + name='ApplicationDatasetMapping', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application')), + ('dataset', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='dataset.dataset')), ], options={ 'db_table': 'application_dataset_mapping', diff --git a/apps/application/migrations/0002_alter_chatrecord_dataset_alter_chatrecord_paragraph.py b/apps/application/migrations/0002_alter_chatrecord_dataset_alter_chatrecord_paragraph.py new file mode 100644 index 000000000..706aaa261 --- /dev/null +++ b/apps/application/migrations/0002_alter_chatrecord_dataset_alter_chatrecord_paragraph.py @@ -0,0 +1,25 @@ +# Generated by Django 4.1.10 on 2023-11-14 08:02 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('dataset', '0001_initial'), + ('application', '0001_initial'), + ] + + operations = [ + migrations.AlterField( + model_name='chatrecord', + name='dataset', + field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset', verbose_name='数据集'), + ), + migrations.AlterField( + model_name='chatrecord', + name='paragraph', + field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph', verbose_name='段落id'), + ), + ] diff --git a/apps/application/migrations/0003_alter_chatrecord_dataset_alter_chatrecord_paragraph.py b/apps/application/migrations/0003_alter_chatrecord_dataset_alter_chatrecord_paragraph.py new file mode 100644 index 000000000..0cc698559 --- /dev/null +++ 
b/apps/application/migrations/0003_alter_chatrecord_dataset_alter_chatrecord_paragraph.py @@ -0,0 +1,25 @@ +# Generated by Django 4.1.10 on 2023-11-14 08:03 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('dataset', '0001_initial'), + ('application', '0002_alter_chatrecord_dataset_alter_chatrecord_paragraph'), + ] + + operations = [ + migrations.AlterField( + model_name='chatrecord', + name='dataset', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset', verbose_name='数据集'), + ), + migrations.AlterField( + model_name='chatrecord', + name='paragraph', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph', verbose_name='段落id'), + ), + ] diff --git a/apps/application/migrations/0004_alter_chatrecord_source_id_and_more.py b/apps/application/migrations/0004_alter_chatrecord_source_id_and_more.py new file mode 100644 index 000000000..6e15c2745 --- /dev/null +++ b/apps/application/migrations/0004_alter_chatrecord_source_id_and_more.py @@ -0,0 +1,23 @@ +# Generated by Django 4.1.10 on 2023-11-14 08:09 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0003_alter_chatrecord_dataset_alter_chatrecord_paragraph'), + ] + + operations = [ + migrations.AlterField( + model_name='chatrecord', + name='source_id', + field=models.UUIDField(null=True, verbose_name='资源id 段落/问题 id '), + ), + migrations.AlterField( + model_name='chatrecord', + name='source_type', + field=models.CharField(choices=[('0', '问题'), ('1', '段落'), ('-1', '其他')], default='-1', max_length=5, verbose_name='资源类型'), + ), + ] diff --git a/apps/application/migrations/0005_alter_chatrecord_source_type.py b/apps/application/migrations/0005_alter_chatrecord_source_type.py new file mode 100644 index 000000000..59776684c --- 
/dev/null +++ b/apps/application/migrations/0005_alter_chatrecord_source_type.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.10 on 2023-11-14 08:11 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0004_alter_chatrecord_source_id_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='chatrecord', + name='source_type', + field=models.CharField(blank=True, choices=[('0', '问题'), ('1', '段落')], default='0', max_length=2, null=True, verbose_name='资源类型'), + ), + ] diff --git a/apps/application/migrations/0006_applicationaccesstoken_applicationapikey.py b/apps/application/migrations/0006_applicationaccesstoken_applicationapikey.py new file mode 100644 index 000000000..c4a65a9a2 --- /dev/null +++ b/apps/application/migrations/0006_applicationaccesstoken_applicationapikey.py @@ -0,0 +1,43 @@ +# Generated by Django 4.1.10 on 2023-11-15 09:55 + +from django.db import migrations, models +import django.db.models.deletion +import uuid + + +class Migration(migrations.Migration): + + dependencies = [ + ('users', '0001_initial'), + ('application', '0005_alter_chatrecord_source_type'), + ] + + operations = [ + migrations.CreateModel( + name='ApplicationAccessToken', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, primary_key=True, serialize=False, to='application.application', verbose_name='应用id')), + ('access_token', models.CharField(max_length=128, unique=True, verbose_name='用户公开访问 认证token')), + ('is_active', models.BooleanField(default=True, verbose_name='是否开启公开访问')), + ], + options={ + 'db_table': 'application_access_token', + }, + ), + migrations.CreateModel( + name='ApplicationApiKey', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + 
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('secret_key', models.CharField(max_length=1024, unique=True, verbose_name='秘钥')), + ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application', verbose_name='应用id')), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='users.user', verbose_name='用户id')), + ], + options={ + 'db_table': 'application_api_key', + }, + ), + ] diff --git a/apps/application/models/application.py b/apps/application/models/application.py index 79cef354d..b7daeadf0 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -12,7 +12,9 @@ from django.contrib.postgres.fields import ArrayField from django.db import models from common.mixins.app_model_mixin import AppModelMixin -from dataset.models.data_set import DataSet +from dataset.models.data_set import DataSet, Paragraph +from embedding.models import SourceType +from setting.models.model_management import Model from users.models import User @@ -22,16 +24,62 @@ class Application(AppModelMixin): desc = models.CharField(max_length=128, verbose_name="引用描述") prologue = models.CharField(max_length=1024, verbose_name="开场白") example = ArrayField(verbose_name="示例列表", base_field=models.CharField(max_length=256, blank=True)) + dialogue_number = models.IntegerField(default=0, verbose_name="会话数量") status = models.BooleanField(default=True, verbose_name="是否发布") user = models.ForeignKey(User, on_delete=models.DO_NOTHING) - is_active = models.BooleanField(default=True) + model = models.ForeignKey(Model, on_delete=models.DO_NOTHING, db_constraint=False) + class Meta: db_table = "application" class ApplicationDatasetMapping(AppModelMixin): - application = models.ForeignKey(Application, on_delete=models.DO_NOTHING) - dataset = 
models.ForeignKey(DataSet, on_delete=models.DO_NOTHING) + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + application = models.ForeignKey(Application, on_delete=models.CASCADE) + dataset = models.ForeignKey(DataSet, on_delete=models.CASCADE) class Meta: db_table = "application_dataset_mapping" + + +class Chat(AppModelMixin): + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + application = models.ForeignKey(Application, on_delete=models.CASCADE) + abstract = models.CharField(max_length=256, verbose_name="摘要") + + class Meta: + db_table = "application_chat" + + +class VoteChoices(models.TextChoices): + """订单类型""" + UN_VOTE = -1, '未投票' + STAR = 0, '赞同' + TRAMPLE = 1, '反对' + + +class ChatRecord(AppModelMixin): + """ + 对话日志 详情 + """ + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + chat = models.ForeignKey(Chat, on_delete=models.CASCADE) + vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices, + default=VoteChoices.UN_VOTE) + dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, verbose_name="数据集", blank=True, null=True) + paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, verbose_name="段落id", blank=True, null=True) + source_id = models.UUIDField(max_length=128, verbose_name="资源id 段落/问题 id ", null=True) + source_type = models.CharField(verbose_name='资源类型', max_length=2, choices=SourceType.choices, + default=SourceType.PROBLEM, blank=True, null=True) + message_tokens = models.IntegerField(verbose_name="请求token数量", default=0) + answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0) + problem_text = models.CharField(max_length=1024, verbose_name="问题") + answer_text = models.CharField(max_length=1024, verbose_name="答案") + improve_problem_id_list = ArrayField(verbose_name="改进标注列表", + 
base_field=models.UUIDField(max_length=128, blank=True) + , default=list) + + index = models.IntegerField(verbose_name="对话下标") + + class Meta: + db_table = "application_chat_record" diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py new file mode 100644 index 000000000..817808de7 --- /dev/null +++ b/apps/application/serializers/application_serializers.py @@ -0,0 +1,366 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: application_serializers.py + @date:2023/11/7 10:02 + @desc: +""" +import hashlib +import os +import uuid +from typing import Dict + +from django.contrib.postgres.fields import ArrayField +from django.core import cache +from django.core import signing +from django.db import transaction, models +from django.db.models import QuerySet +from rest_framework import serializers + +from application.models import Application, ApplicationDatasetMapping +from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey +from common.constants.authentication_type import AuthenticationType +from common.db.search import get_dynamics_model, native_search, native_page_search +from common.db.sql_execute import select_list +from common.exception.app_exception import AppApiException, NotFound404, AppAuthenticationFailed +from common.util.file_util import get_file_content +from dataset.models import DataSet +from setting.models import AuthOperate +from setting.models.model_management import Model +from setting.models_provider.constants.model_provider_constants import ModelProvideConstants +from smartdoc.conf import PROJECT_DIR +from smartdoc.settings import JWT_AUTH + +token_cache = cache.caches['token_cache'] + + +class ModelDatasetAssociation(serializers.Serializer): + user_id = serializers.UUIDField(required=True) + model_id = serializers.CharField(required=True) + dataset_id_list = serializers.ListSerializer(required=False, 
child=serializers.UUIDField(required=True)) + + def is_valid(self, *, raise_exception=True): + super().is_valid(raise_exception=True) + model_id = self.data.get('model_id') + user_id = self.data.get('user_id') + if not QuerySet(Model).filter(id=model_id).exists(): + raise AppApiException(500, f'模型不存在【{model_id}】') + dataset_id_list = list(set(self.data.get('dataset_id_list'))) + exist_dataset_id_list = [str(dataset.id) for dataset in + QuerySet(DataSet).filter(id__in=dataset_id_list, user_id=user_id)] + for dataset_id in dataset_id_list: + if not exist_dataset_id_list.__contains__(dataset_id): + raise AppApiException(500, f'数据集id不存在【{dataset_id}】') + + +class ApplicationSerializerModel(serializers.ModelSerializer): + class Meta: + model = Application + fields = "__all__" + + +class ApplicationSerializer(serializers.Serializer): + name = serializers.CharField(required=True) + desc = serializers.CharField(required=True) + model_id = serializers.CharField(required=True) + multiple_rounds_dialogue = serializers.BooleanField(required=True) + prologue = serializers.CharField(required=True) + example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True)) + dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True)) + + class AccessTokenSerializer(serializers.Serializer): + application_id = serializers.UUIDField(required=True) + + class AccessTokenEditSerializer(serializers.Serializer): + access_token_reset = serializers.UUIDField(required=False) + is_active = serializers.BooleanField(required=False) + + def edit(self, instance: Dict, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + ApplicationSerializer.AccessTokenSerializer.AccessTokenEditSerializer(data=instance).is_valid( + raise_exception=True) + + application_access_token = QuerySet(ApplicationAccessToken).get( + application_id=self.data.get('application_id')) + if 'is_active' in instance: + 
application_access_token.is_active = instance.get("is_active") + if 'access_token_reset' in instance and instance.get('access_token_reset'): + application_access_token.access_token = hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24] + application_access_token.save() + return self.one(with_valid=False) + + def one(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + application_id = self.data.get("application_id") + application_access_token = QuerySet(ApplicationAccessToken).filter( + application_id=application_id).first() + if application_access_token is None: + application_access_token = ApplicationAccessToken(application_id=application_id, + access_token=hashlib.md5( + str(uuid.uuid1()).encode()).hexdigest()[ + 8:24], is_active=True) + application_access_token.save() + return {'application_id': application_access_token.application_id, + 'access_token': application_access_token.access_token, + "is_active": application_access_token.is_active} + + class Authentication(serializers.Serializer): + access_token = serializers.CharField(required=True) + + def auth(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + access_token = self.data.get("access_token") + application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first() + if application_access_token is not None and application_access_token.is_active: + token = signing.dumps({'application_id': str(application_access_token.application_id), + 'user_id': str(application_access_token.application.user.id), + 'access_token': application_access_token.access_token, + 'type': AuthenticationType.APPLICATION_ACCESS_TOKEN.value}) + token_cache.set(token, application_access_token, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA']) + return token + else: + raise AppAuthenticationFailed(401, "无效的access_token") + + class Edit(serializers.Serializer): + name = serializers.CharField(required=False) + desc = serializers.CharField(required=False) 
+ model_id = serializers.CharField(required=False) + multiple_rounds_dialogue = serializers.BooleanField(required=False) + prologue = serializers.CharField(required=False) + example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True)) + dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True)) + + def is_valid(self, *, user_id=None, raise_exception=False): + super().is_valid(raise_exception=True) + ModelDatasetAssociation(data={'user_id': user_id, 'model_id': self.data.get('model_id'), + 'dataset_id_list': self.data.get('dataset_id_list')}).is_valid() + + class Create(serializers.Serializer): + user_id = serializers.UUIDField(required=True) + + @transaction.atomic + def insert(self, application: Dict): + self.is_valid(raise_exception=True) + user_id = self.data.get('user_id') + ApplicationSerializer(data=application).is_valid(user_id=user_id, raise_exception=True) + application_model = ApplicationSerializer.Create.to_application_model(user_id, application) + dataset_id_list = application.get('dataset_id_list', []) + application_dataset_mapping_model_list = [ + ApplicationSerializer.Create.to_application_dateset_mapping(application_model.id, dataset_id) for + dataset_id in dataset_id_list] + # 插入应用 + application_model.save() + # 插入认证信息 + ApplicationAccessToken(application_id=application_model.id, + access_token=hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24]).save() + # 插入关联数据 + QuerySet(ApplicationDatasetMapping).bulk_create(application_dataset_mapping_model_list) + return True + + @staticmethod + def to_application_model(user_id: str, application: Dict): + return Application(id=uuid.uuid1(), name=application.get('name'), desc=application.get('desc'), + prologue=application.get('prologue'), example=application.get('example'), + dialogue_number=3 if application.get('multiple_rounds_dialogue') else 0, + status=True, user_id=user_id, model_id=application.get('model_id'), + ) + 
+ @staticmethod + def to_application_dateset_mapping(application_id: str, dataset_id: str): + return ApplicationDatasetMapping(id=uuid.uuid1(), application_id=application_id, dataset_id=dataset_id) + + class Query(serializers.Serializer): + name = serializers.CharField(required=False) + + desc = serializers.CharField(required=False) + + user_id = serializers.UUIDField(required=True) + + def get_query_set(self): + user_id = self.data.get("user_id") + query_set_dict = {} + query_set = QuerySet(model=get_dynamics_model( + {'temp_application.name': models.CharField(), 'temp_application.desc': models.CharField()})) + if "desc" in self.data and self.data.get('desc') is not None: + query_set = query_set.filter(**{'temp_application.desc__contains': self.data.get("desc")}) + if "name" in self.data and self.data.get('name') is not None: + query_set = query_set.filter(**{'temp_application.name__contains': self.data.get("name")}) + + query_set_dict['default_sql'] = query_set + + query_set_dict['application_custom_sql'] = QuerySet(model=get_dynamics_model( + {'application.user_id': models.CharField(), + })).filter( + **{'application.user_id': user_id} + ) + + query_set_dict['team_member_permission_custom_sql'] = QuerySet(model=get_dynamics_model( + {'user_id': models.CharField(), + 'team_member_permission.auth_target_type': models.CharField(), + 'team_member_permission.operate': ArrayField(verbose_name="权限操作列表", + base_field=models.CharField(max_length=256, + blank=True, + choices=AuthOperate.choices, + default=AuthOperate.USE) + )})).filter( + **{'user_id': user_id, 'team_member_permission.operate__contains': ['USE'], + 'team_member_permission.auth_target_type': 'APPLICATION'}) + + return query_set_dict + + def list(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + return [ApplicationSerializer.Query.reset_application(a) for a in + native_search(self.get_query_set(), select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", 
@staticmethod
def reset_application(application: Dict):
    """Replace ``dialogue_number`` with the boolean ``multiple_rounds_dialogue``.

    The dict is modified in place and also returned. A missing or ``None``
    ``dialogue_number`` is treated as "multi-round dialogue disabled" instead
    of raising ``TypeError``/``KeyError``.

    :param application: application record as a mutable mapping.
    :return: the same mapping, without ``dialogue_number`` and with
        ``multiple_rounds_dialogue`` set to ``True`` when the number was > 0.
    """
    dialogue_number = application.pop('dialogue_number', None)
    application['multiple_rounds_dialogue'] = bool(dialogue_number) and dialogue_number > 0
    return application
def edit(self, instance: Dict, with_valid=True):
    """Update an application and, optionally, its dataset associations.

    Only keys present in ``instance`` with a non-``None`` value are applied.
    ``multiple_rounds_dialogue`` is stored as ``dialogue_number``: the model
    provider's dialogue number when enabled, ``0`` when disabled. This matches
    ``Create`` and ``reset_application``, where a positive number means
    "enabled".

    :param instance: partial application payload.
    :param with_valid: validate this serializer and the payload first.
    :return: the refreshed application as produced by ``one``.
    :raises AppApiException: when a dataset id is not in ``list_dataset``.
    """
    if with_valid:
        self.is_valid()
        ApplicationSerializer.Edit(data=instance).is_valid(
            raise_exception=True)
    application_id = self.data.get("application_id")

    application = QuerySet(Application).get(id=application_id)

    model = QuerySet(Model).get(id=instance.get('model_id') if 'model_id' in instance else application.model_id)

    # Validate the requested dataset ids BEFORE anything is saved, so a
    # rejected request cannot leave the application half-updated.
    dataset_id_list = instance.get('dataset_id_list')
    application_dataset_id_list = None
    if dataset_id_list is not None:
        # Datasets the current user is allowed to associate.
        application_dataset_id_list = [dataset_dict.get('id') for dataset_dict in
                                       self.list_dataset(with_valid=False)]
        # Compare as strings: rows may carry UUID objects while the payload
        # carries strings -- presumably; verify against select_list.
        selectable_ids = {str(dataset_id) for dataset_id in application_dataset_id_list}
        for dataset_id in dataset_id_list:
            if str(dataset_id) not in selectable_ids:
                raise AppApiException(500, f"未知的数据集id{dataset_id},无法关联")

    for update_key in ('name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'example'):
        value = instance.get(update_key)
        if value is None:
            continue
        if update_key == 'multiple_rounds_dialogue':
            application.dialogue_number = (
                ModelProvideConstants[model.provider].value.get_dialogue_number() if value else 0)
        else:
            setattr(application, update_key, value)
    application.save()

    if dataset_id_list is not None:
        # Drop the associations this user could have managed ...
        QuerySet(ApplicationDatasetMapping).filter(dataset_id__in=application_dataset_id_list,
                                                   application_id=application_id).delete()
        # ... and insert the requested ones.
        if dataset_id_list:
            QuerySet(ApplicationDatasetMapping).bulk_create(
                [ApplicationDatasetMapping(application_id=application_id, dataset_id=dataset_id)
                 for dataset_id in dataset_id_list])
    return self.one(with_valid=False)
def generate(self, with_valid=True):
    """Create and persist a new API key for the application.

    :param with_valid: validate ``user_id``/``application_id`` first.
    :return: the serialized ``ApplicationApiKey`` row.
    """
    # Local import so the block does not depend on a file-level import.
    import secrets

    if with_valid:
        self.is_valid(raise_exception=True)
    user_id = self.data.get("user_id")
    application_id = self.data.get("application_id")
    # Same shape as before ('application-' + 32 hex characters), but random.
    # The old value was an MD5 of a uuid1 created just before the row's own
    # uuid1 id, so the secret could be guessed from the id.
    secret_key = 'application-' + secrets.token_hex(16)
    application_api_key = ApplicationApiKey(id=uuid.uuid1(), secret_key=secret_key, user_id=user_id,
                                            application_id=application_id)
    application_api_key.save()
    return ApplicationSerializer.ApplicationKeySerializerModel(application_api_key).data
class MessageManagement:
    """Builds the ``HumanMessage`` objects sent to the chat model."""

    @staticmethod
    def get_message(title: str, content: str, message: str):
        """Wrap a question, optionally grounded on a retrieved paragraph.

        With ``content`` set, the prompt embeds ``title`` and ``content`` as
        known information ahead of the question; without it, the bare
        question is sent.
        """
        if content is not None:
            prompt = (
                f'已知信息:{title}:{content} '
                '根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从已知信息中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 '
                f'问题是:{message}')
        else:
            prompt = message
        return HumanMessage(content=prompt)
class ChatMessageSerializer(serializers.Serializer):
    """Runs one question/answer round of a chat session held in ``chat_cache``."""

    # Key of the cached ``ChatInfo`` describing the session.
    chat_id = serializers.UUIDField(required=True)

    def chat(self, message):
        """Answer ``message`` and stream the reply as server-sent events.

        :param message: the user's question.
        :return: a 404 ``Result`` when the session is no longer cached,
            otherwise a ``StreamingHttpResponse`` of ``data: {json}`` frames.
        """
        self.is_valid(raise_exception=True)
        chat_id = self.data.get('chat_id')
        chat_info: ChatInfo = chat_cache.get(chat_id)
        if chat_info is None:
            return result.Result(response_status=status.HTTP_404_NOT_FOUND, code=404, message="会话过期")

        chat_model = chat_info.chat_model
        vector = VectorStore.get_embedding_vector()
        # Vector search; embeddings already used for this same question in
        # this session are passed along as the fourth argument.
        used_embedding_id_list = [chat_message.embedding_id for chat_message in chat_info.chat_message_list
                                  if chat_message.problem == message]
        _value = vector.search(message, chat_info.dataset_id_list, chat_info.exclude_document_id_list,
                               used_embedding_id_list,
                               True,
                               EmbeddingModel.get_embedding_model())
        # Load the paragraph behind the hit. ``get()`` would raise
        # ``DoesNotExist`` for a stale vector, so the cleanup below was never
        # reached; ``filter().first()`` returns ``None`` instead.
        paragraph = None
        if _value is not None:
            paragraph = QuerySet(Paragraph).filter(id=_value.get('paragraph_id')).first()
            if paragraph is None:
                vector.delete_by_paragraph_id(_value.get('paragraph_id'))
                # The hit pointed at deleted data: treat it as "no hit".
                _value = None

        title, content = (None, None) if paragraph is None else (paragraph.title, paragraph.content)

        if _value is None:
            embedding_id = dataset_id = document_id = paragraph_id = source_type = source_id = None
        else:
            embedding_id, dataset_id, document_id, paragraph_id, source_type, source_id = (
                _value.get('id'), _value.get('dataset_id'), _value.get('document_id'),
                _value.get('paragraph_id'), _value.get('source_type'), _value.get('source_id'))
        # Previous rounds used as context.
        history_message = chat_info.get_context_message()

        # Full prompt: history followed by the current question.
        chat_message = [*history_message, MessageManagement.get_message(title, content, message)]
        # Start the (lazy) model stream.
        result_data = chat_model.stream(chat_message)

        _id = str(uuid.uuid1())

        def event_content(response):
            """Yield one SSE frame per model chunk, then record the round."""
            all_text = ''
            try:
                for chunk in response:
                    all_text += chunk.content
                    yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None,
                                                 'content': chunk.content}) + "\n\n"

                chat_info.append_chat_message(
                    ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id,
                                paragraph_id,
                                source_type,
                                source_id, all_text, chat_model.get_num_tokens_from_messages(chat_message),
                                chat_model.get_num_tokens(all_text)))
                # Re-store the session, which also refreshes its timeout.
                chat_cache.set(chat_id,
                               chat_info, timeout=60 * 30)
            except Exception as e:
                # Stream boundary: report the failure as a well-formed SSE
                # frame instead of yielding the exception object itself.
                yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': False,
                                             'content': str(e)}) + "\n\n"

        r = StreamingHttpResponse(streaming_content=event_content(result_data),
                                  content_type='text/event-stream;charset=utf-8')

        r['Cache-Control'] = 'no-cache'
        return r
def get_end_time(self):
    """Return the current local time minus ``history_day`` days.

    ``history_day`` is read from ``self.data``; the result is the timestamp
    ``get_query_set`` compares ``create_time`` against.
    """
    history_window = datetime.timedelta(days=self.data.get('history_day'))
    return datetime.datetime.now() - history_window
class OpenChat(serializers.Serializer):
    """Opens a chat session for an application owned by the given user."""

    user_id = serializers.UUIDField(required=True)

    application_id = serializers.UUIDField(required=True)

    def is_valid(self, *, raise_exception=False):
        """Validate the fields and check the application belongs to the user.

        Field validation always raises on failure, whatever
        ``raise_exception`` is.

        :raises AppApiException: when no application matches both ids.
        """
        super().is_valid(raise_exception=True)
        owned_application = QuerySet(Application).filter(id=self.data.get('application_id'),
                                                         user_id=self.data.get('user_id'))
        if owned_application.exists():
            return
        raise AppApiException(500, '应用不存在')

    def open(self):
        """Create a session, cache its ``ChatInfo`` for 30 minutes, return its id."""
        self.is_valid(raise_exception=True)
        application_id = self.data.get('application_id')
        application = QuerySet(Application).get(id=application_id)
        model = application.model
        mapping_rows = QuerySet(ApplicationDatasetMapping).filter(application_id=application_id)
        dataset_id_list = [str(mapping.dataset_id) for mapping in mapping_rows]
        provider = ModelProvideConstants[model.provider].value
        chat_model = provider.get_model(model.model_type, model.model_name,
                                        json.loads(decrypt(model.credential)),
                                        streaming=True)

        chat_id = str(uuid.uuid1())
        # Ids of inactive documents in the linked datasets; passed to
        # ChatInfo as exclude_document_id_list.
        inactive_document_id_list = [str(document.id) for document in
                                     QuerySet(Document).filter(dataset_id__in=dataset_id_list,
                                                               is_active=False)]
        session = ChatInfo(chat_id, model, chat_model, application_id, dataset_id_list,
                           inactive_document_id_list,
                           application.dialogue_number)
        chat_cache.set(chat_id, session, timeout=60 * 30)
        return chat_id
def vote_exec(source_type: SourceType, source_id: str, field: str, post_handler):
    """Apply a vote counter change to a problem/paragraph and its embedding.

    The source row is looked up by ``source_id``, ``field`` is set to
    ``post_handler(row)``, and the same value is copied onto the embedding
    with the same source. Unknown source types and missing rows are skipped.
    The previous ``QuerySet.get()`` raised ``DoesNotExist``, which made the
    existing ``is not None`` guards unreachable.

    :param source_type: ``SourceType.PROBLEM`` or ``SourceType.PARAGRAPH``;
        any other value is a no-op.
    :param source_id: primary key of the problem or paragraph.
    :param field: counter attribute to update, e.g. ``star_num``.
    :param post_handler: callable taking the row and returning the new value.
    """
    if source_type == SourceType.PROBLEM:
        source_model = Problem
    elif source_type == SourceType.PARAGRAPH:
        source_model = Paragraph
    else:
        return

    source = QuerySet(source_model).filter(id=source_id).first()
    if source is None:
        return
    setattr(source, field, post_handler(source))
    source.save()

    embedding = QuerySet(Embedding).filter(source_id=source_id, source_type=source_type).first()
    if embedding is not None:
        # Keep the counter on the embedding equal to the one on its source.
        setattr(embedding, field, getattr(source, field))
        embedding.save()
@transaction.atomic
def vote(self, with_valid=True):
    """Cast or cancel a vote on a chat record.

    A record in ``UN_VOTE`` state can be starred or trampled, which
    increments the matching counter via ``vote_exec``. A voted record can
    only be reset to ``UN_VOTE``, which decrements the counter of the vote
    being cancelled. A lock keyed by the record id rejects concurrent
    requests.

    :param with_valid: validate the serializer fields first.
    :return: ``True`` on success.
    :raises AppApiException: when the record is locked, does not exist, or
        has already been voted on and the request is not a cancellation.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    chat_record_id = self.data.get('chat_record_id')
    if not try_lock(chat_record_id):
        raise AppApiException(500, "正在对当前会话纪要进行投票中,请勿重复发送请求")
    try:
        # ``filter().first()`` rather than ``get()``: ``get()`` raises
        # ``DoesNotExist``, so the check below could never fire.
        chat_record_details_model = QuerySet(ChatRecord).filter(id=chat_record_id,
                                                                chat_id=self.data.get('chat_id')).first()
        if chat_record_details_model is None:
            raise AppApiException(500, "不存在的对话 chat_record_id")
        vote_status = self.data.get("vote_status")
        # Remember the current vote before it is overwritten; cancelling
        # needs it to know which counter to decrement.
        previous_vote_status = chat_record_details_model.vote_status
        if previous_vote_status == VoteChoices.UN_VOTE:
            if vote_status == VoteChoices.STAR:
                # Star: mark the record and add one to star_num.
                chat_record_details_model.vote_status = VoteChoices.STAR
                vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
                          'star_num',
                          lambda r: r.star_num + 1)

            if vote_status == VoteChoices.TRAMPLE:
                # Trample: mark the record and add one to trample_num.
                chat_record_details_model.vote_status = VoteChoices.TRAMPLE
                vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
                          'trample_num',
                          lambda r: r.trample_num + 1)
            chat_record_details_model.save()
        elif vote_status == VoteChoices.UN_VOTE:
            # Cancel the existing vote.
            chat_record_details_model.vote_status = VoteChoices.UN_VOTE
            chat_record_details_model.save()
            if previous_vote_status == VoteChoices.STAR:
                # Take one off star_num.
                vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
                          'star_num', lambda r: r.star_num - 1)
            if previous_vote_status == VoteChoices.TRAMPLE:
                # Take one off trample_num.
                vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
                          'trample_num', lambda r: r.trample_num - 1)
        else:
            raise AppApiException(500, "已经投票过,请先取消后再进行投票")
    finally:
        un_lock(chat_record_id)

    return True
content=instance.get("content"), + dataset_id=dataset_id, + title=instance.get("title") if 'title' in instance else '') + + problem = Problem(id=uuid.uuid1(), content=chat_record.problem_text, paragraph_id=paragraph.id, + document_id=document_id, dataset_id=dataset_id) + # 插入问题 + problem.save() + # 插入段落 + paragraph.save() + chat_record.improve_problem_id_list.append(problem.id) + # 添加标注 + chat_record.save() + return True diff --git a/apps/application/sql/list_application.sql b/apps/application/sql/list_application.sql new file mode 100644 index 000000000..283c7bb92 --- /dev/null +++ b/apps/application/sql/list_application.sql @@ -0,0 +1,8 @@ +SELECT * FROM ( SELECT * FROM application ${application_custom_sql} UNION + SELECT + * + FROM + application + WHERE + application."id" IN ( SELECT team_member_permission.target FROM team_member team_member LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id" ${team_member_permission_custom_sql}) + ) temp_application ${default_sql} \ No newline at end of file diff --git a/apps/application/sql/list_application_chat.sql b/apps/application/sql/list_application_chat.sql new file mode 100644 index 000000000..7fcf862d6 --- /dev/null +++ b/apps/application/sql/list_application_chat.sql @@ -0,0 +1,16 @@ +SELECT + * +FROM + application_chat application_chat + LEFT JOIN ( + SELECT COUNT + ( "id" ) AS chat_record_count, + SUM ( CASE WHEN "vote_status" = '0' THEN 1 ELSE 0 END ) AS star_num, + SUM ( CASE WHEN "vote_status" = '1' THEN 1 ELSE 0 END ) AS trample_num, + SUM ( CASE WHEN array_length( application_chat_record.improve_problem_id_list, 1 ) IS NULL THEN 0 ELSE array_length( application_chat_record.improve_problem_id_list, 1 ) END ) AS mark_sum, + chat_id + FROM + application_chat_record + GROUP BY + application_chat_record.chat_id + ) chat_record_temp ON application_chat."id" = chat_record_temp.chat_id \ No newline at end of file diff --git 
a/apps/application/sql/list_application_dataset.sql b/apps/application/sql/list_application_dataset.sql new file mode 100644 index 000000000..691036fc9 --- /dev/null +++ b/apps/application/sql/list_application_dataset.sql @@ -0,0 +1,20 @@ +SELECT + * +FROM + dataset +WHERE + user_id = %s UNION +SELECT + * +FROM + dataset +WHERE + "id" IN ( + SELECT + team_member_permission.target + FROM + team_member team_member + LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id" + WHERE + ( "team_member_permission"."auth_target_type" = 'DATASET' AND "team_member_permission"."operate"::text[] @> ARRAY['USE'] AND team_member.team_id = %s AND team_member.user_id =%s ) + ) \ No newline at end of file diff --git a/apps/application/swagger_api/application_api.py b/apps/application/swagger_api/application_api.py new file mode 100644 index 000000000..b53589fdf --- /dev/null +++ b/apps/application/swagger_api/application_api.py @@ -0,0 +1,165 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: application_api.py + @date:2023/11/7 10:50 + @desc: +""" +from drf_yasg import openapi + +from common.mixins.api_mixin import ApiMixin + +""" + name = serializers.CharField(required=True) + desc = serializers.CharField(required=True) + model_id = serializers.CharField(required=True) + multiple_rounds_dialogue = serializers.BooleanField(required=True) + prologue = serializers.CharField(required=True) + example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True)) + dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True)) +""" + + +class ApplicationApi(ApiMixin): + class Authentication(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['access_token', ], + properties={ + 'access_token': openapi.Schema(type=openapi.TYPE_STRING, title="应用认证token", + description="应用认证token"), + + } + 
) + + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'user_id', 'status', 'create_time', + 'update_time'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="", description="主键id"), + 'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"), + 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"), + 'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"), + "multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话", + description="是否开启多轮对话"), + 'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"), + 'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING), + title="示例列表", description="示例列表"), + 'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户", description="所属用户"), + + 'status': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否发布", description='是否发布'), + + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", description='创建时间'), + + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", description='修改时间'), + 'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, + items=openapi.Schema(type=openapi.TYPE_STRING), + title="关联数据集Id列表", + description="关联数据集Id列表(查询详情的时候返回)") + } + ) + + class ApiKey(ApiMixin): + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='application_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='应用id') + + ] + + class Operate(ApiMixin): + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='application_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='应用id'), + openapi.Parameter(name='api_key_id', + in_=openapi.IN_PATH, + 
class AccessToken(ApiMixin):
    """Swagger description of the application access-token endpoint."""

    @staticmethod
    def get_request_params_api():
        """Path parameters: the application id."""
        application_id = openapi.Parameter(name='application_id',
                                           in_=openapi.IN_PATH,
                                           type=openapi.TYPE_STRING,
                                           required=True,
                                           description='应用id')
        return [application_id]

    @staticmethod
    def get_request_body_api():
        """Request body: two optional boolean switches."""
        properties = {
            'access_token_reset': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重置Token",
                                                 description="重置Token"),
            'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否激活", description="是否激活"),
        }
        return openapi.Schema(type=openapi.TYPE_OBJECT, required=[], properties=properties)
+ @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='application_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='应用id'), + + ] diff --git a/apps/application/swagger_api/chat_api.py b/apps/application/swagger_api/chat_api.py new file mode 100644 index 000000000..89fab5310 --- /dev/null +++ b/apps/application/swagger_api/chat_api.py @@ -0,0 +1,159 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: chat_api.py + @date:2023/11/7 17:29 + @desc: +""" +from drf_yasg import openapi + +from common.mixins.api_mixin import ApiMixin + + +class ChatApi(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['message'], + properties={ + 'message': openapi.Schema(type=openapi.TYPE_STRING, title="问题", description="问题"), + + } + ) + + class OpenChat(ApiMixin): + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='application_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='应用id'), + + ] + + class OpenTempChat(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['model_id', 'multiple_rounds_dialogue'], + properties={ + 'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"), + 'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, + items=openapi.Schema(type=openapi.TYPE_STRING), + title="关联数据集Id列表", description="关联数据集Id列表"), + 'multiple_rounds_dialogue': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮会话", + description="是否开启多轮会话") + } + ) + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='application_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='应用id'), + openapi.Parameter(name='history_day', + in_=openapi.IN_QUERY, + type=openapi.TYPE_NUMBER, + required=True, + description='历史天数') + ] + 
class ChatRecordApi(ApiMixin):
    """Swagger description of the chat-record listing endpoint."""

    @staticmethod
    def get_request_params_api():
        """Path parameters: application id and chat id, both required strings."""
        path_params = (('application_id', '应用id'), ('chat_id', '对话id'))
        return [openapi.Parameter(name=param_name,
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description=param_description)
                for param_name, param_description in path_params]
type=openapi.TYPE_OBJECT, + required=['vote_status'], + properties={ + 'vote_status': openapi.Schema(type=openapi.TYPE_STRING, title="投票状态", + description="-1:取消投票|0:赞同|1:反对"), + + } + ) diff --git a/apps/application/urls.py b/apps/application/urls.py new file mode 100644 index 000000000..2533720be --- /dev/null +++ b/apps/application/urls.py @@ -0,0 +1,35 @@ +from django.urls import path + +from . import views + +app_name = "application" +urlpatterns = [ + path('application', views.Application.as_view(), name="application"), + path('application/profile', views.Application.Profile.as_view()), + path('application/authentication', views.Application.Authentication.as_view()), + path('application/<str:application_id>/api_key', views.Application.ApplicationKey.as_view()), + path("application/<str:application_id>/api_key/<str:api_key_id>", + views.Application.ApplicationKey.Operate.as_view()), + path('application/<str:application_id>', views.Application.Operate.as_view(), name='application/operate'), + path('application/<str:application_id>/list_dataset', views.Application.ListApplicationDataSet.as_view(), + name='application/dataset'), + path('application/<str:application_id>/access_token', views.Application.AccessToken.as_view(), + name='application/access_token'), + path('application/<int:current_page>/<int:page_size>', views.Application.Page.as_view(), name='application_page'), + path('application/<str:application_id>/chat/open', views.ChatView.Open.as_view()), + path("application/chat/open", views.ChatView.OpenTemp.as_view()), + path('application/<str:application_id>/chat', views.ChatView.as_view(), name='chats'), + path('application/<str:application_id>/chat/<int:current_page>/<int:page_size>', views.ChatView.Page.as_view()), + path('application/<str:application_id>/chat/<str:chat_id>/chat_record/', views.ChatView.ChatRecord.as_view()), + path('application/<str:application_id>/chat/<str:chat_id>/chat_record/<int:current_page>/<int:page_size>', + views.ChatView.ChatRecord.Page.as_view()), + path('application/<str:application_id>/chat/<str:chat_id>/chat_record/<str:chat_record_id>/vote', + views.ChatView.ChatRecord.Vote.as_view(), + name=''), + path( + 'application/<str:application_id>/chat/<str:chat_id>/chat_record/<str:chat_record_id>/dataset/<str:dataset_id>/document_id/<str:document_id>/improve', + views.ChatView.ChatRecord.Improve.as_view(), + name=''), + path('application/chat_message/<str:chat_id>', views.ChatView.Message.as_view()) + +] diff
--git a/apps/application/views.py b/apps/application/views.py deleted file mode 100644 index 91ea44a21..000000000 --- a/apps/application/views.py +++ /dev/null @@ -1,3 +0,0 @@ -from django.shortcuts import render - -# Create your views here. diff --git a/apps/application/views/__init__.py b/apps/application/views/__init__.py new file mode 100644 index 000000000..52d004041 --- /dev/null +++ b/apps/application/views/__init__.py @@ -0,0 +1,10 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2023/9/25 17:12 + @desc: +""" +from .application_views import * +from .chat_views import * diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py new file mode 100644 index 000000000..5cfff18f6 --- /dev/null +++ b/apps/application/views/application_views.py @@ -0,0 +1,250 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: application_views.py + @date:2023/10/27 14:56 + @desc: +""" + +from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.request import Request +from rest_framework.views import APIView + +from application.serializers.application_serializers import ApplicationSerializer +from application.swagger_api.application_api import ApplicationApi +from common.auth import TokenAuth, has_permissions +from common.constants.permission_constants import CompareConstants, PermissionConstants, Permission, Group, Operate, \ + ViewPermission, RoleConstants +from common.exception.app_exception import AppAuthenticationFailed +from common.response import result +from common.util.common import query_params_to_single_dict +from dataset.serializers.dataset_serializers import DataSetSerializers + + +class Application(APIView): + authentication_classes = [TokenAuth] + + class Profile(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取应用相关信息", + 
operation_id="获取应用相关信息", + tags=["应用/会话"]) + def get(self, request: Request): + if 'application_id' in request.auth.keywords: + return result.success(ApplicationSerializer.Operate( + data={'application_id': request.auth.keywords.get('application_id'), + 'user_id': request.user.id}).profile()) + else: + raise AppAuthenticationFailed(401, "身份异常") + + class ApplicationKey(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="新增ApiKey", + operation_id="新增ApiKey", + tags=['应用/API_KEY'], + manual_parameters=ApplicationApi.ApiKey.get_request_params_api()) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def post(self, request: Request, application_id: str): + return result.success( + ApplicationSerializer.ApplicationKeySerializer( + data={'application_id': application_id, 'user_id': request.user.id}).generate()) + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取应用API_KEY列表", + operation_id="获取应用API_KEY列表", + tags=['应用/API_KEY'], + manual_parameters=ApplicationApi.ApiKey.get_request_params_api() + ) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str): + return result.success(ApplicationSerializer.ApplicationKeySerializer( + data={'application_id': application_id, 'user_id': request.user.id}).list()) + + class Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['DELETE'], detail=False) + @swagger_auto_schema(operation_summary="删除应用API_KEY", + operation_id="删除应用API_KEY", + tags=['应用/API_KEY'], + 
manual_parameters=ApplicationApi.ApiKey.Operate.get_request_params_api()) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND), lambda r, k: Permission(group=Group.APPLICATION, operate=Operate.DELETE, + dynamic_tag=k.get('application_id')), + compare=CompareConstants.AND) + def delete(self, request: Request, application_id: str, api_key_id: str): + return result.success( + ApplicationSerializer.ApplicationKeySerializer.Operate( + data={'application_id': application_id, 'user_id': request.user.id, + 'api_key_id': api_key_id}).delete()) + + class AccessToken(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="修改 应用AccessToken", + operation_id="修改 应用AccessToken", + tags=['应用/公开访问'], + manual_parameters=ApplicationApi.AccessToken.get_request_params_api(), + request_body=ApplicationApi.AccessToken.get_request_body_api()) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def put(self, request: Request, application_id: str): + return result.success( + ApplicationSerializer.AccessTokenSerializer(data={'application_id': application_id}).edit(request.data)) + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取应用 AccessToken信息", + operation_id="获取应用 AccessToken信息", + manual_parameters=ApplicationApi.AccessToken.get_request_params_api(), + tags=['应用/公开访问'], + ) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + 
compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str): + return result.success( + ApplicationSerializer.AccessTokenSerializer(data={'application_id': application_id}).one()) + + class Authentication(APIView): + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="应用认证", + operation_id="应用认证", + request_body=ApplicationApi.Authentication.get_request_body_api(), + tags=["应用/认证"], + security=[]) + def post(self, request: Request): + return result.success( + ApplicationSerializer.Authentication(data={'access_token': request.data.get("access_token")}).auth()) + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="创建应用", + operation_id="创建应用", + request_body=ApplicationApi.Create.get_request_body_api(), + tags=['应用']) + @has_permissions(PermissionConstants.APPLICATION_CREATE, compare=CompareConstants.AND) + def post(self, request: Request): + ApplicationSerializer.Create(data={'user_id': request.user.id}).insert(request.data) + return result.success(True) + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取应用列表", + operation_id="获取应用列表", + manual_parameters=ApplicationApi.Query.get_request_params_api(), + responses=result.get_api_array_response(ApplicationApi.get_response_body_api()), + tags=['应用']) + @has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND) + def get(self, request: Request): + return result.success( + ApplicationSerializer.Query( + data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).list()) + + class Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['DELETE'], detail=False) + @swagger_auto_schema(operation_summary="删除应用", + operation_id="删除应用", + manual_parameters=ApplicationApi.Operate.get_request_params_api(), + responses=result.get_default_response(), + tags=['应用']) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, 
RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND), + lambda r, k: Permission(group=Group.APPLICATION, operate=Operate.DELETE, + dynamic_tag=k.get('application_id')), compare=CompareConstants.AND) + def delete(self, request: Request, application_id: str): + return result.success(ApplicationSerializer.Operate( + data={'application_id': application_id, 'user_id': request.user.id}).delete( + with_valid=True)) + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="修改应用", + operation_id="修改应用", + manual_parameters=ApplicationApi.Operate.get_request_params_api(), + request_body=ApplicationApi.Create.get_request_body_api(), + responses=result.get_api_array_response(ApplicationApi.get_response_body_api()), + tags=['应用']) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def put(self, request: Request, application_id: str): + return result.success( + ApplicationSerializer.Operate(data={'application_id': application_id, 'user_id': request.user.id}).edit( + request.data)) + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取应用详情", + operation_id="获取应用详情", + manual_parameters=ApplicationApi.Operate.get_request_params_api(), + responses=result.get_api_array_response(ApplicationApi.get_response_body_api()), + tags=['应用']) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN, + RoleConstants.APPLICATION_KEY], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str): + return 
result.success(ApplicationSerializer.Operate( + data={'application_id': application_id, 'user_id': request.user.id}).one()) + + class ListApplicationDataSet(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取当前应用可使用的数据集", + operation_id="获取当前应用可使用的数据集", + manual_parameters=ApplicationApi.Operate.get_request_params_api(), + responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api()), + tags=['应用']) + @has_permissions(ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str): + return result.success(ApplicationSerializer.Operate( + data={'application_id': application_id, 'user_id': request.user.id}).list_dataset()) + + class Page(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="分页获取应用列表", + operation_id="分页获取应用列表", + manual_parameters=result.get_page_request_params( + ApplicationApi.Query.get_request_params_api()), + responses=result.get_page_api_response(ApplicationApi.get_response_body_api()), + tags=['应用']) + @has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND) + def get(self, request: Request, current_page: int, page_size: int): + return result.success( + ApplicationSerializer.Query( + data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).page( + current_page, page_size)) diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py new file mode 100644 index 000000000..cc17384b3 --- /dev/null +++ b/apps/application/views/chat_views.py @@ -0,0 +1,198 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: chat_views.py + @date:2023/11/14 9:53 + @desc: +""" 
+from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.request import Request +from rest_framework.views import APIView + +from application.serializers.chat_message_serializers import ChatMessageSerializer +from application.serializers.chat_serializers import ChatSerializers, ChatRecordSerializer +from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi +from common.auth import TokenAuth, has_permissions +from common.constants.permission_constants import Permission, Group, Operate, \ + RoleConstants, ViewPermission, CompareConstants +from common.exception.app_exception import AppAuthenticationFailed +from common.response import result +from common.util.common import query_params_to_single_dict + + +class ChatView(APIView): + authentication_classes = [TokenAuth] + + class Open(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取会话id,根据应用id", + operation_id="获取会话id,根据应用id", + manual_parameters=ChatApi.OpenChat.get_request_params_api(), + tags=["应用/会话"]) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN, + RoleConstants.APPLICATION_KEY], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND) + ) + def get(self, request: Request, application_id: str): + return result.success(ChatSerializers.OpenChat( + data={'user_id': request.user.id, 'application_id': application_id}).open()) + + class OpenTemp(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="获取会话id(根据模型id,数据集列表,是否多轮会话)", + operation_id="获取会话id", + request_body=ChatApi.OpenTempChat.get_request_body_api(), + tags=["应用/会话"]) + @has_permissions(RoleConstants.ADMIN, RoleConstants.USER) + def 
post(self, request: Request): + return result.success(ChatSerializers.OpenTempChat( + data={**request.data, 'user_id': request.user.id}).open()) + + class Message(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="对话", + operation_id="对话", + request_body=ChatApi.get_request_body_api(), + tags=["应用/会话"]) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY, + RoleConstants.APPLICATION_ACCESS_TOKEN], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))]) + ) + def post(self, request: Request, chat_id: str): + return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message')) + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取对话列表", + operation_id="获取对话列表", + manual_parameters=ChatApi.get_request_params_api(), + tags=["应用/对话日志"] + ) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))]) + ) + def get(self, request: Request, application_id: str): + return result.success(ChatSerializers.Query( + data={**query_params_to_single_dict(request.query_params), 'application_id': application_id, + 'user_id': request.user.id}).list()) + + class Page(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="分页获取对话列表", + operation_id="分页获取对话列表", + manual_parameters=result.get_page_request_params(ChatApi.get_request_params_api()), + tags=["应用/对话日志"] + ) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + 
dynamic_tag=keywords.get('application_id'))]) + ) + def get(self, request: Request, application_id: str, current_page: int, page_size: int): + return result.success(ChatSerializers.Query( + data={**query_params_to_single_dict(request.query_params), 'application_id': application_id, + 'user_id': request.user.id}).page(current_page=current_page, + page_size=page_size)) + + class ChatRecord(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取对话记录列表", + operation_id="获取对话记录列表", + manual_parameters=ChatRecordApi.get_request_params_api(), + tags=["应用/对话日志"] + ) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))]) + ) + def get(self, request: Request, application_id: str, chat_id: str): + return result.success(ChatRecordSerializer.Query( + data={'application_id': application_id, + 'chat_id': chat_id}).list()) + + class Page(APIView): + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取对话记录列表", + operation_id="获取对话记录列表", + manual_parameters=result.get_page_request_params( + ChatRecordApi.get_request_params_api()), + tags=["应用/对话日志"] + ) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))]) + ) + def get(self, request: Request, application_id: str, chat_id: str, current_page: int, page_size: int): + return result.success(ChatRecordSerializer.Query( + data={'application_id': application_id, + 'chat_id': chat_id}).page(current_page, page_size)) + + class Vote(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + 
@swagger_auto_schema(operation_summary="点赞,点踩", + operation_id="点赞,点踩", + manual_parameters=VoteApi.get_request_params_api(), + request_body=VoteApi.get_request_body_api(), + responses=result.get_default_response(), + tags=["应用/会话"] + ) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY, + RoleConstants.APPLICATION_ACCESS_TOKEN], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))]) + ) + def put(self, request: Request, application_id: str, chat_id: str, chat_record_id: str): + return result.success(ChatRecordSerializer.Vote( + data={'vote_status': request.data.get('vote_status'), 'chat_id': chat_id, + 'chat_record_id': chat_record_id}).vote()) + + class Improve(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="标注", + operation_id="标注", + manual_parameters=ImproveApi.get_request_params_api(), + request_body=ImproveApi.get_request_body_api(), + responses=result.get_default_response(), + tags=["应用/对话日志/标注"] + ) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, + operate=Operate.MANAGE, + dynamic_tag=keywords.get( + 'dataset_id'))], + ) + )) + def put(self, request: Request, application_id: str, chat_id: str, chat_record_id: str, dataset_id: str, + document_id: str): + return result.success(ChatRecordSerializer.Improve( + data={'chat_id': chat_id, 'chat_record_id': chat_record_id, + 'dataset_id': dataset_id, 'document_id': document_id}).improve(request.data)) diff --git a/apps/common/auth/authenticate.py b/apps/common/auth/authenticate.py index bac2dd9b2..dffc37137 100644 --- 
a/apps/common/auth/authenticate.py +++ b/apps/common/auth/authenticate.py @@ -12,7 +12,10 @@ from django.core import signing from django.db.models import QuerySet from rest_framework.authentication import TokenAuthentication -from common.constants.permission_constants import Auth, get_permission_list_by_role, RoleConstants +from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey +from common.constants.authentication_type import AuthenticationType +from common.constants.permission_constants import Auth, get_permission_list_by_role, RoleConstants, Permission, Group, \ + Operate from common.exception.app_exception import AppAuthenticationFailed from smartdoc.settings import JWT_AUTH from users.models.user import User, get_user_dynamics_permission @@ -34,13 +37,29 @@ class TokenAuth(TokenAuthentication): if auth is None: raise AppAuthenticationFailed(1003, '未登录,请先登录') try: + if str(auth).startswith("application-"): + application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=auth).first() + if application_api_key is None: + raise AppAuthenticationFailed(500, "secret_key 无效") + permission_list = [Permission(group=Group.APPLICATION, + operate=Operate.USE, + dynamic_tag=str( + application_api_key.application_id)), + Permission(group=Group.APPLICATION, + operate=Operate.MANAGE, + dynamic_tag=str( + application_api_key.application_id)) + ] + return application_api_key.user, Auth(role_list=[RoleConstants.APPLICATION_KEY], + permission_list=permission_list, + application_id=application_api_key.application_id) # 解析 token - user = signing.loads(auth) - if 'id' in user: - cache_token = token_cache.get(auth) - if cache_token is None: - raise AppAuthenticationFailed(1002, "登录过期") - user = QuerySet(User).get(id=user['id']) + auth_details = signing.loads(auth) + cache_token = token_cache.get(auth) + if cache_token is None: + raise AppAuthenticationFailed(1002, "登录过期") + if 'id' in auth_details and auth_details.get('type') == 
AuthenticationType.USER.value: + user = QuerySet(User).get(id=auth_details['id']) # 续期 token_cache.touch(auth, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'].total_seconds()) rule = RoleConstants[user.role] @@ -49,6 +68,26 @@ class TokenAuth(TokenAuthentication): permission_list += get_user_dynamics_permission(str(user.id)) return user, Auth(role_list=[rule], permission_list=permission_list) + if 'application_id' in auth_details and 'access_token' in auth_details and auth_details.get( + 'type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value: + application_access_token = QuerySet(ApplicationAccessToken).filter( + application_id=auth_details.get('application_id')).first() + if application_access_token is None: + raise AppAuthenticationFailed(1002, "身份验证信息不正确") + if not application_access_token.is_active: + raise AppAuthenticationFailed(1002, "身份验证信息不正确") + if not application_access_token.access_token == auth_details.get('access_token'): + raise AppAuthenticationFailed(1002, "身份验证信息不正确") + return application_access_token.application.user, Auth( + role_list=[RoleConstants.APPLICATION_ACCESS_TOKEN], + permission_list=[ + Permission(group=Group.APPLICATION, + operate=Operate.USE, + dynamic_tag=str( + application_access_token.application_id))], + application_id=application_access_token.application_id + ) + else: raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户") diff --git a/apps/common/auth/authentication.py b/apps/common/auth/authentication.py index 5a155bc66..b27e1d1ea 100644 --- a/apps/common/auth/authentication.py +++ b/apps/common/auth/authentication.py @@ -35,23 +35,30 @@ def exist_role_by_role_constants(user_role: List[RoleConstants], return any(list(map(lambda up: role_list.__contains__(up), user_role))) -def exist_permissions_by_view_permission(user_role: List[RoleConstants], user_permission: List[PermissionConstants], - permission: ViewPermission): +def exist_permissions_by_view_permission(user_role: List[RoleConstants], + user_permission: 
List[PermissionConstants | object], + permission: ViewPermission, request, **kwargs): """ 用户是否存在这些权限 + :param request: :param user_role: 用户角色 :param user_permission: 用户权限 :param permission: 所属权限 :return: 是否存在 True False """ role_ok = any(list(map(lambda ur: permission.roleList.__contains__(ur), user_role))) - permission_ok = any(list(map(lambda up: permission.permissionList.__contains__(up), user_permission))) + permission_list = [user_p(request, kwargs) if callable(user_p) else user_p for user_p in + permission.permissionList + ] + permission_ok = any(list(map(lambda up: permission_list.__contains__(up), + user_permission))) return role_ok | permission_ok if permission.compare == CompareConstants.OR else role_ok & permission_ok -def exist_permissions(user_role: List[RoleConstants], user_permission: List[PermissionConstants], permission): +def exist_permissions(user_role: List[RoleConstants], user_permission: List[PermissionConstants], permission, request, + **kwargs): if isinstance(permission, ViewPermission): - return exist_permissions_by_view_permission(user_role, user_permission, permission) + return exist_permissions_by_view_permission(user_role, user_permission, permission, request, **kwargs) elif isinstance(permission, RoleConstants): return exist_role_by_role_constants(user_role, [permission]) elif isinstance(permission, PermissionConstants): @@ -64,9 +71,9 @@ def exist_permissions(user_role: List[RoleConstants], user_permission: List[Perm def exist(user_role: List[RoleConstants], user_permission: List[PermissionConstants], permission, request, **kwargs): if callable(permission): p = permission(request, kwargs) - return exist_permissions(user_role, user_permission, p) + return exist_permissions(user_role, user_permission, p, request, **kwargs) else: - return exist_permissions(user_role, user_permission, permission) + return exist_permissions(user_role, user_permission, permission, request, **kwargs) def has_permissions(*permission, compare=CompareConstants.OR): diff
--git a/apps/common/constants/authentication_type.py b/apps/common/constants/authentication_type.py new file mode 100644 index 000000000..f223d2a72 --- /dev/null +++ b/apps/common/constants/authentication_type.py @@ -0,0 +1,16 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: authentication_type.py + @date:2023/11/14 20:03 + @desc: +""" +from enum import Enum + + +class AuthenticationType(Enum): + # signed token that identifies a logged-in user + USER = "USER" + # signed token that carries an application access token + APPLICATION_ACCESS_TOKEN = "APPLICATION_ACCESS_TOKEN" diff --git a/apps/common/constants/permission_constants.py b/apps/common/constants/permission_constants.py index f948f9f42..e4106ed7e 100644 --- a/apps/common/constants/permission_constants.py +++ b/apps/common/constants/permission_constants.py @@ -21,6 +21,10 @@ class Group(Enum): SETTING = "SETTING" + MODEL = "MODEL" + + TEAM = "TEAM" + class Operate(Enum): """ @@ -40,16 +44,24 @@ class Operate(Enum): USE = "USE" +class RoleGroup(Enum): + USER = 'USER' + APPLICATION_KEY = "APPLICATION_KEY" + APPLICATION_ACCESS_TOKEN = "APPLICATION_ACCESS_TOKEN" + + class Role: - def __init__(self, name: str, decs: str): + def __init__(self, name: str, decs: str, group: RoleGroup): self.name = name self.decs = decs + self.group = group class RoleConstants(Enum): - ADMIN = Role("管理员", "管理员,预制目前不会使用") - USER = Role("用户", "用户所有权限") - SESSION = Role("会话", decs="只拥有应用会话框接口权限") + ADMIN = Role("管理员", "管理员,预制目前不会使用", RoleGroup.USER) + USER = Role("用户", "用户所有权限", RoleGroup.USER) + APPLICATION_ACCESS_TOKEN = Role("会话", "只拥有应用会话框接口权限", RoleGroup.APPLICATION_ACCESS_TOKEN) + APPLICATION_KEY = Role("应用私钥", "应用私钥", RoleGroup.APPLICATION_KEY) class Permission: @@ -90,9 +102,29 @@ class PermissionConstants(Enum): APPLICATION_READ = Permission(group=Group.APPLICATION, operate=Operate.READ, roles=[RoleConstants.ADMIN, RoleConstants.USER]) + APPLICATION_CREATE = Permission(group=Group.APPLICATION, operate=Operate.CREATE, + roles=[RoleConstants.ADMIN, RoleConstants.USER]) + SETTING_READ =
Permission(group=Group.SETTING, operate=Operate.READ, roles=[RoleConstants.ADMIN, RoleConstants.USER]) + MODEL_READ = Permission(group=Group.MODEL, operate=Operate.READ, roles=[RoleConstants.ADMIN, RoleConstants.USER]) + + MODEL_EDIT = Permission(group=Group.MODEL, operate=Operate.EDIT, roles=[RoleConstants.ADMIN, RoleConstants.USER]) + + MODEL_DELETE = Permission(group=Group.MODEL, operate=Operate.DELETE, + roles=[RoleConstants.ADMIN, RoleConstants.USER]) + MODEL_CREATE = Permission(group=Group.MODEL, operate=Operate.CREATE, + roles=[RoleConstants.ADMIN, RoleConstants.USER]) + + TEAM_READ = Permission(group=Group.TEAM, operate=Operate.READ, roles=[RoleConstants.ADMIN, RoleConstants.USER]) + + TEAM_CREATE = Permission(group=Group.TEAM, operate=Operate.CREATE, roles=[RoleConstants.ADMIN, RoleConstants.USER]) + + TEAM_DELETE = Permission(group=Group.TEAM, operate=Operate.DELETE, roles=[RoleConstants.ADMIN, RoleConstants.USER]) + + TEAM_EDIT = Permission(group=Group.TEAM, operate=Operate.EDIT, roles=[RoleConstants.ADMIN, RoleConstants.USER]) + def get_permission_list_by_role(role: RoleConstants): """ @@ -110,9 +142,11 @@ class Auth: 用于存储当前用户的角色和权限 """ - def __init__(self, role_list: List[RoleConstants], permission_list: List[PermissionConstants]): + def __init__(self, role_list: List[RoleConstants], permission_list: List[PermissionConstants | Permission], + **keywords): self.role_list = role_list self.permission_list = permission_list + self.keywords = keywords class CompareConstants(Enum): @@ -123,7 +157,7 @@ class CompareConstants(Enum): class ViewPermission: - def __init__(self, roleList: List[RoleConstants], permissionList: List[PermissionConstants], + def __init__(self, roleList: List[RoleConstants], permissionList: List[PermissionConstants | object], compare=CompareConstants.OR): self.roleList = roleList self.permissionList = permissionList diff --git a/apps/common/db/compiler.py b/apps/common/db/compiler.py index 81f21664f..9a65f93e1 100644 --- 
a/apps/common/db/compiler.py +++ b/apps/common/db/compiler.py @@ -19,7 +19,7 @@ class AppSQLCompiler(SQLCompiler): field_replace_dict = {} self.field_replace_dict = field_replace_dict - def get_query_str(self, with_limits=True, with_table_name=True): + def get_query_str(self, with_limits=True, with_table_name=False): refcounts_before = self.query.alias_refcount.copy() try: extra_select, order_by, group_by = self.pre_sql_setup() diff --git a/apps/common/db/search.py b/apps/common/db/search.py index 007fb9f27..f8d4a6878 100644 --- a/apps/common/db/search.py +++ b/apps/common/db/search.py @@ -32,9 +32,10 @@ def get_dynamics_model(attr: dict, table_name='dynamics'): def generate_sql_by_query_dict(queryset_dict: Dict[str, QuerySet], select_string: str, - field_replace_dict: None | Dict[str, Dict[str, str]] = None): + field_replace_dict: None | Dict[str, Dict[str, str]] = None, with_table_name=False): """ 生成 查询sql + :param with_table_name: :param queryset_dict: 多条件 查询条件 :param select_string: 查询sql :param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入 @@ -45,7 +46,8 @@ def generate_sql_by_query_dict(queryset_dict: Dict[str, QuerySet], select_string result_params = [] for key in queryset_dict.keys(): value = queryset_dict.get(key) - sql, params = compiler_queryset(value, None if field_replace_dict is None else field_replace_dict.get(key)) + sql, params = compiler_queryset(value, None if field_replace_dict is None else field_replace_dict.get(key), + with_table_name) params_dict = {**params_dict, select_string.index("${" + key + "}"): params} select_string = select_string.replace("${" + key + "}", sql) @@ -55,7 +57,7 @@ def generate_sql_by_query_dict(queryset_dict: Dict[str, QuerySet], select_string def generate_sql_by_query(queryset: QuerySet, select_string: str, - field_replace_dict: None | Dict[str, str] = None): + field_replace_dict: None | Dict[str, str] = None, with_table_name=False): """ 生成 查询sql :param queryset: 查询条件 @@ -63,13 +65,14 @@ def 
generate_sql_by_query(queryset: QuerySet, select_string: str, :param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入 :return: sql:需要查询的sql params: sql 参数 """ - sql, params = compiler_queryset(queryset, field_replace_dict) + sql, params = compiler_queryset(queryset, field_replace_dict, with_table_name) return select_string + " " + sql, params -def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, str] = None): +def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, str] = None, with_table_name=False): """ 解析 queryset查询对象 + :param with_table_name: :param queryset: 查询对象 :param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入 :return: sql:需要查询的sql params: sql 参数 @@ -80,15 +83,16 @@ def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, s field_replace_dict = get_field_replace_dict(queryset) app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection, field_replace_dict=field_replace_dict) - sql, params = app_sql_compiler.get_query_str(with_table_name=False) + sql, params = app_sql_compiler.get_query_str(with_table_name) return sql, params def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str, field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None, - with_search_one=False): + with_search_one=False, with_table_name=False): """ 复杂查询 + :param with_table_name: 生成sql是否包含表名 :param queryset: 查询条件构造器 :param select_string: 查询前缀 不包括 where limit 等信息 :param field_replace_dict: 需要替换的字段 @@ -96,9 +100,9 @@ def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str, :return: 查询结果 """ if isinstance(queryset, Dict): - exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict) + exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name) else: - exec_sql, exec_params = generate_sql_by_query(queryset, select_string, 
field_replace_dict) + exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name) if with_search_one: return select_one(exec_sql, exec_params) else: @@ -121,9 +125,11 @@ def page_search(current_page: int, page_size: int, queryset: QuerySet, post_reco def native_page_search(current_page: int, page_size: int, queryset: QuerySet | Dict[str, QuerySet], select_string: str, field_replace_dict=None, - post_records_handler=lambda r: r): + post_records_handler=lambda r: r, + with_table_name=False): """ 复杂分页查询 + :param with_table_name: :param current_page: 当前页 :param page_size: 每页大小 :param queryset: 查询条件 @@ -133,9 +139,9 @@ def native_page_search(current_page: int, page_size: int, queryset: QuerySet | D :return: 分页结果 """ if isinstance(queryset, Dict): - exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict) + exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name) else: - exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict) + exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name) total_sql = "SELECT \"count\"(*) FROM (%s) temp" % exec_sql total = select_one(total_sql, exec_params) limit_sql = connections[DEFAULT_DB_ALIAS].ops.limit_offset_sql( diff --git a/apps/common/event/__init__.py b/apps/common/event/__init__.py new file mode 100644 index 000000000..cb278a862 --- /dev/null +++ b/apps/common/event/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2023/11/10 10:43 + @desc: +""" +from .listener_manage import * +from .listener_chat_message import * + + +def run(): + listener_manage.ListenerManagement().run() + listener_chat_message.ListenerChatMessage().run() diff --git a/apps/common/event/common.py b/apps/common/event/common.py new file mode 100644 index 000000000..bd553f07f --- 
/dev/null +++ b/apps/common/event/common.py @@ -0,0 +1,18 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: common.py + @date:2023/11/10 10:41 + @desc: +""" +from concurrent.futures import ThreadPoolExecutor + +work_thread_pool = ThreadPoolExecutor(5) + + +def poxy(poxy_function): + def inner(args): + work_thread_pool.submit(poxy_function, args) + + return inner diff --git a/apps/common/event/listener_chat_message.py b/apps/common/event/listener_chat_message.py new file mode 100644 index 000000000..8415e3a7a --- /dev/null +++ b/apps/common/event/listener_chat_message.py @@ -0,0 +1,54 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: listener_manage.py + @date:2023/10/20 14:01 + @desc: +""" + +from blinker import signal +from django.db.models import QuerySet + +from application.models import ChatRecord, Chat +from application.serializers.chat_message_serializers import ChatMessage +from common.event.common import poxy + + +class RecordChatMessageArgs: + def __init__(self, index: int, chat_id: str, application_id: str, chat_message: ChatMessage): + self.index = index + self.chat_id = chat_id + self.application_id = application_id + self.chat_message = chat_message + + +class ListenerChatMessage: + record_chat_message_signal = signal("record_chat_message") + + @staticmethod + @poxy + def record_chat_message(args: RecordChatMessageArgs): + if not QuerySet(Chat).filter(id=args.chat_id).exists(): + Chat(id=args.chat_id, application_id=args.application_id, abstract=args.chat_message.problem).save() + # 插入会话记录 + try: + chat_record = ChatRecord( + id=args.chat_message.id, + chat_id=args.chat_id, + dataset_id=args.chat_message.dataset_id, + paragraph_id=args.chat_message.paragraph_id, + source_id=args.chat_message.source_id, + source_type=args.chat_message.source_type, + problem_text=args.chat_message.problem, + answer_text=args.chat_message.answer, + index=args.index, + message_tokens=args.chat_message.message_tokens, + 
answer_tokens=args.chat_message.answer_token) + chat_record.save() + except Exception as e: + print(e) + + def run(self): + # 记录会话 + ListenerChatMessage.record_chat_message_signal.connect(self.record_chat_message) diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index 6b3c85af6..5f84ad335 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -7,7 +7,6 @@ @desc: """ import os -from concurrent.futures import ThreadPoolExecutor import django.db.models from blinker import signal @@ -15,22 +14,14 @@ from django.db.models import QuerySet from common.config.embedding_config import VectorStore, EmbeddingModel from common.db.search import native_search, get_dynamics_model +from common.event.common import poxy from common.util.file_util import get_file_content from dataset.models import Paragraph, Status, Document from embedding.models import SourceType from smartdoc.conf import PROJECT_DIR -def poxy(poxy_function): - def inner(args): - ListenerManagement.work_thread_pool.submit(poxy_function, args) - - return inner - - class ListenerManagement: - work_thread_pool = ThreadPoolExecutor(5) - embedding_by_problem_signal = signal("embedding_by_problem") embedding_by_paragraph_signal = signal("embedding_by_paragraph") embedding_by_dataset_signal = signal("embedding_by_dataset") diff --git a/apps/common/exception/app_exception.py b/apps/common/exception/app_exception.py index 2a50c1e6d..ffa9e91e0 100644 --- a/apps/common/exception/app_exception.py +++ b/apps/common/exception/app_exception.py @@ -20,6 +20,17 @@ class AppApiException(Exception): self.message = message +class NotFound404(AppApiException): + """ + 未认证(未登录)异常 + """ + status_code = status.HTTP_404_NOT_FOUND + + def __init__(self, code, message): + self.code = code + self.message = message + + class AppAuthenticationFailed(AppApiException): """ 未认证(未登录)异常 diff --git a/apps/common/froms/__init__.py b/apps/common/froms/__init__.py new file 
mode 100644 index 000000000..5a1f3be58 --- /dev/null +++ b/apps/common/froms/__init__.py @@ -0,0 +1,22 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2023/10/31 17:56 + @desc: +""" +from .array_card import * +from .base_field import * +from .base_form import * +from .combobox_field import * +from .multi_select import * +from .number_input_field import * +from .object_card import * +from .password_input import * +from .radio_field import * +from .single_select_field import * +from .switch_btn import * +from .tab_card import * +from .table_radio import * +from .text_input_field import * diff --git a/apps/common/froms/array_card.py b/apps/common/froms/array_card.py new file mode 100644 index 000000000..1e4daa303 --- /dev/null +++ b/apps/common/froms/array_card.py @@ -0,0 +1,36 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: array_card.py + @date:2023/10/31 18:03 + @desc: +""" +from typing import List, Dict + +from common.froms.base_field import BaseExecField, TriggerType + + +class ArrayCard(BaseExecField): + """ + 收集List[Object] + """ + + def __init__(self, + label: str, + text_field: str, + value_field: str, + provider: str, + method: str, + required: bool = False, + default_value: object = None, + relation_show_field_list: List[str] = None, + relation_show_value_list: List[str] = None, + relation_trigger_field_list: List[str] = None, + relation_trigger_value_list: List[str] = None, + trigger_type: TriggerType = TriggerType.OPTION_LIST, + attrs: Dict[str, object] = None, + props_info: Dict[str, object] = None): + super().__init__("ArrayCard", label, text_field, value_field, provider, method, required, default_value, + relation_show_field_list, relation_show_value_list, relation_trigger_field_list, + relation_trigger_value_list, trigger_type, attrs, props_info) diff --git a/apps/common/froms/base_field.py b/apps/common/froms/base_field.py new file mode 100644 index 000000000..21da03868 --- /dev/null +++ 
b/apps/common/froms/base_field.py @@ -0,0 +1,159 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_field.py + @date:2023/10/31 18:07 + @desc: +""" +from enum import Enum +from typing import List, Dict + + +class TriggerType(Enum): + # 执行函数获取 OptionList数据 + OPTION_LIST = 'OPTION_LIST' + # 执行函数获取子表单 + CHILD_FORMS = 'CHILD_FORMS' + + +class BaseField: + def __init__(self, + input_type: str, + label: str, + required: bool = False, + default_value: object = None, + relation_show_field_list: List[str] = None, + relation_show_value_list: List[str] = None, + relation_trigger_field_list: List[str] = None, + relation_trigger_value_list: List[str] = None, + trigger_type: TriggerType = TriggerType.OPTION_LIST, + attrs: Dict[str, object] = None, + props_info: Dict[str, object] = None): + """ + + :param input_type: 字段 + :param label: 提示 + :param default_value: 默认值 + :param relation_show_field_list: 指定那些当那些字段有值的时候 当前字段显示 + :param relation_show_value_list: 指定字段有值 并且值在relation_show_value_list列表中则显示当前字段 + :param relation_trigger_field_list: 指定那些字段有值的时候 调用当前字段的 执行函数获取optionList数据 + :param relation_trigger_value_list: 指定那些字段有值 并且值在relation_trigger_value_list列表中 则执行函数获取optionList数据 + :param trigger_type: 执行器类型 OPTION_LIST请求Option_list数据 CHILD_FORMS请求子表单 + :param attrs: 前端attr数据 + :param props_info: 其他额外信息 + """ + if props_info is None: + props_info = {} + if attrs is None: + attrs = {} + self.label = label + self.attrs = attrs + self.props_info = props_info + self.default_value = default_value + self.input_type = input_type + self.relation_show_field_list = [] if relation_show_field_list is None else relation_show_field_list + self.relation_show_value_list = [] if relation_show_value_list is None else relation_show_value_list + self.relation_trigger_field_list = [] if relation_trigger_field_list is None else relation_trigger_field_list + self.relation_trigger_value_field_list = [] if relation_trigger_value_list is None else relation_trigger_value_list + self.required 
= required + self.trigger_type = trigger_type + + def to_dict(self): + return { + 'input_type': self.input_type, + 'label': self.label, + 'required': self.required, + 'default_value': self.default_value, + 'relation_show_field_list': self.relation_show_field_list, + 'relation_show_value_list': self.relation_show_value_list, + 'relation_trigger_field_list': self.relation_trigger_field_list, + 'relation_trigger_value_field_list': self.relation_trigger_value_field_list, + 'trigger_type': self.trigger_type.value, + 'attrs': self.attrs, + 'props_info': self.props_info, + } + + +class BaseDefaultOptionField(BaseField): + def __init__(self, input_type: str, + label: str, + text_field: str, + value_field: str, + option_list: List[dict], + required: bool = False, + default_value: object = None, + relation_show_field_list: List[str] = None, + relation_show_value_list: List[str] = None, + attrs: Dict[str, object] = None, + props_info: Dict[str, object] = None): + """ + + :param input_type: 字段 + :param label: label + :param text_field: 文本字段 + :param value_field: 值字段 + :param option_list: 可选列表 + :param required: 是否必填 + :param default_value: 默认值 + :param relation_show_field_list: 指定那些当那些字段有值的时候 当前字段显示 + :param relation_show_value_list: 指定字段有值 并且值在relation_show_value_list列表中则显示当前字段 + :param attrs: 前端attr数据 + :param props_info: 其他额外信息 + """ + super().__init__(input_type, label, required, default_value, relation_show_field_list, relation_show_value_list, + [], [], TriggerType.OPTION_LIST, attrs, props_info) + self.text_field = text_field + self.value_field = value_field + self.option_list = option_list + + def to_dict(self): + return {**super().to_dict(), 'text_field': self.text_field, 'value_field': self.value_field, + 'option_list': self.option_list} + + +class BaseExecField(BaseField): + def __init__(self, + input_type: str, + label: str, + text_field: str, + value_field: str, + provider: str, + method: str, + required: bool = False, + default_value: object = None, + 
relation_show_field_list: List[str] = None, + relation_show_value_list: List[str] = None, + relation_trigger_field_list: List[str] = None, + relation_trigger_value_list: List[str] = None, + trigger_type: TriggerType = TriggerType.OPTION_LIST, + attrs: Dict[str, object] = None, + props_info: Dict[str, object] = None): + """ + + :param input_type: 字段 + :param label: 提示 + :param text_field: 文本字段 + :param value_field: 值字段 + :param provider: 指定供应商 + :param method: 执行供应商函数 method + :param required: 是否必填 + :param default_value: 默认值 + :param relation_show_field_list: 指定那些当那些字段有值的时候 当前字段显示 + :param relation_show_value_list: 指定字段有值 并且值在relation_show_value_list列表中则显示当前字段 + :param relation_trigger_field_list: 指定那些字段有值的时候 调用当前字段的 执行函数获取optionList数据 + :param relation_trigger_value_list: 指定那些字段有值 并且值在relation_trigger_value_list列表中 则执行函数获取optionList数据 + :param trigger_type: 执行器类型 OPTION_LIST请求Option_list数据 CHILD_FORMS请求子表单 + :param attrs: 前端attr数据 + :param props_info: 其他额外信息 + """ + super().__init__(input_type, label, required, default_value, relation_show_field_list, relation_show_value_list, + relation_trigger_field_list, relation_trigger_value_list, trigger_type, attrs, props_info) + self.text_field = text_field + self.value_field = value_field + self.provider = provider + self.method = method + + def to_dict(self): + return {**super().to_dict(), 'text_field': self.text_field, 'value_field': self.value_field, + 'provider': self.provider, 'method': self.method} diff --git a/apps/common/froms/base_form.py b/apps/common/froms/base_form.py new file mode 100644 index 000000000..49e4556fa --- /dev/null +++ b/apps/common/froms/base_form.py @@ -0,0 +1,16 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_form.py + @date:2023/11/1 16:04 + @desc: +""" +from common.froms import BaseField + + +class BaseForm: + def to_form_list(self): + return [{**self.__getattribute__(key).to_dict(), 'field': key} for key in + list(filter(lambda key: isinstance(self.__getattribute__(key), 
BaseField), + [attr for attr in vars(self.__class__) if not attr.startswith("__")]))] diff --git a/apps/common/froms/combobox_field.py b/apps/common/froms/combobox_field.py new file mode 100644 index 000000000..2a506f852 --- /dev/null +++ b/apps/common/froms/combobox_field.py @@ -0,0 +1,41 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: combobox_field.py + @date:2023/10/31 17:59 + @desc: +""" +from typing import List, Dict + +from common.froms.base_field import BaseExecField, TriggerType + + +class Combobox(BaseExecField): + """ + 多选框 + """ + + def __init__(self, + label: str, + text_field: str, + value_field: str, + option_list: List[str:object], + provider: str = None, + method: str = None, + required: bool = False, + default_value: object = None, + relation_show_field_list: List[str] = None, + relation_show_value_list: List[str] = None, + relation_trigger_field_list: List[str] = None, + relation_trigger_value_list: List[str] = None, + trigger_type: TriggerType = TriggerType.OPTION_LIST, + attrs: Dict[str, object] = None, + props_info: Dict[str, object] = None): + super().__init__("Combobox", label, text_field, value_field, provider, method, required, default_value, + relation_show_field_list, relation_show_value_list, relation_trigger_field_list, + relation_trigger_value_list, trigger_type, attrs, props_info) + self.option_list = option_list + + def to_dict(self): + return {**super().to_dict(), 'option_list': self.option_list} diff --git a/apps/common/froms/multi_select.py b/apps/common/froms/multi_select.py new file mode 100644 index 000000000..12c41f1e1 --- /dev/null +++ b/apps/common/froms/multi_select.py @@ -0,0 +1,41 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: multi_select.py + @date:2023/10/31 18:00 + @desc: +""" +from typing import List, Dict + +from common.froms.base_field import BaseExecField, TriggerType + + +class MultiSelect(BaseExecField): + """ + 下拉单选 + """ + + def __init__(self, + label: str, + text_field: str, + 
value_field: str, + option_list: List[str:object], + provider: str = None, + method: str = None, + required: bool = False, + default_value: object = None, + relation_show_field_list: List[str] = None, + relation_show_value_list: List[str] = None, + relation_trigger_field_list: List[str] = None, + relation_trigger_value_list: List[str] = None, + trigger_type: TriggerType = TriggerType.OPTION_LIST, + attrs: Dict[str, object] = None, + props_info: Dict[str, object] = None): + super().__init__("MultiSelect", label, text_field, value_field, provider, method, required, default_value, + relation_show_field_list, relation_show_value_list, relation_trigger_field_list, + relation_trigger_value_list, trigger_type, attrs, props_info) + self.option_list = option_list + + def to_dict(self): + return {**super().to_dict(), 'option_list': self.option_list} \ No newline at end of file diff --git a/apps/common/froms/number_input_field.py b/apps/common/froms/number_input_field.py new file mode 100644 index 000000000..59d8e7e08 --- /dev/null +++ b/apps/common/froms/number_input_field.py @@ -0,0 +1,27 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: number_input_field.py + @date:2023/10/31 17:58 + @desc: +""" +from typing import List + +from common.froms.base_field import BaseField, TriggerType + + +class NumberInput(BaseField): + """ + 文本输入框 + """ + + def __init__(self, label: str, + required: bool = False, + default_value=None, + relation_show_field_list: List[str] = None, + relation_show_value_list: List[str] = None, + attrs=None, props_info=None): + super().__init__('NumberInput', label, required, default_value, relation_show_field_list, + relation_show_value_list, [], [], + TriggerType.OPTION_LIST, attrs, props_info) diff --git a/apps/common/froms/object_card.py b/apps/common/froms/object_card.py new file mode 100644 index 000000000..3e9075dce --- /dev/null +++ b/apps/common/froms/object_card.py @@ -0,0 +1,36 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + 
@file: object_card.py + @date:2023/10/31 18:02 + @desc: +""" +from typing import List, Dict + +from common.froms.base_field import BaseExecField, TriggerType + + +class ObjectCard(BaseExecField): + """ + 收集对象子表卡片 + """ + + def __init__(self, + label: str, + text_field: str, + value_field: str, + provider: str, + method: str, + required: bool = False, + default_value: object = None, + relation_show_field_list: List[str] = None, + relation_show_value_list: List[str] = None, + relation_trigger_field_list: List[str] = None, + relation_trigger_value_list: List[str] = None, + trigger_type: TriggerType = TriggerType.OPTION_LIST, + attrs: Dict[str, object] = None, + props_info: Dict[str, object] = None): + super().__init__("ObjectCard", label, text_field, value_field, provider, method, required, default_value, + relation_show_field_list, relation_show_value_list, relation_trigger_field_list, + relation_trigger_value_list, trigger_type, attrs, props_info) diff --git a/apps/common/froms/password_input.py b/apps/common/froms/password_input.py new file mode 100644 index 000000000..81b9d697b --- /dev/null +++ b/apps/common/froms/password_input.py @@ -0,0 +1,27 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: password_input.py + @date:2023/11/1 14:48 + @desc: +""" +from typing import List + +from common.froms import BaseField, TriggerType + + +class PasswordInputField(BaseField): + """ + 文本输入框 + """ + + def __init__(self, label: str, + required: bool = False, + default_value=None, + relation_show_field_list: List[str] = None, + relation_show_value_list: List[str] = None, + attrs=None, props_info=None): + super().__init__('TextInput', label, required, default_value, relation_show_field_list, + relation_show_value_list, [], [], + TriggerType.OPTION_LIST, attrs, props_info) diff --git a/apps/common/froms/radio_field.py b/apps/common/froms/radio_field.py new file mode 100644 index 000000000..ce2146070 --- /dev/null +++ b/apps/common/froms/radio_field.py @@ -0,0 +1,41 
@@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: radio_field.py + @date:2023/10/31 17:59 + @desc: +""" +from typing import List, Dict + +from common.froms.base_field import BaseExecField, TriggerType + + +class Radio(BaseExecField): + """ + 下拉单选 + """ + + def __init__(self, + label: str, + text_field: str, + value_field: str, + option_list: List[str:object], + provider: str, + method: str, + required: bool = False, + default_value: object = None, + relation_show_field_list: List[str] = None, + relation_show_value_list: List[str] = None, + relation_trigger_field_list: List[str] = None, + relation_trigger_value_list: List[str] = None, + trigger_type: TriggerType = TriggerType.OPTION_LIST, + attrs: Dict[str, object] = None, + props_info: Dict[str, object] = None): + super().__init__("Radio", label, text_field, value_field, provider, method, required, default_value, + relation_show_field_list, relation_show_value_list, relation_trigger_field_list, + relation_trigger_value_list, trigger_type, attrs, props_info) + self.option_list = option_list + + def to_dict(self): + return {**super().to_dict(), 'option_list': self.option_list} \ No newline at end of file diff --git a/apps/common/froms/single_select_field.py b/apps/common/froms/single_select_field.py new file mode 100644 index 000000000..9d3ecb2a9 --- /dev/null +++ b/apps/common/froms/single_select_field.py @@ -0,0 +1,41 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: single_select_field.py + @date:2023/10/31 18:00 + @desc: +""" +from typing import List, Dict + +from common.froms.base_field import TriggerType, BaseExecField + + +class SingleSelect(BaseExecField): + """ + 下拉单选 + """ + + def __init__(self, + label: str, + text_field: str, + value_field: str, + option_list: List[str:object], + provider: str = None, + method: str = None, + required: bool = False, + default_value: object = None, + relation_show_field_list: List[str] = None, + relation_show_value_list: List[str] = None, + 
relation_trigger_field_list: List[str] = None, + relation_trigger_value_list: List[str] = None, + trigger_type: TriggerType = TriggerType.OPTION_LIST, + attrs: Dict[str, object] = None, + props_info: Dict[str, object] = None): + super().__init__("SingleSelect", label, text_field, value_field, provider, method, required, default_value, + relation_show_field_list, relation_show_value_list, relation_trigger_field_list, + relation_trigger_value_list, trigger_type, attrs, props_info) + self.option_list = option_list + + def to_dict(self): + return {**super().to_dict(), 'option_list': self.option_list} diff --git a/apps/common/froms/switch_btn.py b/apps/common/froms/switch_btn.py new file mode 100644 index 000000000..32aadbd5f --- /dev/null +++ b/apps/common/froms/switch_btn.py @@ -0,0 +1,28 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: switch_btn.py + @date:2023/10/31 18:00 + @desc: +""" +from typing import List + +from common.froms.base_field import TriggerType, BaseField + + +class SwitchBtn(BaseField): + """ + 开关 + """ + + def __init__(self, + label: str, + required: bool = False, + default_value=None, + relation_show_field_list: List[str] = None, + relation_show_value_list: List[str] = None, + attrs=None, props_info=None): + super().__init__('SwitchBtn', label, required, default_value, relation_show_field_list, + relation_show_value_list, [], [], + TriggerType.OPTION_LIST, attrs, props_info) diff --git a/apps/common/froms/tab_card.py b/apps/common/froms/tab_card.py new file mode 100644 index 000000000..50caa3c79 --- /dev/null +++ b/apps/common/froms/tab_card.py @@ -0,0 +1,36 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: tab_card.py + @date:2023/10/31 18:03 + @desc: +""" +from typing import List, Dict + +from common.froms.base_field import BaseExecField, TriggerType + + +class TabCard(BaseExecField): + """ + 收集 Tab类型数据 tab1:{},tab2:{} + """ + + def __init__(self, + label: str, + text_field: str, + value_field: str, + provider: str, + 
method: str, + required: bool = False, + default_value: object = None, + relation_show_field_list: List[str] = None, + relation_show_value_list: List[str] = None, + relation_trigger_field_list: List[str] = None, + relation_trigger_value_list: List[str] = None, + trigger_type: TriggerType = TriggerType.OPTION_LIST, + attrs: Dict[str, object] = None, + props_info: Dict[str, object] = None): + super().__init__("TabCard", label, text_field, value_field, provider, method, required, default_value, + relation_show_field_list, relation_show_value_list, relation_trigger_field_list, + relation_trigger_value_list, trigger_type, attrs, props_info) diff --git a/apps/common/froms/table_radio.py b/apps/common/froms/table_radio.py new file mode 100644 index 000000000..ef7cb7f8d --- /dev/null +++ b/apps/common/froms/table_radio.py @@ -0,0 +1,36 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: table_radio.py + @date:2023/10/31 18:01 + @desc: +""" +from typing import List, Dict + +from common.froms.base_field import TriggerType, BaseExecField + + +class TableRadio(BaseExecField): + """ + table 单选 + """ + + def __init__(self, + label: str, + text_field: str, + value_field: str, + provider: str, + method: str, + required: bool = False, + default_value: object = None, + relation_show_field_list: List[str] = None, + relation_show_value_list: List[str] = None, + relation_trigger_field_list: List[str] = None, + relation_trigger_value_list: List[str] = None, + trigger_type: TriggerType = TriggerType.OPTION_LIST, + attrs: Dict[str, object] = None, + props_info: Dict[str, object] = None): + super().__init__("TableRadio", label, text_field, value_field, provider, method, required, default_value, + relation_show_field_list, relation_show_value_list, relation_trigger_field_list, + relation_trigger_value_list, trigger_type, attrs, props_info) diff --git a/apps/common/froms/text_input_field.py b/apps/common/froms/text_input_field.py new file mode 100644 index 000000000..1fceb65cc --- 
/dev/null +++ b/apps/common/froms/text_input_field.py @@ -0,0 +1,27 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: text_input_field.py + @date:2023/10/31 17:58 + @desc: +""" +from typing import List + +from common.froms.base_field import BaseField, TriggerType + + +class TextInputField(BaseField): + """ + 文本输入框 + """ + + def __init__(self, label: str, + required: bool = False, + default_value=None, + relation_show_field_list: List[str] = None, + relation_show_value_list: List[str] = None, + attrs=None, props_info=None): + super().__init__('TextInput', label, required, default_value, relation_show_field_list, + relation_show_value_list, [], [], + TriggerType.OPTION_LIST, attrs, props_info) diff --git a/apps/common/sql/list_embedding_text.sql b/apps/common/sql/list_embedding_text.sql index 2c55697ec..f45697b47 100644 --- a/apps/common/sql/list_embedding_text.sql +++ b/apps/common/sql/list_embedding_text.sql @@ -5,7 +5,9 @@ SELECT problem.dataset_id AS dataset_id, 0 AS source_type, problem."content" AS "text", - paragraph.is_active AS is_active + paragraph.is_active AS is_active, + problem.star_num as star_num, + problem.trample_num as trample_num FROM problem problem LEFT JOIN paragraph paragraph ON paragraph."id" = problem.paragraph_id @@ -18,8 +20,10 @@ SELECT paragraph."id" AS paragraph_id, paragraph.dataset_id AS dataset_id, 1 AS source_type, - paragraph."content" AS "text", - paragraph.is_active AS is_active + concat_ws(':',paragraph."title",paragraph."content") AS "text", + paragraph.is_active AS is_active, + paragraph.star_num as star_num, + paragraph.trample_num as trample_num FROM paragraph paragraph diff --git a/apps/common/util/common.py b/apps/common/util/common.py index 0fce2affd..6cf4b29c8 100644 --- a/apps/common/util/common.py +++ b/apps/common/util/common.py @@ -6,14 +6,27 @@ @date:2023/10/16 16:42 @desc: """ +import importlib from functools import reduce from typing import Dict def query_params_to_single_dict(query_params: Dict): - 
return reduce(lambda x, y: {**x, y[0]: y[1]}, list(filter(lambda row: row[1] is not None, - list(map(lambda row: ( - row[0], row[1][0] if isinstance(row[1][0], - list) and len( - row[1][0]) > 0 else row[1][0]), - query_params.items())))), {}) + return reduce(lambda x, y: {**x, **y}, list( + filter(lambda item: item is not None, [({key: value} if value is not None and len(value) > 0 else None) for + key, value in + query_params.items()])), {}) + + +def get_exec_method(clazz_: str, method_: str): + """ + 根据 class 和method函数 获取执行函数 + :param clazz_: class 字符串 + :param method_: 执行函数 + :return: 执行函数 + """ + clazz_split = clazz_.split('.') + clazz_name = clazz_split[-1] + package = ".".join([clazz_split[index] for index in range(len(clazz_split) - 1)]) + package_model = importlib.import_module(package) + return getattr(getattr(package_model, clazz_name), method_) diff --git a/apps/common/util/rsa_util.py b/apps/common/util/rsa_util.py new file mode 100644 index 000000000..ed950776a --- /dev/null +++ b/apps/common/util/rsa_util.py @@ -0,0 +1,70 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: rsa_util.py + @date:2023/11/3 11:13 + @desc: +""" +import base64 +import os + +from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher +from Crypto.PublicKey import RSA + +# 对密钥加密的密码 +secret_code = "mac_kb_password" + + +def generate(): + """ + 生成 私钥秘钥对 + :return:{key:'公钥',value:'私钥'} + """ + # 生成一个 2048 位的密钥 + key = RSA.generate(2048) + + # 获取私钥 + encrypted_key = key.export_key(passphrase=secret_code, pkcs=8, + protection="scryptAndAES128-CBC") + return {'key': key.publickey().export_key(), 'value': encrypted_key} + + +def get_key_pair(): + if not os.path.exists("/opt/maxkb/conf/receiver.pem"): + kv = generate() + private_file_out = open("/opt/maxkb/conf/private.pem", "wb") + private_file_out.write(kv.get('value')) + private_file_out.close() + receiver_file_out = open("/opt/maxkb/conf/receiver.pem", "wb") + receiver_file_out.write(kv.get('key')) + receiver_file_out.close() 
+ return {'key': open("/opt/maxkb/conf/receiver.pem").read(), 'value': open("/opt/maxkb/conf/private.pem").read()} + + +def encrypt(msg, public_key: str | None = None): + """ + 加密 + :param msg: 加密数据 + :param public_key: 公钥 + :return: 加密后的数据 + """ + if public_key is None: + public_key = get_key_pair().get('key') + cipher = PKCS1_cipher.new(RSA.importKey(public_key)) + encrypt_msg = cipher.encrypt(msg.encode("utf-8")) + return base64.b64encode(encrypt_msg).decode() + + +def decrypt(msg, pri_key: str | None = None): + """ + 解密 + :param msg: 需要解密的数据 + :param pri_key: 私钥 + :return: 解密后数据 + """ + if pri_key is None: + pri_key = get_key_pair().get('value') + cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) + decrypt_data = cipher.decrypt(base64.b64decode(msg), 0) + return decrypt_data.decode("utf-8") diff --git a/apps/common/util/test.py b/apps/common/util/test.py new file mode 100644 index 000000000..f271cf058 --- /dev/null +++ b/apps/common/util/test.py @@ -0,0 +1,79 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: test.py + @date:2023/11/15 15:13 + @desc: +""" +import time +from django.core import signing +import hashlib +from django.core.cache import cache + +# alg使用的算法 +HEADER = {'typ': 'JWP', 'alg': 'default'} +TOKEN_KEY = 'solomon_world_token' +TOKEN_SALT = 'solomonwanc@gmail.com' +TIME_OUT = 30 * 60 + +# 加密 +def encrypt(obj): + value = signing.dumps(obj, key=TOKEN_KEY, salt=TOKEN_SALT) + value = signing.b64_encode(value.encode()).decode() + return value + + +# 解密 +def decrypt(src): + src = signing.b64_decode(src.encode()).decode() + raw = signing.loads(src, key=TOKEN_KEY, salt=TOKEN_SALT) + print(type(raw)) + return raw + + +# 生成token信息 +def create_token(username, password): + # 1. 加密头信息 + header = encrypt(HEADER) + # 2. 构造Payload + payload = { + "username": username, + "password": password, + "iat": time.time() + } + payload = encrypt(payload) + # 3. 
生成签名 + md5 = hashlib.md5() + md5.update(("%s.%s" % (header, payload)).encode()) + signature = md5.hexdigest() + token = "%s.%s.%s" % (header, payload, signature) + # 4.存储到缓存中 + cache.set(username, token, TIME_OUT) + return token + + +def get_payload(token): + payload = str(token).split('.')[1] + payload = decrypt(payload) + return payload + + +# 通过token获取用户名 +def get_username(token): + payload = get_payload(token) + return payload['username'] + pass + + +def check_token(token): + username = get_username(token) + print('username', username) + last_token = cache.get(username) + if last_token: + return last_token == token + return False + + +if __name__ == '__main__': + token = create_token('zhangsan', 'lisi') diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index 5673c1999..c199f842f 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -98,13 +98,15 @@ class DataSetSerializers(serializers.ModelSerializer): query_set_dict['team_member_permission_custom_sql'] = QuerySet(model=get_dynamics_model( {'user_id': models.CharField(), + 'team_member_permission.auth_target_type': models.CharField(), 'team_member_permission.operate': ArrayField(verbose_name="权限操作列表", base_field=models.CharField(max_length=256, blank=True, choices=AuthOperate.choices, default=AuthOperate.USE) )})).filter( - **{'user_id': user_id, 'team_member_permission.operate__contains': ['USE']}) + **{'user_id': user_id, 'team_member_permission.operate__contains': ['USE'], + 'team_member_permission.auth_target_type': 'DATASET'}) return query_set_dict diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py index 9e4e7867a..a604e3a05 100644 --- a/apps/dataset/serializers/paragraph_serializers.py +++ b/apps/dataset/serializers/paragraph_serializers.py @@ -19,7 +19,7 @@ from common.db.search import page_search from 
common.event.listener_manage import ListenerManagement from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin -from dataset.models import Paragraph, Problem +from dataset.models import Paragraph, Problem, Document from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer @@ -41,7 +41,7 @@ class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer): message="段落在1-1024个字符之间") ]) - title = serializers.CharField(required=False) + title = serializers.CharField(required=False, allow_null=True, allow_blank=True) problem_list = ProblemInstanceSerializer(required=False, many=True) @@ -165,6 +165,12 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): document_id = serializers.UUIDField(required=True) + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + if not QuerySet(Document).filter(id=self.data.get('document_id'), + dataset_id=self.data.get('dataset_id')).exists(): + raise AppApiException(500, "文档id不正确") + def save(self, instance: Dict, with_valid=True, with_embedding=True): if with_valid: ParagraphSerializers(data=instance).is_valid(raise_exception=True) diff --git a/apps/dataset/serializers/problem_serializers.py b/apps/dataset/serializers/problem_serializers.py index 16c9ff1fa..6845bb34e 100644 --- a/apps/dataset/serializers/problem_serializers.py +++ b/apps/dataset/serializers/problem_serializers.py @@ -54,6 +54,13 @@ class ProblemSerializers(ApiMixin, serializers.Serializer): paragraph_id = serializers.UUIDField(required=True) + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id'), + document_id=self.data.get('document_id'), + dataset_id=self.data.get('dataset_id')).exists(): + raise AppApiException(500, "段落id不正确") + def save(self, instance: Dict, with_valid=True, with_embedding=True): if with_valid: self.is_valid() diff 
--git a/apps/dataset/views/dataset.py b/apps/dataset/views/dataset.py index 175bedc40..20cd37321 100644 --- a/apps/dataset/views/dataset.py +++ b/apps/dataset/views/dataset.py @@ -26,7 +26,8 @@ class Dataset(APIView): @swagger_auto_schema(operation_summary="获取数据集列表", operation_id="获取数据集列表", manual_parameters=DataSetSerializers.Query.get_request_params_api(), - responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api())) + responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api()), + tags=["数据集"]) @has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND) def get(self, request: Request): d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)}) @@ -37,7 +38,9 @@ class Dataset(APIView): @swagger_auto_schema(operation_summary="创建数据集", operation_id="创建数据集", request_body=DataSetSerializers.Create.get_request_body_api(), - responses=get_api_response(DataSetSerializers.Create.get_response_body_api())) + responses=get_api_response(DataSetSerializers.Create.get_response_body_api()), + tags=["数据集"] + ) @has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND) def post(self, request: Request): s = DataSetSerializers.Create(data=request.data) @@ -50,7 +53,8 @@ class Dataset(APIView): @action(methods="DELETE", detail=False) @swagger_auto_schema(operation_summary="删除数据集", operation_id="删除数据集", manual_parameters=DataSetSerializers.Operate.get_request_params_api(), - responses=result.get_default_response()) + responses=result.get_default_response(), + tags=["数据集"]) @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE, dynamic_tag=keywords.get('dataset_id')), lambda r, k: Permission(group=Group.DATASET, operate=Operate.DELETE, @@ -62,7 +66,8 @@ class Dataset(APIView): @action(methods="GET", detail=False) @swagger_auto_schema(operation_summary="查询数据集详情根据数据集id", operation_id="查询数据集详情根据数据集id", 
manual_parameters=DataSetSerializers.Operate.get_request_params_api(), - responses=get_api_response(DataSetSerializers.Operate.get_response_body_api())) + responses=get_api_response(DataSetSerializers.Operate.get_response_body_api()), + tags=["数据集"]) @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE, dynamic_tag=keywords.get('dataset_id'))) def get(self, request: Request, dataset_id: str): @@ -72,7 +77,9 @@ class Dataset(APIView): @swagger_auto_schema(operation_summary="修改数据集信息", operation_id="修改数据集信息", manual_parameters=DataSetSerializers.Operate.get_request_params_api(), request_body=DataSetSerializers.Operate.get_request_body_api(), - responses=get_api_response(DataSetSerializers.Operate.get_response_body_api())) + responses=get_api_response(DataSetSerializers.Operate.get_response_body_api()), + tags=["数据集"] + ) @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE, dynamic_tag=keywords.get('dataset_id'))) def put(self, request: Request, dataset_id: str): @@ -87,7 +94,9 @@ class Dataset(APIView): operation_id="获取数据集分页列表", manual_parameters=get_page_request_params( DataSetSerializers.Query.get_request_params_api()), - responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api())) + responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api()), + tags=["数据集"] + ) @has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND) def get(self, request: Request, current_page, page_size): d = DataSetSerializers.Query( diff --git a/apps/dataset/views/document.py b/apps/dataset/views/document.py index 9354e7b25..eaf4cdba0 100644 --- a/apps/dataset/views/document.py +++ b/apps/dataset/views/document.py @@ -28,7 +28,8 @@ class Document(APIView): operation_id="创建文档", request_body=DocumentSerializers.Create.get_request_body_api(), manual_parameters=DocumentSerializers.Create.get_request_params_api(), - 
responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api())) + responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()), + tags=["数据集/文档"]) @has_permissions( lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, dynamic_tag=k.get('dataset_id'))) @@ -40,7 +41,8 @@ class Document(APIView): @swagger_auto_schema(operation_summary="文档列表", operation_id="文档列表", manual_parameters=DocumentSerializers.Query.get_request_params_api(), - responses=result.get_api_response(DocumentSerializers.Query.get_response_body_api())) + responses=result.get_api_response(DocumentSerializers.Query.get_response_body_api()), + tags=["数据集/文档"]) @has_permissions( lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, dynamic_tag=k.get('dataset_id'))) @@ -57,7 +59,8 @@ class Document(APIView): @swagger_auto_schema(operation_summary="获取文档详情", operation_id="获取文档详情", manual_parameters=DocumentSerializers.Operate.get_request_params_api(), - responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api())) + responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()), + tags=["数据集/文档"]) @has_permissions( lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, dynamic_tag=k.get('dataset_id'))) @@ -71,7 +74,8 @@ class Document(APIView): operation_id="修改文档", manual_parameters=DocumentSerializers.Operate.get_request_params_api(), request_body=DocumentSerializers.Operate.get_request_body_api(), - responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()) + responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()), + tags=["数据集/文档"] ) @has_permissions( lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, @@ -86,7 +90,8 @@ class Document(APIView): @swagger_auto_schema(operation_summary="删除文档", operation_id="删除文档", manual_parameters=DocumentSerializers.Operate.get_request_params_api(), - 
responses=result.get_default_response()) + responses=result.get_default_response(), + tags=["数据集/文档"]) @has_permissions( lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, dynamic_tag=k.get('dataset_id'))) @@ -101,7 +106,9 @@ class Document(APIView): @action(methods=['POST'], detail=False) @swagger_auto_schema(operation_summary="分段文档", operation_id="分段文档", - manual_parameters=DocumentSerializers.Split.get_request_params_api()) + manual_parameters=DocumentSerializers.Split.get_request_params_api(), + tags=["数据集/文档"], + security=[]) def post(self, request: Request): ds = DocumentSerializers.Split( data={'file': request.FILES.getlist('file'), @@ -116,7 +123,8 @@ class Document(APIView): @swagger_auto_schema(operation_summary="获取数据集分页列表", operation_id="获取数据集分页列表", manual_parameters=DocumentSerializers.Query.get_request_params_api(), - responses=result.get_page_api_response(DocumentSerializers.Query.get_response_body_api())) + responses=result.get_page_api_response(DocumentSerializers.Query.get_response_body_api()), + tags=["数据集/文档"]) @has_permissions( lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, dynamic_tag=k.get('dataset_id'))) diff --git a/apps/dataset/views/paragraph.py b/apps/dataset/views/paragraph.py index 37fe76fda..6162c939b 100644 --- a/apps/dataset/views/paragraph.py +++ b/apps/dataset/views/paragraph.py @@ -25,7 +25,9 @@ class Paragraph(APIView): @swagger_auto_schema(operation_summary="段落列表", operation_id="段落列表", manual_parameters=ParagraphSerializers.Query.get_request_params_api(), - responses=result.get_api_array_response(ParagraphSerializers.Query.get_response_body_api())) + responses=result.get_api_array_response(ParagraphSerializers.Query.get_response_body_api()), + tags=["数据集/文档/段落"] + ) @has_permissions( lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, dynamic_tag=k.get('dataset_id'))) @@ -41,7 +43,8 @@ class Paragraph(APIView): operation_id="创建段落", 
manual_parameters=ParagraphSerializers.Create.get_request_params_api(), request_body=ParagraphSerializers.Create.get_request_body_api(), - responses=result.get_api_response(ParagraphSerializers.Query.get_response_body_api())) + responses=result.get_api_response(ParagraphSerializers.Query.get_response_body_api()), + tags=["数据集/文档/段落"]) @has_permissions( lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, dynamic_tag=k.get('dataset_id'))) @@ -57,7 +60,8 @@ class Paragraph(APIView): operation_id="修改段落数据", manual_parameters=ParagraphSerializers.Operate.get_request_params_api(), request_body=ParagraphSerializers.Operate.get_request_body_api(), - responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api())) + responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api()) + ,tags=["数据集/文档/段落"]) @has_permissions( lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, dynamic_tag=k.get('dataset_id'))) @@ -71,7 +75,8 @@ class Paragraph(APIView): @swagger_auto_schema(operation_summary="获取段落详情", operation_id="获取段落详情", manual_parameters=ParagraphSerializers.Operate.get_request_params_api(), - responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api())) + responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api()), + tags=["数据集/文档/段落"]) @has_permissions( lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, dynamic_tag=k.get('dataset_id'))) @@ -85,7 +90,8 @@ class Paragraph(APIView): @swagger_auto_schema(operation_summary="删除段落", operation_id="删除段落", manual_parameters=ParagraphSerializers.Operate.get_request_params_api(), - responses=result.get_default_response()) + responses=result.get_default_response(), + tags=["数据集/文档/段落"]) @has_permissions( lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, dynamic_tag=k.get('dataset_id'))) @@ -103,7 +109,8 @@ class Paragraph(APIView): operation_id="分页获取段落列表", 
manual_parameters=result.get_page_request_params( ParagraphSerializers.Query.get_request_params_api()), - responses=result.get_page_api_response(ParagraphSerializers.Query.get_response_body_api())) + responses=result.get_page_api_response(ParagraphSerializers.Query.get_response_body_api()), + tags=["数据集/文档/段落"]) @has_permissions( lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, dynamic_tag=k.get('dataset_id'))) diff --git a/apps/dataset/views/problem.py b/apps/dataset/views/problem.py index d12064fbb..5e77a880a 100644 --- a/apps/dataset/views/problem.py +++ b/apps/dataset/views/problem.py @@ -25,7 +25,8 @@ class Problem(APIView): operation_id="添加段落关联问题", manual_parameters=ProblemSerializers.Create.get_request_params_api(), request_body=ProblemSerializers.Create.get_request_body_api(), - responses=result.get_api_response(ProblemSerializers.Operate.get_response_body_api())) + responses=result.get_api_response(ProblemSerializers.Operate.get_response_body_api()), + tags=["数据集/文档/段落/问题"]) @has_permissions( lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, dynamic_tag=k.get('dataset_id'))) @@ -38,7 +39,8 @@ class Problem(APIView): @swagger_auto_schema(operation_summary="获取段落问题列表", operation_id="获取段落问题列表", manual_parameters=ProblemSerializers.Query.get_request_params_api(), - responses=result.get_api_array_response(ProblemSerializers.Operate.get_response_body_api())) + responses=result.get_api_array_response(ProblemSerializers.Operate.get_response_body_api()), + tags=["数据集/文档/段落/问题"]) @has_permissions( lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, dynamic_tag=k.get('dataset_id'))) @@ -54,7 +56,8 @@ class Problem(APIView): @swagger_auto_schema(operation_summary="删除段落问题", operation_id="删除段落问题", manual_parameters=ProblemSerializers.Query.get_request_params_api(), - responses=result.get_default_response()) + responses=result.get_default_response(), + tags=["数据集/文档/段落/问题"]) @has_permissions( lambda r, k: 
Permission(group=Group.DATASET, operate=Operate.MANAGE, dynamic_tag=k.get('dataset_id'))) diff --git a/apps/embedding/migrations/0002_embedding_star_num_embedding_trample_num.py b/apps/embedding/migrations/0002_embedding_star_num_embedding_trample_num.py new file mode 100644 index 000000000..0b024fd67 --- /dev/null +++ b/apps/embedding/migrations/0002_embedding_star_num_embedding_trample_num.py @@ -0,0 +1,23 @@ +# Generated by Django 4.1.10 on 2023-11-09 07:45 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('embedding', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='embedding', + name='star_num', + field=models.IntegerField(default=0, verbose_name='点赞数量'), + ), + migrations.AddField( + model_name='embedding', + name='trample_num', + field=models.IntegerField(default=0, verbose_name='点踩数量'), + ), + ] diff --git a/apps/embedding/migrations/0003_alter_embedding_unique_together.py b/apps/embedding/migrations/0003_alter_embedding_unique_together.py new file mode 100644 index 000000000..dacda192a --- /dev/null +++ b/apps/embedding/migrations/0003_alter_embedding_unique_together.py @@ -0,0 +1,17 @@ +# Generated by Django 4.1.10 on 2023-11-14 07:30 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('embedding', '0002_embedding_star_num_embedding_trample_num'), + ] + + operations = [ + migrations.AlterUniqueTogether( + name='embedding', + unique_together={('source_id', 'source_type')}, + ), + ] diff --git a/apps/embedding/migrations/0004_alter_embedding_source_type.py b/apps/embedding/migrations/0004_alter_embedding_source_type.py new file mode 100644 index 000000000..d2b089f74 --- /dev/null +++ b/apps/embedding/migrations/0004_alter_embedding_source_type.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.10 on 2023-11-14 08:09 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = 
[ + ('embedding', '0003_alter_embedding_unique_together'), + ] + + operations = [ + migrations.AlterField( + model_name='embedding', + name='source_type', + field=models.CharField(choices=[('0', '问题'), ('1', '段落'), ('-1', '其他')], default='0', max_length=5, verbose_name='资源类型'), + ), + ] diff --git a/apps/embedding/migrations/0005_alter_embedding_source_type.py b/apps/embedding/migrations/0005_alter_embedding_source_type.py new file mode 100644 index 000000000..7095dd4d6 --- /dev/null +++ b/apps/embedding/migrations/0005_alter_embedding_source_type.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.10 on 2023-11-14 08:11 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('embedding', '0004_alter_embedding_source_type'), + ] + + operations = [ + migrations.AlterField( + model_name='embedding', + name='source_type', + field=models.CharField(choices=[('0', '问题'), ('1', '段落')], default='0', max_length=5, verbose_name='资源类型'), + ), + ] diff --git a/apps/embedding/models/embedding.py b/apps/embedding/models/embedding.py index e3606f664..83f786c68 100644 --- a/apps/embedding/models/embedding.py +++ b/apps/embedding/models/embedding.py @@ -23,7 +23,7 @@ class Embedding(models.Model): source_id = models.CharField(max_length=128, verbose_name="资源id") - source_type = models.CharField(verbose_name='资源类型', max_length=1, choices=SourceType.choices, + source_type = models.CharField(verbose_name='资源类型', max_length=5, choices=SourceType.choices, default=SourceType.PROBLEM) is_active = models.BooleanField(verbose_name="是否可用", max_length=1, default=True) @@ -36,5 +36,11 @@ class Embedding(models.Model): embedding = VectorField(verbose_name="向量") + star_num = models.IntegerField(default=0, verbose_name="点赞数量") + + trample_num = models.IntegerField(default=0, + verbose_name="点踩数量") + class Meta: db_table = "embedding" + unique_together = ['source_id', 'source_type'] diff --git a/apps/embedding/sql/embedding_search.sql 
b/apps/embedding/sql/embedding_search.sql new file mode 100644 index 000000000..7787eb60b --- /dev/null +++ b/apps/embedding/sql/embedding_search.sql @@ -0,0 +1,15 @@ +SELECT * FROM (SELECT + *, + ( 1 - ( embedding.embedding <=> %s ) ) AS similarity, +CASE + + WHEN embedding.star_num - embedding.trample_num = 0 THEN + 0 ELSE ( ( ( embedding.star_num - embedding.trample_num ) - aggs.min_value ) / ( aggs.max_value - aggs.min_value ) ) + END AS score +FROM + embedding, + ( SELECT MIN ( star_num - trample_num ) AS min_value, MAX ( star_num - trample_num ) AS max_value FROM embedding ${embedding_query}) aggs + ${embedding_query} + ) temp + WHERE similarity>0.5 + ORDER BY (similarity + score) DESC LIMIT 1 \ No newline at end of file diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py index 90d57abb9..ecd9bdfaf 100644 --- a/apps/embedding/vector/base_vector.py +++ b/apps/embedding/vector/base_vector.py @@ -47,6 +47,8 @@ class BaseVectorStore(ABC): def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, is_active: bool, + star_num: int, + trample_num: int, embedding=None): """ 插入向量数据 @@ -58,12 +60,15 @@ class BaseVectorStore(ABC): :param is_active: 是否禁用 :param embedding: 向量化处理器 :param paragraph_id 段落id + :param star_num 点赞数量 + :param trample_num 点踩数量 :return: bool """ if embedding is None: embedding = EmbeddingModel.get_embedding_model() self.save_pre_handler() - self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding) + self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, star_num, + trample_num, embedding) def batch_save(self, data_list: List[Dict], embedding=None): """ @@ -81,6 +86,8 @@ class BaseVectorStore(ABC): @abstractmethod def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, is_active: bool, + star_num: int, + trample_num: 
int, embedding: HuggingFaceEmbeddings): pass @@ -89,7 +96,10 @@ class BaseVectorStore(ABC): pass @abstractmethod - def search(self, query_text, dataset_id_list: list[str], is_active: bool, embedding: HuggingFaceEmbeddings): + def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], + exclude_id_list: list[str], + is_active: bool, + embedding: HuggingFaceEmbeddings): pass @abstractmethod diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py index 2daa8cbed..b58e964d5 100644 --- a/apps/embedding/vector/pg_vector.py +++ b/apps/embedding/vector/pg_vector.py @@ -6,14 +6,20 @@ @date:2023/10/19 15:28 @desc: """ +import json +import os import uuid from typing import Dict, List from django.db.models import QuerySet from langchain.embeddings import HuggingFaceEmbeddings +from common.db.search import native_search, generate_sql_by_query_dict +from common.db.sql_execute import select_one +from common.util.file_util import get_file_content from embedding.models import Embedding, SourceType from embedding.vector.base_vector import BaseVectorStore +from smartdoc.conf import PROJECT_DIR class PGVector(BaseVectorStore): @@ -27,6 +33,8 @@ class PGVector(BaseVectorStore): def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, is_active: bool, + star_num: int, + trample_num: int, embedding: HuggingFaceEmbeddings): text_embedding = embedding.embed_query(text) embedding = Embedding(id=uuid.uuid1(), @@ -37,6 +45,8 @@ class PGVector(BaseVectorStore): source_id=source_id, embedding=text_embedding, source_type=source_type, + star_num=star_num, + trample_num=trample_num ) embedding.save() return True @@ -44,19 +54,41 @@ class PGVector(BaseVectorStore): def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings): texts = [row.get('text') for row in text_list] embeddings = embedding.embed_documents(texts) - 
QuerySet(Embedding).bulk_create([Embedding(id=uuid.uuid1(), - document_id=text_list[index].get('document_id'), - paragraph_id=text_list[index].get('paragraph_id'), - dataset_id=text_list[index].get('dataset_id'), - is_active=text_list[index].get('is_active', True), - source_id=text_list[index].get('source_id'), - source_type=text_list[index].get('source_type'), - embedding=embeddings[index]) for index in - range(0, len(text_list))]) if len(text_list) > 0 else None + embedding_list = [Embedding(id=uuid.uuid1(), + document_id=text_list[index].get('document_id'), + paragraph_id=text_list[index].get('paragraph_id'), + dataset_id=text_list[index].get('dataset_id'), + is_active=text_list[index].get('is_active', True), + source_id=text_list[index].get('source_id'), + source_type=text_list[index].get('source_type'), + star_num=text_list[index].get('star_num'), + trample_num=text_list[index].get('trample_num'), + embedding=embeddings[index]) for index in + range(0, len(text_list))] + QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None return True - def search(self, query_text, dataset_id_list: list[str], is_active: bool, embedding: HuggingFaceEmbeddings): - pass + def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], + exclude_id_list: list[str], + is_active: bool, + embedding: HuggingFaceEmbeddings): + exclude_dict = {} + if dataset_id_list is None or len(dataset_id_list) == 0: + return None + query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=is_active) + embedding_query = embedding.embed_query(query_text) + if exclude_document_id_list is not None and len(exclude_document_id_list) > 0: + exclude_dict.__setitem__('document_id__in', exclude_document_id_list) + if exclude_id_list is not None and len(exclude_id_list) > 0: + exclude_dict.__setitem__('id__in', exclude_id_list) + query_set = query_set.exclude(**exclude_dict) + exec_sql, exec_params = 
generate_sql_by_query_dict({'embedding_query': query_set}, + select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', + 'embedding_search.sql')), + with_table_name=True) + embedding_model = select_one(exec_sql, (json.dumps(embedding_query), *exec_params, *exec_params)) + return embedding_model def update_by_source_id(self, source_id: str, instance: Dict): QuerySet(Embedding).filter(source_id=source_id).update(**instance) diff --git a/apps/setting/migrations/0002_model.py b/apps/setting/migrations/0002_model.py new file mode 100644 index 000000000..6afb591fa --- /dev/null +++ b/apps/setting/migrations/0002_model.py @@ -0,0 +1,34 @@ +# Generated by Django 4.1.10 on 2023-11-06 08:51 + +from django.db import migrations, models +import django.db.models.deletion +import uuid + + +class Migration(migrations.Migration): + + dependencies = [ + ('users', '0001_initial'), + ('setting', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='Model', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('name', models.CharField(max_length=128, verbose_name='名称')), + ('model_type', models.CharField(max_length=128, verbose_name='模型类型')), + ('model_name', models.CharField(max_length=128, verbose_name='模型名称')), + ('provider', models.CharField(max_length=20, verbose_name='供应商')), + ('credential', models.CharField(max_length=5120, verbose_name='模型认证信息')), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user', verbose_name='成员用户id')), + ], + options={ + 'db_table': 'model', + 'unique_together': {('name', 'user_id')}, + }, + ), + ] diff --git a/apps/setting/migrations/0003_alter_model_provider.py b/apps/setting/migrations/0003_alter_model_provider.py new file mode 
100644 index 000000000..a235711fd --- /dev/null +++ b/apps/setting/migrations/0003_alter_model_provider.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.10 on 2023-11-13 05:55 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('setting', '0002_model'), + ] + + operations = [ + migrations.AlterField( + model_name='model', + name='provider', + field=models.CharField(max_length=128, verbose_name='供应商'), + ), + ] diff --git a/apps/setting/models_provider/__init__.py b/apps/setting/models_provider/__init__.py new file mode 100644 index 000000000..53b7001e5 --- /dev/null +++ b/apps/setting/models_provider/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2023/10/31 17:16 + @desc: +""" diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py new file mode 100644 index 000000000..18a6a6e13 --- /dev/null +++ b/apps/setting/models_provider/base_model_provider.py @@ -0,0 +1,130 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_model_provider.py + @date:2023/10/31 16:19 + @desc: +""" +from abc import ABC, abstractmethod +from enum import Enum +from functools import reduce +from typing import Dict, List + +from langchain.chat_models.base import BaseChatModel +from langchain.schema import BaseMessage +from langchain.schema.language_model import LanguageModelInput + + +class IModelProvider(ABC): + + @abstractmethod + def get_model_provide_info(self): + pass + + @abstractmethod + def get_model_type_list(self): + pass + + @abstractmethod + def get_model_list(self, model_type): + pass + + @abstractmethod + def get_model_credential(self, model_type, model_name): + pass + + @abstractmethod + def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseChatModel: + pass + + @abstractmethod + def get_dialogue_number(self): + pass + + +class 
BaseModelCredential(ABC): + + @abstractmethod + def is_valid(self, model_type: str, model_name, model: Dict[str, object], raise_exception=False): + pass + + @abstractmethod + def encryption_dict(self, model_info: Dict[str, object]): + """ + :param model_info: 模型数据 + :return: 加密后数据 + """ + pass + + @staticmethod + def encryption(message: str): + """ + 加密敏感字段数据 加密方式是 如果密码是 1234567890 那么给前端则是 123******890 + :param message: + :return: + """ + max_pre_len = 8 + max_post_len = 4 + message_len = len(message) + pre_len = int(message_len / 5 * 2) + post_len = int(message_len / 5 * 1) + pre_str = "".join([message[index] for index in + range(0, max_pre_len if pre_len > max_pre_len else 1 if pre_len <= 0 else int(pre_len))]) + end_str = "".join( + [message[index] for index in + range(message_len - (int(post_len) if pre_len < max_post_len else max_post_len), message_len)]) + content = "***************" + return pre_str + content + end_str + + +class ModelTypeConst(Enum): + LLM = {'code': 'LLM', 'message': '大语言模型'} + + +class ModelInfo: + def __init__(self, name: str, desc: str, model_type: ModelTypeConst, model_credential: BaseModelCredential, + **keywords): + self.name = name + self.desc = desc + self.model_type = model_type.name + self.model_credential = model_credential + if keywords is not None: + for key in keywords.keys(): + self.__setattr__(key, keywords.get(key)) + + def get_name(self): + """ + 获取模型名称 + :return: 模型名称 + """ + return self.name + + def get_desc(self): + """ + 获取模型描述 + :return: 模型描述 + """ + return self.desc + + def get_model_type(self): + return self.model_type + + def to_dict(self): + return reduce(lambda x, y: {**x, **y}, + [{attr: self.__getattribute__(attr)} for attr in vars(self) if + not attr.startswith("__") and not attr == 'model_credential'], {}) + + +class ModelProvideInfo: + def __init__(self, provider: str, name: str, icon: str): + self.provider = provider + + self.name = name + + self.icon = icon + + def to_dict(self): + return reduce(lambda x, 
y: {**x, **y}, + [{attr: self.__getattribute__(attr)} for attr in vars(self) if + not attr.startswith("__")], {}) diff --git a/apps/setting/models_provider/constants/model_provider_constants.py b/apps/setting/models_provider/constants/model_provider_constants.py new file mode 100644 index 000000000..911332ac5 --- /dev/null +++ b/apps/setting/models_provider/constants/model_provider_constants.py @@ -0,0 +1,17 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: model_provider_constants.py + @date:2023/11/2 14:55 + @desc: +""" +from enum import Enum + +from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider +from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider + + +class ModelProvideConstants(Enum): + model_azure_provider = AzureModelProvider() + model_wenxin_provider = WenxinModelProvider() diff --git a/apps/setting/models_provider/impl/azure_model_provider/__init__.py b/apps/setting/models_provider/impl/azure_model_provider/__init__.py new file mode 100644 index 000000000..53b7001e5 --- /dev/null +++ b/apps/setting/models_provider/impl/azure_model_provider/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2023/10/31 17:16 + @desc: +""" diff --git a/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py b/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py new file mode 100644 index 000000000..69f6f4a37 --- /dev/null +++ b/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py @@ -0,0 +1,110 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: azure_model_provider.py + @date:2023/10/31 16:19 + @desc: +""" +import os +from typing import Dict, List + +from langchain.chat_models import AzureChatOpenAI +from langchain.chat_models.base import BaseChatModel +from langchain.schema import HumanMessage, BaseMessage +from 
class AzureLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for the Azure OpenAI chat models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
        """
        Check that the model type, model name and credential are usable.

        :param model_type: model type code, must be offered by the provider
        :param model_name: model name, must be a key of ``model_dict``
        :param model_credential: credential values entered by the user
        :param raise_exception: raise instead of returning ``False`` for a
                                missing field or a failed probe call
        :return: ``True`` when valid, otherwise ``False``
        :raises AppApiException: unsupported type/name (always), or any other
                                 failure when ``raise_exception`` is set
        """
        supported_types = AzureModelProvider().get_model_type_list()
        if not any(entry.get('value') == model_type for entry in supported_types):
            raise AppApiException(500, f'{model_type} 模型类型不支持')
        if model_name not in model_dict:
            raise AppApiException(500, f'{model_name} 模型名称不支持')

        # Report the first required field that is absent.
        required_keys = ('api_base', 'api_key', 'deployment_name')
        absent = next((name for name in required_keys if name not in model_credential), None)
        if absent is not None:
            if not raise_exception:
                return False
            raise AppApiException(500, f'{absent} 字段为必填字段')

        # Probe the provider with a throw-away message to verify the credential.
        try:
            AzureModelProvider().query(model_type, model_name, model_credential,
                                       message=[HumanMessage(content='valid')])
        except AppApiException:
            # Application errors already carry a user-facing message.
            raise
        except Exception:
            if not raise_exception:
                return False
            raise AppApiException(500, '校验失败,请检查参数是否正确')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` is masked for display."""
        masked_key = super().encryption(model.get('api_key', ''))
        return {**model, 'api_key': masked_key}

    # Form fields shown to the user when the model is registered.
    api_base = froms.TextInputField('API 域名', required=True)

    api_key = froms.PasswordInputField("API Key", required=True)

    deployment_name = froms.TextInputField("部署名", required=True)
class AzureModelProvider(IModelProvider):
    """Model provider backed by Azure OpenAI chat deployments."""

    @staticmethod
    def _get_model_info(model_name) -> ModelInfo:
        """Look up ``model_name`` in ``model_dict`` or fail with a clear API error."""
        model_info = model_dict.get(model_name)
        if model_info is None:
            # Same message as the Wenxin provider, instead of an
            # AttributeError on ``None`` further down.
            raise AppApiException(500, f'不支持的模型:{model_name}')
        return model_info

    def get_dialogue_number(self):
        """Return the dialogue-number setting of this provider (3)."""
        return 3

    def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> AzureChatOpenAI:
        """
        Build the chat model client for ``model_name``.

        :param model_type: model type code (not used by this provider)
        :param model_name: key of ``model_dict``
        :param model_credential: dict holding ``api_base``, ``api_key`` and
                                 ``deployment_name``
        :param model_kwargs: accepted for interface compatibility, unused
        :return: a configured ``AzureChatOpenAI`` instance
        :raises AppApiException: when ``model_name`` is not supported
        """
        model_info = self._get_model_info(model_name)
        azure_chat_open_ai = AzureChatOpenAI(
            openai_api_base=model_credential.get('api_base'),
            openai_api_version=model_info.api_version,
            deployment_name=model_credential.get('deployment_name'),
            openai_api_key=model_credential.get('api_key'),
            openai_api_type="azure",
            tiktoken_model_name=model_name
        )
        return azure_chat_open_ai

    def get_model_credential(self, model_type, model_name):
        """
        Return the credential form registered for ``model_name``.

        :raises AppApiException: when ``model_name`` is not supported
        """
        return self._get_model_info(model_name).model_credential

    def get_model_provide_info(self):
        """Return the provider key, display name and icon content."""
        icon_path = os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'azure_model_provider',
                                 'icon', 'azure_icon_svg')
        return ModelProvideInfo(provider='model_azure_provider', name='Azure OpenAI',
                                icon=get_file_content(icon_path))

    def get_model_list(self, model_type: str):
        """
        List the models of the given type as dictionaries.

        :raises AppApiException: when ``model_type`` is ``None``
        """
        if model_type is None:
            raise AppApiException(500, '模型类型不能为空')
        return [model_info.to_dict() for model_info in model_dict.values()
                if model_info.model_type == model_type]

    def get_model_type_list(self):
        """Return the model types offered by this provider."""
        return [{'key': "大语言模型", 'value': "LLM"}]
class QianfanChatModel(QianfanChatEndpoint):
    """
    Qianfan chat endpoint whose ``stream`` first normalizes the conversation
    into a strictly alternating human / AI message sequence.
    """

    def stream(
            self,
            input: LanguageModelInput,
            config: Optional[RunnableConfig] = None,
            *,
            stop: Optional[List[str]] = None,
            **kwargs: Any,
    ) -> Iterator[BaseMessageChunk]:
        """
        Stream the model answer chunk by chunk.

        :param input: the conversation; the code indexes it and reads
                      ``.content`` of every item, so a list of messages is
                      expected — NOTE(review): a plain string or prompt value
                      would not work here, confirm callers never pass one
        :param config: optional runnable config (callbacks, tags, metadata,
                       run name)
        :param stop: optional stop sequences
        :param kwargs: forwarded to the underlying model call
        :return: iterator over message chunks
        """
        # An even number of messages gets a placeholder human message in
        # front, so the sequence has odd length and both starts and ends
        # with a human turn after the re-labelling below.
        if len(input) % 2 == 0:
            input = [HumanMessage(content='占位'), *input]
        # Re-label purely by position: even index -> human, odd index -> AI.
        input = [
            HumanMessage(content=message.content) if index % 2 == 0 else AIMessage(content=message.content)
            for index, message in enumerate(input)]
        if type(self)._stream == BaseChatModel._stream:
            # model doesn't implement streaming, so use default implementation
            yield cast(
                BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
            )
        else:
            config = config or {}
            messages = self._convert_input(input).to_messages()
            params = self._get_invocation_params(stop=stop, **kwargs)
            options = {"stop": stop, **kwargs}
            callback_manager = CallbackManager.configure(
                config.get("callbacks"),
                self.callbacks,
                self.verbose,
                config.get("tags"),
                self.tags,
                config.get("metadata"),
                self.metadata,
            )
            (run_manager,) = callback_manager.on_chat_model_start(
                dumpd(self),
                [messages],
                invocation_params=params,
                options=options,
                name=config.get("run_name"),
            )
            try:
                generation: Optional[ChatGenerationChunk] = None
                for chunk in self._stream(
                        messages, stop=stop, run_manager=run_manager, **kwargs
                ):
                    yield chunk.message
                    if generation is None:
                        generation = chunk
                    else:
                        # Accumulate every chunk so the end-of-run callback
                        # receives the complete answer, not just the first
                        # chunk.
                        generation += chunk
                assert generation is not None
            except BaseException as e:
                run_manager.on_llm_error(e)
                raise e
            else:
                run_manager.on_llm_end(
                    LLMResult(generations=[[generation]]),
                )
class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for the Wenxin (Qianfan) chat models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
        """
        Check that the model type, model name and credential are usable.

        :param model_type: model type code, must be offered by the provider
        :param model_name: model name, must be a key of ``model_dict``
        :param model_credential: credential values entered by the user
        :param raise_exception: raise instead of returning ``False`` for a
                                missing field or a failed probe call
        :return: ``True`` when valid, otherwise ``False``
        :raises AppApiException: unsupported type/name (always), or any other
                                 failure when ``raise_exception`` is set
        """
        model_type_list = WenxinModelProvider().get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(500, f'{model_type} 模型类型不支持')

        if model_name not in model_dict:
            raise AppApiException(500, f'{model_name} 模型名称不支持')

        for key in ['api_key', 'secret_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(500, f'{key} 字段为必填字段')
                else:
                    return False
        # Probe the endpoint with a throw-away message to verify the keys.
        try:
            WenxinModelProvider().get_model(model_type, model_name, model_credential)([HumanMessage(content='valid')])
        except Exception as e:
            if raise_exception:
                raise AppApiException(500, "校验失败,请检查 api_key secret_key 是否正确") from e
            # A failed probe means the credential is NOT valid; previously
            # execution fell through and reported success.
            return False
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        """
        Return a copy of ``model_info`` whose ``secret_key`` is masked.

        NOTE(review): ``api_key`` is returned unmasked although it is also a
        password field — confirm this is intended.
        """
        return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))}

    def build_model(self, model_info: Dict[str, object]):
        """
        Copy ``api_key`` and ``secret_key`` from ``model_info`` onto this
        instance and return it.

        :raises AppApiException: when ``api_key``, ``secret_key`` or ``model``
                                 is missing from ``model_info``
        """
        for key in ['api_key', 'secret_key', 'model']:
            if key not in model_info:
                raise AppApiException(500, f'{key} 字段为必填字段')
        self.api_key = model_info.get('api_key')
        self.secret_key = model_info.get('secret_key')
        return self

    # Form fields shown to the user when the model is registered.
    api_key = froms.PasswordInputField('API Key', required=True)

    secret_key = froms.PasswordInputField("Secret Key", required=True)
class WenxinModelProvider(IModelProvider):
    """Model provider backed by the Baidu Qianfan chat endpoint."""

    def get_dialogue_number(self):
        """Return the dialogue-number setting of this provider (2)."""
        return 2

    def get_model(self, model_type, model_name, model_credential: Dict[str, object],
                  **model_kwargs) -> QianfanChatEndpoint:
        """
        Build the chat model client for ``model_name``.

        :param model_type: model type code (not used by this provider)
        :param model_name: name passed to the endpoint as ``model``
        :param model_credential: dict holding ``api_key`` and ``secret_key``
        :param model_kwargs: only ``streaming`` is read (default ``False``)
        :return: a configured ``QianfanChatModel`` instance
        """
        return QianfanChatModel(model=model_name,
                                qianfan_ak=model_credential.get('api_key'),
                                qianfan_sk=model_credential.get('secret_key'),
                                streaming=model_kwargs.get('streaming', False))

    def get_model_type_list(self):
        """Return the model types offered by this provider."""
        return [{'key': "大语言模型", 'value': "LLM"}]

    def get_model_list(self, model_type):
        """
        List the models of the given type as dictionaries.

        :raises AppApiException: when ``model_type`` is ``None``
        """
        if model_type is None:
            raise AppApiException(500, '模型类型不能为空')
        return [model_info.to_dict() for model_info in model_dict.values()
                if model_info.model_type == model_type]

    def get_model_credential(self, model_type, model_name):
        """
        Return the credential form registered for ``model_name``.

        :raises AppApiException: when ``model_name`` is not supported
        """
        if model_name in model_dict:
            return model_dict.get(model_name).model_credential
        raise AppApiException(500, f'不支持的模型:{model_name}')

    def get_model_provide_info(self):
        """Return the provider key, display name and icon content."""
        # The display name used to be the copy-pasted 'Azure OpenAI', which
        # made both providers indistinguishable in the provider list.
        # The icon file in this provider's directory is still named
        # 'azure_icon_svg'; the path is kept because that is the file shipped.
        icon_path = os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'wenxin_model_provider',
                                 'icon', 'azure_icon_svg')
        return ModelProvideInfo(provider='model_wenxin_provider', name='千帆大模型',
                                icon=get_file_content(icon_path))
def is_valid(self, *, raise_exception=False):
    """
    Validate the request fields, the uniqueness of the model name for the
    user, and the credential against the selected provider.

    Every failure is raised; ``raise_exception`` is accepted only for
    signature compatibility with the base serializer.

    :return: ``True`` when all checks pass
    :raises AppApiException: duplicate name, unknown provider, or a
                             credential rejected by the provider
    """
    super().is_valid(raise_exception=True)
    if QuerySet(Model).filter(user_id=self.data.get('user_id'),
                              name=self.data.get('name')).exists():
        raise AppApiException(500, f'模型名称【{self.data.get("name")}】已存在')
    provider = self.data.get('provider')
    # An unknown provider must be a clean API error, not a bare KeyError.
    if provider not in ModelProvideConstants.__members__:
        raise AppApiException(500, f'不支持的供应商:{provider}')
    model_type = self.data.get('model_type')
    model_name = self.data.get('model_name')
    # Validate the model credential data
    model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type, model_name)
    model_credential.is_valid(model_type, model_name, self.data.get('credential'), raise_exception=True)
    # Serializer.is_valid is expected to return a bool.
    return True
class ProviderSerializer(serializers.Serializer):
    """Dispatches a named method call to a registered model provider."""

    # Provider key, a member name of ModelProvideConstants
    provider = serializers.CharField(required=True)

    # Name of the provider method to call
    method = serializers.CharField(required=True)

    def exec(self, exec_params: Dict[str, object], with_valid=False):
        """
        Call ``method`` on the selected provider with ``exec_params``.

        :param exec_params: single argument passed to the provider method
        :param with_valid: validate the serializer fields first
        :return: whatever the provider method returns
        :raises AppApiException: unknown provider, or a method name that is
                                 private or not callable on the provider
        """
        if with_valid:
            self.is_valid(raise_exception=True)

        provider = self.data.get('provider')
        method = self.data.get('method')
        if provider not in ModelProvideConstants.__members__:
            raise AppApiException(500, f'不支持的供应商:{provider}')
        # The method name is supplied by the caller, so never expose private
        # or dunder attributes, and only ever call real callables.
        handler = None
        if not method.startswith('_'):
            handler = getattr(ModelProvideConstants[provider].value, method, None)
        if not callable(handler):
            raise AppApiException(500, f'不支持的函数:{method}')
        return handler(exec_params)
class ModelCreateApi(ApiMixin):
    """Swagger request-body description for the create-model endpoint."""

    @staticmethod
    def get_request_body_api():
        """
        Return the schema of the create-model request body.

        The required list and the titles mirror the fields the create
        serializer expects (``user_id`` is filled in by the view).
        """
        return openapi.Schema(type=openapi.TYPE_OBJECT,
                              title="调用函数所需要的参数",
                              description="调用函数所需要的参数",
                              # Previously listed 'provide' and 'model_info',
                              # neither of which is a property of this body.
                              required=['name', 'provider', 'model_type', 'model_name', 'credential'],
                              properties={
                                  'name': openapi.Schema(type=openapi.TYPE_STRING,
                                                         title="模型名称",
                                                         description="模型名称"),
                                  'provider': openapi.Schema(type=openapi.TYPE_STRING,
                                                             title="供应商",
                                                             description="供应商"),
                                  'model_type': openapi.Schema(type=openapi.TYPE_STRING,
                                                               title="模型类型",
                                                               description="模型类型"),
                                  'model_name': openapi.Schema(type=openapi.TYPE_STRING,
                                                               title="基础模型名称",
                                                               description="基础模型名称"),
                                  'credential': openapi.Schema(type=openapi.TYPE_OBJECT,
                                                               title="模型证书信息",
                                                               description="模型证书信息")
                              }
                              )
class ModelList(ApiMixin):
    """Swagger metadata for the provider model-list endpoint."""

    @staticmethod
    def get_request_params_api():
        """Return the two required query parameters: provider and model type."""
        provider_param = openapi.Parameter(name='provider',
                                           in_=openapi.IN_QUERY,
                                           type=openapi.TYPE_STRING,
                                           required=True,
                                           description='供应名称')
        model_type_param = openapi.Parameter(name='model_type',
                                             in_=openapi.IN_QUERY,
                                             type=openapi.TYPE_STRING,
                                             required=True,
                                             description='模型类型')
        return [provider_param, model_type_param]

    @staticmethod
    def get_response_body_api():
        """Return the schema of one entry of the model list response."""
        name_schema = openapi.Schema(type=openapi.TYPE_STRING, title="模型名称",
                                     description="模型名称", default="模型名称")
        desc_schema = openapi.Schema(type=openapi.TYPE_STRING, title="模型描述",
                                     description="模型描述", default="xxx模型")
        model_type_schema = openapi.Schema(type=openapi.TYPE_STRING, title="模型类型值",
                                           description="模型类型值", default="LLM")
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['name', 'desc', 'model_type'],
            properties={
                'name': name_schema,
                'desc': desc_schema,
                'model_type': model_type_schema,
            }
        )
@staticmethod
def get_request_body_api():
    """
    Return the request-body schema for calling a provider function.

    The body is a free-form object: no properties are declared here, only
    a title and description.
    """
    return openapi.Schema(type=openapi.TYPE_OBJECT,
                          title="调用函数所需要的参数",
                          description="调用函数所需要的参数",
                          )
@swagger_auto_schema(operation_summary="获取团队成员列表", operation_id="获取团员成员列表", - responses=result.get_api_response(get_response_body_api())) + responses=result.get_api_response(get_response_body_api()), + tags=["团队"]) + @has_permissions(PermissionConstants.TEAM_READ) def get(self, request: Request): return result.success(TeamMemberSerializer(data={'team_id': str(request.user.id)}).list_member()) @action(methods=['POST'], detail=False) @swagger_auto_schema(operation_summary="添加成员", operation_id="添加成员", - request_body=TeamMemberSerializer().get_request_body_api()) + request_body=TeamMemberSerializer().get_request_body_api(), + tags=["团队"]) + @has_permissions(PermissionConstants.TEAM_CREATE) def post(self, request: Request): team = TeamMemberSerializer(data={'team_id': str(request.user.id)}) return result.success((team.add_member(**request.data))) @@ -41,7 +46,9 @@ class TeamMember(APIView): @action(methods=['GET'], detail=False) @swagger_auto_schema(operation_summary="获取团队成员权限", operation_id="获取团队成员权限", - manual_parameters=TeamMemberSerializer.Operate.get_request_params_api()) + manual_parameters=TeamMemberSerializer.Operate.get_request_params_api(), + tags=["团队"]) + @has_permissions(PermissionConstants.TEAM_READ) def get(self, request: Request, member_id: str): return result.success(TeamMemberSerializer.Operate( data={'member_id': member_id, 'team_id': str(request.user.id)}).list_member_permission()) @@ -50,8 +57,10 @@ class TeamMember(APIView): @swagger_auto_schema(operation_summary="修改团队成员权限", operation_id="修改团队成员权限", request_body=UpdateTeamMemberPermissionSerializer().get_request_body_api(), - manual_parameters=TeamMemberSerializer.Operate.get_request_params_api() + manual_parameters=TeamMemberSerializer.Operate.get_request_params_api(), + tags=["团队"] ) + @has_permissions(PermissionConstants.TEAM_EDIT) def put(self, request: Request, member_id: str): return result.success(TeamMemberSerializer.Operate( data={'member_id': member_id, 'team_id': 
class Model(APIView):
    """Endpoints for creating a model and listing the current user's models."""
    authentication_classes = [TokenAuth]

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="创建模型",
                         operation_id="创建模型",
                         request_body=ModelCreateApi.get_request_body_api(),
                         tags=["模型"])
    @has_permissions(PermissionConstants.MODEL_CREATE)
    def post(self, request: Request):
        """Create a model owned by the requesting user and return it."""
        current_user = request.user
        # The owner always comes from the authenticated user, overriding any
        # 'user_id' present in the request body.
        payload = {**request.data, 'user_id': str(current_user.id)}
        create_serializer = ModelSerializer.Create(data=payload)
        created = create_serializer.insert(current_user.id, with_valid=True)
        return result.success(created)

    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(operation_summary="获取模型列表",
                         operation_id="获取模型列表",
                         manual_parameters=ModelQueryApi.get_request_params_api(),
                         tags=["模型"])
    @has_permissions(PermissionConstants.MODEL_READ)
    def get(self, request: Request):
        """List the requesting user's models, filtered by the query string."""
        filters = {**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}
        query_serializer = ModelSerializer.Query(data=filters)
        return result.success(query_serializer.list(with_valid=True))
class ModelList(APIView):
    """Endpoint listing the models a provider offers for a model type."""
    authentication_classes = [TokenAuth]

    @action(methods=['GET'], detail=False)
    # operation_id was a copy of the model-form endpoint's id
    # ("获取模型创建表单"); operation ids must be unique within the schema.
    @swagger_auto_schema(operation_summary="获取模型列表",
                         operation_id="获取供应商模型列表",
                         manual_parameters=ProvideApi.ModelList.get_request_params_api(),
                         responses=result.get_api_array_response(ProvideApi.ModelList.get_response_body_api()),
                         tags=["模型"]
                         )
    @has_permissions(PermissionConstants.MODEL_READ)
    def get(self, request: Request):
        """
        Return the provider's models of the requested type.

        Reads ``provider`` and ``model_type`` from the query string.
        """
        provider = request.query_params.get('provider')
        model_type = request.query_params.get('model_type')

        return result.success(
            ModelProvideConstants[provider].value.get_model_list(
                model_type))
"chat_cache": { + 'BACKEND': 'common.cache.file_cache.FileCache', + 'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "chat_cache") # 文件夹路径 } } diff --git a/apps/smartdoc/urls.py b/apps/smartdoc/urls.py index e74af4c48..6246f8919 100644 --- a/apps/smartdoc/urls.py +++ b/apps/smartdoc/urls.py @@ -5,13 +5,13 @@ The `urlpatterns` list routes URLs to views. For more information please see: https://docs.djangoproject.com/en/4.2/topics/http/urls/ Examples: Function views - 1. Add an import: from my_app import views + 1. Add an import: froms my_app import views 2. Add a URL to urlpatterns: path('', views.home, name='home') Class-based views - 1. Add an import: from other_app.views import Home + 1. Add an import: froms other_app.views import Home 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') Including another URLconf - 1. Import the include() function: from django.urls import include, path + 1. Import the include() function: froms django.urls import include, path 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) """ import os @@ -43,7 +43,8 @@ schema_view = get_schema_view( urlpatterns = [ path("api/", include("users.urls")), path("api/", include("dataset.urls")), - path("api/", include("setting.urls")) + path("api/", include("setting.urls")), + path("api/", include("application.urls")) ] diff --git a/apps/smartdoc/wsgi.py b/apps/smartdoc/wsgi.py index ce5346b2b..df250f58e 100644 --- a/apps/smartdoc/wsgi.py +++ b/apps/smartdoc/wsgi.py @@ -17,9 +17,9 @@ application = get_wsgi_application() def post_handler(): - from common.event.listener_manage import ListenerManagement - ListenerManagement().run() - ListenerManagement.init_embedding_model_signal.send() + from common import event + event.run() + event.ListenerManagement.init_embedding_model_signal.send() post_handler() diff --git a/apps/users/serializers/user_serializers.py b/apps/users/serializers/user_serializers.py index 17f1d507d..ece3b31e4 100644 --- 
a/apps/users/serializers/user_serializers.py +++ b/apps/users/serializers/user_serializers.py @@ -19,6 +19,7 @@ from django.db.models import Q from drf_yasg import openapi from rest_framework import serializers +from common.constants.authentication_type import AuthenticationType from common.constants.exception_code_constants import ExceptionCodeConstants from common.constants.permission_constants import RoleConstants, get_permission_list_by_role from common.exception.app_exception import AppApiException @@ -66,7 +67,8 @@ class LoginSerializer(ApiMixin, serializers.Serializer): :return: 用户Token(认证信息) """ user = self.is_valid() - token = signing.dumps({'username': user.username, 'id': str(user.id), 'email': user.email}) + token = signing.dumps({'username': user.username, 'id': str(user.id), 'email': user.email, + 'type': AuthenticationType.USER.value}) return token class Meta: diff --git a/apps/users/views/user.py b/apps/users/views/user.py index 44cde7973..d79cf2340 100644 --- a/apps/users/views/user.py +++ b/apps/users/views/user.py @@ -7,7 +7,6 @@ @desc: """ from django.core import cache -from django.db.models import QuerySet from drf_yasg import openapi from drf_yasg.utils import swagger_auto_schema from rest_framework.decorators import action @@ -21,7 +20,6 @@ from common.auth.authentication import has_permissions from common.constants.permission_constants import PermissionConstants, CompareConstants from common.response import result from smartdoc.settings import JWT_AUTH -from users.models.user import User as UserModel from users.serializers.user_serializers import RegisterSerializer, LoginSerializer, CheckCodeSerializer, \ RePasswordSerializer, \ SendEmailSerializer, UserProfile @@ -36,7 +34,8 @@ class User(APIView): @action(methods=['GET'], detail=False) @swagger_auto_schema(operation_summary="获取当前用户信息", operation_id="获取当前用户信息", - responses=result.get_api_response(UserProfile.get_response_body_api())) + 
responses=result.get_api_response(UserProfile.get_response_body_api()), + tags=['用户']) @has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND) def get(self, request: Request): return result.success(UserProfile.get_user_profile(request.user)) @@ -58,7 +57,8 @@ class ResetCurrentUserPasswordView(APIView): description="密码") } ), - responses=RePasswordSerializer().get_response_body_api()) + responses=RePasswordSerializer().get_response_body_api(), + tags=['用户']) def post(self, request: Request): data = {'email': request.user.email} data.update(request.data) @@ -77,7 +77,8 @@ class SendEmailToCurrentUserView(APIView): @permission_classes((AllowAny,)) @swagger_auto_schema(operation_summary="发送邮件到当前用户", operation_id="发送邮件到当前用户", - responses=SendEmailSerializer().get_response_body_api()) + responses=SendEmailSerializer().get_response_body_api(), + tags=['用户']) def post(self, request: Request): serializer_obj = SendEmailSerializer(data={'email': request.user.email, 'type': "reset_password"}) if serializer_obj.is_valid(raise_exception=True): @@ -91,7 +92,8 @@ class Logout(APIView): @permission_classes((AllowAny,)) @swagger_auto_schema(operation_summary="登出", operation_id="登出", - responses=SendEmailSerializer().get_response_body_api()) + responses=SendEmailSerializer().get_response_body_api(), + tags=['用户']) def post(self, request: Request): token_cache.delete(request.META.get('HTTP_AUTHORIZATION', None )) @@ -104,7 +106,9 @@ class Login(APIView): @swagger_auto_schema(operation_summary="登录", operation_id="登录", request_body=LoginSerializer().get_request_body_api(), - responses=LoginSerializer().get_response_body_api()) + responses=LoginSerializer().get_response_body_api(), + security=[], + tags=['用户']) def post(self, request: Request): login_request = LoginSerializer(data=request.data) # 校验请求参数 @@ -121,7 +125,9 @@ class Register(APIView): @swagger_auto_schema(operation_summary="用户注册", operation_id="用户注册", 
request_body=RegisterSerializer().get_request_body_api(), - responses=RegisterSerializer().get_response_body_api()) + responses=RegisterSerializer().get_response_body_api(), + security=[], + tags=['用户']) def post(self, request: Request): serializer_obj = RegisterSerializer(data=request.data) if serializer_obj.is_valid(raise_exception=True): @@ -136,7 +142,9 @@ class RePasswordView(APIView): @swagger_auto_schema(operation_summary="修改密码", operation_id="修改密码", request_body=RePasswordSerializer().get_request_body_api(), - responses=RePasswordSerializer().get_response_body_api()) + responses=RePasswordSerializer().get_response_body_api(), + security=[], + tags=['用户']) def post(self, request: Request): serializer_obj = RePasswordSerializer(data=request.data) return result.success(serializer_obj.reset_password()) @@ -149,7 +157,9 @@ class CheckCode(APIView): @swagger_auto_schema(operation_summary="校验验证码是否正确", operation_id="校验验证码是否正确", request_body=CheckCodeSerializer().get_request_body_api(), - responses=CheckCodeSerializer().get_response_body_api()) + responses=CheckCodeSerializer().get_response_body_api(), + security=[], + tags=['用户']) def post(self, request: Request): return result.success(CheckCodeSerializer(data=request.data).is_valid(raise_exception=True)) @@ -160,7 +170,9 @@ class SendEmail(APIView): @swagger_auto_schema(operation_summary="发送邮件", operation_id="发送邮件", request_body=SendEmailSerializer().get_request_body_api(), - responses=SendEmailSerializer().get_response_body_api()) + responses=SendEmailSerializer().get_response_body_api(), + security=[], + tags=['用户']) def post(self, request: Request): serializer_obj = SendEmailSerializer(data=request.data) if serializer_obj.is_valid(raise_exception=True): diff --git a/pyproject.toml b/pyproject.toml index e844bfa77..d1ae05745 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ djangorestframework = "3.14.0" drf-yasg = "1.21.7" django-filter = "23.2" elasticsearch = "8.9.0" -langchain = "0.0.274" 
+langchain = "^0.0.321" psycopg2-binary = "2.9.7" jieba = "^0.42.1" diskcache = "^5.6.3" @@ -22,6 +22,10 @@ chardet = "^5.2.0" torch = "^2.1.0" sentence-transformers = "^2.2.2" blinker = "^1.6.3" +openai = "^0.28.1" +tiktoken = "^0.5.1" +qianfan = "^0.1.1" +pycryptodome = "^3.19.0" [build-system]