refactor: 优化模型参数

This commit is contained in:
wxg0103 2024-08-14 18:54:32 +08:00 committed by wxg0103
parent a556d27f74
commit 210f99bf70
48 changed files with 1281 additions and 510 deletions

View File

@ -73,6 +73,9 @@ class IChatStep(IBaseChatPipelineStep):
no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
temperature = serializers.FloatField(required=False, allow_null=True, error_messages=ErrMessage.float("温度"))
max_tokens = serializers.IntegerField(required=False, allow_null=True,
error_messages=ErrMessage.integer("最大token数"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)

View File

@ -111,7 +111,7 @@ class BaseChatStep(IChatStep):
client_id=None, client_type=None,
no_references_setting=None,
**kwargs):
chat_model = get_model_instance_by_model_user_id(model_id, user_id) if model_id is not None else None
chat_model = get_model_instance_by_model_user_id(model_id, user_id, **kwargs) if model_id is not None else None
if stream:
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,

View File

@ -43,6 +43,7 @@ from dataset.serializers.common_serializers import list_paragraph, get_embedding
from embedding.models import SearchMode
from setting.models import AuthOperate
from setting.models.model_management import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.serializers.provider_serializers import ModelSerializer
from smartdoc.conf import PROJECT_DIR
from django.conf import settings
@ -109,6 +110,10 @@ class DatasetSettingSerializer(serializers.Serializer):
class ModelSettingSerializer(serializers.Serializer):
prompt = serializers.CharField(required=True, max_length=2048, error_messages=ErrMessage.char("提示词"))
temperature = serializers.FloatField(required=False, allow_null=True,
error_messages=ErrMessage.char("温度"))
max_tokens = serializers.IntegerField(required=False, allow_null=True,
error_messages=ErrMessage.integer("最大token数"))
class ApplicationWorkflowSerializer(serializers.Serializer):
@ -541,6 +546,7 @@ class ApplicationSerializer(serializers.Serializer):
class Operate(serializers.Serializer):
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
model_id = serializers.UUIDField(required=False, error_messages=ErrMessage.uuid("模型id"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -722,6 +728,35 @@ class ApplicationSerializer(serializers.Serializer):
[self.data.get('user_id') if self.data.get('user_id') == str(application.user_id) else None,
application.user_id, self.data.get('user_id')])
def get_other_file_list(self):
temperature = None
max_tokens = None
application = Application.objects.filter(id=self.initial_data.get("application_id")).first()
if application:
setting_dict = application.model_setting
temperature = setting_dict.get("temperature")
max_tokens = setting_dict.get("max_tokens")
model = Model.objects.filter(id=self.initial_data.get("model_id")).first()
if model:
res = ModelProvideConstants[model.provider].value.get_model_credential(model.model_type,
model.model_name).get_other_fields(
model.model_name)
if temperature and res.get('temperature'):
res['temperature']['value'] = temperature
if max_tokens and res.get('max_tokens'):
res['max_tokens']['value'] = max_tokens
return res
def save_other_config(self, data):
application = Application.objects.filter(id=self.initial_data.get("application_id")).first()
if application:
setting_dict = application.model_setting
for key in ['max_tokens', 'temperature']:
if key in data:
setting_dict[key] = data[key]
application.model_setting = setting_dict
application.save()
class ApplicationKeySerializerModel(serializers.ModelSerializer):
class Meta:
model = ApplicationApiKey

View File

@ -80,6 +80,8 @@ class ChatInfo:
'model_id': self.application.model.id if self.application.model is not None else None,
'problem_optimization': self.application.problem_optimization,
'stream': True,
'temperature': model_setting.get('temperature') if 'temperature' in model_setting else None,
'max_tokens': model_setting.get('max_tokens') if 'max_tokens' in model_setting else None,
'search_mode': self.application.dataset_setting.get(
'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding',
'no_references_setting': self.application.dataset_setting.get(

View File

@ -19,6 +19,8 @@ urlpatterns = [
path('application/<str:application_id>/statistics/chat_record_aggregate_trend',
views.ApplicationStatistics.ChatRecordAggregateTrend.as_view()),
path('application/<str:application_id>/model', views.Application.Model.as_view()),
path('application/<str:application_id>/model/<str:model_id>', views.Application.Model.Operate.as_view()),
path('application/<str:application_id>/other-config', views.Application.Model.OtherConfig.as_view()),
path('application/<str:application_id>/hit_test', views.Application.HitTest.as_view()),
path('application/<str:application_id>/api_key', views.Application.ApplicationKey.as_view()),
path("application/<str:application_id>/api_key/<str:api_key_id>",

View File

@ -187,6 +187,28 @@ class Application(APIView):
data={'application_id': application_id,
'user_id': request.user.id}).list_model(request.query_params.get('model_type')))
class Operate(APIView):
authentication_classes = [TokenAuth]
@swagger_auto_schema(operation_summary="获取应用参数设置其他字段",
operation_id="获取应用参数设置其他字段",
tags=["应用/会话"])
def get(self, request: Request, application_id: str, model_id: str):
return result.success(
ApplicationSerializer.Operate(
data={'application_id': application_id, 'model_id': model_id}).get_other_file_list())
class OtherConfig(APIView):
authentication_classes = [TokenAuth]
@swagger_auto_schema(operation_summary="获取应用参数设置其他字段",
operation_id="获取应用参数设置其他字段",
tags=["应用/会话"])
def put(self, request: Request, application_id: str):
return result.success(
ApplicationSerializer.Operate(
data={'application_id': application_id}).save_other_config(request.data))
class Profile(APIView):
authentication_classes = [TokenAuth]

View File

@ -13,7 +13,7 @@ from common.util.rsa_util import rsa_long_decrypt
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
def get_model_(provider, model_type, model_name, credential):
def get_model_(provider, model_type, model_name, credential, **kwargs):
"""
获取模型实例
@param provider: 供应商
@ -25,17 +25,17 @@ def get_model_(provider, model_type, model_name, credential):
model = get_provider(provider).get_model(model_type, model_name,
json.loads(
rsa_long_decrypt(credential)),
streaming=True)
streaming=True, **kwargs)
return model
def get_model(model):
def get_model(model, **kwargs):
"""
获取模型实例
@param model: model 数据库Model实例对象
@return: 模型实例
"""
return get_model_(model.provider, model.model_type, model.model_name, model.credential)
return get_model_(model.provider, model.model_type, model.model_name, model.credential, **kwargs)
def get_provider(provider):

View File

@ -108,6 +108,13 @@ class BaseModelCredential(ABC):
"""
pass
def get_other_fields(self, model_name):
"""
获取其他字段
:return:
"""
pass
@staticmethod
def encryption(message: str):
"""

View File

@ -60,3 +60,16 @@ class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
region_name = forms.TextInputField('Region Name', required=True)
access_key_id = forms.TextInputField('Access Key ID', required=True)
secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)
def get_other_fields(self, model_name):
return {
'max_tokens': {
'value': 1024,
'min': 1,
'max': 8192,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}

View File

@ -1,12 +1,18 @@
from typing import List, Dict, Any
from typing import List, Dict, Any, Optional, Iterator
from langchain_community.chat_models import BedrockChat
from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage
from langchain_community.chat_models.bedrock import ChatPromptAdapter
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, get_buffer_string, AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from setting.models_provider.base_model_provider import MaxKBBaseModel
class BedrockModel(MaxKBBaseModel, BedrockChat):
@staticmethod
def is_cache_model():
return False
def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
streaming: bool = False, **kwargs):
super().__init__(model_id=model_id, region_name=region_name,
@ -15,21 +21,52 @@ class BedrockModel(MaxKBBaseModel, BedrockChat):
@classmethod
def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
**model_kwargs) -> 'BedrockModel':
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
return cls(
model_id=model_name,
region_name=model_credential['region_name'],
credentials_profile_name=model_credential['credentials_profile_name'],
streaming=model_kwargs.pop('streaming', False),
**model_kwargs
**optional_params
)
def _get_num_tokens(self, content: str) -> int:
"""Helper method to count tokens in a string."""
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(content))
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
return sum(self._get_num_tokens(get_buffer_string([message])) for message in messages)
def get_num_tokens(self, text: str) -> int:
return self._get_num_tokens(text)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
provider = self._get_provider()
prompt, system, formatted_messages = None, None, None
if provider == "anthropic":
system, formatted_messages = ChatPromptAdapter.format_messages(
provider, messages
)
else:
prompt = ChatPromptAdapter.convert_messages_to_prompt(
provider=provider, messages=messages
)
for chunk in self._prepare_input_and_invoke_stream(
prompt=prompt,
system=system,
messages=formatted_messages,
stop=stop,
run_manager=run_manager,
**kwargs,
):
delta = chunk.text
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))

View File

@ -53,3 +53,25 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField("API Key (api_key)", required=True)
deployment_name = forms.TextInputField("部署名 (deployment_name)", required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.7,
'min': 0.1,
'max': 1,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 800,
'min': 1,
'max': 4096,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}

View File

@ -6,36 +6,102 @@
@date2024/4/28 11:45
@desc:
"""
from typing import List, Dict
from langchain_core.messages import BaseMessage, get_buffer_string
from typing import List, Dict, Optional, Any, Iterator, Type
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, AIMessageChunk, BaseMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from langchain_openai.chat_models.base import _convert_delta_to_message_chunk
from langchain_openai import AzureChatOpenAI
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI):
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
return AzureChatModel(
azure_endpoint=model_credential.get('api_base'),
openai_api_version=model_credential.get('api_version', '2024-02-15-preview'),
deployment_name=model_credential.get('deployment_name'),
openai_api_key=model_credential.get('api_key'),
openai_api_type="azure"
openai_api_type="azure",
**optional_params
)
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
return self.__dict__.get('_last_generation_info')
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
try:
return super().get_num_tokens_from_messages(messages)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
return self.get_last_generation_info().get('prompt_tokens', 0)
def get_num_tokens(self, text: str) -> int:
try:
return super().get_num_tokens(text)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
return self.get_last_generation_info().get('completion_tokens', 0)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
kwargs["stream"] = True
kwargs["stream_options"] = {"include_usage": True}
payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
if self.include_response_headers:
raw_response = self.client.with_raw_response.create(**payload)
response = raw_response.parse()
base_generation_info = {"headers": dict(raw_response.headers)}
else:
response = self.client.create(**payload)
base_generation_info = {}
with response:
is_first_chunk = True
for chunk in response:
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
if len(chunk["choices"]) == 0:
if token_usage := chunk.get("usage"):
self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
logprobs = None
else:
continue
else:
choice = chunk["choices"][0]
if choice["delta"] is None:
continue
message_chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
generation_info = {**base_generation_info} if is_first_chunk else {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
if model_name := chunk.get("model"):
generation_info["model_name"] = model_name
if system_fingerprint := chunk.get("system_fingerprint"):
generation_info["system_fingerprint"] = system_fingerprint
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
default_chunk_class = message_chunk.__class__
generation_chunk = ChatGenerationChunk(
message=message_chunk, generation_info=generation_info or None
)
if run_manager:
run_manager.on_llm_new_token(
generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
)
is_first_chunk = False
yield generation_chunk

View File

@ -42,7 +42,7 @@ class BaseChatOpenAI(ChatOpenAI):
for chunk in response:
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
if len(chunk["choices"]) == 0:
if len(chunk["choices"]) == 0 or chunk["choices"][0]["finish_reason"] == "length" or chunk["choices"][0]["finish_reason"] == "stop":
if token_usage := chunk.get("usage"):
self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
logprobs = None

View File

@ -46,3 +46,25 @@ class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_key = forms.PasswordInputField('API Key', required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.7,
'min': 0.1,
'max': 1,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 800,
'min': 1,
'max': 4096,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}

View File

@ -13,12 +13,24 @@ from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
class DeepSeekChatModel(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
deepseek_chat_open_ai = DeepSeekChatModel(
model=model_name,
openai_api_base='https://api.deepseek.com',
openai_api_key=model_credential.get('api_key')
openai_api_key=model_credential.get('api_key'),
**optional_params
)
return deepseek_chat_open_ai

View File

@ -32,7 +32,8 @@ class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model.invoke([HumanMessage(content='你好')])
res = model.invoke([HumanMessage(content='你好')])
print(res)
except Exception as e:
if isinstance(e, AppApiException):
raise e
@ -46,3 +47,25 @@ class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_key = forms.PasswordInputField('API Key', required=True)
def get_other_fields(self):
return {
'temperature': {
'value': 0.7,
'min': 0.1,
'max': 1.0,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 800,
'min': 1,
'max': 4096,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}

View File

@ -16,11 +16,23 @@ from setting.models_provider.base_model_provider import MaxKBBaseModel
class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI):
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'temperature' in model_kwargs:
optional_params['temperature'] = model_kwargs['temperature']
if 'max_tokens' in model_kwargs:
optional_params['max_output_tokens'] = model_kwargs['max_tokens']
gemini_chat = GeminiChatModel(
model=model_name,
google_api_key=model_credential.get('api_key')
google_api_key=model_credential.get('api_key'),
**optional_params
)
return gemini_chat

View File

@ -47,3 +47,25 @@ class KimiLLMModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名', required=True)
api_key = forms.PasswordInputField('API Key', required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.3,
'min': 0.1,
'max': 1.0,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 1024,
'min': 1,
'max': 4096,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}

View File

@ -8,20 +8,28 @@
"""
from typing import List, Dict
from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
class KimiChatModel(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
kimi_chat_open_ai = KimiChatModel(
openai_api_base=model_credential['api_base'],
openai_api_key=model_credential['api_key'],
model_name=model_name,
**optional_params
)
return kimi_chat_open_ai

View File

@ -42,3 +42,25 @@ class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名', required=True)
api_key = forms.PasswordInputField('API Key', required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.3,
'min': 0.1,
'max': 1.0,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 1024,
'min': 1,
'max': 4096,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}

View File

@ -30,5 +30,12 @@ class OllamaChatModel(MaxKBBaseModel, BaseChatOpenAI):
api_base = model_credential.get('api_base', '')
base_url = get_base_url(api_base)
base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
return OllamaChatModel(model=model_name, openai_api_base=base_url,
openai_api_key=model_credential.get('api_key'))
openai_api_key=model_credential.get('api_key'),
stream_usage=True, **optional_params)

View File

@ -47,3 +47,25 @@ class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名', required=True)
api_key = forms.PasswordInputField('API Key', required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.7,
'min': 0.1,
'max': 1.0,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 800,
'min': 1,
'max': 4096,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}

View File

@ -6,21 +6,32 @@
@date2024/4/18 15:28
@desc:
"""
from typing import List, Dict
from typing import List, Dict, Optional, Iterator, Any, Type
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, AIMessageChunk
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGenerationChunk
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import _convert_delta_to_message_chunk
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
class OpenAIChatModel(MaxKBBaseModel, BaseChatOpenAI):
class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
azure_chat_open_ai = OpenAIChatModel(
model=model_name,
openai_api_base=model_credential.get('api_base'),
openai_api_key=model_credential.get('api_key'),
streaming=model_kwargs.get('streaming', False),
max_tokens=model_kwargs.get('max_tokens', 5),
temperature=model_kwargs.get('temperature', 0.5),
**optional_params,
stream_usage=True
)
return azure_chat_open_ai

View File

@ -45,3 +45,16 @@ class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_key = forms.PasswordInputField('API Key', required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 1.0,
'min': 0.1,
'max': 1.9,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
}

View File

@ -6,10 +6,13 @@
@date2024/4/28 11:44
@desc:
"""
from typing import List, Dict
from typing import List, Dict, Optional, Iterator, Any
from langchain_community.chat_models import ChatTongyi
from langchain_community.llms.tongyi import generate_with_last_element_mark
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.outputs import ChatGenerationChunk
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
@ -18,16 +21,61 @@ from setting.models_provider.base_model_provider import MaxKBBaseModel
class QwenChatModel(MaxKBBaseModel, ChatTongyi):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
chat_tong_yi = QwenChatModel(
model_name=model_name,
dashscope_api_key=model_credential.get('api_key')
dashscope_api_key=model_credential.get('api_key'),
**optional_params,
)
return chat_tong_yi
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
return self.__dict__.get('_last_generation_info')
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
return self.get_last_generation_info().get('input_tokens', 0)
def get_num_tokens(self, text: str) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
return self.get_last_generation_info().get('output_tokens', 0)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params: Dict[str, Any] = self._invocation_params(
messages=messages, stop=stop, stream=True, **kwargs
)
for stream_resp, is_last_chunk in generate_with_last_element_mark(
self.stream_completion_with_retry(**params)
):
choice = stream_resp["output"]["choices"][0]
message = choice["message"]
if (
choice["finish_reason"] == "stop"
and message["content"] == ""
):
token_usage = stream_resp["usage"]
self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
if (
choice["finish_reason"] == "null"
and message["content"] == ""
and "tool_calls" not in message
):
continue
chunk = ChatGenerationChunk(
**self._chat_generation_from_qwen_resp(
stream_resp, is_chunk=True, is_last_chunk=is_last_chunk
)
)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk

View File

@ -44,4 +44,17 @@ class TencentLLMModelCredential(BaseForm, BaseModelCredential):
hunyuan_app_id = forms.TextInputField('APP ID', required=True)
hunyuan_secret_id = forms.PasswordInputField('SecretId', required=True)
hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True)
hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.5,
'min': 0.1,
'max': 2.0,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
}

View File

@ -236,6 +236,11 @@ class ChatHunyuan(BaseChatModel):
choice["Delta"], default_chunk_class
)
default_chunk_class = chunk.__class__
# FinishReason === stop
if choice.get("FinishReason") == "stop":
self.__dict__.setdefault("_last_generation_info", {}).update(
response.get("Usage", {})
)
cg_chunk = ChatGenerationChunk(message=chunk)
if run_manager:
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)

View File

@ -1,6 +1,6 @@
# coding=utf-8
from typing import List, Dict
from typing import List, Dict, Optional, Any
from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage
@ -15,12 +15,18 @@ class TencentModel(MaxKBBaseModel, ChatHunyuan):
hunyuan_secret_id = credentials.get('hunyuan_secret_id')
hunyuan_secret_key = credentials.get('hunyuan_secret_key')
optional_params = {}
if 'temperature' in kwargs:
optional_params['temperature'] = kwargs['temperature']
if not all([hunyuan_app_id, hunyuan_secret_id, hunyuan_secret_key]):
raise ValueError(
"All of 'hunyuan_app_id', 'hunyuan_secret_id', and 'hunyuan_secret_key' must be provided in credentials.")
super().__init__(model=model_name, hunyuan_app_id=hunyuan_app_id, hunyuan_secret_id=hunyuan_secret_id,
hunyuan_secret_key=hunyuan_secret_key, streaming=streaming, **kwargs)
hunyuan_secret_key=hunyuan_secret_key, streaming=streaming,
temperature=optional_params.get('temperature', None)
)
@staticmethod
def new_instance(model_type: str, model_name: str, model_credential: Dict[str, object],
@ -28,10 +34,11 @@ class TencentModel(MaxKBBaseModel, ChatHunyuan):
streaming = model_kwargs.pop('streaming', False)
return TencentModel(model_name=model_name, credentials=model_credential, streaming=streaming, **model_kwargs)
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
return self.__dict__.get('_last_generation_info')
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)
return self.get_last_generation_info().get('PromptTokens', 0)
def get_num_tokens(self, text: str) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
return self.get_last_generation_info().get('CompletionTokens', 0)

View File

@ -48,3 +48,25 @@ class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential):
access_key_id = forms.PasswordInputField('Access Key ID', required=True)
secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)
def get_other_fields(self):
return {
'temperature': {
'value': 0.3,
'min': 0.1,
'max': 1.0,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 1024,
'min': 1,
'max': 4096,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}

View File

@ -1,27 +1,21 @@
from typing import List, Dict
from langchain_community.chat_models import VolcEngineMaasChat
from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from langchain_openai import ChatOpenAI
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
class VolcanicEngineChatModel(MaxKBBaseModel, ChatOpenAI):
class VolcanicEngineChatModel(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
volcanic_engine_chat = VolcanicEngineChatModel(
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
return VolcanicEngineChatModel(
model=model_name,
volc_engine_maas_ak=model_credential.get("access_key_id"),
volc_engine_maas_sk=model_credential.get("secret_access_key"),
openai_api_base=model_credential.get('api_base'),
openai_api_key=model_credential.get('api_key'),
**optional_params
)
return volcanic_engine_chat
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
def get_num_tokens(self, text: str) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))

View File

@ -14,7 +14,7 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel
from smartdoc.conf import PROJECT_DIR
@ -24,7 +24,7 @@ model_info_list = [
ModelInfo('ep-xxxxxxxxxx-yyyy',
'用户前往火山方舟的模型推理页面创建推理接入点这里需要输入ep-xxxxxxxxxx-yyyy进行调用',
ModelTypeConst.LLM,
volcanic_engine_llm_model_credential, OpenAIChatModel
volcanic_engine_llm_model_credential, VolcanicEngineChatModel
)
]

View File

@ -53,3 +53,16 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField('API Key', required=True)
secret_key = forms.PasswordInputField("Secret Key", required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.95,
'min': 0.1,
'max': 1.0,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
}

View File

@ -6,28 +6,73 @@
@date2023/11/10 17:45
@desc:
"""
from typing import List, Dict
import uuid
from typing import List, Dict, Optional, Any, Iterator
from langchain.schema.messages import BaseMessage, get_buffer_string
from langchain.schema.messages import get_buffer_string
from langchain_community.chat_models import QianfanChatEndpoint
from langchain_community.chat_models.baidu_qianfan_endpoint import _convert_dict_to_message
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.outputs import ChatGenerationChunk
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from langchain_core.messages import (
AIMessageChunk,
BaseMessage,
)
class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_output_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
return QianfanChatModel(model=model_name,
qianfan_ak=model_credential.get('api_key'),
qianfan_sk=model_credential.get('secret_key'),
streaming=model_kwargs.get('streaming', False))
streaming=model_kwargs.get('streaming', False),
**optional_params)
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
return self.__dict__.get('_last_generation_info')
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
return self.get_last_generation_info().get('prompt_tokens', 0)
def get_num_tokens(self, text: str) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
return self.get_last_generation_info().get('completion_tokens', 0)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = self._convert_prompt_msg_params(messages, **kwargs)
params["stop"] = stop
params["stream"] = True
for res in self.client.do(**params):
if res:
msg = _convert_dict_to_message(res)
additional_kwargs = msg.additional_kwargs.get("function_call", {})
if msg.content == "":
token_usage = res.get("body").get("usage")
self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
chunk = ChatGenerationChunk(
text=res["result"],
message=AIMessageChunk( # type: ignore[call-arg]
content=msg.content,
role="assistant",
additional_kwargs=additional_kwargs,
),
generation_info=msg.additional_kwargs,
)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk

View File

@ -49,3 +49,28 @@ class XunFeiLLMModelCredential(BaseForm, BaseModelCredential):
spark_app_id = forms.TextInputField('APP ID', required=True)
spark_api_key = forms.PasswordInputField("API Key", required=True)
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
def get_other_fields(self, model_name):
max_value = 8192
if model_name == 'general' or model_name == 'pro-128k':
max_value = 4096
return {
'temperature': {
'value': 0.5,
'min': 0.1,
'max': 1,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 4096,
'min': 1,
'max': max_value,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}

View File

@ -25,14 +25,19 @@ class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs:
optional_params['temperature'] = model_kwargs['temperature']
return XFChatSparkLLM(
spark_app_id=model_credential.get('spark_app_id'),
spark_api_key=model_credential.get('spark_api_key'),
spark_api_secret=model_credential.get('spark_api_secret'),
spark_api_url=model_credential.get('spark_api_url'),
spark_llm_domain=model_name,
temperature=model_kwargs.get('temperature', 0.5),
max_tokens=model_kwargs.get('max_tokens', 5),
streaming=model_kwargs.get('streaming', False),
**optional_params
)
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:

View File

@ -39,3 +39,25 @@ class XinferenceLLMModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名', required=True)
api_key = forms.PasswordInputField('API Key', required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.7,
'min': 0.1,
'max': 1,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 800,
'min': 1,
'max': 4096,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}

View File

@ -26,14 +26,13 @@ class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI):
base_url = get_base_url(api_base)
base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
optional_params = {}
if 'max_tokens' in model_kwargs:
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs:
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
return XinferenceChatModel(
model=model_name,
openai_api_base=base_url,
openai_api_key=model_credential.get('api_key'),
streaming=model_kwargs.get('streaming', False),
**optional_params
)

View File

@ -45,3 +45,26 @@ class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_key = forms.PasswordInputField('API Key', required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.95,
'min': 0.1,
'max': 1.0,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 1024,
'min': 1,
'max': 4095,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}

View File

@ -33,11 +33,17 @@ class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs:
optional_params['temperature'] = model_kwargs['temperature']
zhipuai_chat = ZhipuChatModel(
temperature=0.5,
api_key=model_credential.get('api_key'),
model=model_name,
max_tokens=model_kwargs.get('max_tokens', 5)
streaming=model_kwargs.get('streaming', False),
**optional_params
)
return zhipuai_chat

View File

@ -22,7 +22,7 @@ def get_model_by_id(_id, user_id):
return model
def get_model_instance_by_model_user_id(model_id, user_id):
def get_model_instance_by_model_user_id(model_id, user_id, **kwargs):
"""
获取模型实例,根据模型相关数据
@param model_id: 模型id
@ -30,4 +30,4 @@ def get_model_instance_by_model_user_id(model_id, user_id):
@return: 模型实例
"""
model = get_model_by_id(model_id, user_id)
return ModelManage.get_model(model_id, lambda _id: get_model(model))
return ModelManage.get_model(model_id, lambda _id: get_model(model, **kwargs))

View File

@ -1,8 +1,8 @@
import { Result } from '@/request/Result'
import { get, post, postStream, del, put } from '@/request/index'
import type { pageRequest } from '@/api/type/common'
import type { ApplicationFormType } from '@/api/type/application'
import { type Ref } from 'vue'
import {Result} from '@/request/Result'
import {get, post, postStream, del, put} from '@/request/index'
import type {pageRequest} from '@/api/type/common'
import type {ApplicationFormType} from '@/api/type/application'
import {type Ref} from 'vue'
const prefix = '/application'
@ -11,25 +11,25 @@ const prefix = '/application'
* @param
*/
const getAllAppilcation: () => Promise<Result<any[]>> = () => {
return get(`${prefix}`)
return get(`${prefix}`)
}
/**
*
* page {
"current_page": "string",
"page_size": "string",
}
"current_page": "string",
"page_size": "string",
}
* param {
"name": "string",
}
"name": "string",
}
*/
const getApplication: (
page: pageRequest,
param: any,
loading?: Ref<boolean>
page: pageRequest,
param: any,
loading?: Ref<boolean>
) => Promise<Result<any>> = (page, param, loading) => {
return get(`${prefix}/${page.current_page}/${page.page_size}`, param, loading)
return get(`${prefix}/${page.current_page}/${page.page_size}`, param, loading)
}
/**
@ -37,10 +37,10 @@ const getApplication: (
* @param
*/
const postApplication: (
data: ApplicationFormType,
loading?: Ref<boolean>
data: ApplicationFormType,
loading?: Ref<boolean>
) => Promise<Result<any>> = (data, loading) => {
return post(`${prefix}`, data, undefined, loading)
return post(`${prefix}`, data, undefined, loading)
}
/**
@ -48,11 +48,11 @@ const postApplication: (
* @param
*/
const putApplication: (
application_id: String,
data: ApplicationFormType,
loading?: Ref<boolean>
application_id: String,
data: ApplicationFormType,
loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, data, loading) => {
return put(`${prefix}/${application_id}`, data, undefined, loading)
return put(`${prefix}/${application_id}`, data, undefined, loading)
}
/**
@ -60,10 +60,10 @@ const putApplication: (
* @param application_id
*/
const delApplication: (
application_id: String,
loading?: Ref<boolean>
application_id: String,
loading?: Ref<boolean>
) => Promise<Result<boolean>> = (application_id, loading) => {
return del(`${prefix}/${application_id}`, undefined, {}, loading)
return del(`${prefix}/${application_id}`, undefined, {}, loading)
}
/**
@ -71,10 +71,10 @@ const delApplication: (
* @param application_id
*/
const getApplicationDetail: (
application_id: string,
loading?: Ref<boolean>
application_id: string,
loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, loading) => {
return get(`${prefix}/${application_id}`, undefined, loading)
return get(`${prefix}/${application_id}`, undefined, loading)
}
/**
@ -82,10 +82,10 @@ const getApplicationDetail: (
* @param application_id
*/
const getApplicationDataset: (
application_id: string,
loading?: Ref<boolean>
application_id: string,
loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, loading) => {
return get(`${prefix}/${application_id}/list_dataset`, undefined, loading)
return get(`${prefix}/${application_id}/list_dataset`, undefined, loading)
}
/**
@ -93,10 +93,10 @@ const getApplicationDataset: (
* @param application_id
*/
const getAccessToken: (application_id: string, loading?: Ref<boolean>) => Promise<Result<any>> = (
application_id,
loading
application_id,
loading
) => {
return get(`${prefix}/${application_id}/access_token`, undefined, loading)
return get(`${prefix}/${application_id}/access_token`, undefined, loading)
}
/**
@ -107,71 +107,71 @@ const getAccessToken: (application_id: string, loading?: Ref<boolean>) => Promis
* }
*/
const putAccessToken: (
application_id: string,
data: any,
loading?: Ref<boolean>
application_id: string,
data: any,
loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, data, loading) => {
return put(`${prefix}/${application_id}/access_token`, data, undefined, loading)
return put(`${prefix}/${application_id}/access_token`, data, undefined, loading)
}
/**
*
* @param
* @param
{
"access_token": "string"
}
"access_token": "string"
}
*/
const postAppAuthentication: (access_token: string, loading?: Ref<boolean>) => Promise<any> = (
access_token,
loading
access_token,
loading
) => {
return post(`${prefix}/authentication`, { access_token }, undefined, loading)
return post(`${prefix}/authentication`, {access_token}, undefined, loading)
}
/**
*
* @param
* @param
{
"access_token": "string"
}
"access_token": "string"
}
*/
const getAppProfile: (loading?: Ref<boolean>) => Promise<any> = (loading) => {
return get(`${prefix}/profile`, undefined, loading)
return get(`${prefix}/profile`, undefined, loading)
}
/**
* Id
* @param
* @param
}
}
*/
const postChatOpen: (data: ApplicationFormType) => Promise<Result<any>> = (data) => {
return post(`${prefix}/chat/open`, data)
return post(`${prefix}/chat/open`, data)
}
/**
* Id
* @param
* @param
}
}
*/
const postWorkflowChatOpen: (data: ApplicationFormType) => Promise<Result<any>> = (data) => {
return post(`${prefix}/chat_workflow/open`, data)
return post(`${prefix}/chat_workflow/open`, data)
}
/**
* Id
* @param
* @param
* {
"model_id": "string",
"multiple_rounds_dialogue": true,
"dataset_id_list": [
"string"
]
}
"model_id": "string",
"multiple_rounds_dialogue": true,
"dataset_id_list": [
"string"
]
}
*/
const getChatOpen: (application_id: String) => Promise<Result<any>> = (application_id) => {
return get(`${prefix}/${application_id}/chat/open`)
return get(`${prefix}/${application_id}/chat/open`)
}
/**
*
@ -180,32 +180,32 @@ const getChatOpen: (application_id: String) => Promise<Result<any>> = (applicati
* data
*/
const postChatMessage: (chat_id: string, data: any) => Promise<any> = (chat_id, data) => {
return postStream(`/api${prefix}/chat_message/${chat_id}`, data)
return postStream(`/api${prefix}/chat_message/${chat_id}`, data)
}
/**
*
* @param
* @param
* application_id : string; chat_id : string; chat_record_id : string
* {
"vote_status": "string", // -1 0 1
}
"vote_status": "string", // -1 0 1
}
*/
const putChatVote: (
application_id: string,
chat_id: string,
chat_record_id: string,
vote_status: string,
loading?: Ref<boolean>
application_id: string,
chat_id: string,
chat_record_id: string,
vote_status: string,
loading?: Ref<boolean>
) => Promise<any> = (application_id, chat_id, chat_record_id, vote_status, loading) => {
return put(
`${prefix}/${application_id}/chat/${chat_id}/chat_record/${chat_record_id}/vote`,
{
vote_status
},
undefined,
loading
)
return put(
`${prefix}/${application_id}/chat/${chat_id}/chat_record/${chat_record_id}/vote`,
{
vote_status
},
undefined,
loading
)
}
/**
@ -216,11 +216,11 @@ const putChatVote: (
* @returns
*/
const getApplicationHitTest: (
application_id: string,
data: any,
loading?: Ref<boolean>
application_id: string,
data: any,
loading?: Ref<boolean>
) => Promise<Result<Array<any>>> = (application_id, data, loading) => {
return get(`${prefix}/${application_id}/hit_test`, data, loading)
return get(`${prefix}/${application_id}/hit_test`, data, loading)
}
/**
@ -231,10 +231,10 @@ const getApplicationHitTest: (
* @returns
*/
const getApplicationModel: (
application_id: string,
loading?: Ref<boolean>
application_id: string,
loading?: Ref<boolean>
) => Promise<Result<Array<any>>> = (application_id, loading) => {
return get(`${prefix}/${application_id}/model`, loading)
return get(`${prefix}/${application_id}/model`, loading)
}
/**
@ -242,31 +242,61 @@ const getApplicationModel: (
* @param
*/
const putPublishApplication: (
application_id: String,
data: ApplicationFormType,
loading?: Ref<boolean>
application_id: String,
data: ApplicationFormType,
loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, data, loading) => {
return put(`${prefix}/${application_id}/publish`, data, undefined, loading)
return put(`${prefix}/${application_id}/publish`, data, undefined, loading)
}
/**
*
* @param application_id
* @param model_id
* @param loading
* @returns
*/
const getModelOtherConfig: (
application_id: string,
model_id: string,
loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, model_id, loading) => {
return get(`${prefix}/${application_id}/model/${model_id}`, undefined, loading)
}
/**
*
* @param application_id
*
*/
const putModelOtherConfig: (
application_id: string,
data: any,
loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, data, loading) => {
return put(`${prefix}/${application_id}/other-config`, data, undefined, loading)
}
export default {
getAllAppilcation,
getApplication,
postApplication,
putApplication,
postChatOpen,
getChatOpen,
postChatMessage,
delApplication,
getApplicationDetail,
getApplicationDataset,
getAccessToken,
putAccessToken,
postAppAuthentication,
getAppProfile,
putChatVote,
getApplicationHitTest,
getApplicationModel,
putPublishApplication,
postWorkflowChatOpen
getAllAppilcation,
getApplication,
postApplication,
putApplication,
postChatOpen,
getChatOpen,
postChatMessage,
delApplication,
getApplicationDetail,
getApplicationDataset,
getAccessToken,
putAccessToken,
postAppAuthentication,
getAppProfile,
putChatVote,
getApplicationHitTest,
getApplicationModel,
putPublishApplication,
postWorkflowChatOpen,
getModelOtherConfig,
putModelOtherConfig
}

View File

@ -1,136 +1,163 @@
import { defineStore } from 'pinia'
import {defineStore} from 'pinia'
import applicationApi from '@/api/application'
import applicationXpackApi from '@/api/application-xpack'
import { type Ref } from 'vue'
import {type Ref, type UnwrapRef} from 'vue'
import useUserStore from './user'
import type {ApplicationFormType} from "@/api/type/application";
const useApplicationStore = defineStore({
id: 'application',
state: () => ({
location: `${window.location.origin}/ui/chat/`
}),
actions: {
async asyncGetAllApplication() {
return new Promise((resolve, reject) => {
applicationApi
.getAllAppilcation()
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
id: 'application',
state: () => ({
location: `${window.location.origin}/ui/chat/`
}),
actions: {
async asyncGetAllApplication() {
return new Promise((resolve, reject) => {
applicationApi
.getAllAppilcation()
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
async asyncGetApplicationDetail(id: string, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.getApplicationDetail(id, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
async asyncGetApplicationDetail(id: string, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.getApplicationDetail(id, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
async asyncGetApplicationDataset(id: string, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.getApplicationDataset(id, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
async asyncGetApplicationDataset(id: string, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.getApplicationDataset(id, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
async asyncGetAccessToken(id: string, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
const user = useUserStore()
if (user.isEnterprise()) {
applicationXpackApi
.getAccessToken(id, loading)
.then((data) => {
resolve(data)
async asyncGetAccessToken(id: string, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
const user = useUserStore()
if (user.isEnterprise()) {
applicationXpackApi
.getAccessToken(id, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
} else {
applicationApi
.getAccessToken(id, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
}
})
.catch((error) => {
reject(error)
})
} else {
applicationApi
.getAccessToken(id, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
}
})
},
},
async asyncGetAppProfile(loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
const user = useUserStore()
if (user.isEnterprise()) {
applicationXpackApi
.getAppXpackProfile(loading)
.then((data) => {
resolve(data)
async asyncGetAppProfile(loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
const user = useUserStore()
if (user.isEnterprise()) {
applicationXpackApi
.getAppXpackProfile(loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
} else {
applicationApi
.getAppProfile(loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
}
})
.catch((error) => {
reject(error)
})
} else {
applicationApi
.getAppProfile(loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
}
})
},
},
async asyncAppAuthentication(token: string, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.postAppAuthentication(token, loading)
.then((res) => {
localStorage.setItem('accessToken', res.data)
sessionStorage.setItem('accessToken', res.data)
resolve(res)
})
.catch((error) => {
reject(error)
})
})
},
async refreshAccessToken(token: string) {
this.asyncAppAuthentication(token)
},
// 修改应用
async asyncPutApplication(id: string, data: any, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.putApplication(id, data, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
async asyncAppAuthentication(token: string, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.postAppAuthentication(token, loading)
.then((res) => {
localStorage.setItem('accessToken', res.data)
sessionStorage.setItem('accessToken', res.data)
resolve(res)
})
.catch((error) => {
reject(error)
})
})
},
async refreshAccessToken(token: string) {
this.asyncAppAuthentication(token)
},
// 修改应用
async asyncPutApplication(id: string, data: any, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.putApplication(id, data, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
// 获取模型的 温度/max_token字段设置
async asyncGetModelConfig(id: string, model_id: string, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.getModelOtherConfig(id, model_id, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
// 保存应用的 温度/max_token字段设置
async asyncPostModelConfig(id: string, params: any, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.putModelOtherConfig(id, params, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
}
}
})
export default useApplicationStore

View File

@ -67,9 +67,14 @@
<template #label>
<div class="flex-between">
<span>{{ $t('views.application.applicationForm.form.aiModel.label') }}</span>
<!-- <el-button type="primary" link @click="openAIParamSettingDialog">
<el-button
type="primary"
link
@click="openAIParamSettingDialog"
:disabled="!applicationForm.model_id"
>
{{ $t('views.application.applicationForm.form.paramSetting') }}
</el-button> -->
</el-button>
</div>
</template>
<el-select
@ -104,9 +109,9 @@
>公用
</el-tag>
</div>
<el-icon class="check-icon" v-if="item.id === applicationForm.model_id"
><Check
/></el-icon>
<el-icon class="check-icon" v-if="item.id === applicationForm.model_id">
<Check />
</el-icon>
</el-option>
<!-- 不可用 -->
<el-option
@ -127,15 +132,17 @@
$t('views.application.applicationForm.form.aiModel.unavailable')
}}</span>
</div>
<el-icon class="check-icon" v-if="item.id === applicationForm.model_id"
><Check
/></el-icon>
<el-icon class="check-icon" v-if="item.id === applicationForm.model_id">
<Check />
</el-icon>
</el-option>
</el-option-group>
<template #footer>
<div class="w-full text-left cursor" @click="openCreateModel()">
<el-button type="primary" link>
<el-icon class="mr-4"><Plus /></el-icon>
<el-icon class="mr-4">
<Plus />
</el-icon>
{{ $t('views.application.applicationForm.form.addModel') }}
</el-button>
</div>
@ -155,12 +162,14 @@
>
</div>
<el-tooltip effect="dark" placement="right">
<template #content>{{
$t('views.application.applicationForm.form.prompt.tooltip', {
data: '{data}',
question: '{question}'
})
}}</template>
<template #content
>{{
$t('views.application.applicationForm.form.prompt.tooltip', {
data: '{data}',
question: '{question}'
})
}}
</template>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
@ -192,20 +201,22 @@
}}</span>
<div>
<el-button type="primary" link @click="openParamSettingDialog">
<AppIcon iconName="app-operation" class="mr-4"></AppIcon
>{{ $t('views.application.applicationForm.form.paramSetting') }}
<AppIcon iconName="app-operation" class="mr-4"></AppIcon>
{{ $t('views.application.applicationForm.form.paramSetting') }}
</el-button>
<el-button type="primary" link @click="openDatasetDialog">
<el-icon class="mr-4"><Plus /></el-icon
>{{ $t('views.application.applicationForm.form.add') }}
<el-icon class="mr-4">
<Plus />
</el-icon>
{{ $t('views.application.applicationForm.form.add') }}
</el-button>
</div>
</div>
</template>
<div class="w-full">
<el-text type="info" v-if="applicationForm.dataset_id_list?.length === 0">{{
$t('views.application.applicationForm.form.relatedKnowledgeBaseWhere')
}}</el-text>
<el-text type="info" v-if="applicationForm.dataset_id_list?.length === 0"
>{{ $t('views.application.applicationForm.form.relatedKnowledgeBaseWhere') }}
</el-text>
<el-row :gutter="12" v-else>
<el-col
:xs="24"
@ -237,7 +248,9 @@
</div>
</div>
<el-button text @click="removeDataset(item)">
<el-icon><Close /></el-icon>
<el-icon>
<Close />
</el-icon>
</el-button>
</div>
</el-card>
@ -314,7 +327,7 @@
</el-col>
</el-row>
<AIModeParamSettingDialog ref="AIModeParamSettingDialogRef" @refresh="refreshParam" />
<AIModeParamSettingDialog ref="AIModeParamSettingDialogRef" @refresh="refreshParam" :id="id" />
<ParamSettingDialog ref="ParamSettingDialogRef" @refresh="refreshParam" />
<AddDatasetDialog
ref="AddDatasetDialogRef"
@ -351,6 +364,7 @@ import { relatedObject } from '@/utils/utils'
import { MsgSuccess } from '@/utils/message'
import useStore from '@/stores'
import { t } from '@/locales'
const { model, application } = useStore()
const route = useRoute()
@ -436,7 +450,14 @@ const submit = async (formEl: FormInstance | undefined) => {
}
const openAIParamSettingDialog = () => {
AIModeParamSettingDialogRef.value?.open(applicationForm.value)
const model_id = applicationForm.value.model_id
if (!model_id) {
MsgSuccess(t('请选择AI 模型'))
return
}
application.asyncGetModelConfig(id, model_id, loading).then((res: any) => {
AIModeParamSettingDialogRef.value?.open(res.data)
})
}
const openParamSettingDialog = () => {
@ -463,9 +484,11 @@ function removeDataset(id: any) {
)
}
}
function addDataset(val: Array<string>) {
applicationForm.value.dataset_id_list = val
}
function openDatasetDialog() {
AddDatasetDialogRef.value.open(applicationForm.value.dataset_id_list)
}
@ -525,15 +548,18 @@ onMounted(() => {
.relate-dataset-card {
color: var(--app-text-color);
}
.dialog-bg {
border-radius: 8px;
background: var(--dialog-bg-gradient-color);
overflow: hidden;
box-sizing: border-box;
}
.scrollbar-height-left {
height: calc(var(--app-main-height) - 64px);
}
.scrollbar-height {
height: calc(var(--app-main-height) - 160px);
}

View File

@ -8,122 +8,135 @@
append-to-body
>
<el-form label-position="top" ref="paramFormRef" :model="form">
<el-form-item>
<el-form-item v-for="(item, key) in form" :key="key">
<template #label>
<div class="flex align-center">
<div class="flex-between mr-4">
<span>温度</span>
<span>{{ item.label }}</span>
</div>
<el-tooltip effect="dark" placement="right">
<template #content
>较高的数值会使输出更加随机而较低的数值会使其更加集中和确定</template
>
<template #content>{{ item.tooltip }}</template>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
</template>
<el-slider
v-model="form.similarity"
v-model="item.value"
show-input
:show-input-controls="false"
:min="0"
:max="1"
:precision="2"
:step="0.01"
class="custom-slider"
/>
</el-form-item>
<el-form-item>
<template #label>
<div class="flex align-center">
<div class="flex-between mr-4">
<span>输出最大Tokens</span>
</div>
<el-tooltip effect="dark" placement="right">
<template #content>指定模型可生成的最大token个数</template>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
</template>
<el-slider
v-model="form.max_paragraph_char_number"
show-input
:show-input-controls="false"
:min="1"
:max="10000"
:min="item.min"
:max="item.max"
:precision="item.precision || 0"
:step="item.step || 1"
class="custom-slider"
/>
</el-form-item>
</el-form>
<template #footer>
<span class="dialog-footer p-16">
<el-button @click.prevent="dialogVisible = false">{{
$t('views.application.applicationForm.buttons.cancel')
}}</el-button>
<el-button type="primary" @click="submit(paramFormRef)" :loading="loading">
<el-button @click.prevent="dialogVisible = false">
{{ $t('views.application.applicationForm.buttons.cancel') }}
</el-button>
<el-button type="primary" @click="submit" :loading="loading">
{{ $t('views.application.applicationForm.buttons.confirm') }}
</el-button>
</span>
</template>
</el-dialog>
</template>
<script setup lang="ts">
import { ref, watch, reactive } from 'vue'
import { cloneDeep } from 'lodash'
import type { FormInstance, FormRules } from 'element-plus'
import { ref, reactive, watch } from 'vue'
import { cloneDeep, set } from 'lodash'
import type { FormInstance } from 'element-plus'
import useStore from '@/stores'
const { application } = useStore()
const emit = defineEmits(['refresh'])
const paramFormRef = ref()
const form = ref<any>({
similarity: 0.6,
max_paragraph_char_number: 5000
})
const dialogVisible = ref<boolean>(false)
const paramFormRef = ref<FormInstance>()
const form = reactive<Form>({})
const dialogVisible = ref(false)
const loading = ref(false)
const props = defineProps<{
id: any
}>()
const resetForm = () => {
// form
Object.keys(form).forEach((key) => delete form[key])
}
watch(dialogVisible, (bool) => {
if (!bool) {
form.value = {
similarity: 0.6,
max_paragraph_char_number: 5000
}
}
})
interface Form {
[key: string]: FormField
}
interface FormField {
value: any
min?: number
max?: number
step?: number
label?: string
precision?: number
tooltip?: string
}
const open = (data: any) => {
form.value = cloneDeep(data)
const newData = cloneDeep(data)
Object.keys(form).forEach((key) => {
delete form[key]
})
Object.keys(newData).forEach((key) => {
set(form, key, newData[key])
})
dialogVisible.value = true
}
const submit = async (formEl: FormInstance | undefined) => {
if (!formEl) return
await formEl.validate((valid, fields) => {
if (valid) {
emit('refresh', form.value)
dialogVisible.value = false
}
})
const submit = async () => {
if (paramFormRef.value) {
await paramFormRef.value.validate((valid, fields) => {
if (valid) {
const data = Object.keys(form).reduce(
(acc, key) => {
acc[key] = form[key].value
return acc
},
{} as Record<string, any>
)
application.asyncPostModelConfig(props.id, data, loading).then(() => {
emit('refresh', form)
dialogVisible.value = false
})
}
})
}
}
watch(dialogVisible, (bool) => {
if (!bool) {
resetForm()
}
})
defineExpose({ open })
</script>
<style lang="scss" scope>
<style lang="scss" scoped>
.aiMode-param-dialog {
padding: 8px 8px 24px 8px;
.el-dialog__header {
padding: 16px 16px 0 16px;
}
.el-dialog__body {
padding: 16px !important;
}
.dialog-max-height {
height: 550px;
}
.custom-slider {
.el-input-number.is-without-controls .el-input__wrapper {
padding: 0 !important;

View File

@ -1,42 +1,51 @@
<template>
<div class="p-24" v-loading="loading">
<el-form
ref="authFormRef"
:rules="rules"
:model="form"
label-position="top"
require-asterisk-position="right"
ref="authFormRef"
:rules="rules"
:model="form"
label-position="top"
require-asterisk-position="right"
>
<el-form-item :label="$t('login.cas.ldpUri')" prop="config_data.ldpUri">
<el-input v-model="form.config_data.ldpUri" :placeholder="$t('login.cas.ldpUriPlaceholder')"/>
<el-input
v-model="form.config_data.ldpUri"
:placeholder="$t('login.cas.ldpUriPlaceholder')"
/>
</el-form-item>
<el-form-item :label="$t('login.cas.redirectUrl')" prop="config_data.redirectUrl">
<el-input v-model="form.config_data.redirectUrl" :placeholder="$t('login.cas.redirectUrlPlaceholder')"/>
<el-input
v-model="form.config_data.redirectUrl"
:placeholder="$t('login.cas.redirectUrlPlaceholder')"
/>
</el-form-item>
<el-form-item>
<el-checkbox v-model="form.is_active">{{ $t('login.cas.enableAuthentication') }}</el-checkbox>
<el-checkbox v-model="form.is_active"
>{{ $t('login.cas.enableAuthentication') }}
</el-checkbox>
</el-form-item>
</el-form>
<div class="text-right">
<el-button @click="submit(authFormRef)" type="primary" :disabled="loading"> {{ $t('login.cas.save') }}
<el-button @click="submit(authFormRef)" type="primary" :disabled="loading">
{{ $t('login.cas.save') }}
</el-button>
</div>
</div>
</template>
<script setup lang="ts">
import {reactive, ref, watch, onMounted} from 'vue'
import { reactive, ref, watch, onMounted } from 'vue'
import authApi from '@/api/auth-setting'
import type {FormInstance, FormRules} from 'element-plus'
import {t} from '@/locales'
import {MsgSuccess} from '@/utils/message'
import type { FormInstance, FormRules } from 'element-plus'
import { t } from '@/locales'
import { MsgSuccess } from '@/utils/message'
const form = ref<any>({
id: '',
auth_type: 'CAS',
config_data: {
ldpUri: '',
redirectUrl: '',
redirectUrl: ''
},
is_active: true
})
@ -46,27 +55,25 @@ const authFormRef = ref()
const loading = ref(false)
const rules = reactive<FormRules<any>>({
'config_data.ldpUri': [{required: true, message: t('login.ldap.ldpUriPlaceholder'), trigger: 'blur'}],
'config_data.redirectUrl': [{
required: true,
message: t('login.ldap.redirectUrlPlaceholder'),
trigger: 'blur'
}],
'config_data.ldpUri': [
{ required: true, message: t('login.cas.ldpUriPlaceholder'), trigger: 'blur' }
],
'config_data.redirectUrl': [
{
required: true,
message: t('login.cas.redirectUrlPlaceholder'),
trigger: 'blur'
}
]
})
const submit = async (formEl: FormInstance | undefined, test?: string) => {
const submit = async (formEl: FormInstance | undefined) => {
if (!formEl) return
await formEl.validate((valid, fields) => {
if (valid) {
if (test) {
authApi.postAuthSetting(form.value, loading).then((res) => {
MsgSuccess(t('login.cas.testConnectionSuccess'))
})
} else {
authApi.putAuthSetting(form.value.auth_type, form.value, loading).then((res) => {
MsgSuccess(t('login.cas.saveSuccess'))
})
}
authApi.putAuthSetting(form.value.auth_type, form.value, loading).then((res) => {
MsgSuccess(t('login.cas.saveSuccess'))
})
}
})
}

View File

@ -1,49 +1,69 @@
<template>
<div class="p-24" v-loading="loading">
<el-form
ref="authFormRef"
:rules="rules"
:model="form"
label-position="top"
require-asterisk-position="right"
ref="authFormRef"
:rules="rules"
:model="form"
label-position="top"
require-asterisk-position="right"
>
<el-form-item :label="$t('login.ldap.address')" prop="config_data.ldap_server">
<el-input v-model="form.config_data.ldap_server" :placeholder="$t('login.ldap.serverPlaceholder')"/>
<el-input
v-model="form.config_data.ldap_server"
:placeholder="$t('login.ldap.serverPlaceholder')"
/>
</el-form-item>
<el-form-item :label="$t('login.ldap.bindDN')" prop="config_data.base_dn">
<el-input v-model="form.config_data.base_dn" :placeholder="$t('login.ldap.bindDNPlaceholder')"/>
<el-input
v-model="form.config_data.base_dn"
:placeholder="$t('login.ldap.bindDNPlaceholder')"
/>
</el-form-item>
<el-form-item :label="$t('login.ldap.password')" prop="config_data.password">
<el-input v-model="form.config_data.password" :placeholder="$t('login.ldap.passwordPlaceholder')"
show-password/>
<el-input
v-model="form.config_data.password"
:placeholder="$t('login.ldap.passwordPlaceholder')"
show-password
/>
</el-form-item>
<el-form-item :label="$t('login.ldap.ou')" prop="config_data.ou">
<el-input v-model="form.config_data.ou" :placeholder="$t('login.ldap.ouPlaceholder')"/>
<el-input v-model="form.config_data.ou" :placeholder="$t('login.ldap.ouPlaceholder')" />
</el-form-item>
<el-form-item :label="$t('login.ldap.ldap_filter')" prop="config_data.ldap_filter">
<el-input v-model="form.config_data.ldap_filter" :placeholder="$t('login.ldap.ldap_filterPlaceholder')"/>
<el-input
v-model="form.config_data.ldap_filter"
:placeholder="$t('login.ldap.ldap_filterPlaceholder')"
/>
</el-form-item>
<el-form-item :label="$t('login.ldap.ldap_mapping')" prop="config_data.ldap_mapping">
<el-input v-model="form.config_data.ldap_mapping" placeholder='{"name":"name","email":"mail","username":"cn"}'/>
<el-input
v-model="form.config_data.ldap_mapping"
placeholder='{"name":"name","email":"mail","username":"cn"}'
/>
</el-form-item>
<el-form-item>
<el-checkbox v-model="form.is_active">{{ $t('login.ldap.enableAuthentication') }}</el-checkbox>
<el-checkbox v-model="form.is_active">{{
$t('login.ldap.enableAuthentication')
}}</el-checkbox>
</el-form-item>
<el-button @click="submit(authFormRef, 'test')" :disabled="loading"> {{ $t('login.ldap.test') }}</el-button>
<el-button @click="submit(authFormRef, 'test')" :disabled="loading">
{{ $t('login.ldap.test') }}</el-button
>
</el-form>
<div class="text-right">
<el-button @click="submit(authFormRef)" type="primary" :disabled="loading"> {{ $t('login.ldap.save') }}
<el-button @click="submit(authFormRef)" type="primary" :disabled="loading">
{{ $t('login.ldap.save') }}
</el-button>
</div>
</div>
</template>
<script setup lang="ts">
import {reactive, ref, watch, onMounted} from 'vue'
import { reactive, ref, watch, onMounted } from 'vue'
import authApi from '@/api/auth-setting'
import type {FormInstance, FormRules} from 'element-plus'
import {t} from '@/locales'
import {MsgSuccess} from '@/utils/message'
import type { FormInstance, FormRules } from 'element-plus'
import { t } from '@/locales'
import { MsgSuccess } from '@/utils/message'
const form = ref<any>({
id: '',
@ -54,7 +74,7 @@ const form = ref<any>({
password: '',
ou: '',
ldap_filter: '',
ldap_mapping: '',
ldap_mapping: ''
},
is_active: true
})
@ -64,12 +84,22 @@ const authFormRef = ref()
const loading = ref(false)
const rules = reactive<FormRules<any>>({
'config_data.ldap_server': [{required: true, message: t('login.ldap.serverPlaceholder'), trigger: 'blur'}],
'config_data.base_dn': [{required: true, message: t('login.ldap.bindDNPlaceholder'), trigger: 'blur'}],
'config_data.password': [{required: true, message: t('login.ldap.passwordPlaceholder'), trigger: 'blur'}],
'config_data.ou': [{required: true, message: t('login.ldap.ouPlaceholder'), trigger: 'blur'}],
'config_data.ldap_filter': [{required: true, message: t('login.ldap.ldap_filterPlaceholder'), trigger: 'blur'}],
'config_data.ldap_mapping': [{required: true, message: t('login.ldap.ldap_mappingPlaceholder'), trigger: 'blur'}]
'config_data.ldap_server': [
{ required: true, message: t('login.ldap.serverPlaceholder'), trigger: 'blur' }
],
'config_data.base_dn': [
{ required: true, message: t('login.ldap.bindDNPlaceholder'), trigger: 'blur' }
],
'config_data.password': [
{ required: true, message: t('login.ldap.passwordPlaceholder'), trigger: 'blur' }
],
'config_data.ou': [{ required: true, message: t('login.ldap.ouPlaceholder'), trigger: 'blur' }],
'config_data.ldap_filter': [
{ required: true, message: t('login.ldap.ldap_filterPlaceholder'), trigger: 'blur' }
],
'config_data.ldap_mapping': [
{ required: true, message: t('login.ldap.ldap_mappingPlaceholder'), trigger: 'blur' }
]
})
const submit = async (formEl: FormInstance | undefined, test?: string) => {
@ -94,7 +124,9 @@ function getDetail() {
if (res.data && JSON.stringify(res.data) !== '{}') {
form.value = res.data
if (res.data.config_data.ldap_mapping) {
form.value.config_data.ldap_mapping = JSON.stringify(JSON.parse(res.data.config_data.ldap_mapping))
form.value.config_data.ldap_mapping = JSON.stringify(
JSON.parse(res.data.config_data.ldap_mapping)
)
}
}
})

View File

@ -1,49 +1,69 @@
<template>
<div class="p-24" v-loading="loading">
<el-form
ref="authFormRef"
:rules="rules"
:model="form"
label-position="top"
require-asterisk-position="right"
ref="authFormRef"
:rules="rules"
:model="form"
label-position="top"
require-asterisk-position="right"
>
<el-form-item :label="$t('login.oidc.authEndpoint')" prop="config_data.authEndpoint">
<el-input v-model="form.config_data.authEndpoint" :placeholder="$t('login.oidc.authEndpointPlaceholder')"/>
<el-input
v-model="form.config_data.authEndpoint"
:placeholder="$t('login.oidc.authEndpointPlaceholder')"
/>
</el-form-item>
<el-form-item :label="$t('login.oidc.tokenEndpoint')" prop="config_data.tokenEndpoint">
<el-input v-model="form.config_data.tokenEndpoint" :placeholder="$t('login.oidc.tokenEndpointPlaceholder')"/>
<el-input
v-model="form.config_data.tokenEndpoint"
:placeholder="$t('login.oidc.tokenEndpointPlaceholder')"
/>
</el-form-item>
<el-form-item :label="$t('login.oidc.userInfoEndpoint')" prop="config_data.userInfoEndpoint">
<el-input v-model="form.config_data.userInfoEndpoint"
:placeholder="$t('login.oidc.userInfoEndpointPlaceholder')"/>
<el-input
v-model="form.config_data.userInfoEndpoint"
:placeholder="$t('login.oidc.userInfoEndpointPlaceholder')"
/>
</el-form-item>
<el-form-item :label="$t('login.oidc.clientId')" prop="config_data.clientId">
<el-input v-model="form.config_data.clientId" :placeholder="$t('login.oidc.clientIdPlaceholder')"/>
<el-input
v-model="form.config_data.clientId"
:placeholder="$t('login.oidc.clientIdPlaceholder')"
/>
</el-form-item>
<el-form-item :label="$t('login.oidc.clientSecret')" prop="config_data.clientSecret">
<el-input v-model="form.config_data.clientSecret" :placeholder="$t('login.oidc.clientSecretPlaceholder')"
show-password/>
<el-input
v-model="form.config_data.clientSecret"
:placeholder="$t('login.oidc.clientSecretPlaceholder')"
show-password
/>
</el-form-item>
<el-form-item :label="$t('login.oidc.redirectUrl')" prop="config_data.redirectUrl">
<el-input v-model="form.config_data.redirectUrl" :placeholder="$t('login.oidc.redirectUrlPlaceholder')"/>
<el-input
v-model="form.config_data.redirectUrl"
:placeholder="$t('login.oidc.redirectUrlPlaceholder')"
/>
</el-form-item>
<el-form-item>
<el-checkbox v-model="form.is_active">{{ $t('login.oidc.enableAuthentication') }}</el-checkbox>
<el-checkbox v-model="form.is_active"
>{{ $t('login.oidc.enableAuthentication') }}
</el-checkbox>
</el-form-item>
</el-form>
<div class="text-right">
<el-button @click="submit(authFormRef)" type="primary" :disabled="loading"> {{ $t('login.ldap.save') }}
<el-button @click="submit(authFormRef)" type="primary" :disabled="loading">
{{ $t('login.ldap.save') }}
</el-button>
</div>
</div>
</template>
<script setup lang="ts">
import {reactive, ref, watch, onMounted} from 'vue'
import { reactive, ref, watch, onMounted } from 'vue'
import authApi from '@/api/auth-setting'
import type {FormInstance, FormRules} from 'element-plus'
import {t} from '@/locales'
import {MsgSuccess} from '@/utils/message'
import type { FormInstance, FormRules } from 'element-plus'
import { t } from '@/locales'
import { MsgSuccess } from '@/utils/message'
const form = ref<any>({
id: '',
@ -64,32 +84,40 @@ const authFormRef = ref()
const loading = ref(false)
const rules = reactive<FormRules<any>>({
'config_data.authEndpoint': [{required: true, message: t('login.oidc.authEndpointPlaceholder'), trigger: 'blur'}],
'config_data.tokenEndpoint': [{required: true, message: t('login.oidc.tokenEndpointPlaceholder'), trigger: 'blur'}],
'config_data.userInfoEndpoint': [{
required: true,
message: t('login.oidc.userInfoEndpointPlaceholder'),
trigger: 'blur'
}],
'config_data.clientId': [{required: true, message: t('login.oidc.clientIdPlaceholder'), trigger: 'blur'}],
'config_data.clientSecret': [{required: true, message: t('login.oidc.clientSecretPlaceholder'), trigger: 'blur'}],
'config_data.redirectUrl': [{required: true, message: t('login.oidc.redirectUrlPlaceholder'), trigger: 'blur'}],
'config_data.logoutEndpoint': [{required: true, message: t('login.oidc.logoutEndpointPlaceholder'), trigger: 'blur'}]
'config_data.authEndpoint': [
{ required: true, message: t('login.oidc.authEndpointPlaceholder'), trigger: 'blur' }
],
'config_data.tokenEndpoint': [
{ required: true, message: t('login.oidc.tokenEndpointPlaceholder'), trigger: 'blur' }
],
'config_data.userInfoEndpoint': [
{
required: true,
message: t('login.oidc.userInfoEndpointPlaceholder'),
trigger: 'blur'
}
],
'config_data.clientId': [
{ required: true, message: t('login.oidc.clientIdPlaceholder'), trigger: 'blur' }
],
'config_data.clientSecret': [
{ required: true, message: t('login.oidc.clientSecretPlaceholder'), trigger: 'blur' }
],
'config_data.redirectUrl': [
{ required: true, message: t('login.oidc.redirectUrlPlaceholder'), trigger: 'blur' }
],
'config_data.logoutEndpoint': [
{ required: true, message: t('login.oidc.logoutEndpointPlaceholder'), trigger: 'blur' }
]
})
const submit = async (formEl: FormInstance | undefined, test?: string) => {
if (!formEl) return
await formEl.validate((valid, fields) => {
if (valid) {
if (test) {
authApi.postAuthSetting(form.value, loading).then((res) => {
MsgSuccess(t('login.ldap.testConnectionSuccess'))
})
} else {
authApi.putAuthSetting(form.value.auth_type, form.value, loading).then((res) => {
MsgSuccess(t('login.ldap.saveSuccess'))
})
}
authApi.putAuthSetting(form.value.auth_type, form.value, loading).then((res) => {
MsgSuccess(t('login.ldap.saveSuccess'))
})
}
})
}
@ -98,9 +126,6 @@ function getDetail() {
authApi.getAuthSetting(form.value.auth_type, loading).then((res: any) => {
if (res.data && JSON.stringify(res.data) !== '{}') {
form.value = res.data
if (res.data.config_data.ldap_mapping) {
form.value.config_data.ldap_mapping = JSON.stringify(JSON.parse(res.data.config_data.ldap_mapping))
}
}
})
}

View File

@ -9,6 +9,7 @@
@click="clickListHandle"
value-key="provider"
default-active=""
style="overflow-y: auto"
>
<template #default="{ row, index }">
<div class="flex" v-if="index === 0">