Mirror of https://github.com/1Panel-dev/MaxKB.git
Synced 2025-12-26 01:33:05 +00:00

refactor: Optimize model parameters

This commit is contained in:
parent a556d27f74
commit 210f99bf70
@@ -73,6 +73,9 @@ class IChatStep(IBaseChatPipelineStep):
        no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置"))
        user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
+       temperature = serializers.FloatField(required=False, allow_null=True, error_messages=ErrMessage.float("温度"))
+       max_tokens = serializers.IntegerField(required=False, allow_null=True,
+                                             error_messages=ErrMessage.integer("最大token数"))

        def is_valid(self, *, raise_exception=False):
            super().is_valid(raise_exception=True)
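Both new fields are declared required=False and allow_null=True, so requests from existing clients that omit them still validate. A minimal sketch of that behaviour, assuming a configured Django/DRF environment (ChatParams is a hypothetical serializer, not the repository's class):

from rest_framework import serializers

class ChatParams(serializers.Serializer):
    # Mirrors the two new optional fields above: omitted and explicit-null
    # values both pass validation.
    temperature = serializers.FloatField(required=False, allow_null=True)
    max_tokens = serializers.IntegerField(required=False, allow_null=True)

assert ChatParams(data={}).is_valid()  # fields omitted entirely
assert ChatParams(data={'temperature': None, 'max_tokens': 1024}).is_valid()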
@@ -111,7 +111,7 @@ class BaseChatStep(IChatStep):
                client_id=None, client_type=None,
                no_references_setting=None,
                **kwargs):
-       chat_model = get_model_instance_by_model_user_id(model_id, user_id) if model_id is not None else None
+       chat_model = get_model_instance_by_model_user_id(model_id, user_id, **kwargs) if model_id is not None else None
        if stream:
            return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
                                       paragraph_list,
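Forwarding **kwargs here is what lets per-application temperature and max_tokens reach the model factory, and the commit applies the same filtering idiom in every provider. A self-contained toy version of that idiom (build_model is a hypothetical stand-in, not the repository's function):

from typing import Optional

def build_model(model_id: Optional[str], user_id: str, **kwargs):
    # Keep only the generation kwargs the caller actually set, mirroring
    # the optional_params pattern repeated throughout this commit.
    if model_id is None:
        return None
    params = {k: v for k, v in kwargs.items()
              if k in ('temperature', 'max_tokens') and v is not None}
    return {'model_id': model_id, 'user_id': user_id, **params}

print(build_model('m1', 'u1', temperature=0.7, max_tokens=None))
# {'model_id': 'm1', 'user_id': 'u1', 'temperature': 0.7}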
@@ -43,6 +43,7 @@ from dataset.serializers.common_serializers import list_paragraph, get_embedding
from embedding.models import SearchMode
from setting.models import AuthOperate
+from setting.models.model_management import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.serializers.provider_serializers import ModelSerializer
from smartdoc.conf import PROJECT_DIR
from django.conf import settings
@@ -109,6 +110,10 @@ class DatasetSettingSerializer(serializers.Serializer):


class ModelSettingSerializer(serializers.Serializer):
    prompt = serializers.CharField(required=True, max_length=2048, error_messages=ErrMessage.char("提示词"))
+   temperature = serializers.FloatField(required=False, allow_null=True,
+                                        error_messages=ErrMessage.char("温度"))
+   max_tokens = serializers.IntegerField(required=False, allow_null=True,
+                                         error_messages=ErrMessage.integer("最大token数"))


class ApplicationWorkflowSerializer(serializers.Serializer):
@@ -541,6 +546,7 @@ class ApplicationSerializer(serializers.Serializer):
    class Operate(serializers.Serializer):
        application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
        user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
+       model_id = serializers.UUIDField(required=False, error_messages=ErrMessage.uuid("模型id"))

        def is_valid(self, *, raise_exception=False):
            super().is_valid(raise_exception=True)
@@ -722,6 +728,35 @@ class ApplicationSerializer(serializers.Serializer):
                [self.data.get('user_id') if self.data.get('user_id') == str(application.user_id) else None,
                 application.user_id, self.data.get('user_id')])

        def get_other_file_list(self):
            temperature = None
            max_tokens = None
            application = Application.objects.filter(id=self.initial_data.get("application_id")).first()
            if application:
                setting_dict = application.model_setting
                temperature = setting_dict.get("temperature")
                max_tokens = setting_dict.get("max_tokens")
            model = Model.objects.filter(id=self.initial_data.get("model_id")).first()
            if model:
                res = ModelProvideConstants[model.provider].value.get_model_credential(model.model_type,
                                                                                       model.model_name).get_other_fields(
                    model.model_name)
                if temperature and res.get('temperature'):
                    res['temperature']['value'] = temperature
                if max_tokens and res.get('max_tokens'):
                    res['max_tokens']['value'] = max_tokens
                return res

        def save_other_config(self, data):
            application = Application.objects.filter(id=self.initial_data.get("application_id")).first()
            if application:
                setting_dict = application.model_setting
                for key in ['max_tokens', 'temperature']:
                    if key in data:
                        setting_dict[key] = data[key]
                application.model_setting = setting_dict
                application.save()

    class ApplicationKeySerializerModel(serializers.ModelSerializer):
        class Meta:
            model = ApplicationApiKey
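get_other_file_list overlays the application's saved settings onto the provider's field metadata, and save_other_config persists the overrides into Application.model_setting. A runnable sketch of just the merge step, over plain dicts shaped like the get_other_fields return values later in this commit:

def merge_overrides(fields: dict, overrides: dict) -> dict:
    # Overlay saved application settings onto provider slider metadata.
    for key in ('temperature', 'max_tokens'):
        if overrides.get(key) is not None and key in fields:
            fields[key]['value'] = overrides[key]
    return fields

provider_fields = {'temperature': {'value': 0.7, 'min': 0.1, 'max': 1.0},
                   'max_tokens': {'value': 800, 'min': 1, 'max': 4096}}
print(merge_overrides(provider_fields, {'temperature': 0.2}))
# temperature.value becomes 0.2; max_tokens keeps the provider default

Note that the diff itself guards with truthiness (if temperature and ...), so a stored value of 0 would silently fall back to the provider default.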
@@ -80,6 +80,8 @@ class ChatInfo:
            'model_id': self.application.model.id if self.application.model is not None else None,
            'problem_optimization': self.application.problem_optimization,
            'stream': True,
+           'temperature': model_setting.get('temperature') if 'temperature' in model_setting else None,
+           'max_tokens': model_setting.get('max_tokens') if 'max_tokens' in model_setting else None,
            'search_mode': self.application.dataset_setting.get(
                'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding',
            'no_references_setting': self.application.dataset_setting.get(
@@ -19,6 +19,8 @@ urlpatterns = [
    path('application/<str:application_id>/statistics/chat_record_aggregate_trend',
         views.ApplicationStatistics.ChatRecordAggregateTrend.as_view()),
    path('application/<str:application_id>/model', views.Application.Model.as_view()),
+   path('application/<str:application_id>/model/<str:model_id>', views.Application.Model.Operate.as_view()),
+   path('application/<str:application_id>/other-config', views.Application.Model.OtherConfig.as_view()),
    path('application/<str:application_id>/hit_test', views.Application.HitTest.as_view()),
    path('application/<str:application_id>/api_key', views.Application.ApplicationKey.as_view()),
    path("application/<str:application_id>/api_key/<str:api_key_id>",
@@ -187,6 +187,28 @@ class Application(APIView):
                data={'application_id': application_id,
                      'user_id': request.user.id}).list_model(request.query_params.get('model_type')))

        class Operate(APIView):
            authentication_classes = [TokenAuth]

            @swagger_auto_schema(operation_summary="获取应用参数设置其他字段",
                                 operation_id="获取应用参数设置其他字段",
                                 tags=["应用/会话"])
            def get(self, request: Request, application_id: str, model_id: str):
                return result.success(
                    ApplicationSerializer.Operate(
                        data={'application_id': application_id, 'model_id': model_id}).get_other_file_list())

        class OtherConfig(APIView):
            authentication_classes = [TokenAuth]

            @swagger_auto_schema(operation_summary="获取应用参数设置其他字段",
                                 operation_id="获取应用参数设置其他字段",
                                 tags=["应用/会话"])
            def put(self, request: Request, application_id: str):
                return result.success(
                    ApplicationSerializer.Operate(
                        data={'application_id': application_id}).save_other_config(request.data))

    class Profile(APIView):
        authentication_classes = [TokenAuth]
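Together with the urls.py hunk above, this wires up GET .../model/<model_id> for reading the parameter metadata and PUT .../other-config for saving overrides. A hypothetical client call against a running instance, assuming the default /api prefix and a TokenAuth Authorization header (placeholders left unfilled):

import requests  # illustration only; base URL and auth scheme are assumptions

BASE = 'http://localhost:8080/api'
HEADERS = {'Authorization': '<token>'}

# Read the provider-specific temperature/max_tokens metadata:
meta = requests.get(f'{BASE}/application/<app_id>/model/<model_id>',
                    headers=HEADERS).json()

# Persist overrides into the application's model_setting:
requests.put(f'{BASE}/application/<app_id>/other-config',
             json={'temperature': 0.2, 'max_tokens': 1024}, headers=HEADERS)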
@@ -13,7 +13,7 @@ from common.util.rsa_util import rsa_long_decrypt
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants


-def get_model_(provider, model_type, model_name, credential):
+def get_model_(provider, model_type, model_name, credential, **kwargs):
    """
    获取模型实例
    @param provider: 供应商
@@ -25,17 +25,17 @@ def get_model_(provider, model_type, model_name, credential):
    model = get_provider(provider).get_model(model_type, model_name,
                                             json.loads(
                                                 rsa_long_decrypt(credential)),
-                                            streaming=True)
+                                            streaming=True, **kwargs)
    return model


-def get_model(model):
+def get_model(model, **kwargs):
    """
    获取模型实例
    @param model: model 数据库Model实例对象
    @return: 模型实例
    """
-   return get_model_(model.provider, model.model_type, model.model_name, model.credential)
+   return get_model_(model.provider, model.model_type, model.model_name, model.credential, **kwargs)


def get_provider(provider):
@@ -108,6 +108,13 @@ class BaseModelCredential(ABC):
        """
        pass

+   def get_other_fields(self, model_name):
+       """
+       获取其他字段
+       :return:
+       """
+       pass
+
    @staticmethod
    def encryption(message: str):
        """
@@ -60,3 +60,16 @@ class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
    region_name = forms.TextInputField('Region Name', required=True)
    access_key_id = forms.TextInputField('Access Key ID', required=True)
    secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)

    def get_other_fields(self, model_name):
        return {
            'max_tokens': {
                'value': 1024,
                'min': 1,
                'max': 8192,
                'step': 1,
                'label': '输出最大Tokens',
                'precision': 0,
                'tooltip': '指定模型可生成的最大token个数'
            }
        }
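Every get_other_fields implementation in this commit returns the same slider-metadata shape per parameter. A small checker sketch for that shape (the key set is inferred from the hunks in this commit, not an official schema):

REQUIRED_KEYS = {'value', 'min', 'max', 'step', 'label', 'precision', 'tooltip'}

def check_field_meta(fields: dict) -> None:
    # Sanity-check metadata returned by a get_other_fields implementation.
    for name, meta in fields.items():
        missing = REQUIRED_KEYS - meta.keys()
        assert not missing, f"{name} missing {missing}"
        assert meta['min'] <= meta['value'] <= meta['max'], name

check_field_meta({'max_tokens': {'value': 1024, 'min': 1, 'max': 8192,
                                 'step': 1, 'label': '输出最大Tokens',
                                 'precision': 0,
                                 'tooltip': '指定模型可生成的最大token个数'}})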
@@ -1,12 +1,18 @@
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Optional, Iterator
from langchain_community.chat_models import BedrockChat
-from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage
+from langchain_community.chat_models.bedrock import ChatPromptAdapter
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.messages import BaseMessage, get_buffer_string, AIMessageChunk
+from langchain_core.outputs import ChatGenerationChunk
from setting.models_provider.base_model_provider import MaxKBBaseModel


class BedrockModel(MaxKBBaseModel, BedrockChat):

    @staticmethod
    def is_cache_model():
        return False

    def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
                 streaming: bool = False, **kwargs):
        super().__init__(model_id=model_id, region_name=region_name,
@@ -15,21 +21,52 @@ class BedrockModel(MaxKBBaseModel, BedrockChat):
    @classmethod
    def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
                     **model_kwargs) -> 'BedrockModel':
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']

        return cls(
            model_id=model_name,
            region_name=model_credential['region_name'],
            credentials_profile_name=model_credential['credentials_profile_name'],
            streaming=model_kwargs.pop('streaming', False),
-           **model_kwargs
+           **optional_params
        )

    def _get_num_tokens(self, content: str) -> int:
        """Helper method to count tokens in a string."""
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(content))

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        return sum(self._get_num_tokens(get_buffer_string([message])) for message in messages)

    def get_num_tokens(self, text: str) -> int:
        return self._get_num_tokens(text)

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        provider = self._get_provider()
        prompt, system, formatted_messages = None, None, None

        if provider == "anthropic":
            system, formatted_messages = ChatPromptAdapter.format_messages(
                provider, messages
            )
        else:
            prompt = ChatPromptAdapter.convert_messages_to_prompt(
                provider=provider, messages=messages
            )

        for chunk in self._prepare_input_and_invoke_stream(
                prompt=prompt,
                system=system,
                messages=formatted_messages,
                stop=stop,
                run_manager=run_manager,
                **kwargs,
        ):
            delta = chunk.text
            yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
@@ -53,3 +53,25 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
    api_key = forms.PasswordInputField("API Key (api_key)", required=True)

    deployment_name = forms.TextInputField("部署名 (deployment_name)", required=True)

    def get_other_fields(self, model_name):
        return {
            'temperature': {
                'value': 0.7,
                'min': 0.1,
                'max': 1,
                'step': 0.01,
                'label': '温度',
                'precision': 2,
                'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
            },
            'max_tokens': {
                'value': 800,
                'min': 1,
                'max': 4096,
                'step': 1,
                'label': '输出最大Tokens',
                'precision': 0,
                'tooltip': '指定模型可生成的最大token个数'
            }
        }
@@ -6,36 +6,102 @@
    @date:2024/4/28 11:45
    @desc:
"""
-from typing import List, Dict
+from typing import List, Dict, Optional, Any, Iterator, Type
+from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_core.messages import BaseMessage, AIMessageChunk, BaseMessageChunk
+from langchain_core.outputs import ChatGenerationChunk
+from langchain_openai.chat_models.base import _convert_delta_to_message_chunk
from langchain_openai import AzureChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel


class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI):
    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']

        return AzureChatModel(
            azure_endpoint=model_credential.get('api_base'),
            openai_api_version=model_credential.get('api_version', '2024-02-15-preview'),
            deployment_name=model_credential.get('deployment_name'),
            openai_api_key=model_credential.get('api_key'),
-           openai_api_type="azure"
+           openai_api_type="azure",
+           **optional_params
        )

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        return self.__dict__.get('_last_generation_info')

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-       try:
-           return super().get_num_tokens_from_messages(messages)
-       except Exception as e:
-           tokenizer = TokenizerManage.get_tokenizer()
-           return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+       return self.get_last_generation_info().get('prompt_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
-       try:
-           return super().get_num_tokens(text)
-       except Exception as e:
-           tokenizer = TokenizerManage.get_tokenizer()
-           return len(tokenizer.encode(text))
+       return self.get_last_generation_info().get('completion_tokens', 0)

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        kwargs["stream"] = True
        kwargs["stream_options"] = {"include_usage": True}
        payload = self._get_request_payload(messages, stop=stop, **kwargs)
        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
        if self.include_response_headers:
            raw_response = self.client.with_raw_response.create(**payload)
            response = raw_response.parse()
            base_generation_info = {"headers": dict(raw_response.headers)}
        else:
            response = self.client.create(**payload)
            base_generation_info = {}
        with response:
            is_first_chunk = True
            for chunk in response:
                if not isinstance(chunk, dict):
                    chunk = chunk.model_dump()
                if len(chunk["choices"]) == 0:
                    if token_usage := chunk.get("usage"):
                        self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
                        logprobs = None
                    else:
                        continue
                else:
                    choice = chunk["choices"][0]
                    if choice["delta"] is None:
                        continue
                    message_chunk = _convert_delta_to_message_chunk(
                        choice["delta"], default_chunk_class
                    )
                    generation_info = {**base_generation_info} if is_first_chunk else {}
                    if finish_reason := choice.get("finish_reason"):
                        generation_info["finish_reason"] = finish_reason
                        if model_name := chunk.get("model"):
                            generation_info["model_name"] = model_name
                        if system_fingerprint := chunk.get("system_fingerprint"):
                            generation_info["system_fingerprint"] = system_fingerprint

                    logprobs = choice.get("logprobs")
                    if logprobs:
                        generation_info["logprobs"] = logprobs
                default_chunk_class = message_chunk.__class__
                generation_chunk = ChatGenerationChunk(
                    message=message_chunk, generation_info=generation_info or None
                )
                if run_manager:
                    run_manager.on_llm_new_token(
                        generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
                    )
                is_first_chunk = False
                yield generation_chunk
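LangChain chat models are pydantic models, so the hunk stashes usage statistics in the instance __dict__ via setdefault instead of declaring a field, and get_last_generation_info reads it back. A toy demonstration of that trick (assuming pydantic is installed):

from pydantic import BaseModel

class Toy(BaseModel):
    name: str = 'demo'

t = Toy()
# Stash ad-hoc state without declaring a pydantic field, as the diff does:
t.__dict__.setdefault('_last_generation_info', {}).update({'prompt_tokens': 3})
print(t.__dict__.get('_last_generation_info'))  # {'prompt_tokens': 3}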
@@ -42,7 +42,7 @@ class BaseChatOpenAI(ChatOpenAI):
        for chunk in response:
            if not isinstance(chunk, dict):
                chunk = chunk.model_dump()
-           if len(chunk["choices"]) == 0:
+           if len(chunk["choices"]) == 0 or chunk["choices"][0]["finish_reason"] == "length" or chunk["choices"][0]["finish_reason"] == "stop":
                if token_usage := chunk.get("usage"):
                    self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
                logprobs = None
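The widened condition also captures usage when a provider reports it on the final content chunk (finish_reason 'stop' or 'length') rather than on a separate empty-choices chunk. The capture pattern itself, reduced to plain dicts so it runs without an OpenAI client:

chunks = [
    {'choices': [{'delta': {'content': 'hi'}, 'finish_reason': None}]},
    {'choices': [{'delta': {}, 'finish_reason': 'stop'}],
     'usage': {'prompt_tokens': 3, 'completion_tokens': 1}},
]
last_generation_info = {}
for chunk in chunks:
    if len(chunk['choices']) == 0 or chunk['choices'][0].get('finish_reason') in ('length', 'stop'):
        if token_usage := chunk.get('usage'):
            last_generation_info.update(token_usage)
print(last_generation_info)  # {'prompt_tokens': 3, 'completion_tokens': 1}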
@@ -46,3 +46,25 @@ class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_key = forms.PasswordInputField('API Key', required=True)

    def get_other_fields(self, model_name):
        return {
            'temperature': {
                'value': 0.7,
                'min': 0.1,
                'max': 1,
                'step': 0.01,
                'label': '温度',
                'precision': 2,
                'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
            },
            'max_tokens': {
                'value': 800,
                'min': 1,
                'max': 4096,
                'step': 1,
                'label': '输出最大Tokens',
                'precision': 0,
                'tooltip': '指定模型可生成的最大token个数'
            }
        }
@@ -13,12 +13,24 @@ from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI


class DeepSeekChatModel(MaxKBBaseModel, BaseChatOpenAI):

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']

        deepseek_chat_open_ai = DeepSeekChatModel(
            model=model_name,
            openai_api_base='https://api.deepseek.com',
-           openai_api_key=model_credential.get('api_key')
+           openai_api_key=model_credential.get('api_key'),
+           **optional_params
        )
        return deepseek_chat_open_ai
@@ -32,7 +32,8 @@ class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
            return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
-           model.invoke([HumanMessage(content='你好')])
+           res = model.invoke([HumanMessage(content='你好')])
+           print(res)
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
@@ -46,3 +47,25 @@ class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_key = forms.PasswordInputField('API Key', required=True)

    def get_other_fields(self):
        return {
            'temperature': {
                'value': 0.7,
                'min': 0.1,
                'max': 1.0,
                'step': 0.01,
                'label': '温度',
                'precision': 2,
                'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
            },
            'max_tokens': {
                'value': 800,
                'min': 1,
                'max': 4096,
                'step': 1,
                'label': '输出最大Tokens',
                'precision': 0,
                'tooltip': '指定模型可生成的最大token个数'
            }
        }
@@ -16,11 +16,23 @@ from setting.models_provider.base_model_provider import MaxKBBaseModel


class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI):

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = {}
        if 'temperature' in model_kwargs:
            optional_params['temperature'] = model_kwargs['temperature']
        if 'max_tokens' in model_kwargs:
            optional_params['max_output_tokens'] = model_kwargs['max_tokens']

        gemini_chat = GeminiChatModel(
            model=model_name,
-           google_api_key=model_credential.get('api_key')
+           google_api_key=model_credential.get('api_key'),
+           **optional_params
        )
        return gemini_chat
@@ -47,3 +47,25 @@ class KimiLLMModelCredential(BaseForm, BaseModelCredential):

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def get_other_fields(self, model_name):
        return {
            'temperature': {
                'value': 0.3,
                'min': 0.1,
                'max': 1.0,
                'step': 0.01,
                'label': '温度',
                'precision': 2,
                'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
            },
            'max_tokens': {
                'value': 1024,
                'min': 1,
                'max': 4096,
                'step': 1,
                'label': '输出最大Tokens',
                'precision': 0,
                'tooltip': '指定模型可生成的最大token个数'
            }
        }
@@ -8,20 +8,28 @@
"""
from typing import List, Dict

from langchain_core.messages import BaseMessage, get_buffer_string

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI


class KimiChatModel(MaxKBBaseModel, BaseChatOpenAI):

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']

        kimi_chat_open_ai = KimiChatModel(
            openai_api_base=model_credential['api_base'],
            openai_api_key=model_credential['api_key'],
            model_name=model_name,
            **optional_params
        )
        return kimi_chat_open_ai
@@ -42,3 +42,25 @@ class OllamaLLMModelCredential(BaseForm, BaseModelCredential):

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def get_other_fields(self, model_name):
        return {
            'temperature': {
                'value': 0.3,
                'min': 0.1,
                'max': 1.0,
                'step': 0.01,
                'label': '温度',
                'precision': 2,
                'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
            },
            'max_tokens': {
                'value': 1024,
                'min': 1,
                'max': 4096,
                'step': 1,
                'label': '输出最大Tokens',
                'precision': 0,
                'tooltip': '指定模型可生成的最大token个数'
            }
        }
@@ -30,5 +30,12 @@ class OllamaChatModel(MaxKBBaseModel, BaseChatOpenAI):
        api_base = model_credential.get('api_base', '')
        base_url = get_base_url(api_base)
        base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']

        return OllamaChatModel(model=model_name, openai_api_base=base_url,
-                              openai_api_key=model_credential.get('api_key'))
+                              openai_api_key=model_credential.get('api_key'),
+                              stream_usage=True, **optional_params)
@@ -47,3 +47,25 @@ class OpenAILLMModelCredential(BaseForm, BaseModelCredential):

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def get_other_fields(self, model_name):
        return {
            'temperature': {
                'value': 0.7,
                'min': 0.1,
                'max': 1.0,
                'step': 0.01,
                'label': '温度',
                'precision': 2,
                'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
            },
            'max_tokens': {
                'value': 800,
                'min': 1,
                'max': 4096,
                'step': 1,
                'label': '输出最大Tokens',
                'precision': 0,
                'tooltip': '指定模型可生成的最大token个数'
            }
        }
@@ -6,21 +6,32 @@
    @date:2024/4/18 15:28
    @desc:
"""
-from typing import List, Dict
+from typing import List, Dict, Optional, Iterator, Any, Type

+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, AIMessageChunk
+from langchain_core.messages.ai import UsageMetadata
+from langchain_core.outputs import ChatGenerationChunk
from langchain_openai import ChatOpenAI
+from langchain_openai.chat_models.base import _convert_delta_to_message_chunk

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI


-class OpenAIChatModel(MaxKBBaseModel, BaseChatOpenAI):
+class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        azure_chat_open_ai = OpenAIChatModel(
            model=model_name,
            openai_api_base=model_credential.get('api_base'),
            openai_api_key=model_credential.get('api_key'),
            streaming=model_kwargs.get('streaming', False),
-           max_tokens=model_kwargs.get('max_tokens', 5),
-           temperature=model_kwargs.get('temperature', 0.5),
+           **optional_params,
+           stream_usage=True
        )
        return azure_chat_open_ai
@@ -45,3 +45,16 @@ class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_key = forms.PasswordInputField('API Key', required=True)

    def get_other_fields(self, model_name):
        return {
            'temperature': {
                'value': 1.0,
                'min': 0.1,
                'max': 1.9,
                'step': 0.01,
                'label': '温度',
                'precision': 2,
                'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
            },
        }
@@ -6,10 +6,13 @@
    @date:2024/4/28 11:44
    @desc:
"""
-from typing import List, Dict
+from typing import List, Dict, Optional, Iterator, Any

from langchain_community.chat_models import ChatTongyi
+from langchain_community.llms.tongyi import generate_with_last_element_mark
+from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_core.outputs import ChatGenerationChunk

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
@@ -18,16 +21,61 @@ from setting.models_provider.base_model_provider import MaxKBBaseModel
class QwenChatModel(MaxKBBaseModel, ChatTongyi):
    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        chat_tong_yi = QwenChatModel(
            model_name=model_name,
-           dashscope_api_key=model_credential.get('api_key')
+           dashscope_api_key=model_credential.get('api_key'),
+           **optional_params,
        )
        return chat_tong_yi

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        return self.__dict__.get('_last_generation_info')

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-       tokenizer = TokenizerManage.get_tokenizer()
-       return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+       return self.get_last_generation_info().get('input_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
-       tokenizer = TokenizerManage.get_tokenizer()
-       return len(tokenizer.encode(text))
+       return self.get_last_generation_info().get('output_tokens', 0)

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        params: Dict[str, Any] = self._invocation_params(
            messages=messages, stop=stop, stream=True, **kwargs
        )

        for stream_resp, is_last_chunk in generate_with_last_element_mark(
                self.stream_completion_with_retry(**params)
        ):
            choice = stream_resp["output"]["choices"][0]
            message = choice["message"]
            if (
                    choice["finish_reason"] == "stop"
                    and message["content"] == ""
            ):
                token_usage = stream_resp["usage"]
                self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
            if (
                    choice["finish_reason"] == "null"
                    and message["content"] == ""
                    and "tool_calls" not in message
            ):
                continue

            chunk = ChatGenerationChunk(
                **self._chat_generation_from_qwen_resp(
                    stream_resp, is_chunk=True, is_last_chunk=is_last_chunk
                )
            )
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk
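Providers touched by this commit report usage under different key names: prompt_tokens/completion_tokens (OpenAI-compatible, Qianfan), input_tokens/output_tokens (Qwen), PromptTokens/CompletionTokens (Hunyuan). A hypothetical normalizer over those keys, for illustration only:

ALIASES = {
    'prompt': ('prompt_tokens', 'input_tokens', 'PromptTokens'),
    'completion': ('completion_tokens', 'output_tokens', 'CompletionTokens'),
}

def normalize_usage(info: dict) -> dict:
    # Map any of the provider-specific usage keys seen in this commit
    # onto one common shape, defaulting to 0 when absent.
    return {norm: next((info[k] for k in keys if k in info), 0)
            for norm, keys in ALIASES.items()}

print(normalize_usage({'input_tokens': 12, 'output_tokens': 34}))
# {'prompt': 12, 'completion': 34}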
@@ -44,4 +44,17 @@ class TencentLLMModelCredential(BaseForm, BaseModelCredential):

    hunyuan_app_id = forms.TextInputField('APP ID', required=True)
    hunyuan_secret_id = forms.PasswordInputField('SecretId', required=True)
    hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True)

    def get_other_fields(self, model_name):
        return {
            'temperature': {
                'value': 0.5,
                'min': 0.1,
                'max': 2.0,
                'step': 0.01,
                'label': '温度',
                'precision': 2,
                'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
            },
        }
@@ -236,6 +236,11 @@ class ChatHunyuan(BaseChatModel):
                    choice["Delta"], default_chunk_class
                )
                default_chunk_class = chunk.__class__
                # FinishReason === stop
                if choice.get("FinishReason") == "stop":
                    self.__dict__.setdefault("_last_generation_info", {}).update(
                        response.get("Usage", {})
                    )
                cg_chunk = ChatGenerationChunk(message=chunk)
                if run_manager:
                    run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
@@ -1,6 +1,6 @@
# coding=utf-8

-from typing import List, Dict
+from typing import List, Dict, Optional, Any

from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage
@@ -15,12 +15,18 @@ class TencentModel(MaxKBBaseModel, ChatHunyuan):
        hunyuan_secret_id = credentials.get('hunyuan_secret_id')
        hunyuan_secret_key = credentials.get('hunyuan_secret_key')

        optional_params = {}
        if 'temperature' in kwargs:
            optional_params['temperature'] = kwargs['temperature']

        if not all([hunyuan_app_id, hunyuan_secret_id, hunyuan_secret_key]):
            raise ValueError(
                "All of 'hunyuan_app_id', 'hunyuan_secret_id', and 'hunyuan_secret_key' must be provided in credentials.")

        super().__init__(model=model_name, hunyuan_app_id=hunyuan_app_id, hunyuan_secret_id=hunyuan_secret_id,
-                        hunyuan_secret_key=hunyuan_secret_key, streaming=streaming, **kwargs)
+                        hunyuan_secret_key=hunyuan_secret_key, streaming=streaming,
+                        temperature=optional_params.get('temperature', None)
+                        )

    @staticmethod
    def new_instance(model_type: str, model_name: str, model_credential: Dict[str, object],
@@ -28,10 +34,11 @@ class TencentModel(MaxKBBaseModel, ChatHunyuan):
        streaming = model_kwargs.pop('streaming', False)
        return TencentModel(model_name=model_name, credentials=model_credential, streaming=streaming, **model_kwargs)

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        return self.__dict__.get('_last_generation_info')

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-       tokenizer = TokenizerManage.get_tokenizer()
-       return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)
+       return self.get_last_generation_info().get('PromptTokens', 0)

    def get_num_tokens(self, text: str) -> int:
-       tokenizer = TokenizerManage.get_tokenizer()
-       return len(tokenizer.encode(text))
+       return self.get_last_generation_info().get('CompletionTokens', 0)
@@ -48,3 +48,25 @@ class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential):

    access_key_id = forms.PasswordInputField('Access Key ID', required=True)
    secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)

    def get_other_fields(self):
        return {
            'temperature': {
                'value': 0.3,
                'min': 0.1,
                'max': 1.0,
                'step': 0.01,
                'label': '温度',
                'precision': 2,
                'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
            },
            'max_tokens': {
                'value': 1024,
                'min': 1,
                'max': 4096,
                'step': 1,
                'label': '输出最大Tokens',
                'precision': 0,
                'tooltip': '指定模型可生成的最大token个数'
            }
        }
@@ -1,27 +1,21 @@
from typing import List, Dict

-from langchain_community.chat_models import VolcEngineMaasChat
-from langchain_core.messages import BaseMessage, get_buffer_string
-
-from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
-from langchain_openai import ChatOpenAI
+from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI


-class VolcanicEngineChatModel(MaxKBBaseModel, ChatOpenAI):
+class VolcanicEngineChatModel(MaxKBBaseModel, BaseChatOpenAI):
    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-       volcanic_engine_chat = VolcanicEngineChatModel(
+       optional_params = {}
+       if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
+           optional_params['max_tokens'] = model_kwargs['max_tokens']
+       if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
+           optional_params['temperature'] = model_kwargs['temperature']
+       return VolcanicEngineChatModel(
            model=model_name,
-           volc_engine_maas_ak=model_credential.get("access_key_id"),
-           volc_engine_maas_sk=model_credential.get("secret_access_key"),
+           openai_api_base=model_credential.get('api_base'),
+           openai_api_key=model_credential.get('api_key'),
+           **optional_params
        )
-       return volcanic_engine_chat
-
-   def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-       tokenizer = TokenizerManage.get_tokenizer()
-       return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
-
-   def get_num_tokens(self, text: str) -> int:
-       tokenizer = TokenizerManage.get_tokenizer()
-       return len(tokenizer.encode(text))
@@ -14,7 +14,7 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
-from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
+from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel

from smartdoc.conf import PROJECT_DIR

@@ -24,7 +24,7 @@ model_info_list = [
    ModelInfo('ep-xxxxxxxxxx-yyyy',
              '用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用',
              ModelTypeConst.LLM,
-             volcanic_engine_llm_model_credential, OpenAIChatModel
+             volcanic_engine_llm_model_credential, VolcanicEngineChatModel
              )
]
@@ -53,3 +53,16 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
    api_key = forms.PasswordInputField('API Key', required=True)

    secret_key = forms.PasswordInputField("Secret Key", required=True)

    def get_other_fields(self, model_name):
        return {
            'temperature': {
                'value': 0.95,
                'min': 0.1,
                'max': 1.0,
                'step': 0.01,
                'label': '温度',
                'precision': 2,
                'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
            },
        }
@@ -6,28 +6,73 @@
    @date:2023/11/10 17:45
    @desc:
"""
-from typing import List, Dict
+import uuid
+from typing import List, Dict, Optional, Any, Iterator

-from langchain.schema.messages import BaseMessage, get_buffer_string
+from langchain.schema.messages import get_buffer_string
from langchain_community.chat_models import QianfanChatEndpoint
+from langchain_community.chat_models.baidu_qianfan_endpoint import _convert_dict_to_message
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.outputs import ChatGenerationChunk

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
+from langchain_core.messages import (
+    AIMessageChunk,
+    BaseMessage,
+)


class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_output_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return QianfanChatModel(model=model_name,
                                qianfan_ak=model_credential.get('api_key'),
                                qianfan_sk=model_credential.get('secret_key'),
-                               streaming=model_kwargs.get('streaming', False))
+                               streaming=model_kwargs.get('streaming', False),
+                               **optional_params)

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        return self.__dict__.get('_last_generation_info')

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-       tokenizer = TokenizerManage.get_tokenizer()
-       return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+       return self.get_last_generation_info().get('prompt_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
-       tokenizer = TokenizerManage.get_tokenizer()
-       return len(tokenizer.encode(text))
+       return self.get_last_generation_info().get('completion_tokens', 0)

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        params = self._convert_prompt_msg_params(messages, **kwargs)
        params["stop"] = stop
        params["stream"] = True
        for res in self.client.do(**params):
            if res:
                msg = _convert_dict_to_message(res)
                additional_kwargs = msg.additional_kwargs.get("function_call", {})
                if msg.content == "":
                    token_usage = res.get("body").get("usage")
                    self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
                chunk = ChatGenerationChunk(
                    text=res["result"],
                    message=AIMessageChunk(  # type: ignore[call-arg]
                        content=msg.content,
                        role="assistant",
                        additional_kwargs=additional_kwargs,
                    ),
                    generation_info=msg.additional_kwargs,
                )
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text, chunk=chunk)
                yield chunk
@@ -49,3 +49,28 @@ class XunFeiLLMModelCredential(BaseForm, BaseModelCredential):
    spark_app_id = forms.TextInputField('APP ID', required=True)
    spark_api_key = forms.PasswordInputField("API Key", required=True)
    spark_api_secret = forms.PasswordInputField('API Secret', required=True)

    def get_other_fields(self, model_name):
        max_value = 8192
        if model_name == 'general' or model_name == 'pro-128k':
            max_value = 4096
        return {
            'temperature': {
                'value': 0.5,
                'min': 0.1,
                'max': 1,
                'step': 0.01,
                'label': '温度',
                'precision': 2,
                'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
            },
            'max_tokens': {
                'value': 4096,
                'min': 1,
                'max': max_value,
                'step': 1,
                'label': '输出最大Tokens',
                'precision': 0,
                'tooltip': '指定模型可生成的最大token个数'
            }
        }
@@ -25,14 +25,19 @@ class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = {}
        if 'max_tokens' in model_kwargs:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs:
            optional_params['temperature'] = model_kwargs['temperature']
        return XFChatSparkLLM(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            spark_llm_domain=model_name,
-           temperature=model_kwargs.get('temperature', 0.5),
-           max_tokens=model_kwargs.get('max_tokens', 5),
            streaming=model_kwargs.get('streaming', False),
+           **optional_params
        )

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
@@ -39,3 +39,25 @@ class XinferenceLLMModelCredential(BaseForm, BaseModelCredential):

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def get_other_fields(self, model_name):
        return {
            'temperature': {
                'value': 0.7,
                'min': 0.1,
                'max': 1,
                'step': 0.01,
                'label': '温度',
                'precision': 2,
                'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
            },
            'max_tokens': {
                'value': 800,
                'min': 1,
                'max': 4096,
                'step': 1,
                'label': '输出最大Tokens',
                'precision': 0,
                'tooltip': '指定模型可生成的最大token个数'
            }
        }
@@ -26,14 +26,13 @@ class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI):
        base_url = get_base_url(api_base)
        base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
        optional_params = {}
-       if 'max_tokens' in model_kwargs:
+       if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
-       if 'temperature' in model_kwargs:
+       if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return XinferenceChatModel(
            model=model_name,
            openai_api_base=base_url,
            openai_api_key=model_credential.get('api_key'),
            streaming=model_kwargs.get('streaming', False),
            **optional_params
        )
@@ -45,3 +45,26 @@ class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_key = forms.PasswordInputField('API Key', required=True)

    def get_other_fields(self, model_name):
        return {
            'temperature': {
                'value': 0.95,
                'min': 0.1,
                'max': 1.0,
                'step': 0.01,
                'label': '温度',
                'precision': 2,
                'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
            },
            'max_tokens': {
                'value': 1024,
                'min': 1,
                'max': 4095,
                'step': 1,
                'label': '输出最大Tokens',
                'precision': 0,
                'tooltip': '指定模型可生成的最大token个数'
            }
        }
@@ -33,11 +33,17 @@ class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = {}
        if 'max_tokens' in model_kwargs:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs:
            optional_params['temperature'] = model_kwargs['temperature']

        zhipuai_chat = ZhipuChatModel(
            temperature=0.5,
            api_key=model_credential.get('api_key'),
            model=model_name,
-           max_tokens=model_kwargs.get('max_tokens', 5)
+           streaming=model_kwargs.get('streaming', False),
+           **optional_params
        )
        return zhipuai_chat
@@ -22,7 +22,7 @@ def get_model_by_id(_id, user_id):
    return model


-def get_model_instance_by_model_user_id(model_id, user_id):
+def get_model_instance_by_model_user_id(model_id, user_id, **kwargs):
    """
    获取模型实例,根据模型相关数据
    @param model_id: 模型id
@@ -30,4 +30,4 @@ def get_model_instance_by_model_user_id(model_id, user_id):
    @return: 模型实例
    """
    model = get_model_by_id(model_id, user_id)
-   return ModelManage.get_model(model_id, lambda _id: get_model(model))
+   return ModelManage.get_model(model_id, lambda _id: get_model(model, **kwargs))
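One subtlety worth noting: the cache key in ModelManage.get_model is the model_id alone, while the new kwargs only reach the factory lambda. If an instance were already cached, fresh parameter values would never apply, which is presumably why the models touched by this commit report is_cache_model() as False. A toy illustration of the pitfall:

_cache = {}

def get_cached(model_id, factory):
    # Cache keyed by model_id only, like the lambda above.
    if model_id not in _cache:
        _cache[model_id] = factory(model_id)
    return _cache[model_id]

m1 = get_cached('m1', lambda _id: {'temperature': 0.2})
m2 = get_cached('m1', lambda _id: {'temperature': 0.9})  # factory ignored
assert m1 is m2  # the second call returns the cached instance; 0.9 never applies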
@ -1,8 +1,8 @@
|
|||
import { Result } from '@/request/Result'
|
||||
import { get, post, postStream, del, put } from '@/request/index'
|
||||
import type { pageRequest } from '@/api/type/common'
|
||||
import type { ApplicationFormType } from '@/api/type/application'
|
||||
import { type Ref } from 'vue'
|
||||
import {Result} from '@/request/Result'
|
||||
import {get, post, postStream, del, put} from '@/request/index'
|
||||
import type {pageRequest} from '@/api/type/common'
|
||||
import type {ApplicationFormType} from '@/api/type/application'
|
||||
import {type Ref} from 'vue'
|
||||
|
||||
const prefix = '/application'
|
||||
|
||||
|
|
@ -11,25 +11,25 @@ const prefix = '/application'
|
|||
* @param 参数
|
||||
*/
|
||||
const getAllAppilcation: () => Promise<Result<any[]>> = () => {
|
||||
return get(`${prefix}`)
|
||||
return get(`${prefix}`)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取分页应用
|
||||
* page {
|
||||
"current_page": "string",
|
||||
"page_size": "string",
|
||||
}
|
||||
"current_page": "string",
|
||||
"page_size": "string",
|
||||
}
|
||||
* param {
|
||||
"name": "string",
|
||||
}
|
||||
"name": "string",
|
||||
}
|
||||
*/
|
||||
const getApplication: (
|
||||
page: pageRequest,
|
||||
param: any,
|
||||
loading?: Ref<boolean>
|
||||
page: pageRequest,
|
||||
param: any,
|
||||
loading?: Ref<boolean>
|
||||
) => Promise<Result<any>> = (page, param, loading) => {
|
||||
return get(`${prefix}/${page.current_page}/${page.page_size}`, param, loading)
|
||||
return get(`${prefix}/${page.current_page}/${page.page_size}`, param, loading)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -37,10 +37,10 @@ const getApplication: (
|
|||
* @param 参数
|
||||
*/
|
||||
const postApplication: (
|
||||
data: ApplicationFormType,
|
||||
loading?: Ref<boolean>
|
||||
data: ApplicationFormType,
|
||||
loading?: Ref<boolean>
|
||||
) => Promise<Result<any>> = (data, loading) => {
|
||||
return post(`${prefix}`, data, undefined, loading)
|
||||
return post(`${prefix}`, data, undefined, loading)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -48,11 +48,11 @@ const postApplication: (
|
|||
* @param 参数
|
||||
*/
|
||||
const putApplication: (
|
||||
application_id: String,
|
||||
data: ApplicationFormType,
|
||||
loading?: Ref<boolean>
|
||||
application_id: String,
|
||||
data: ApplicationFormType,
|
||||
loading?: Ref<boolean>
|
||||
) => Promise<Result<any>> = (application_id, data, loading) => {
|
||||
return put(`${prefix}/${application_id}`, data, undefined, loading)
|
||||
return put(`${prefix}/${application_id}`, data, undefined, loading)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -60,10 +60,10 @@ const putApplication: (
|
|||
* @param 参数 application_id
|
||||
*/
|
||||
const delApplication: (
|
||||
application_id: String,
|
||||
loading?: Ref<boolean>
|
||||
application_id: String,
|
||||
loading?: Ref<boolean>
|
||||
) => Promise<Result<boolean>> = (application_id, loading) => {
|
||||
return del(`${prefix}/${application_id}`, undefined, {}, loading)
|
||||
return del(`${prefix}/${application_id}`, undefined, {}, loading)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -71,10 +71,10 @@ const delApplication: (
|
|||
* @param 参数 application_id
|
||||
*/
|
||||
const getApplicationDetail: (
|
||||
application_id: string,
|
||||
loading?: Ref<boolean>
|
||||
application_id: string,
|
||||
loading?: Ref<boolean>
|
||||
) => Promise<Result<any>> = (application_id, loading) => {
|
||||
return get(`${prefix}/${application_id}`, undefined, loading)
|
||||
  return get(`${prefix}/${application_id}`, undefined, loading)
}

/**
@@ -82,10 +82,10 @@ const getApplicationDetail: (
 * @param application_id
 */
const getApplicationDataset: (
  application_id: string,
  loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, loading) => {
  return get(`${prefix}/${application_id}/list_dataset`, undefined, loading)
}

/**
@@ -93,10 +93,10 @@ const getApplicationDataset: (
 * @param application_id
 */
const getAccessToken: (application_id: string, loading?: Ref<boolean>) => Promise<Result<any>> = (
  application_id,
  loading
) => {
  return get(`${prefix}/${application_id}/access_token`, undefined, loading)
}

/**
@@ -107,71 +107,71 @@ const getAccessToken: (application_id: string, loading?: Ref<boolean>) => Promis
 * }
 */
const putAccessToken: (
  application_id: string,
  data: any,
  loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, data, loading) => {
  return put(`${prefix}/${application_id}/access_token`, data, undefined, loading)
}

/**
 * Application authentication
 * @param params
 * {
 *   "access_token": "string"
 * }
 */
const postAppAuthentication: (access_token: string, loading?: Ref<boolean>) => Promise<any> = (
  access_token,
  loading
) => {
-  return post(`${prefix}/authentication`, { access_token }, undefined, loading)
+  return post(`${prefix}/authentication`, {access_token}, undefined, loading)
}

/**
 * Get application info for the chat page
 * @param params
 * {
 *   "access_token": "string"
 * }
 */
const getAppProfile: (loading?: Ref<boolean>) => Promise<any> = (loading) => {
  return get(`${prefix}/profile`, undefined, loading)
}

/**
 * Get a temporary (debug) chat id
 * @param params
 */
const postChatOpen: (data: ApplicationFormType) => Promise<Result<any>> = (data) => {
  return post(`${prefix}/chat/open`, data)
}

/**
 * Get a temporary workflow chat id
 * @param params
 */
const postWorkflowChatOpen: (data: ApplicationFormType) => Promise<Result<any>> = (data) => {
  return post(`${prefix}/chat_workflow/open`, data)
}

/**
 * Get a formal chat id
 * @param params
 * {
 *   "model_id": "string",
 *   "multiple_rounds_dialogue": true,
 *   "dataset_id_list": ["string"]
 * }
 */
const getChatOpen: (application_id: String) => Promise<Result<any>> = (application_id) => {
  return get(`${prefix}/${application_id}/chat/open`)
}
/**
 * Chat
@@ -180,32 +180,32 @@ const getChatOpen: (application_id: String) => Promise<Result<any>> = (applicati
 * data
 */
const postChatMessage: (chat_id: string, data: any) => Promise<any> = (chat_id, data) => {
  return postStream(`/api${prefix}/chat_message/${chat_id}`, data)
}

/**
 * Upvote / downvote a chat record
 * @param params
 * application_id : string; chat_id : string; chat_record_id : string
 * {
 *   "vote_status": "string" // -1 0 1
 * }
 */
const putChatVote: (
  application_id: string,
  chat_id: string,
  chat_record_id: string,
  vote_status: string,
  loading?: Ref<boolean>
) => Promise<any> = (application_id, chat_id, chat_record_id, vote_status, loading) => {
  return put(
    `${prefix}/${application_id}/chat/${chat_id}/chat_record/${chat_record_id}/vote`,
    {
      vote_status
    },
    undefined,
    loading
  )
}

/**
@@ -216,11 +216,11 @@ const putChatVote: (
 * @returns
 */
const getApplicationHitTest: (
  application_id: string,
  data: any,
  loading?: Ref<boolean>
) => Promise<Result<Array<any>>> = (application_id, data, loading) => {
  return get(`${prefix}/${application_id}/hit_test`, data, loading)
}

/**
@@ -231,10 +231,10 @@ const getApplicationHitTest: (
 * @returns
 */
const getApplicationModel: (
  application_id: string,
  loading?: Ref<boolean>
) => Promise<Result<Array<any>>> = (application_id, loading) => {
  return get(`${prefix}/${application_id}/model`, loading)
}

/**
|
||||
|
|
@@ -242,31 +242,61 @@ const getApplicationModel: (
 * @param params
 */
const putPublishApplication: (
  application_id: String,
  data: ApplicationFormType,
  loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, data, loading) => {
  return put(`${prefix}/${application_id}/publish`, data, undefined, loading)
}

+/**
+ * Get the extra model settings (temperature / max_tokens)
+ * @param application_id
+ * @param model_id
+ * @param loading
+ * @returns
+ */
+const getModelOtherConfig: (
+  application_id: string,
+  model_id: string,
+  loading?: Ref<boolean>
+) => Promise<Result<any>> = (application_id, model_id, loading) => {
+  return get(`${prefix}/${application_id}/model/${model_id}`, undefined, loading)
+}
+
+/**
+ * Save the extra model settings
+ * @param application_id
+ *
+ */
+const putModelOtherConfig: (
+  application_id: string,
+  data: any,
+  loading?: Ref<boolean>
+) => Promise<Result<any>> = (application_id, data, loading) => {
+  return put(`${prefix}/${application_id}/other-config`, data, undefined, loading)
+}

export default {
  getAllAppilcation,
  getApplication,
  postApplication,
  putApplication,
  postChatOpen,
  getChatOpen,
  postChatMessage,
  delApplication,
  getApplicationDetail,
  getApplicationDataset,
  getAccessToken,
  putAccessToken,
  postAppAuthentication,
  getAppProfile,
  putChatVote,
  getApplicationHitTest,
  getApplicationModel,
  putPublishApplication,
-  postWorkflowChatOpen
+  postWorkflowChatOpen,
+  getModelOtherConfig,
+  putModelOtherConfig
}
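
For orientation, a hedged sketch of how these two new endpoints pair up. The ids and the response shape are illustrative assumptions (the parameter-setting dialog further down consumes the GET response as a map of slider descriptors); the diff itself does not spell this out.

import applicationApi from '@/api/application'

const appId = 'hypothetical-application-id'
const modelId = 'hypothetical-model-id'

// GET {prefix}/{application_id}/model/{model_id} -> current parameter descriptors
applicationApi.getModelOtherConfig(appId, modelId).then(() => {
  // PUT {prefix}/{application_id}/other-config with the flat edited values
  return applicationApi.putModelOtherConfig(appId, { temperature: 0.7, max_tokens: 800 })
})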
@@ -1,136 +1,163 @@
-import { defineStore } from 'pinia'
+import {defineStore} from 'pinia'
import applicationApi from '@/api/application'
import applicationXpackApi from '@/api/application-xpack'
-import { type Ref } from 'vue'
+import {type Ref, type UnwrapRef} from 'vue'

import useUserStore from './user'
+import type {ApplicationFormType} from "@/api/type/application";

const useApplicationStore = defineStore({
  id: 'application',
  state: () => ({
    location: `${window.location.origin}/ui/chat/`
  }),
  actions: {
    async asyncGetAllApplication() {
      return new Promise((resolve, reject) => {
        applicationApi
          .getAllAppilcation()
          .then((data) => {
            resolve(data)
          })
          .catch((error) => {
            reject(error)
          })
      })
    },

    async asyncGetApplicationDetail(id: string, loading?: Ref<boolean>) {
      return new Promise((resolve, reject) => {
        applicationApi
          .getApplicationDetail(id, loading)
          .then((data) => {
            resolve(data)
          })
          .catch((error) => {
            reject(error)
          })
      })
    },

    async asyncGetApplicationDataset(id: string, loading?: Ref<boolean>) {
      return new Promise((resolve, reject) => {
        applicationApi
          .getApplicationDataset(id, loading)
          .then((data) => {
            resolve(data)
          })
          .catch((error) => {
            reject(error)
          })
      })
    },

    async asyncGetAccessToken(id: string, loading?: Ref<boolean>) {
      return new Promise((resolve, reject) => {
        const user = useUserStore()
        if (user.isEnterprise()) {
          applicationXpackApi
            .getAccessToken(id, loading)
            .then((data) => {
              resolve(data)
            })
            .catch((error) => {
              reject(error)
            })
        } else {
          applicationApi
            .getAccessToken(id, loading)
            .then((data) => {
              resolve(data)
            })
            .catch((error) => {
              reject(error)
            })
        }
      })
    },

    async asyncGetAppProfile(loading?: Ref<boolean>) {
      return new Promise((resolve, reject) => {
        const user = useUserStore()
        if (user.isEnterprise()) {
          applicationXpackApi
            .getAppXpackProfile(loading)
            .then((data) => {
              resolve(data)
            })
            .catch((error) => {
              reject(error)
            })
        } else {
          applicationApi
            .getAppProfile(loading)
            .then((data) => {
              resolve(data)
            })
            .catch((error) => {
              reject(error)
            })
        }
      })
    },

    async asyncAppAuthentication(token: string, loading?: Ref<boolean>) {
      return new Promise((resolve, reject) => {
        applicationApi
          .postAppAuthentication(token, loading)
          .then((res) => {
            localStorage.setItem('accessToken', res.data)
            sessionStorage.setItem('accessToken', res.data)
            resolve(res)
          })
          .catch((error) => {
            reject(error)
          })
      })
    },
    async refreshAccessToken(token: string) {
      this.asyncAppAuthentication(token)
    },
    // Update an application
    async asyncPutApplication(id: string, data: any, loading?: Ref<boolean>) {
      return new Promise((resolve, reject) => {
        applicationApi
          .putApplication(id, data, loading)
          .then((data) => {
            resolve(data)
          })
          .catch((error) => {
            reject(error)
          })
      })
    },
+    // Get the model's temperature / max_tokens settings
+    async asyncGetModelConfig(id: string, model_id: string, loading?: Ref<boolean>) {
+      return new Promise((resolve, reject) => {
+        applicationApi
+          .getModelOtherConfig(id, model_id, loading)
+          .then((data) => {
+            resolve(data)
+          })
+          .catch((error) => {
+            reject(error)
+          })
+      })
+    },
+    // Save the application's temperature / max_tokens settings
+    async asyncPostModelConfig(id: string, params: any, loading?: Ref<boolean>) {
+      return new Promise((resolve, reject) => {
+        applicationApi
+          .putModelOtherConfig(id, params, loading)
+          .then((data) => {
+            resolve(data)
+          })
+          .catch((error) => {
+            reject(error)
+          })
+      })
+    }
  }
})

export default useApplicationStore
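
A sketch of driving the two new store actions from a component, mirroring what ApplicationForm.vue does below; the ids, the loading ref, and the saved values are placeholders, not part of the commit.

import { ref } from 'vue'
import useStore from '@/stores'

const { application } = useStore()
const loading = ref(false)

// Fetch the per-model parameter descriptors, then persist edited values back.
application.asyncGetModelConfig('app-id', 'model-id', loading).then((res: any) => {
  // res.data is what AIModeParamSettingDialog.open(...) receives
  application.asyncPostModelConfig('app-id', { temperature: 0.7, max_tokens: 800 }, loading)
})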
@@ -67,9 +67,14 @@
          <template #label>
            <div class="flex-between">
              <span>{{ $t('views.application.applicationForm.form.aiModel.label') }}</span>
-              <!-- <el-button type="primary" link @click="openAIParamSettingDialog">
+              <el-button
+                type="primary"
+                link
+                @click="openAIParamSettingDialog"
+                :disabled="!applicationForm.model_id"
+              >
                {{ $t('views.application.applicationForm.form.paramSetting') }}
-              </el-button> -->
+              </el-button>
            </div>
          </template>
          <el-select
@@ -104,9 +109,9 @@
              >Public
            </el-tag>
          </div>
-          <el-icon class="check-icon" v-if="item.id === applicationForm.model_id"
-            ><Check
-          /></el-icon>
+          <el-icon class="check-icon" v-if="item.id === applicationForm.model_id">
+            <Check />
+          </el-icon>
        </el-option>
        <!-- Unavailable -->
        <el-option
@@ -127,15 +132,17 @@
              $t('views.application.applicationForm.form.aiModel.unavailable')
            }}</span>
          </div>
-          <el-icon class="check-icon" v-if="item.id === applicationForm.model_id"
-            ><Check
-          /></el-icon>
+          <el-icon class="check-icon" v-if="item.id === applicationForm.model_id">
+            <Check />
+          </el-icon>
        </el-option>
      </el-option-group>
      <template #footer>
        <div class="w-full text-left cursor" @click="openCreateModel()">
          <el-button type="primary" link>
-            <el-icon class="mr-4"><Plus /></el-icon>
+            <el-icon class="mr-4">
+              <Plus />
+            </el-icon>
            {{ $t('views.application.applicationForm.form.addModel') }}
          </el-button>
        </div>
@@ -155,12 +162,14 @@
        >
      </div>
      <el-tooltip effect="dark" placement="right">
-        <template #content>{{
-          $t('views.application.applicationForm.form.prompt.tooltip', {
-            data: '{data}',
-            question: '{question}'
-          })
-        }}</template>
+        <template #content
+          >{{
+            $t('views.application.applicationForm.form.prompt.tooltip', {
+              data: '{data}',
+              question: '{question}'
+            })
+          }}
+        </template>
        <AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
      </el-tooltip>
    </div>
@@ -192,20 +201,22 @@
        }}</span>
        <div>
          <el-button type="primary" link @click="openParamSettingDialog">
-            <AppIcon iconName="app-operation" class="mr-4"></AppIcon
-            >{{ $t('views.application.applicationForm.form.paramSetting') }}
+            <AppIcon iconName="app-operation" class="mr-4"></AppIcon>
+            {{ $t('views.application.applicationForm.form.paramSetting') }}
          </el-button>
          <el-button type="primary" link @click="openDatasetDialog">
-            <el-icon class="mr-4"><Plus /></el-icon
-            >{{ $t('views.application.applicationForm.form.add') }}
+            <el-icon class="mr-4">
+              <Plus />
+            </el-icon>
+            {{ $t('views.application.applicationForm.form.add') }}
          </el-button>
        </div>
      </div>
    </template>
    <div class="w-full">
-      <el-text type="info" v-if="applicationForm.dataset_id_list?.length === 0">{{
-        $t('views.application.applicationForm.form.relatedKnowledgeBaseWhere')
-      }}</el-text>
+      <el-text type="info" v-if="applicationForm.dataset_id_list?.length === 0"
+        >{{ $t('views.application.applicationForm.form.relatedKnowledgeBaseWhere') }}
+      </el-text>
      <el-row :gutter="12" v-else>
        <el-col
          :xs="24"
@@ -237,7 +248,9 @@
            </div>
          </div>
          <el-button text @click="removeDataset(item)">
-            <el-icon><Close /></el-icon>
+            <el-icon>
+              <Close />
+            </el-icon>
          </el-button>
        </div>
      </el-card>
@@ -314,7 +327,7 @@
      </el-col>
    </el-row>

-    <AIModeParamSettingDialog ref="AIModeParamSettingDialogRef" @refresh="refreshParam" />
+    <AIModeParamSettingDialog ref="AIModeParamSettingDialogRef" @refresh="refreshParam" :id="id" />
    <ParamSettingDialog ref="ParamSettingDialogRef" @refresh="refreshParam" />
    <AddDatasetDialog
      ref="AddDatasetDialogRef"
@@ -351,6 +364,7 @@ import { relatedObject } from '@/utils/utils'
import { MsgSuccess } from '@/utils/message'
import useStore from '@/stores'
+import { t } from '@/locales'

const { model, application } = useStore()

const route = useRoute()
@@ -436,7 +450,14 @@ const submit = async (formEl: FormInstance | undefined) => {
}

const openAIParamSettingDialog = () => {
-  AIModeParamSettingDialogRef.value?.open(applicationForm.value)
+  const model_id = applicationForm.value.model_id
+  if (!model_id) {
+    MsgSuccess(t('Please select an AI model'))
+    return
+  }
+  application.asyncGetModelConfig(id, model_id, loading).then((res: any) => {
+    AIModeParamSettingDialogRef.value?.open(res.data)
+  })
}

const openParamSettingDialog = () => {
@@ -463,9 +484,11 @@ function removeDataset(id: any) {
    )
  }
}

function addDataset(val: Array<string>) {
  applicationForm.value.dataset_id_list = val
}

function openDatasetDialog() {
  AddDatasetDialogRef.value.open(applicationForm.value.dataset_id_list)
}
@@ -525,15 +548,18 @@ onMounted(() => {
.relate-dataset-card {
  color: var(--app-text-color);
}

.dialog-bg {
  border-radius: 8px;
  background: var(--dialog-bg-gradient-color);
  overflow: hidden;
  box-sizing: border-box;
}

.scrollbar-height-left {
  height: calc(var(--app-main-height) - 64px);
}

.scrollbar-height {
  height: calc(var(--app-main-height) - 160px);
}
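
For clarity, a hedged sketch of the payload openAIParamSettingDialog now hands to the dialog: res.data is assumed to be a map of slider-field descriptors, inferred from the FormField interface the dialog defines in the next file; the keys and values below are illustrative only.

// Assumed shape of res.data passed to AIModeParamSettingDialogRef.value?.open(...)
const exampleConfig = {
  temperature: {
    value: 0.7,
    min: 0,
    max: 1,
    step: 0.01,
    precision: 2,
    label: 'Temperature',
    tooltip: 'Higher values make output more random; lower values make it more focused'
  },
  max_tokens: {
    value: 800,
    min: 1,
    max: 10000,
    label: 'Max output tokens',
    tooltip: 'Maximum number of tokens the model may generate'
  }
}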
@@ -8,122 +8,135 @@
    append-to-body
  >
    <el-form label-position="top" ref="paramFormRef" :model="form">
-      <el-form-item>
+      <el-form-item v-for="(item, key) in form" :key="key">
        <template #label>
          <div class="flex align-center">
            <div class="flex-between mr-4">
-              <span>Temperature</span>
+              <span>{{ item.label }}</span>
            </div>
            <el-tooltip effect="dark" placement="right">
-              <template #content
-                >Higher values make the output more random; lower values make it more focused and deterministic</template
-              >
+              <template #content>{{ item.tooltip }}</template>
              <AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
            </el-tooltip>
          </div>
        </template>
        <el-slider
-          v-model="form.similarity"
+          v-model="item.value"
          show-input
          :show-input-controls="false"
-          :min="0"
-          :max="1"
-          :precision="2"
-          :step="0.01"
-          class="custom-slider"
-        />
-      </el-form-item>
-      <el-form-item>
-        <template #label>
-          <div class="flex align-center">
-            <div class="flex-between mr-4">
-              <span>Max output tokens</span>
-            </div>
-            <el-tooltip effect="dark" placement="right">
-              <template #content>Specifies the maximum number of tokens the model may generate</template>
-              <AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
-            </el-tooltip>
-          </div>
-        </template>
-        <el-slider
-          v-model="form.max_paragraph_char_number"
-          show-input
-          :show-input-controls="false"
-          :min="1"
-          :max="10000"
+          :min="item.min"
+          :max="item.max"
+          :precision="item.precision || 0"
+          :step="item.step || 1"
          class="custom-slider"
        />
      </el-form-item>
    </el-form>

    <template #footer>
      <span class="dialog-footer p-16">
-        <el-button @click.prevent="dialogVisible = false">{{
-          $t('views.application.applicationForm.buttons.cancel')
-        }}</el-button>
-        <el-button type="primary" @click="submit(paramFormRef)" :loading="loading">
+        <el-button @click.prevent="dialogVisible = false">
+          {{ $t('views.application.applicationForm.buttons.cancel') }}
+        </el-button>
+        <el-button type="primary" @click="submit" :loading="loading">
          {{ $t('views.application.applicationForm.buttons.confirm') }}
        </el-button>
      </span>
    </template>
  </el-dialog>
</template>

<script setup lang="ts">
-import { ref, watch, reactive } from 'vue'
-import { cloneDeep } from 'lodash'
-import type { FormInstance, FormRules } from 'element-plus'
+import { ref, reactive, watch } from 'vue'
+import { cloneDeep, set } from 'lodash'
+import type { FormInstance } from 'element-plus'
+import useStore from '@/stores'
+
+const { application } = useStore()

const emit = defineEmits(['refresh'])

-const paramFormRef = ref()
-
-const form = ref<any>({
-  similarity: 0.6,
-  max_paragraph_char_number: 5000
-})
-
-const dialogVisible = ref<boolean>(false)
+const paramFormRef = ref<FormInstance>()
+const form = reactive<Form>({})
+const dialogVisible = ref(false)
const loading = ref(false)
+const props = defineProps<{
+  id: any
+}>()
+const resetForm = () => {
+  // Clear the form object and wait for fresh data
+  Object.keys(form).forEach((key) => delete form[key])
+}

-watch(dialogVisible, (bool) => {
-  if (!bool) {
-    form.value = {
-      similarity: 0.6,
-      max_paragraph_char_number: 5000
-    }
-  }
-})
+interface Form {
+  [key: string]: FormField
+}
+
+interface FormField {
+  value: any
+  min?: number
+  max?: number
+  step?: number
+  label?: string
+  precision?: number
+  tooltip?: string
+}

const open = (data: any) => {
-  form.value = cloneDeep(data)
+  const newData = cloneDeep(data)
+  Object.keys(form).forEach((key) => {
+    delete form[key]
+  })
+  Object.keys(newData).forEach((key) => {
+    set(form, key, newData[key])
+  })
  dialogVisible.value = true
}

-const submit = async (formEl: FormInstance | undefined) => {
-  if (!formEl) return
-  await formEl.validate((valid, fields) => {
-    if (valid) {
-      emit('refresh', form.value)
-      dialogVisible.value = false
-    }
-  })
-}
+const submit = async () => {
+  if (paramFormRef.value) {
+    await paramFormRef.value.validate((valid, fields) => {
+      if (valid) {
+        const data = Object.keys(form).reduce(
+          (acc, key) => {
+            acc[key] = form[key].value
+            return acc
+          },
+          {} as Record<string, any>
+        )
+        application.asyncPostModelConfig(props.id, data, loading).then(() => {
+          emit('refresh', form)
+          dialogVisible.value = false
+        })
+      }
+    })
+  }
+}
+
+watch(dialogVisible, (bool) => {
+  if (!bool) {
+    resetForm()
+  }
+})

defineExpose({ open })
</script>
-<style lang="scss" scope>
+
+<style lang="scss" scoped>
.aiMode-param-dialog {
  padding: 8px 8px 24px 8px;

  .el-dialog__header {
    padding: 16px 16px 0 16px;
  }

  .el-dialog__body {
    padding: 16px !important;
  }

  .dialog-max-height {
    height: 550px;
  }

  .custom-slider {
    .el-input-number.is-without-controls .el-input__wrapper {
      padding: 0 !important;
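
The reduce inside submit() deserves a concrete illustration: it flattens the descriptor map back into the plain key/value payload the save endpoint expects. A standalone example with made-up values:

const form: Record<string, { value: any }> = {
  temperature: { value: 0.7 },
  max_tokens: { value: 800 }
}
// Same reduction as in submit(): keep only each field's current value.
const data = Object.keys(form).reduce(
  (acc, key) => {
    acc[key] = form[key].value
    return acc
  },
  {} as Record<string, any>
)
console.log(data) // { temperature: 0.7, max_tokens: 800 }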
@@ -1,42 +1,51 @@
<template>
  <div class="p-24" v-loading="loading">
    <el-form
      ref="authFormRef"
      :rules="rules"
      :model="form"
      label-position="top"
      require-asterisk-position="right"
    >
      <el-form-item :label="$t('login.cas.ldpUri')" prop="config_data.ldpUri">
-        <el-input v-model="form.config_data.ldpUri" :placeholder="$t('login.cas.ldpUriPlaceholder')"/>
+        <el-input
+          v-model="form.config_data.ldpUri"
+          :placeholder="$t('login.cas.ldpUriPlaceholder')"
+        />
      </el-form-item>
      <el-form-item :label="$t('login.cas.redirectUrl')" prop="config_data.redirectUrl">
-        <el-input v-model="form.config_data.redirectUrl" :placeholder="$t('login.cas.redirectUrlPlaceholder')"/>
+        <el-input
+          v-model="form.config_data.redirectUrl"
+          :placeholder="$t('login.cas.redirectUrlPlaceholder')"
+        />
      </el-form-item>
      <el-form-item>
-        <el-checkbox v-model="form.is_active">{{ $t('login.cas.enableAuthentication') }}</el-checkbox>
+        <el-checkbox v-model="form.is_active"
+          >{{ $t('login.cas.enableAuthentication') }}
+        </el-checkbox>
      </el-form-item>
    </el-form>

    <div class="text-right">
-      <el-button @click="submit(authFormRef)" type="primary" :disabled="loading"> {{ $t('login.cas.save') }}
+      <el-button @click="submit(authFormRef)" type="primary" :disabled="loading">
+        {{ $t('login.cas.save') }}
      </el-button>
    </div>
  </div>
</template>
<script setup lang="ts">
-import {reactive, ref, watch, onMounted} from 'vue'
+import { reactive, ref, watch, onMounted } from 'vue'
import authApi from '@/api/auth-setting'
-import type {FormInstance, FormRules} from 'element-plus'
-import {t} from '@/locales'
-import {MsgSuccess} from '@/utils/message'
+import type { FormInstance, FormRules } from 'element-plus'
+import { t } from '@/locales'
+import { MsgSuccess } from '@/utils/message'

const form = ref<any>({
  id: '',
  auth_type: 'CAS',
  config_data: {
    ldpUri: '',
-    redirectUrl: '',
+    redirectUrl: ''
  },
  is_active: true
})
@@ -46,27 +55,25 @@ const authFormRef = ref()
const loading = ref(false)

const rules = reactive<FormRules<any>>({
-  'config_data.ldpUri': [{required: true, message: t('login.ldap.ldpUriPlaceholder'), trigger: 'blur'}],
-  'config_data.redirectUrl': [{
-    required: true,
-    message: t('login.ldap.redirectUrlPlaceholder'),
-    trigger: 'blur'
-  }],
+  'config_data.ldpUri': [
+    { required: true, message: t('login.cas.ldpUriPlaceholder'), trigger: 'blur' }
+  ],
+  'config_data.redirectUrl': [
+    {
+      required: true,
+      message: t('login.cas.redirectUrlPlaceholder'),
+      trigger: 'blur'
+    }
+  ]
})

-const submit = async (formEl: FormInstance | undefined, test?: string) => {
+const submit = async (formEl: FormInstance | undefined) => {
  if (!formEl) return
  await formEl.validate((valid, fields) => {
    if (valid) {
-      if (test) {
-        authApi.postAuthSetting(form.value, loading).then((res) => {
-          MsgSuccess(t('login.cas.testConnectionSuccess'))
-        })
-      } else {
-        authApi.putAuthSetting(form.value.auth_type, form.value, loading).then((res) => {
-          MsgSuccess(t('login.cas.saveSuccess'))
-        })
-      }
+      authApi.putAuthSetting(form.value.auth_type, form.value, loading).then((res) => {
+        MsgSuccess(t('login.cas.saveSuccess'))
+      })
    }
  })
}
@@ -1,49 +1,69 @@
<template>
  <div class="p-24" v-loading="loading">
    <el-form
      ref="authFormRef"
      :rules="rules"
      :model="form"
      label-position="top"
      require-asterisk-position="right"
    >
      <el-form-item :label="$t('login.ldap.address')" prop="config_data.ldap_server">
-        <el-input v-model="form.config_data.ldap_server" :placeholder="$t('login.ldap.serverPlaceholder')"/>
+        <el-input
+          v-model="form.config_data.ldap_server"
+          :placeholder="$t('login.ldap.serverPlaceholder')"
+        />
      </el-form-item>
      <el-form-item :label="$t('login.ldap.bindDN')" prop="config_data.base_dn">
-        <el-input v-model="form.config_data.base_dn" :placeholder="$t('login.ldap.bindDNPlaceholder')"/>
+        <el-input
+          v-model="form.config_data.base_dn"
+          :placeholder="$t('login.ldap.bindDNPlaceholder')"
+        />
      </el-form-item>
      <el-form-item :label="$t('login.ldap.password')" prop="config_data.password">
-        <el-input v-model="form.config_data.password" :placeholder="$t('login.ldap.passwordPlaceholder')"
-                  show-password/>
+        <el-input
+          v-model="form.config_data.password"
+          :placeholder="$t('login.ldap.passwordPlaceholder')"
+          show-password
+        />
      </el-form-item>
      <el-form-item :label="$t('login.ldap.ou')" prop="config_data.ou">
-        <el-input v-model="form.config_data.ou" :placeholder="$t('login.ldap.ouPlaceholder')"/>
+        <el-input v-model="form.config_data.ou" :placeholder="$t('login.ldap.ouPlaceholder')" />
      </el-form-item>
      <el-form-item :label="$t('login.ldap.ldap_filter')" prop="config_data.ldap_filter">
-        <el-input v-model="form.config_data.ldap_filter" :placeholder="$t('login.ldap.ldap_filterPlaceholder')"/>
+        <el-input
+          v-model="form.config_data.ldap_filter"
+          :placeholder="$t('login.ldap.ldap_filterPlaceholder')"
+        />
      </el-form-item>
      <el-form-item :label="$t('login.ldap.ldap_mapping')" prop="config_data.ldap_mapping">
-        <el-input v-model="form.config_data.ldap_mapping" placeholder='{"name":"name","email":"mail","username":"cn"}'/>
+        <el-input
+          v-model="form.config_data.ldap_mapping"
+          placeholder='{"name":"name","email":"mail","username":"cn"}'
+        />
      </el-form-item>
      <el-form-item>
-        <el-checkbox v-model="form.is_active">{{ $t('login.ldap.enableAuthentication') }}</el-checkbox>
+        <el-checkbox v-model="form.is_active">{{
+          $t('login.ldap.enableAuthentication')
+        }}</el-checkbox>
      </el-form-item>
-      <el-button @click="submit(authFormRef, 'test')" :disabled="loading"> {{ $t('login.ldap.test') }}</el-button>
+      <el-button @click="submit(authFormRef, 'test')" :disabled="loading">
+        {{ $t('login.ldap.test') }}</el-button
+      >
    </el-form>

    <div class="text-right">
-      <el-button @click="submit(authFormRef)" type="primary" :disabled="loading"> {{ $t('login.ldap.save') }}
+      <el-button @click="submit(authFormRef)" type="primary" :disabled="loading">
+        {{ $t('login.ldap.save') }}
      </el-button>
    </div>
  </div>
</template>
<script setup lang="ts">
-import {reactive, ref, watch, onMounted} from 'vue'
+import { reactive, ref, watch, onMounted } from 'vue'
import authApi from '@/api/auth-setting'
-import type {FormInstance, FormRules} from 'element-plus'
-import {t} from '@/locales'
-import {MsgSuccess} from '@/utils/message'
+import type { FormInstance, FormRules } from 'element-plus'
+import { t } from '@/locales'
+import { MsgSuccess } from '@/utils/message'

const form = ref<any>({
  id: '',
@@ -54,7 +74,7 @@ const form = ref<any>({
    password: '',
    ou: '',
    ldap_filter: '',
-    ldap_mapping: '',
+    ldap_mapping: ''
  },
  is_active: true
})
@@ -64,12 +84,22 @@ const authFormRef = ref()
const loading = ref(false)

const rules = reactive<FormRules<any>>({
-  'config_data.ldap_server': [{required: true, message: t('login.ldap.serverPlaceholder'), trigger: 'blur'}],
-  'config_data.base_dn': [{required: true, message: t('login.ldap.bindDNPlaceholder'), trigger: 'blur'}],
-  'config_data.password': [{required: true, message: t('login.ldap.passwordPlaceholder'), trigger: 'blur'}],
-  'config_data.ou': [{required: true, message: t('login.ldap.ouPlaceholder'), trigger: 'blur'}],
-  'config_data.ldap_filter': [{required: true, message: t('login.ldap.ldap_filterPlaceholder'), trigger: 'blur'}],
-  'config_data.ldap_mapping': [{required: true, message: t('login.ldap.ldap_mappingPlaceholder'), trigger: 'blur'}]
+  'config_data.ldap_server': [
+    { required: true, message: t('login.ldap.serverPlaceholder'), trigger: 'blur' }
+  ],
+  'config_data.base_dn': [
+    { required: true, message: t('login.ldap.bindDNPlaceholder'), trigger: 'blur' }
+  ],
+  'config_data.password': [
+    { required: true, message: t('login.ldap.passwordPlaceholder'), trigger: 'blur' }
+  ],
+  'config_data.ou': [{ required: true, message: t('login.ldap.ouPlaceholder'), trigger: 'blur' }],
+  'config_data.ldap_filter': [
+    { required: true, message: t('login.ldap.ldap_filterPlaceholder'), trigger: 'blur' }
+  ],
+  'config_data.ldap_mapping': [
+    { required: true, message: t('login.ldap.ldap_mappingPlaceholder'), trigger: 'blur' }
+  ]
})

const submit = async (formEl: FormInstance | undefined, test?: string) => {
@@ -94,7 +124,9 @@ function getDetail() {
    if (res.data && JSON.stringify(res.data) !== '{}') {
      form.value = res.data
      if (res.data.config_data.ldap_mapping) {
-        form.value.config_data.ldap_mapping = JSON.stringify(JSON.parse(res.data.config_data.ldap_mapping))
+        form.value.config_data.ldap_mapping = JSON.stringify(
+          JSON.parse(res.data.config_data.ldap_mapping)
+        )
      }
    }
  })
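
The JSON.parse/JSON.stringify round-trip above normalizes whatever mapping string the backend stored into compact, canonical JSON before it lands in the input field. A standalone example:

const stored = '{ "name" : "name",  "email": "mail", "username": "cn" }'
const normalized = JSON.stringify(JSON.parse(stored))
console.log(normalized) // {"name":"name","email":"mail","username":"cn"}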
@@ -1,49 +1,69 @@
<template>
  <div class="p-24" v-loading="loading">
    <el-form
      ref="authFormRef"
      :rules="rules"
      :model="form"
      label-position="top"
      require-asterisk-position="right"
    >
      <el-form-item :label="$t('login.oidc.authEndpoint')" prop="config_data.authEndpoint">
-        <el-input v-model="form.config_data.authEndpoint" :placeholder="$t('login.oidc.authEndpointPlaceholder')"/>
+        <el-input
+          v-model="form.config_data.authEndpoint"
+          :placeholder="$t('login.oidc.authEndpointPlaceholder')"
+        />
      </el-form-item>
      <el-form-item :label="$t('login.oidc.tokenEndpoint')" prop="config_data.tokenEndpoint">
-        <el-input v-model="form.config_data.tokenEndpoint" :placeholder="$t('login.oidc.tokenEndpointPlaceholder')"/>
+        <el-input
+          v-model="form.config_data.tokenEndpoint"
+          :placeholder="$t('login.oidc.tokenEndpointPlaceholder')"
+        />
      </el-form-item>
      <el-form-item :label="$t('login.oidc.userInfoEndpoint')" prop="config_data.userInfoEndpoint">
-        <el-input v-model="form.config_data.userInfoEndpoint"
-                  :placeholder="$t('login.oidc.userInfoEndpointPlaceholder')"/>
+        <el-input
+          v-model="form.config_data.userInfoEndpoint"
+          :placeholder="$t('login.oidc.userInfoEndpointPlaceholder')"
+        />
      </el-form-item>
      <el-form-item :label="$t('login.oidc.clientId')" prop="config_data.clientId">
-        <el-input v-model="form.config_data.clientId" :placeholder="$t('login.oidc.clientIdPlaceholder')"/>
+        <el-input
+          v-model="form.config_data.clientId"
+          :placeholder="$t('login.oidc.clientIdPlaceholder')"
+        />
      </el-form-item>
      <el-form-item :label="$t('login.oidc.clientSecret')" prop="config_data.clientSecret">
-        <el-input v-model="form.config_data.clientSecret" :placeholder="$t('login.oidc.clientSecretPlaceholder')"
-                  show-password/>
+        <el-input
+          v-model="form.config_data.clientSecret"
+          :placeholder="$t('login.oidc.clientSecretPlaceholder')"
+          show-password
+        />
      </el-form-item>
      <el-form-item :label="$t('login.oidc.redirectUrl')" prop="config_data.redirectUrl">
-        <el-input v-model="form.config_data.redirectUrl" :placeholder="$t('login.oidc.redirectUrlPlaceholder')"/>
+        <el-input
+          v-model="form.config_data.redirectUrl"
+          :placeholder="$t('login.oidc.redirectUrlPlaceholder')"
+        />
      </el-form-item>
      <el-form-item>
-        <el-checkbox v-model="form.is_active">{{ $t('login.oidc.enableAuthentication') }}</el-checkbox>
+        <el-checkbox v-model="form.is_active"
+          >{{ $t('login.oidc.enableAuthentication') }}
+        </el-checkbox>
      </el-form-item>
    </el-form>

    <div class="text-right">
-      <el-button @click="submit(authFormRef)" type="primary" :disabled="loading"> {{ $t('login.ldap.save') }}
+      <el-button @click="submit(authFormRef)" type="primary" :disabled="loading">
+        {{ $t('login.ldap.save') }}
      </el-button>
    </div>
  </div>
</template>
<script setup lang="ts">
-import {reactive, ref, watch, onMounted} from 'vue'
+import { reactive, ref, watch, onMounted } from 'vue'
import authApi from '@/api/auth-setting'
-import type {FormInstance, FormRules} from 'element-plus'
-import {t} from '@/locales'
-import {MsgSuccess} from '@/utils/message'
+import type { FormInstance, FormRules } from 'element-plus'
+import { t } from '@/locales'
+import { MsgSuccess } from '@/utils/message'

const form = ref<any>({
  id: '',
@@ -64,32 +84,40 @@ const authFormRef = ref()
const loading = ref(false)

const rules = reactive<FormRules<any>>({
-  'config_data.authEndpoint': [{required: true, message: t('login.oidc.authEndpointPlaceholder'), trigger: 'blur'}],
-  'config_data.tokenEndpoint': [{required: true, message: t('login.oidc.tokenEndpointPlaceholder'), trigger: 'blur'}],
-  'config_data.userInfoEndpoint': [{
-    required: true,
-    message: t('login.oidc.userInfoEndpointPlaceholder'),
-    trigger: 'blur'
-  }],
-  'config_data.clientId': [{required: true, message: t('login.oidc.clientIdPlaceholder'), trigger: 'blur'}],
-  'config_data.clientSecret': [{required: true, message: t('login.oidc.clientSecretPlaceholder'), trigger: 'blur'}],
-  'config_data.redirectUrl': [{required: true, message: t('login.oidc.redirectUrlPlaceholder'), trigger: 'blur'}],
-  'config_data.logoutEndpoint': [{required: true, message: t('login.oidc.logoutEndpointPlaceholder'), trigger: 'blur'}]
+  'config_data.authEndpoint': [
+    { required: true, message: t('login.oidc.authEndpointPlaceholder'), trigger: 'blur' }
+  ],
+  'config_data.tokenEndpoint': [
+    { required: true, message: t('login.oidc.tokenEndpointPlaceholder'), trigger: 'blur' }
+  ],
+  'config_data.userInfoEndpoint': [
+    {
+      required: true,
+      message: t('login.oidc.userInfoEndpointPlaceholder'),
+      trigger: 'blur'
+    }
+  ],
+  'config_data.clientId': [
+    { required: true, message: t('login.oidc.clientIdPlaceholder'), trigger: 'blur' }
+  ],
+  'config_data.clientSecret': [
+    { required: true, message: t('login.oidc.clientSecretPlaceholder'), trigger: 'blur' }
+  ],
+  'config_data.redirectUrl': [
+    { required: true, message: t('login.oidc.redirectUrlPlaceholder'), trigger: 'blur' }
+  ],
+  'config_data.logoutEndpoint': [
+    { required: true, message: t('login.oidc.logoutEndpointPlaceholder'), trigger: 'blur' }
+  ]
})

const submit = async (formEl: FormInstance | undefined, test?: string) => {
  if (!formEl) return
  await formEl.validate((valid, fields) => {
    if (valid) {
-      if (test) {
-        authApi.postAuthSetting(form.value, loading).then((res) => {
-          MsgSuccess(t('login.ldap.testConnectionSuccess'))
-        })
-      } else {
-        authApi.putAuthSetting(form.value.auth_type, form.value, loading).then((res) => {
-          MsgSuccess(t('login.ldap.saveSuccess'))
-        })
-      }
+      authApi.putAuthSetting(form.value.auth_type, form.value, loading).then((res) => {
+        MsgSuccess(t('login.ldap.saveSuccess'))
+      })
    }
  })
}
@@ -98,9 +126,6 @@ function getDetail() {
  authApi.getAuthSetting(form.value.auth_type, loading).then((res: any) => {
    if (res.data && JSON.stringify(res.data) !== '{}') {
      form.value = res.data
-      if (res.data.config_data.ldap_mapping) {
-        form.value.config_data.ldap_mapping = JSON.stringify(JSON.parse(res.data.config_data.ldap_mapping))
-      }
    }
  })
}
@@ -9,6 +9,7 @@
    @click="clickListHandle"
    value-key="provider"
    default-active=""
+    style="overflow-y: auto"
  >
    <template #default="{ row, index }">
      <div class="flex" v-if="index === 0">