From 081f034420aec68e60145ee47a85e82f5fa0907a Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Mon, 18 Dec 2023 15:20:29 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=A8=A1=E5=9E=8B=E4=B8=8B=E8=BD=BD?= =?UTF-8?q?=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../serializers/chat_message_serializers.py | 4 +- apps/smartdoc/settings/base.py | 2 +- install_model.py | 71 +++++++++++++++++++ main.py | 1 + 4 files changed, 76 insertions(+), 2 deletions(-) create mode 100644 install_model.py diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 981a4e380..db4fddf61 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -140,7 +140,9 @@ class ChatMessageSerializer(serializers.Serializer): ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id, paragraph_id, source_type, - source_id, c, 0, + source_id, + c if c is not None else '抱歉,根据已知信息无法回答这个问题,请重新描述您的问题或提供更多信息~', + 0, 0)) # 重新设置缓存 chat_cache.set(chat_id, diff --git a/apps/smartdoc/settings/base.py b/apps/smartdoc/settings/base.py index b8f84f8e4..9277f2984 100644 --- a/apps/smartdoc/settings/base.py +++ b/apps/smartdoc/settings/base.py @@ -157,7 +157,7 @@ TIME_ZONE = CONFIG.get_time_zone() USE_I18N = True -USE_TZ = True +USE_TZ = False # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/4.2/howto/static-files/ diff --git a/install_model.py b/install_model.py new file mode 100644 index 000000000..1400cf93f --- /dev/null +++ b/install_model.py @@ -0,0 +1,71 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: install_model.py + @date:2023/12/18 14:02 + @desc: +""" +import json +import os.path + +from transformers import GPT2TokenizerFast +import sentence_transformers + +prefix_dir = os.path.join(os.path.dirname(os.path.abspath(os.getcwd())), 'model') + +model_config = [ + { + 'download_params': { + 'cache_dir': os.path.join(prefix_dir, 'base'), + 'pretrained_model_name_or_path': 'gpt2' + }, + 'download_function': GPT2TokenizerFast.from_pretrained + }, + { + 'download_params': { + 'cache_dir': os.path.join(prefix_dir, 'base'), + 'pretrained_model_name_or_path': 'gpt2-medium' + }, + 'download_function': GPT2TokenizerFast.from_pretrained + }, + { + 'download_params': { + 'cache_dir': os.path.join(prefix_dir, 'base'), + 'pretrained_model_name_or_path': 'gpt2-large' + }, + 'download_function': GPT2TokenizerFast.from_pretrained + }, + { + 'download_params': { + 'cache_dir': os.path.join(prefix_dir, 'base'), + 'pretrained_model_name_or_path': 'gpt2-xl' + }, + 'download_function': GPT2TokenizerFast.from_pretrained + }, + { + 'download_params': { + 'cache_dir': os.path.join(prefix_dir, 'base'), + 'pretrained_model_name_or_path': 'distilgpt2' + }, + 'download_function': GPT2TokenizerFast.from_pretrained + }, + { + 'download_params': { + 'model_name_or_path': 'shibing624/text2vec-base-chinese', + 'cache_folder': os.path.join(prefix_dir, 'embedding') + }, + 'download_function': sentence_transformers.SentenceTransformer + } + +] + + +def install(): + for model in model_config: + print(json.dumps(model.get('download_params'))) + model.get('download_function')(**model.get('download_params')) + + +if __name__ == '__main__': + install() diff --git a/main.py b/main.py index 237efdc3a..8f8802302 100644 --- a/main.py +++ b/main.py @@ -48,6 +48,7 @@ def start_services(): if __name__ == '__main__': + os.environ['TRANSFORMERS_CACHE'] = '/opt/maxkb/model' parser = argparse.ArgumentParser( description=""" qabot service control tools;