diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..67b063a9c --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or 
harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +support@fit2cloud.com. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. 
This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..830ab828b --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,25 @@ +# Contributing + +## Create pull request +PRs are always welcome, even if they only contain small fixes like typos or a few lines of code. 
If a change requires significant effort, please document it in an issue and start a discussion before beginning work. + +Please break your work into small, incremental PRs. A PR that bundles many features and code changes is hard to review. + +The [development guide](https://github.com/1Panel-dev/MaxKB/wiki/3-%E5%BC%80%E5%8F%91%E7%8E%AF%E5%A2%83%E6%90%AD%E5%BB%BA) describes the repository structure, how to set up the development environment, how to run the project, and more. + +Note: If you split your pull request into small changes, make sure that each change can be merged into master without breaking anything. Otherwise, it cannot be merged until the feature is complete. + +## Report issues +Reporting an issue is a great way to contribute. Well-written and complete bug reports are always welcome! Please open an issue and fill in the information the template asks for. + +Before opening an issue, please search the existing issues to avoid submitting a duplicate. +If you find a match, you can "subscribe" to it to get notified of updates. If you have additional helpful information about the issue, please leave a comment. + +When reporting issues, always include: + +* Which version you are using. +* Steps to reproduce the issue. +* Screenshots or log files, if needed. + +Because the issues are open to the public, when submitting files, be sure to remove any sensitive information, e.g. user name, password, IP address, and company name. You can +replace those parts with "REDACTED" or other strings like "****". 
\ No newline at end of file diff --git a/README.md b/README.md index 4c5021ad7..bbc1452bf 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ MaxKB 是一款基于 LLM 大语言模型的知识库问答系统。MaxKB = Max - **开箱即用**:支持直接上传文档、自动爬取在线文档,支持文本自动拆分、向量化,智能问答交互体验好; - **无缝嵌入**:支持零编码快速嵌入到第三方业务系统; -- **多模型支持**:支持对接主流的大模型,包括本地私有大模型(如 Llama 2)、Azure OpenAI 和百度千帆大模型等。 +- **多模型支持**:支持对接主流的大模型,包括本地私有大模型(如 Llama 2、Llama 3)、通义千问、OpenAI、Azure OpenAI、Kimi 和百度千帆大模型等。 ## 快速开始 @@ -33,6 +33,9 @@ docker run -d --name=maxkb -p 8080:8080 -v ~/.maxkb:/var/lib/postgresql/data 1pa - [使用手册](https://github.com/1Panel-dev/MaxKB/wiki/1-%E5%AE%89%E8%A3%85%E9%83%A8%E7%BD%B2) - [论坛求助](https://bbs.fit2cloud.com/c/mk/11) - [演示视频](https://www.bilibili.com/video/BV1BE421M7YM/) +- 技术交流群 + + ## UI 展示 @@ -53,7 +56,7 @@ docker run -d --name=maxkb -p 8080:8080 -v ~/.maxkb:/var/lib/postgresql/data 1pa - 后端:[Python / Django](https://www.djangoproject.com/) - LangChain:[LangChain](https://www.langchain.com/) - 向量数据库:[PostgreSQL / pgvector](https://www.postgresql.org/) -- 大模型:Azure OpenAI、百度千帆大模型、[Ollama](https://github.com/ollama/ollama) +- 大模型:Azure OpenAI、OpenAI、百度千帆大模型、[Ollama](https://github.com/ollama/ollama)、通义千问、Kimi ## Star History diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..22be037d0 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,17 @@ +# 安全说明 + +如果您发现安全问题,请直接联系我们: + +- support@fit2cloud.com +- 400-052-0755 + +感谢您的支持! + +# Security Policy + +All security bugs should be reported to the contact as below: + +- support@fit2cloud.com +- 400-052-0755 + +Thanks for your support! 
\ No newline at end of file diff --git a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py index f4e9296af..549abfaf2 100644 --- a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py +++ b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py @@ -6,15 +6,16 @@ @date:2024/1/9 18:10 @desc: 检索知识库 """ +import re from abc import abstractmethod from typing import List, Type +from django.core import validators from rest_framework import serializers from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel from application.chat_pipeline.pipeline_manage import PipelineManage from common.util.field_message import ErrMessage -from dataset.models import Paragraph class ISearchDatasetStep(IBaseChatPipelineStep): @@ -38,6 +39,10 @@ class ISearchDatasetStep(IBaseChatPipelineStep): # 相似度 0-1之间 similarity = serializers.FloatField(required=True, max_value=1, min_value=0, error_messages=ErrMessage.float("引用分段数")) + search_mode = serializers.CharField(required=True, validators=[ + validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"), + message="类型只支持embedding|keywords|blend", code=500) + ], error_messages=ErrMessage.char("检索模式")) def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]: return self.InstanceSerializer @@ -50,6 +55,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep): @abstractmethod def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, + search_mode: str = None, **kwargs) -> List[ParagraphPipelineModel]: """ 关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询 @@ -60,6 +66,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep): :param exclude_document_id_list: 需要排除的文档id 
:param exclude_paragraph_id_list: 需要排除段落id :param padding_problem_text 补全问题 + :param search_mode 检索模式 :return: 段落列表 """ pass diff --git a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py index 1781d4f3d..dcd375ce4 100644 --- a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py +++ b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py @@ -17,6 +17,7 @@ from common.config.embedding_config import VectorStore, EmbeddingModel from common.db.search import native_search from common.util.file_util import get_file_content from dataset.models import Paragraph +from embedding.models import SearchMode from smartdoc.conf import PROJECT_DIR @@ -24,13 +25,14 @@ class BaseSearchDatasetStep(ISearchDatasetStep): def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, + search_mode: str = None, **kwargs) -> List[ParagraphPipelineModel]: exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text embedding_model = EmbeddingModel.get_embedding_model() embedding_value = embedding_model.embed_query(exec_problem_text) vector = VectorStore.get_embedding_vector() - embedding_list = vector.query(embedding_value, dataset_id_list, exclude_document_id_list, - exclude_paragraph_id_list, True, top_n, similarity) + embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list, + exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode)) if embedding_list is None: return [] paragraph_list = self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector) diff --git a/apps/application/migrations/0003_application_icon.py 
b/apps/application/migrations/0003_application_icon.py new file mode 100644 index 000000000..6e040becc --- /dev/null +++ b/apps/application/migrations/0003_application_icon.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.13 on 2024-04-23 11:16 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0002_chat_client_id'), + ] + + operations = [ + migrations.AddField( + model_name='application', + name='icon', + field=models.CharField(default='/ui/favicon.ico', max_length=256, verbose_name='应用icon'), + ), + ] diff --git a/apps/application/models/application.py b/apps/application/models/application.py index e20e92385..101816d38 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -19,7 +19,7 @@ from users.models import User def get_dataset_setting_dict(): - return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000} + return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000, 'search_mode': 'embedding'} def get_model_setting_dict(): @@ -37,6 +37,7 @@ class Application(AppModelMixin): dataset_setting = models.JSONField(verbose_name="数据集参数设置", default=get_dataset_setting_dict) model_setting = models.JSONField(verbose_name="模型参数相关设置", default=get_model_setting_dict) problem_optimization = models.BooleanField(verbose_name="问题优化", default=False) + icon = models.CharField(max_length=256, verbose_name="应用icon", default="/ui/favicon.ico") @staticmethod def get_default_model_prompt(): diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 3b20a1b39..30af7f56e 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -8,12 +8,13 @@ """ import hashlib import os +import re import uuid from functools import reduce from typing import Dict from django.contrib.postgres.fields import 
ArrayField -from django.core import cache +from django.core import cache, validators from django.core import signing from django.db import transaction, models from django.db.models import QuerySet @@ -28,10 +29,12 @@ from common.constants.authentication_type import AuthenticationType from common.db.search import get_dynamics_model, native_search, native_page_search from common.db.sql_execute import select_list from common.exception.app_exception import AppApiException, NotFound404 +from common.field.common import UploadedImageField from common.util.field_message import ErrMessage from common.util.file_util import get_file_content -from dataset.models import DataSet, Document +from dataset.models import DataSet, Document, Image from dataset.serializers.common_serializers import list_paragraph +from embedding.models import SearchMode from setting.models import AuthOperate from setting.models.model_management import Model from setting.models_provider.constants.model_provider_constants import ModelProvideConstants @@ -77,6 +80,10 @@ class DatasetSettingSerializer(serializers.Serializer): error_messages=ErrMessage.float("相识度")) max_paragraph_char_number = serializers.IntegerField(required=True, min_value=500, max_value=10000, error_messages=ErrMessage.integer("最多引用字符数")) + search_mode = serializers.CharField(required=True, validators=[ + validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"), + message="类型只支持embedding|keywords|blend", code=500) + ], error_messages=ErrMessage.char("检索模式")) class ModelSettingSerializer(serializers.Serializer): @@ -245,6 +252,7 @@ class ApplicationSerializer(serializers.Serializer): # 问题补全 problem_optimization = serializers.BooleanField(required=False, allow_null=True, error_messages=ErrMessage.boolean("问题补全")) + icon = serializers.CharField(required=False, allow_null=True, error_messages=ErrMessage.char("icon图标")) class Create(serializers.Serializer): user_id = serializers.UUIDField(required=True, 
error_messages=ErrMessage.uuid("用户id")) @@ -291,6 +299,10 @@ class ApplicationSerializer(serializers.Serializer): error_messages=ErrMessage.integer("topN")) similarity = serializers.FloatField(required=True, max_value=1, min_value=0, error_messages=ErrMessage.float("相关度")) + search_mode = serializers.CharField(required=True, validators=[ + validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"), + message="类型只支持embedding|keywords|blend", code=500) + ], error_messages=ErrMessage.char("检索模式")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -312,6 +324,7 @@ class ApplicationSerializer(serializers.Serializer): hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list, self.data.get('top_number'), self.data.get('similarity'), + SearchMode(self.data.get('search_mode')), EmbeddingModel.get_embedding_model()) hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {}) p_list = list_paragraph([h.get('paragraph_id') for h in hit_list]) @@ -369,6 +382,8 @@ class ApplicationSerializer(serializers.Serializer): def reset_application(application: Dict): application['multiple_rounds_dialogue'] = True if application.get('dialogue_number') > 0 else False del application['dialogue_number'] + if 'dataset_setting' in application: + application['dataset_setting'] = {'search_mode': 'embedding', **application['dataset_setting']} return application def page(self, current_page: int, page_size: int, with_valid=True): @@ -381,7 +396,25 @@ class ApplicationSerializer(serializers.Serializer): class ApplicationModel(serializers.ModelSerializer): class Meta: model = Application - fields = ['id', 'name', 'desc', 'prologue', 'dialogue_number'] + fields = ['id', 'name', 'desc', 'prologue', 'dialogue_number', 'icon'] + + class IconOperate(serializers.Serializer): + application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id")) + 
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + image = UploadedImageField(required=True, error_messages=ErrMessage.image("图片")) + + def edit(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + application = QuerySet(Application).filter(id=self.data.get('application_id')).first() + if application is None: + raise AppApiException(500, '不存在的应用id') + image_id = uuid.uuid1() + image = Image(id=image_id, image=self.data.get('image').read(), image_name=self.data.get('image').name) + image.save() + application.icon = f'/api/image/{image_id}' + application.save() + return {**ApplicationSerializer.Query.reset_application(ApplicationSerializerModel(application).data)} class Operate(serializers.Serializer): application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id")) @@ -445,7 +478,7 @@ class ApplicationSerializer(serializers.Serializer): update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status', 'dataset_setting', 'model_setting', 'problem_optimization', - 'api_key_is_active'] + 'api_key_is_active', 'icon'] for update_key in update_keys: if update_key in instance and instance.get(update_key) is not None: if update_key == 'multiple_rounds_dialogue': diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index dbbffb2a2..daf850e76 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -77,6 +77,8 @@ class ChatInfo: 'model_id': self.application.model.id if self.application.model is not None else None, 'problem_optimization': self.application.problem_optimization, 'stream': True, + 'search_mode': self.application.dataset_setting.get( + 'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding' } @@ -184,9 +186,9 @@ class ChatMessageSerializer(serializers.Serializer): 
pipeline_manage_builder.append_step(BaseResetProblemStep) # 构建流水线管理器 pipeline_message = (pipeline_manage_builder.append_step(BaseSearchDatasetStep) - .append_step(BaseGenerateHumanMessageStep) - .append_step(BaseChatStep) - .build()) + .append_step(BaseGenerateHumanMessageStep) + .append_step(BaseChatStep) + .build()) exclude_paragraph_id_list = [] # 相同问题是否需要排除已经查询到的段落 if re_chat: diff --git a/apps/application/swagger_api/application_api.py b/apps/application/swagger_api/application_api.py index 38d747bb1..542fd9f75 100644 --- a/apps/application/swagger_api/application_api.py +++ b/apps/application/swagger_api/application_api.py @@ -12,6 +12,17 @@ from common.mixins.api_mixin import ApiMixin class ApplicationApi(ApiMixin): + class EditApplicationIcon(ApiMixin): + @staticmethod + def get_request_params_api(): + return [ + openapi.Parameter(name='file', + in_=openapi.IN_FORM, + type=openapi.TYPE_FILE, + required=True, + description='上传文件') + ] + class Authentication(ApiMixin): @staticmethod def get_request_body_api(): @@ -143,7 +154,9 @@ class ApplicationApi(ApiMixin): 'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(), 'model_setting': ApplicationApi.ModelSetting.get_request_body_api(), 'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化", - description="是否开启问题优化", default=True) + description="是否开启问题优化", default=True), + 'icon': openapi.Schema(type=openapi.TYPE_STRING, title="icon", + description="icon", default="/ui/favicon.ico") } ) @@ -161,6 +174,8 @@ class ApplicationApi(ApiMixin): default=0.6), 'max_paragraph_char_number': openapi.Schema(type=openapi.TYPE_NUMBER, title='最多引用字符数', description="最多引用字符数", default=3000), + 'search_mode': openapi.Schema(type=openapi.TYPE_STRING, title='检索模式', + description="embedding|keywords|blend", default='embedding'), } ) diff --git a/apps/application/urls.py b/apps/application/urls.py index 507759286..30866c81a 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ 
-8,6 +8,7 @@ urlpatterns = [ path('application/profile', views.Application.Profile.as_view()), path('application/embed', views.Application.Embed.as_view()), path('application/authentication', views.Application.Authentication.as_view()), + path('application//edit_icon', views.Application.EditIcon.as_view()), path('application//statistics/customer_count', views.ApplicationStatistics.CustomerCount.as_view()), path('application//statistics/customer_count_trend', diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index 12e47f854..3ebed0899 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -10,6 +10,7 @@ from django.http import HttpResponse from drf_yasg.utils import swagger_auto_schema from rest_framework.decorators import action +from rest_framework.parsers import MultiPartParser from rest_framework.request import Request from rest_framework.views import APIView @@ -131,6 +132,28 @@ class ApplicationStatistics(APIView): class Application(APIView): authentication_classes = [TokenAuth] + class EditIcon(APIView): + authentication_classes = [TokenAuth] + parser_classes = [MultiPartParser] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="修改应用icon", + operation_id="修改应用icon", + tags=['应用'], + manual_parameters=ApplicationApi.EditApplicationIcon.get_request_params_api(), + request_body=ApplicationApi.Operate.get_request_body_api()) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND), PermissionConstants.APPLICATION_EDIT, + compare=CompareConstants.AND) + def put(self, request: Request, application_id: str): + return result.success( + ApplicationSerializer.IconOperate( + data={'application_id': application_id, 'user_id': request.user.id, + 
'image': request.FILES.get('file')}).edit(request.data)) + class Embed(APIView): @action(methods=["GET"], detail=False) @swagger_auto_schema(operation_summary="获取嵌入js", @@ -343,7 +366,8 @@ class Application(APIView): ApplicationSerializer.HitTest(data={'id': application_id, 'user_id': request.user.id, "query_text": request.query_params.get("query_text"), "top_number": request.query_params.get("top_number"), - 'similarity': request.query_params.get('similarity')}).hit_test( + 'similarity': request.query_params.get('similarity'), + 'search_mode': request.query_params.get('search_mode')}).hit_test( )) class Operate(APIView): diff --git a/apps/common/auth/handle/impl/application_key.py b/apps/common/auth/handle/impl/application_key.py index 5ebd9db28..cee128ef8 100644 --- a/apps/common/auth/handle/impl/application_key.py +++ b/apps/common/auth/handle/impl/application_key.py @@ -34,7 +34,7 @@ class ApplicationKey(AuthBaseHandle): return application_api_key.user, Auth(role_list=[RoleConstants.APPLICATION_KEY], permission_list=permission_list, application_id=application_api_key.application_id, - client_id=token, + client_id=str(application_api_key.id), client_type=AuthenticationType.API_KEY.value) def support(self, request, token: str, get_token_details): diff --git a/apps/common/field/common.py b/apps/common/field/common.py index 3d6c95f11..c615e587a 100644 --- a/apps/common/field/common.py +++ b/apps/common/field/common.py @@ -32,3 +32,11 @@ class FunctionField(serializers.Field): def to_representation(self, value): return value + + +class UploadedImageField(serializers.ImageField): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def to_representation(self, value): + return value diff --git a/apps/common/middleware/static_headers_middleware.py b/apps/common/middleware/static_headers_middleware.py index 8f5606ab8..79b799a70 100644 --- a/apps/common/middleware/static_headers_middleware.py +++ b/apps/common/middleware/static_headers_middleware.py @@ -17,7 
+17,14 @@ class StaticHeadersMiddleware(MiddlewareMixin): if request.path.startswith('/ui/chat/'): access_token = request.path.replace('/ui/chat/', '') application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first() - if application_access_token is not None and application_access_token.white_active: - # 添加自定义的响应头 - response['Content-Security-Policy'] = f'frame-ancestors {" ".join(application_access_token.white_list)}' + if application_access_token is not None: + if application_access_token.white_active: + # 添加自定义的响应头 + response[ + 'Content-Security-Policy'] = f'frame-ancestors {" ".join(application_access_token.white_list)}' + response.content = (response.content.decode('utf-8').replace( + '', + f'') + .replace('MaxKB', f'{application_access_token.application.name}').encode( + "utf-8")) return response diff --git a/apps/common/swagger_api/common_api.py b/apps/common/swagger_api/common_api.py index 71a876ae0..c3d8be6ca 100644 --- a/apps/common/swagger_api/common_api.py +++ b/apps/common/swagger_api/common_api.py @@ -33,6 +33,13 @@ class CommonApi: default=0.6, required=True, description='相关性'), + openapi.Parameter(name='search_mode', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + default="embedding", + required=True, + description='检索模式embedding|keywords|blend' + ) ] @staticmethod diff --git a/apps/common/util/field_message.py b/apps/common/util/field_message.py index 3bcf067e1..93b51b920 100644 --- a/apps/common/util/field_message.py +++ b/apps/common/util/field_message.py @@ -95,3 +95,12 @@ class ErrMessage: 'invalid': gettext_lazy('【%s】日期格式错误。请改用以下格式之一: {format}。'), 'datetime': gettext_lazy('【%s】应为日期,但得到的是日期时间。') } + + @staticmethod + def image(field: str): + return { + 'required': gettext_lazy('【%s】此字段必填。' % field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field), + 'invalid_image': gettext_lazy('【%s】上载有效的图像。您上载的文件不是图像或图像已损坏。' % field), + 'max_length': gettext_lazy('请确保此文件名最多包含 {max_length} 个字符(长度为 {length})。') + } 
diff --git a/apps/common/util/ts_vecto_util.py b/apps/common/util/ts_vecto_util.py new file mode 100644 index 000000000..b5d4de3fd --- /dev/null +++ b/apps/common/util/ts_vecto_util.py @@ -0,0 +1,107 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: ts_vecto_util.py + @date:2024/4/16 15:26 + @desc: +""" +import re +import uuid +from typing import List + +import jieba +from jieba import analyse + +from common.util.split_model import group_by + +jieba_word_list_cache = [chr(item) for item in range(38, 84)] + +for jieba_word in jieba_word_list_cache: + jieba.add_word('#' + jieba_word + '#') +# r"(?i)\b(?:https?|ftp|tcp|file)://[^\s]+\b", +# 某些不分词数据 +# r'"([^"]*)"' +word_pattern_list = [r"v\d+.\d+.\d+", + r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}"] + +remove_chars = '\n , :\'<>!@#¥%……&*()!@#$%^&*(): ;,/"./-' + + +def get_word_list(text: str): + result = [] + for pattern in word_pattern_list: + word_list = re.findall(pattern, text) + for child_list in word_list: + for word in child_list if isinstance(child_list, tuple) else [child_list]: + # 不能有: 所以再使用: 进行分割 + if word.__contains__(':'): + item_list = word.split(":") + for w in item_list: + result.append(w) + else: + result.append(word) + return result + + +def replace_word(word_dict, text: str): + for key in word_dict: + text = re.sub('(?= 0]) + + +def to_query(text: str): + # 获取不分词的数据 + word_list = get_word_list(text) + # 获取关键词关系 + word_dict = to_word_dict(word_list, text) + # 替换字符串 + text = replace_word(word_dict, text) + extract_tags = analyse.extract_tags(text, topK=5, withWeight=True, allowPOS=('ns', 'n', 'vn', 'v', 'eng')) + result = " ".join([get_key_by_word_dict(word, word_dict) for word, score in extract_tags if + not remove_chars.__contains__(word)]) + # 删除词库 + for word in word_list: + jieba.del_word(word) + return result diff --git a/apps/dataset/migrations/0002_image.py b/apps/dataset/migrations/0002_image.py new file mode 100644 index 000000000..a5fb59eb1 --- /dev/null +++ 
b/apps/dataset/migrations/0002_image.py @@ -0,0 +1,27 @@ +# Generated by Django 4.1.13 on 2024-04-22 19:31 + +from django.db import migrations, models +import uuid + + +class Migration(migrations.Migration): + + dependencies = [ + ('dataset', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='Image', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('image', models.BinaryField(verbose_name='图片数据')), + ('image_name', models.CharField(default='', max_length=256, verbose_name='图片名称')), + ], + options={ + 'db_table': 'image', + }, + ), + ] diff --git a/apps/dataset/models/data_set.py b/apps/dataset/models/data_set.py index 4cea2ef90..9ee76ffe2 100644 --- a/apps/dataset/models/data_set.py +++ b/apps/dataset/models/data_set.py @@ -105,3 +105,12 @@ class ProblemParagraphMapping(AppModelMixin): class Meta: db_table = "problem_paragraph_mapping" + + +class Image(AppModelMixin): + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + image = models.BinaryField(verbose_name="图片数据") + image_name = models.CharField(max_length=256, verbose_name="图片名称", default="") + + class Meta: + db_table = "image" diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index 6f03eca2a..5d03def4a 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -37,6 +37,7 @@ from common.util.split_model import get_split_model from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping from dataset.serializers.common_serializers import list_paragraph, MetaSerializer from dataset.serializers.document_serializers import 
DocumentSerializers, DocumentInstanceSerializer +from embedding.models import SearchMode from setting.models import AuthOperate from smartdoc.conf import PROJECT_DIR @@ -457,6 +458,10 @@ class DataSetSerializers(serializers.ModelSerializer): error_messages=ErrMessage.char("响应Top")) similarity = serializers.FloatField(required=True, max_value=1, min_value=0, error_messages=ErrMessage.char("相似度")) + search_mode = serializers.CharField(required=True, validators=[ + validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"), + message="类型只支持embedding|keywords|blend", code=500) + ], error_messages=ErrMessage.char("检索模式")) def is_valid(self, *, raise_exception=True): super().is_valid(raise_exception=True) @@ -474,6 +479,7 @@ class DataSetSerializers(serializers.ModelSerializer): hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list, self.data.get('top_number'), self.data.get('similarity'), + SearchMode(self.data.get('search_mode')), EmbeddingModel.get_embedding_model()) hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {}) p_list = list_paragraph([h.get('paragraph_id') for h in hit_list]) diff --git a/apps/dataset/serializers/image_serializers.py b/apps/dataset/serializers/image_serializers.py new file mode 100644 index 000000000..46a1d72bc --- /dev/null +++ b/apps/dataset/serializers/image_serializers.py @@ -0,0 +1,42 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: image_serializers.py + @date:2024/4/22 16:36 + @desc: +""" +import uuid + +from django.db.models import QuerySet +from django.http import HttpResponse +from rest_framework import serializers + +from common.exception.app_exception import NotFound404 +from common.field.common import UploadedImageField +from common.util.field_message import ErrMessage +from dataset.models import Image + + +class ImageSerializer(serializers.Serializer): + image = UploadedImageField(required=True, 
error_messages=ErrMessage.image("图片")) + + def upload(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + image_id = uuid.uuid1() + image = Image(id=image_id, image=self.data.get('image').read(), image_name=self.data.get('image').name) + image.save() + return f'/api/image/{image_id}' + + class Operate(serializers.Serializer): + id = serializers.UUIDField(required=True) + + def get(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + image_id = self.data.get('id') + image = QuerySet(Image).filter(id=image_id).first() + if image is None: + raise NotFound404(404, "不存在的图片") + return HttpResponse(image.image, status=200, headers={'Content-Type': 'image/png'}) diff --git a/apps/dataset/swagger_api/image_api.py b/apps/dataset/swagger_api/image_api.py new file mode 100644 index 000000000..f69b94719 --- /dev/null +++ b/apps/dataset/swagger_api/image_api.py @@ -0,0 +1,22 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: image_api.py + @date:2024/4/23 11:23 + @desc: +""" +from drf_yasg import openapi + +from common.mixins.api_mixin import ApiMixin + + +class ImageApi(ApiMixin): + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='file', + in_=openapi.IN_FORM, + type=openapi.TYPE_FILE, + required=True, + description='上传图片文件') + ] diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py index 38a92c845..2868bcbbd 100644 --- a/apps/dataset/urls.py +++ b/apps/dataset/urls.py @@ -40,5 +40,6 @@ urlpatterns = [ path('dataset//problem//', views.Problem.Page.as_view()), path('dataset//problem/', views.Problem.Operate.as_view()), path('dataset//problem//paragraph', views.Problem.Paragraph.as_view()), - + path('image/', views.Image.Operate.as_view()), + path('image', views.Image.as_view()) ] diff --git a/apps/dataset/views/__init__.py b/apps/dataset/views/__init__.py index b82d9ef72..6b2abcfb1 100644 --- a/apps/dataset/views/__init__.py +++ b/apps/dataset/views/__init__.py @@ -10,3 +10,4 
@@ from .dataset import * from .document import * from .paragraph import * from .problem import * +from .image import * diff --git a/apps/dataset/views/dataset.py b/apps/dataset/views/dataset.py index 96106ff86..d3720977b 100644 --- a/apps/dataset/views/dataset.py +++ b/apps/dataset/views/dataset.py @@ -111,7 +111,8 @@ class Dataset(APIView): DataSetSerializers.HitTest(data={'id': dataset_id, 'user_id': request.user.id, "query_text": request.query_params.get("query_text"), "top_number": request.query_params.get("top_number"), - 'similarity': request.query_params.get('similarity')}).hit_test( + 'similarity': request.query_params.get('similarity'), + 'search_mode': request.query_params.get('search_mode')}).hit_test( )) class Operate(APIView): diff --git a/apps/dataset/views/image.py b/apps/dataset/views/image.py new file mode 100644 index 000000000..124336f87 --- /dev/null +++ b/apps/dataset/views/image.py @@ -0,0 +1,43 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: image.py + @date:2024/4/22 16:23 + @desc: +""" +from drf_yasg import openapi +from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.parsers import MultiPartParser +from rest_framework.views import APIView +from rest_framework.views import Request + +from common.auth import TokenAuth +from common.response import result +from dataset.serializers.image_serializers import ImageSerializer + + +class Image(APIView): + authentication_classes = [TokenAuth] + parser_classes = [MultiPartParser] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="上传图片", + operation_id="上传图片", + manual_parameters=[openapi.Parameter(name='file', + in_=openapi.IN_FORM, + type=openapi.TYPE_FILE, + required=True, + description='上传文件')], + tags=["图片"]) + def post(self, request: Request): + return result.success(ImageSerializer(data={'image': request.FILES.get('file')}).upload()) + + class Operate(APIView): + 
@action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取图片", + operation_id="获取图片", + tags=["图片"]) + def get(self, request: Request, image_id: str): + return ImageSerializer.Operate(data={'id': image_id}).get() diff --git a/apps/embedding/migrations/0002_embedding_search_vector.py b/apps/embedding/migrations/0002_embedding_search_vector.py new file mode 100644 index 000000000..3ed58d582 --- /dev/null +++ b/apps/embedding/migrations/0002_embedding_search_vector.py @@ -0,0 +1,56 @@ +# Generated by Django 4.1.13 on 2024-04-16 11:43 + +import django.contrib.postgres.search +from django.db import migrations + +from common.util.common import sub_array +from common.util.ts_vecto_util import to_ts_vector +from dataset.models import Status +from embedding.models import Embedding + + +def update_embedding_search_vector(embedding, paragraph_list): + paragraphs = [paragraph for paragraph in paragraph_list if paragraph.id == embedding.get('paragraph')] + if len(paragraphs) > 0: + content = paragraphs[0].title + paragraphs[0].content + return Embedding(id=embedding.get('id'), search_vector=to_ts_vector(content)) + return Embedding(id=embedding.get('id'), search_vector="") + + +def save_keywords(apps, schema_editor): + document = apps.get_model("dataset", "Document") + embedding = apps.get_model("embedding", "Embedding") + paragraph = apps.get_model('dataset', 'Paragraph') + db_alias = schema_editor.connection.alias + document_list = document.objects.using(db_alias).all() + for document in document_list: + document.status = Status.embedding + document.save() + paragraph_list = paragraph.objects.using(db_alias).filter(document=document).all() + embedding_list = embedding.objects.using(db_alias).filter(document=document).values('id', 'search_vector', + 'paragraph') + embedding_update_list = [update_embedding_search_vector(embedding, paragraph_list) for embedding + in embedding_list] + child_array = sub_array(embedding_update_list, 50) + for c in 
child_array: + try: + embedding.objects.using(db_alias).bulk_update(c, ['search_vector']) + except Exception as e: + print(e) + document.status = Status.success + document.save() + + +class Migration(migrations.Migration): + dependencies = [ + ('embedding', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='embedding', + name='search_vector', + field=django.contrib.postgres.search.SearchVectorField(default='', verbose_name='分词'), + ), + migrations.RunPython(save_keywords) + ] diff --git a/apps/embedding/models/embedding.py b/apps/embedding/models/embedding.py index f7cc6bf31..24c78f41f 100644 --- a/apps/embedding/models/embedding.py +++ b/apps/embedding/models/embedding.py @@ -10,6 +10,7 @@ from django.db import models from common.field.vector_field import VectorField from dataset.models.data_set import Document, Paragraph, DataSet +from django.contrib.postgres.search import SearchVectorField class SourceType(models.TextChoices): @@ -19,6 +20,12 @@ class SourceType(models.TextChoices): TITLE = 2, '标题' +class SearchMode(models.TextChoices): + embedding = 'embedding' + keywords = 'keywords' + blend = 'blend' + + class Embedding(models.Model): id = models.CharField(max_length=128, primary_key=True, verbose_name="主键id") @@ -37,6 +44,8 @@ class Embedding(models.Model): embedding = VectorField(verbose_name="向量") + search_vector = SearchVectorField(verbose_name="分词", default="") + meta = models.JSONField(verbose_name="元数据", default=dict) class Meta: diff --git a/apps/embedding/sql/blend_search.sql b/apps/embedding/sql/blend_search.sql new file mode 100644 index 000000000..afb1f0040 --- /dev/null +++ b/apps/embedding/sql/blend_search.sql @@ -0,0 +1,26 @@ +SELECT + paragraph_id, + comprehensive_score, + comprehensive_score AS similarity +FROM + ( + SELECT DISTINCT ON + ( "paragraph_id" ) ( similarity ),* , + similarity AS comprehensive_score + FROM + ( + SELECT + *, + (( 1 - ( embedding.embedding <=> %s ) )+ts_rank_cd( embedding.search_vector, 
websearch_to_tsquery('simple', %s ), 32 )) AS similarity + FROM + embedding ${embedding_query} + ) TEMP + ORDER BY + paragraph_id, + similarity DESC + ) DISTINCT_TEMP +WHERE + comprehensive_score >%s +ORDER BY + comprehensive_score DESC + LIMIT %s \ No newline at end of file diff --git a/apps/embedding/sql/keywords_search.sql b/apps/embedding/sql/keywords_search.sql new file mode 100644 index 000000000..a27d0a694 --- /dev/null +++ b/apps/embedding/sql/keywords_search.sql @@ -0,0 +1,17 @@ +SELECT + paragraph_id, + comprehensive_score, + comprehensive_score as similarity +FROM + ( + SELECT DISTINCT ON + ("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score + FROM + ( SELECT *,ts_rank_cd(embedding.search_vector,websearch_to_tsquery('simple',%s),32) AS similarity FROM embedding ${keywords_query}) TEMP + ORDER BY + paragraph_id, + similarity DESC + ) DISTINCT_TEMP +WHERE comprehensive_score>%s +ORDER BY comprehensive_score DESC +LIMIT %s \ No newline at end of file diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py index c031e1143..496150dda 100644 --- a/apps/embedding/vector/base_vector.py +++ b/apps/embedding/vector/base_vector.py @@ -14,7 +14,7 @@ from langchain_community.embeddings import HuggingFaceEmbeddings from common.config.embedding_config import EmbeddingModel from common.util.common import sub_array -from embedding.models import SourceType +from embedding.models import SourceType, SearchMode lock = threading.Lock() @@ -113,13 +113,16 @@ class BaseVectorStore(ABC): return result[0] @abstractmethod - def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str], - exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float): + def query(self, query_text:str,query_embedding: List[float], dataset_id_list: list[str], + exclude_document_id_list: list[str], + exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float, + 
search_mode: SearchMode): pass @abstractmethod def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int, similarity: float, + search_mode: SearchMode, embedding: HuggingFaceEmbeddings): pass diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py index 6f9d25b2a..d5e5d125e 100644 --- a/apps/embedding/vector/pg_vector.py +++ b/apps/embedding/vector/pg_vector.py @@ -9,6 +9,7 @@ import json import os import uuid +from abc import ABC, abstractmethod from typing import Dict, List from django.db.models import QuerySet @@ -18,7 +19,8 @@ from common.config.embedding_config import EmbeddingModel from common.db.search import generate_sql_by_query_dict from common.db.sql_execute import select_list from common.util.file_util import get_file_content -from embedding.models import Embedding, SourceType +from common.util.ts_vecto_util import to_ts_vector, to_query +from embedding.models import Embedding, SourceType, SearchMode from embedding.vector.base_vector import BaseVectorStore from smartdoc.conf import PROJECT_DIR @@ -57,7 +59,8 @@ class PGVector(BaseVectorStore): paragraph_id=paragraph_id, source_id=source_id, embedding=text_embedding, - source_type=source_type) + source_type=source_type, + search_vector=to_ts_vector(text)) embedding.save() return True @@ -71,13 +74,15 @@ class PGVector(BaseVectorStore): is_active=text_list[index].get('is_active', True), source_id=text_list[index].get('source_id'), source_type=text_list[index].get('source_type'), - embedding=embeddings[index]) for index in + embedding=embeddings[index], + search_vector=to_ts_vector(text_list[index]['text'])) for index in range(0, len(text_list))] QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None return True def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int, similarity: float, + search_mode: SearchMode, embedding: 
HuggingFaceEmbeddings): if dataset_id_list is None or len(dataset_id_list) == 0: return [] @@ -87,17 +92,14 @@ class PGVector(BaseVectorStore): if exclude_document_id_list is not None and len(exclude_document_id_list) > 0: exclude_dict.__setitem__('document_id__in', exclude_document_id_list) query_set = query_set.exclude(**exclude_dict) - exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set}, - select_string=get_file_content( - os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', - 'hit_test.sql')), - with_table_name=True) - embedding_model = select_list(exec_sql, - [json.dumps(embedding_query), *exec_params, similarity, top_number]) - return embedding_model + for search_handle in search_handle_list: + if search_handle.support(search_mode): + return search_handle.handle(query_set, query_text, embedding_query, top_number, similarity, search_mode) - def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str], - exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float): + def query(self, query_text: str, query_embedding: List[float], dataset_id_list: list[str], + exclude_document_id_list: list[str], + exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float, + search_mode: SearchMode): exclude_dict = {} if dataset_id_list is None or len(dataset_id_list) == 0: return [] @@ -107,14 +109,9 @@ class PGVector(BaseVectorStore): if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0: exclude_dict.__setitem__('paragraph_id__in', exclude_paragraph_list) query_set = query_set.exclude(**exclude_dict) - exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set}, - select_string=get_file_content( - os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', - 'embedding_search.sql')), - with_table_name=True) - embedding_model = select_list(exec_sql, - [json.dumps(query_embedding), *exec_params, similarity, top_n]) - 
return embedding_model + for search_handle in search_handle_list: + if search_handle.support(search_mode): + return search_handle.handle(query_set, query_text, query_embedding, top_n, similarity, search_mode) def update_by_source_id(self, source_id: str, instance: Dict): QuerySet(Embedding).filter(source_id=source_id).update(**instance) @@ -141,3 +138,81 @@ class PGVector(BaseVectorStore): def delete_by_paragraph_id(self, paragraph_id: str): QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete() + + +class ISearch(ABC): + @abstractmethod + def support(self, search_mode: SearchMode): + pass + + @abstractmethod + def handle(self, query_set, query_text, query_embedding, top_number: int, + similarity: float, search_mode: SearchMode): + pass + + +class EmbeddingSearch(ISearch): + def handle(self, + query_set, + query_text, + query_embedding, + top_number: int, + similarity: float, + search_mode: SearchMode): + exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set}, + select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', + 'embedding_search.sql')), + with_table_name=True) + embedding_model = select_list(exec_sql, + [json.dumps(query_embedding), *exec_params, similarity, top_number]) + return embedding_model + + def support(self, search_mode: SearchMode): + return search_mode.value == SearchMode.embedding.value + + +class KeywordsSearch(ISearch): + def handle(self, + query_set, + query_text, + query_embedding, + top_number: int, + similarity: float, + search_mode: SearchMode): + exec_sql, exec_params = generate_sql_by_query_dict({'keywords_query': query_set}, + select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', + 'keywords_search.sql')), + with_table_name=True) + embedding_model = select_list(exec_sql, + [to_query(query_text), *exec_params, similarity, top_number]) + return embedding_model + + def support(self, search_mode: SearchMode): + return search_mode.value == 
SearchMode.keywords.value + + +class BlendSearch(ISearch): + def handle(self, + query_set, + query_text, + query_embedding, + top_number: int, + similarity: float, + search_mode: SearchMode): + exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set}, + select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', + 'blend_search.sql')), + with_table_name=True) + embedding_model = select_list(exec_sql, + [json.dumps(query_embedding), to_query(query_text), *exec_params, similarity, + top_number]) + return embedding_model + + def support(self, search_mode: SearchMode): + return search_mode.value == SearchMode.blend.value + + +search_handle_list = [EmbeddingSearch(), KeywordsSearch(), BlendSearch()] diff --git a/apps/setting/models_provider/constants/model_provider_constants.py b/apps/setting/models_provider/constants/model_provider_constants.py index 3816795e5..0a46cbfa4 100644 --- a/apps/setting/models_provider/constants/model_provider_constants.py +++ b/apps/setting/models_provider/constants/model_provider_constants.py @@ -14,6 +14,8 @@ from setting.models_provider.impl.openai_model_provider.openai_model_provider im from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import KimiModelProvider +from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider +from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider class ModelProvideConstants(Enum): @@ -23,3 +25,5 @@ class ModelProvideConstants(Enum): model_openai_provider = OpenAIModelProvider() model_kimi_provider = KimiModelProvider() model_qwen_provider = QwenModelProvider() + model_zhipu_provider = ZhiPuModelProvider() + model_xf_provider = XunFeiModelProvider() 
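Review note on the search-mode dispatch added to `pg_vector.py` above: the loop over `search_handle_list` falls through and implicitly returns `None` when no handler reports support for the requested mode. The following is a minimal, framework-free sketch of the same strategy pattern with an explicit failure path. It is an illustration only, not the project's code: the `SearchMode` enum here is a plain-Python stand-in for the Django `TextChoices` class, the `handle` methods return placeholder strings instead of running SQL, and `dispatch` and `SEARCH_HANDLERS` are hypothetical names.

```python
from abc import ABC, abstractmethod
from enum import Enum


class SearchMode(str, Enum):
    # Stand-in for embedding.models.SearchMode (a Django TextChoices in the diff).
    embedding = "embedding"
    keywords = "keywords"
    blend = "blend"


class ISearch(ABC):
    @abstractmethod
    def support(self, search_mode: SearchMode) -> bool:
        """Return True if this handler serves the given mode."""

    @abstractmethod
    def handle(self, query_text: str) -> str:
        """Run the search; here it only returns a placeholder string."""


class EmbeddingSearch(ISearch):
    def support(self, search_mode: SearchMode) -> bool:
        return search_mode == SearchMode.embedding

    def handle(self, query_text: str) -> str:
        return f"embedding:{query_text}"


class KeywordsSearch(ISearch):
    def support(self, search_mode: SearchMode) -> bool:
        return search_mode == SearchMode.keywords

    def handle(self, query_text: str) -> str:
        return f"keywords:{query_text}"


class BlendSearch(ISearch):
    def support(self, search_mode: SearchMode) -> bool:
        return search_mode == SearchMode.blend

    def handle(self, query_text: str) -> str:
        return f"blend:{query_text}"


# Hypothetical name; plays the role of search_handle_list in the diff.
SEARCH_HANDLERS = [EmbeddingSearch(), KeywordsSearch(), BlendSearch()]


def dispatch(search_mode: SearchMode, query_text: str, handlers=SEARCH_HANDLERS) -> str:
    # The first handler that supports the mode wins, as in the diff.
    for handler in handlers:
        if handler.support(search_mode):
            return handler.handle(query_text)
    # Unlike the diff, fail loudly instead of returning None.
    raise ValueError(f"unsupported search mode: {search_mode!r}")
```

Raising on an unmatched mode keeps callers such as `hit_test` from iterating over `None` later; whether the project prefers an exception or an empty list here is a design choice this sketch does not settle.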
diff --git a/apps/setting/models_provider/impl/xf_model_provider/__init__.py b/apps/setting/models_provider/impl/xf_model_provider/__init__.py new file mode 100644 index 000000000..c743b4e18 --- /dev/null +++ b/apps/setting/models_provider/impl/xf_model_provider/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/04/19 15:55 + @desc: +""" \ No newline at end of file diff --git a/apps/setting/models_provider/impl/xf_model_provider/icon/xf_icon_svg b/apps/setting/models_provider/impl/xf_model_provider/icon/xf_icon_svg new file mode 100644 index 000000000..b74e351e2 --- /dev/null +++ b/apps/setting/models_provider/impl/xf_model_provider/icon/xf_icon_svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/xf_chat_model.py b/apps/setting/models_provider/impl/xf_model_provider/model/xf_chat_model.py new file mode 100644 index 000000000..a09d48092 --- /dev/null +++ b/apps/setting/models_provider/impl/xf_model_provider/model/xf_chat_model.py @@ -0,0 +1,43 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/04/19 15:55 + @desc: +""" + +from typing import List, Optional, Any, Iterator + +from langchain_community.chat_models import ChatSparkLLM +from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.messages import BaseMessage, AIMessageChunk +from langchain_core.outputs import ChatGenerationChunk + + +class XFChatSparkLLM(ChatSparkLLM): + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + default_chunk_class = AIMessageChunk + + self.client.arun( + [_convert_message_to_dict(m) for m in messages], + self.spark_user_id, + 
self.model_kwargs, + True, + ) + for content in self.client.subscribe(timeout=self.request_timeout): + if "data" not in content: + continue + delta = content["data"] + chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) + cg_chunk = ChatGenerationChunk(message=chunk) + if run_manager: + run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk) + yield cg_chunk diff --git a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py new file mode 100644 index 000000000..28059c5c6 --- /dev/null +++ b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py @@ -0,0 +1,103 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: xf_model_provider.py + @date:2024/04/19 14:47 + @desc: +""" +import os +from typing import Dict + +from langchain.schema import HumanMessage +from langchain_community.chat_models import ChatSparkLLM + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \ + ModelInfo, IModelProvider, ValidCode +from setting.models_provider.impl.xf_model_provider.model.xf_chat_model import XFChatSparkLLM +from smartdoc.conf import PROJECT_DIR +import ssl + +ssl._create_default_https_context = ssl.create_default_context + + +class XunFeiLLMModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): + model_type_list = XunFeiModelProvider().get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 
'spark_api_secret']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = XunFeiModelProvider().get_model(model_type, model_name, model_credential) + model.invoke([HumanMessage(content='你好')]) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))} + + spark_api_url = forms.TextInputField('API 域名', required=True) + spark_app_id = forms.TextInputField('APP ID', required=True) + spark_api_key = forms.PasswordInputField("API Key", required=True) + spark_api_secret = forms.PasswordInputField('API Secret', required=True) + + +qwen_model_credential = XunFeiLLMModelCredential() + +model_dict = { + 'generalv3.5': ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential), + 'generalv3': ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential), + 'generalv2': ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential) +} + + +class XunFeiModelProvider(IModelProvider): + + def get_dialogue_number(self): + return 3 + + def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> XFChatSparkLLM: + zhipuai_chat = XFChatSparkLLM( + spark_app_id=model_credential.get('spark_app_id'), + spark_api_key=model_credential.get('spark_api_key'), + spark_api_secret=model_credential.get('spark_api_secret'), + spark_api_url=model_credential.get('spark_api_url'), + spark_llm_domain=model_name + ) + return zhipuai_chat + + def get_model_credential(self, model_type, model_name): + if model_name in model_dict: + return model_dict.get(model_name).model_credential + return qwen_model_credential + + def 
get_model_provide_info(self): + return ModelProvideInfo(provider='model_xf_provider', name='讯飞星火', icon=get_file_content( + os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'xf_model_provider', 'icon', + 'xf_icon_svg'))) + + def get_model_list(self, model_type: str): + if model_type is None: + raise AppApiException(500, '模型类型不能为空') + return [model_dict.get(key).to_dict() for key in + list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))] + + def get_model_type_list(self): + return [{'key': "大语言模型", 'value': "LLM"}] diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/__init__.py b/apps/setting/models_provider/impl/zhipu_model_provider/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/icon/zhipuai_icon_svg b/apps/setting/models_provider/impl/zhipu_model_provider/icon/zhipuai_icon_svg new file mode 100644 index 000000000..f39fedcbb --- /dev/null +++ b/apps/setting/models_provider/impl/zhipu_model_provider/icon/zhipuai_icon_svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py b/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py new file mode 100644 index 000000000..b84bb3d15 --- /dev/null +++ b/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py @@ -0,0 +1,93 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: zhipu_model_provider.py + @date:2024/04/19 13:5 + @desc: +""" +import os +from typing import Dict + +from langchain.schema import HumanMessage +from langchain_community.chat_models import ChatZhipuAI + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, 
BaseModelCredential, \ + ModelInfo, IModelProvider, ValidCode +from smartdoc.conf import PROJECT_DIR + + +class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): + model_type_list = ZhiPuModelProvider().get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = ZhiPuModelProvider().get_model(model_type, model_name, model_credential) + model.invoke([HumanMessage(content='你好')]) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_key = forms.PasswordInputField('API Key', required=True) + + +qwen_model_credential = ZhiPuLLMModelCredential() + +model_dict = { + 'glm-4': ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential), + 'glm-4v': ModelInfo('glm-4v', '', ModelTypeConst.LLM, qwen_model_credential), + 'glm-3-turbo': ModelInfo('glm-3-turbo', '', ModelTypeConst.LLM, qwen_model_credential) +} + + +class ZhiPuModelProvider(IModelProvider): + + def get_dialogue_number(self): + return 3 + + def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatZhipuAI: + zhipuai_chat = ChatZhipuAI( + temperature=0.5, + api_key=model_credential.get('api_key'), + model=model_name + ) + return zhipuai_chat + + def get_model_credential(self, model_type, model_name): + if model_name in model_dict: + 
return model_dict.get(model_name).model_credential
+        return qwen_model_credential
+
+    def get_model_provide_info(self):
+        return ModelProvideInfo(provider='model_zhipu_provider', name='智谱AI', icon=get_file_content(
+            os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'zhipu_model_provider', 'icon',
+                         'zhipuai_icon_svg')))
+
+    def get_model_list(self, model_type: str):
+        if model_type is None:
+            raise AppApiException(500, '模型类型不能为空')
+        return [model_dict.get(key).to_dict() for key in
+                list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
+
+    def get_model_type_list(self):
+        return [{'key': "大语言模型", 'value': "LLM"}]
diff --git a/apps/users/serializers/user_serializers.py b/apps/users/serializers/user_serializers.py
index 6672a4cb5..352bf6f96 100644
--- a/apps/users/serializers/user_serializers.py
+++ b/apps/users/serializers/user_serializers.py
@@ -418,7 +418,8 @@ class UserProfile(ApiMixin):
         permission_list = get_user_dynamics_permission(str(user.id))
         permission_list += [p.value for p in get_permission_list_by_role(RoleConstants[user.role])]
         return {'id': user.id, 'username': user.username, 'email': user.email, 'role': user.role,
-                'permissions': [str(p) for p in permission_list]}
+                'permissions': [str(p) for p in permission_list],
+                'is_edit_password': user.password == 'd880e722c47a34d8e9fce789fc62389d' if user.role == 'ADMIN' else False}

     @staticmethod
     def get_response_body_api():
diff --git a/pyproject.toml b/pyproject.toml
index 35d4a42f0..f70b250bb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,6 +33,10 @@ pymupdf = "1.24.1"
 python-docx = "^1.1.0"
 xlwt = "^1.3.0"
 dashscope = "^1.17.0"
+zhipuai = "^2.0.1"
+httpx = "^0.27.0"
+httpx-sse = "^0.4.0"
+websocket-client = "^1.7.0"

 [build-system]
 requires = ["poetry-core"]
diff --git a/ui/src/api/application-overview.ts b/ui/src/api/application-overview.ts
index e6a4d2c4d..0513a0d34 100644
--- a/ui/src/api/application-overview.ts
+++ b/ui/src/api/application-overview.ts
@@ -67,10 +67,24 @@ const getStatistics: (
   return get(`${prefix}/${application_id}/statistics/chat_record_aggregate_trend`, data, loading)
 }

+/**
+ * Update the application icon
+ * @param application_id
+ * data: file
+ */
+const putAppIcon: (
+  application_id: string,
+  data: any,
+  loading?: Ref
+) => Promise> = (application_id, data, loading) => {
+  return put(`${prefix}/${application_id}/edit_icon`, data, undefined, loading)
+}
+
 export default {
   getAPIKey,
   postAPIKey,
   delAPIKey,
   putAPIKey,
-  getStatistics
+  getStatistics,
+  putAppIcon
 }
diff --git a/ui/src/api/image.ts b/ui/src/api/image.ts
new file mode 100644
index 000000000..425e8c6c3
--- /dev/null
+++ b/ui/src/api/image.ts
@@ -0,0 +1,15 @@
+import { Result } from '@/request/Result'
+import { get, post, del, put } from '@/request/index'
+
+const prefix = '/image'
+/**
+ * Upload an image
+ * @param file: file
+ */
+const postImage: (data: any) => Promise> = (data) => {
+  return post(`${prefix}`, data)
+}
+
+export default {
+  postImage
+}
diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts
index eefd8f31f..b39ed6707 100644
--- a/ui/src/api/type/application.ts
+++ b/ui/src/api/type/application.ts
@@ -10,6 +10,7 @@ interface ApplicationFormType {
   dataset_setting?: any
   model_setting?: any
   problem_optimization?: boolean
+  icon?: string | undefined
 }
 interface chatType {
   id: string
diff --git a/ui/src/api/type/user.ts b/ui/src/api/type/user.ts
index 04cbd4140..6724252c9 100644
--- a/ui/src/api/type/user.ts
+++ b/ui/src/api/type/user.ts
@@ -19,6 +19,10 @@ interface User {
    * User permissions
    */
   permissions: Array
+  /**
+   * Whether the password needs to be changed
+   */
+  is_edit_password?: boolean
 }

 interface LoginRequest {
diff --git a/ui/src/components/ai-chat/index.vue b/ui/src/components/ai-chat/index.vue
index a76a3ab12..7390a84fa 100644
--- a/ui/src/components/ai-chat/index.vue
+++ b/ui/src/components/ai-chat/index.vue
@@ -364,15 +364,15 @@ const getWrite = (chat: any, reader: any, stream: boolean) => {
         let str = decoder.decode(value, { stream: true })
         // Explanation start: the stream is not delivered in the same chunks the backend sent. We expect each chunk to be data:{xxx}\n\n, but what arrives may be data:{ -> xxx}\n\n. In short, fetch does not guarantee that every chunk starts with data: and ends with \n\n
         tempResult += str
-        if (tempResult.endsWith('\n\n')) {
-          str = tempResult
-          tempResult = ''
+        const split = tempResult.match(/data:.*}\n\n/g)
+        if (split) {
+          str = split.join('')
+          tempResult = tempResult.replace(str, '')
         } else {
           return reader.read().then(write_stream)
         }
         // Explanation end
         if (str && str.startsWith('data:')) {
-          const split = str.match(/data:.*}\n\n/g)
           if (split) {
             for (const index in split) {
               const chunk = JSON?.parse(split[index].replace('data:', ''))
diff --git a/ui/src/components/read-write/index.vue b/ui/src/components/read-write/index.vue
index 5cc99a404..013bf7fc0 100644
--- a/ui/src/components/read-write/index.vue
+++ b/ui/src/components/read-write/index.vue
@@ -21,6 +21,8 @@
         autofocus
         :maxlength="maxlength || '-'"
         :show-word-limit="maxlength ? true : false"
+        @blur="isEdit = false"
+        @keyup.enter="submit"
       >
@@ -64,6 +66,10 @@ const loading = ref(false)
 watch(isEdit, (bool) => {
   if (!bool) {
     writeValue.value = ''
+  } else {
+    nextTick(() => {
+      inputRef.value?.focus()
+    })
   }
 })

@@ -80,10 +86,6 @@ function editNameHandle() {
   isEdit.value = true
 }

-onMounted(() => {
-  nextTick(() => {
-    inputRef.value?.focus()
-  })
-})
+onMounted(() => {})
diff --git a/ui/src/layout/components/breadcrumb/index.vue b/ui/src/layout/components/breadcrumb/index.vue
index e2089c864..2e98be0e8 100644
--- a/ui/src/layout/components/breadcrumb/index.vue
+++ b/ui/src/layout/components/breadcrumb/index.vue
@@ -18,6 +18,7 @@
           class="mr-8"
           :size="24"
         />
+
+
diff --git a/ui/src/views/application-overview/index.vue b/ui/src/views/application-overview/index.vue
index d981e5063..bbda53a29 100644
--- a/ui/src/views/application-overview/index.vue
+++ b/ui/src/views/application-overview/index.vue
@@ -5,14 +5,37 @@

应用信息

- +
+ + + + + + + +
+

{{ detail?.name }}

@@ -102,6 +125,7 @@
+
+
diff --git a/ui/src/views/application/index.vue b/ui/src/views/application/index.vue
index 58969cbc8..e7d92a0ec 100644
--- a/ui/src/views/application/index.vue
+++ b/ui/src/views/application/index.vue
@@ -41,12 +41,21 @@
 >
@@ -85,6 +94,7 @@
 import { ref, onMounted, reactive, computed } from 'vue'
 import applicationApi from '@/api/application'
 import { MsgSuccess, MsgConfirm } from '@/utils/message'
+import { isAppIcon } from '@/utils/application'
 import { useRouter } from 'vue-router'
 import useStore from '@/stores'
 const { application } = useStore()
diff --git a/ui/src/views/dataset/DatasetSetting.vue b/ui/src/views/dataset/DatasetSetting.vue
index f4f13dedc..539689c06 100644
--- a/ui/src/views/dataset/DatasetSetting.vue
+++ b/ui/src/views/dataset/DatasetSetting.vue
@@ -59,12 +59,21 @@
 {{ item.name }}
@@ -87,6 +96,7 @@ import BaseForm from '@/views/dataset/component/BaseForm.vue'
 import datasetApi from '@/api/dataset'
 import type { ApplicationFormType } from '@/api/type/application'
 import { MsgSuccess } from '@/utils/message'
+import { isAppIcon } from '@/utils/application'
 import useStore from '@/stores'
 const route = useRoute()
 const {
diff --git a/ui/src/views/document/index.vue b/ui/src/views/document/index.vue
index 3c9b5e216..8705ed818 100644
--- a/ui/src/views/document/index.vue
+++ b/ui/src/views/document/index.vue
@@ -54,7 +54,7 @@