This commit is contained in:
liqiang-fit2cloud 2024-11-15 10:14:01 +08:00
commit 3ac5477556
28 changed files with 623 additions and 179 deletions

View File

@ -1,23 +1,36 @@
# coding=utf-8
import io
from django.db.models import QuerySet
from application.flow.i_step_node import NodeResult
from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode
from dataset.models import File
from dataset.serializers.document_serializers import split_handles, parse_table_handle_list, FileBufferHandle
class BaseDocumentExtractNode(IDocumentExtractNode):
def execute(self, document, **kwargs):
get_buffer = FileBufferHandle().get_buffer
self.context['document_list'] = document
content = ''
spliter = '\n-----------------------------------\n'
if len(document) > 0:
for doc in document:
file = QuerySet(File).filter(id=doc['file_id']).first()
file_type = doc['name'].split('.')[-1]
if file_type.lower() in ['txt', 'md', 'csv', 'html']:
content += spliter + doc['name'] + '\n' + file.get_byte().tobytes().decode('utf-8')
splitter = '\n-----------------------------------\n'
if document is None:
return NodeResult({'content': content}, {})
for doc in document:
file = QuerySet(File).filter(id=doc['file_id']).first()
buffer = io.BytesIO(file.get_byte().tobytes())
buffer.name = doc['name'] # this is the important line
for split_handle in (parse_table_handle_list + split_handles):
if split_handle.support(buffer, get_buffer):
# 回到文件头
buffer.seek(0)
file_content = split_handle.get_content(buffer)
content += splitter + '## ' + doc['name'] + '\n' + file_content
break
return NodeResult({'content': content}, {})

View File

@ -16,6 +16,8 @@ class ImageUnderstandNodeSerializer(serializers.Serializer):
# 多轮对话数量
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char("对话存储类型"))
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))
@ -32,7 +34,7 @@ class IImageUnderstandNode(INode):
self.node_params_serializer.data.get('image_list')[1:])
return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
chat_record_id,
image,
**kwargs) -> NodeResult:

View File

@ -63,7 +63,7 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
self.context['question'] = details.get('question')
self.answer_text = details.get('answer')
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id, chat_record_id,
image,
**kwargs) -> NodeResult:
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
@ -71,10 +71,10 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)
self.context['question'] = question.content
# todo 处理上传图片
message_list = self.generate_message_list(image_model, system, prompt, history_message, image)
self.context['message_list'] = message_list
self.context['image_list'] = image
self.context['dialogue_type'] = dialogue_type
if stream:
r = image_model.stream(message_list)
return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list,
@ -86,15 +86,31 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context)
@staticmethod
def get_history_message(history_chat_record, dialogue_number):
def get_history_message(self, history_chat_record, dialogue_number):
start_index = len(history_chat_record) - dialogue_number
history_message = reduce(lambda x, y: [*x, *y], [
[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
[self.generate_history_human_message(history_chat_record[index]), history_chat_record[index].get_ai_message()]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
return history_message
def generate_history_human_message(self, chat_record):
for data in chat_record.details.values():
if self.node.id == data['node_id'] and 'image_list' in data:
image_list = data['image_list']
if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
return HumanMessage(content=chat_record.problem_text)
file_id = image_list[0]['file_id']
file = QuerySet(File).filter(id=file_id).first()
base64_image = base64.b64encode(file.get_byte()).decode("utf-8")
return HumanMessage(
content=[
{'type': 'text', 'text': data['question']},
{'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}},
])
return HumanMessage(content=chat_record.problem_text)
def generate_prompt_question(self, prompt):
return HumanMessage(self.workflow_manage.generate_prompt(prompt))
@ -148,5 +164,6 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
'answer_tokens': self.context.get('answer_tokens'),
'status': self.status,
'err_message': self.err_message,
'image_list': self.context.get('image_list')
'image_list': self.context.get('image_list'),
'dialogue_type': self.context.get('dialogue_type')
}

View File

@ -143,7 +143,6 @@ class Flow:
if model_params_setting is None:
model_params_setting = model_params_setting_form.get_default_form_data()
node.properties.get('node_data', {})['model_params_setting'] = model_params_setting
model_params_setting_form.valid_form(model_params_setting)
if node.properties.get('status', 200) != 200:
raise ValidationError(ErrorDetail(f'节点{node.properties.get("stepName")} 不可用'))
node_list = [node for node in self.nodes if (node.type == 'function-lib-node')]

View File

@ -836,8 +836,6 @@ class ApplicationSerializer(serializers.Serializer):
ApplicationSerializer.Edit(data=instance).is_valid(
raise_exception=True)
application_id = self.data.get("application_id")
valid_model_params_setting(instance.get('model_id'),
instance.get('model_params_setting'))
application = QuerySet(Application).get(id=application_id)
if instance.get('model_id') is None or len(instance.get('model_id')) == 0:

View File

@ -294,7 +294,6 @@ class ChatSerializers(serializers.Serializer):
return chat_id
def open_simple(self, application):
valid_model_params_setting(application.model_id, application.model_params_setting)
application_id = self.data.get('application_id')
dataset_id_list = [str(row.dataset_id) for row in
QuerySet(ApplicationDatasetMapping).filter(
@ -376,7 +375,6 @@ class ChatSerializers(serializers.Serializer):
model_id = self.data.get('model_id')
dataset_id_list = self.data.get('dataset_id_list')
dialogue_number = 3 if self.data.get('multiple_rounds_dialogue', False) else 0
valid_model_params_setting(model_id, self.data.get('model_params_setting'))
application = Application(id=None, dialogue_number=dialogue_number, model_id=model_id,
dataset_setting=self.data.get('dataset_setting'),
model_setting=self.data.get('model_setting'),

View File

@ -442,7 +442,8 @@ class ChatView(APIView):
def post(self, request: Request, application_id: str, chat_id: str):
files = request.FILES.getlist('file')
file_ids = []
meta = {'application_id': application_id, 'chat_id': chat_id}
debug = request.data.get("debug", "false").lower() == "true"
meta = {'application_id': application_id, 'chat_id': chat_id, 'debug': debug}
for file in files:
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
file_ids.append({'name': file.name, 'url': file_url, 'file_id': file_url.split('/')[-1]})

View File

@ -17,3 +17,7 @@ class BaseParseTableHandle(ABC):
@abstractmethod
def handle(self, file, get_buffer,save_image):
pass
@abstractmethod
def get_content(self, file):
pass

View File

@ -18,3 +18,7 @@ class BaseSplitHandle(ABC):
@abstractmethod
def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
pass
@abstractmethod
def get_content(self, file):
pass

View File

@ -189,3 +189,13 @@ class DocSplitHandle(BaseSplitHandle):
".DOC") or file_name.endswith(".DOCX"):
return True
return False
def get_content(self, file):
try:
image_list = []
buffer = file.read()
doc = Document(io.BytesIO(buffer))
return self.to_md(doc, image_list, get_image_id_func())
except BaseException as e:
traceback.print_exception(e)
return ''

View File

@ -7,6 +7,7 @@
@desc:
"""
import re
import traceback
from typing import List
from bs4 import BeautifulSoup
@ -59,3 +60,14 @@ class HTMLSplitHandle(BaseSplitHandle):
return {'name': file.name,
'content': split_model.parse(content)
}
def get_content(self, file):
buffer = file.read()
try:
encoding = get_encoding(buffer)
content = buffer.decode(encoding)
return html2text(content)
except BaseException as e:
traceback.print_exception(e)
return ''

View File

@ -11,6 +11,7 @@ import os
import re
import tempfile
import time
import traceback
from typing import List
import fitz
@ -297,3 +298,17 @@ class PdfSplitHandle(BaseSplitHandle):
if file_name.endswith(".pdf") or file_name.endswith(".PDF"):
return True
return False
def get_content(self, file):
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
# 将上传的文件保存到临时文件中
temp_file.write(file.read())
# 获取临时文件的路径
temp_file_path = temp_file.name
pdf_document = fitz.open(temp_file_path)
try:
return self.handle_pdf_content(file, pdf_document)
except BaseException as e:
traceback.print_exception(e)
return ''

View File

@ -34,3 +34,11 @@ class CsvSplitHandle(BaseParseTableHandle):
paragraphs.append({'title': '', 'content': line})
return [{'name': file.name, 'paragraphs': paragraphs}]
def get_content(self, file):
buffer = file.read()
try:
return buffer.decode(detect(buffer)['encoding'])
except BaseException as e:
max_kb.error(f'csv split handle error: {e}')
return [{'name': file.name, 'paragraphs': []}]

View File

@ -60,3 +60,24 @@ class XlsSplitHandle(BaseParseTableHandle):
max_kb.error(f'excel split handle error: {e}')
return [{'name': file.name, 'paragraphs': []}]
return result
def get_content(self, file):
# 打开 .xls 文件
workbook = xlrd.open_workbook(file_contents=file.read(), formatting_info=True)
sheets = workbook.sheets()
md_tables = ''
for sheet in sheets:
# 获取表头和内容
headers = sheet.row_values(0)
data = [sheet.row_values(row_idx) for row_idx in range(1, sheet.nrows)]
# 构建 Markdown 表格
md_table = '| ' + ' | '.join(headers) + ' |\n'
md_table += '| ' + ' | '.join(['---'] * len(headers)) + ' |\n'
for row in data:
# 将每个单元格中的内容替换换行符为 <br> 以保留原始格式
md_table += '| ' + ' | '.join([str(cell).replace('\n', '<br>') if cell else '' for cell in row]) + ' |\n'
md_tables += md_table + '\n\n'
return md_tables

View File

@ -72,3 +72,31 @@ class XlsxSplitHandle(BaseParseTableHandle):
max_kb.error(f'excel split handle error: {e}')
return [{'name': file.name, 'paragraphs': []}]
return result
def get_content(self, file):
# 加载 Excel 文件
workbook = load_workbook(file)
md_tables = ''
# 如果未指定 sheet_name则使用第一个工作表
for sheetname in workbook.sheetnames:
sheet = workbook[sheetname] if sheetname else workbook.active
# 获取工作表的所有行
rows = list(sheet.iter_rows(values_only=True))
if not rows:
continue
# 提取表头和内容
headers = rows[0]
data = rows[1:]
# 构建 Markdown 表格
md_table = '| ' + ' | '.join(headers) + ' |\n'
md_table += '| ' + ' | '.join(['---'] * len(headers)) + ' |\n'
for row in data:
md_table += '| ' + ' | '.join(
[str(cell).replace('\n', '<br>') if cell is not None else '' for cell in row]) + ' |\n'
md_tables += md_table + '\n\n'
return md_tables

View File

@ -7,6 +7,7 @@
@desc:
"""
import re
import traceback
from typing import List
from charset_normalizer import detect
@ -49,3 +50,11 @@ class TextSplitHandle(BaseSplitHandle):
return {'name': file.name,
'content': split_model.parse(content)
}
def get_content(self, file):
buffer = file.read()
try:
return buffer.decode(detect(buffer)['encoding'])
except BaseException as e:
traceback.print_exception(e)
return ''

View File

@ -8,8 +8,10 @@
"""
from .client_access_num_job import *
from .clean_chat_job import *
from .clean_debug_file_job import *
def run():
client_access_num_job.run()
clean_chat_job.run()
clean_debug_file_job.run()

View File

@ -0,0 +1,36 @@
# coding=utf-8
import logging
from datetime import timedelta
from apscheduler.schedulers.background import BackgroundScheduler
from django.db.models import Q
from django.utils import timezone
from django_apscheduler.jobstores import DjangoJobStore
from common.lock.impl.file_lock import FileLock
from dataset.models import File
scheduler = BackgroundScheduler()
scheduler.add_jobstore(DjangoJobStore(), "default")
lock = FileLock()
def clean_debug_file():
logging.getLogger("max_kb").info('开始清理debug文件')
two_hours_ago = timezone.now() - timedelta(hours=2)
# 删除对应的文件
File.objects.filter(Q(create_time__lt=two_hours_ago) & Q(meta__debug=True)).delete()
logging.getLogger("max_kb").info('结束清理debug文件')
def run():
if lock.try_lock('clean_debug_file', 30 * 30):
try:
scheduler.start()
clean_debug_file_job = scheduler.get_job(job_id='clean_debug_file')
if clean_debug_file_job is not None:
clean_debug_file_job.remove()
scheduler.add_job(clean_debug_file, 'cron', hour='2', minute='0', second='0', id='clean_debug_file')
finally:
lock.un_lock('clean_debug_file')

View File

@ -182,6 +182,25 @@
</div>
</template>
<!-- 文档内容提取 -->
<template v-if="item.type === WorkflowType.DocumentExtractNode">
<div class="card-never border-r-4">
<h5 class="p-8-12">参数输出</h5>
<div class="p-8-12 border-t-dashed lighter">
<el-scrollbar height="150">
<MdPreview
v-if="item.content"
ref="editorRef"
editorId="preview-only"
:modelValue="item.content"
style="background: none"
/>
<template v-else> - </template>
</el-scrollbar>
</div>
</div>
</template>
<!-- 函数库 -->
<template
v-if="

View File

@ -100,6 +100,7 @@ const props = withDefaults(
appId?: string
chatId: string
sendMessage: (question: string, other_params_data?: any, chat?: chatType) => void
openChatId: () => Promise<string>
}>(),
{
applicationDetails: () => ({}),
@ -165,8 +166,14 @@ const uploadFile = async (file: any, fileList: any) => {
}
if (!chatId_context.value) {
const res = await applicationApi.getChatOpen(props.applicationDetails.id as string)
chatId_context.value = res.data
const res = await props.openChatId()
chatId_context.value = res
}
if (props.type === 'debug-ai-chat') {
formData.append('debug', 'true')
} else {
formData.append('debug', 'false')
}
applicationApi

View File

@ -38,6 +38,7 @@
:is-mobile="isMobile"
:type="type"
:send-message="sendMessage"
:open-chat-id="openChatId"
v-model:chat-id="chartOpenId"
v-model:loading="loading"
v-if="type !== 'log'"
@ -139,57 +140,50 @@ const handleDebounceClick = debounce((val, other_params_data?: any, chat?: chatT
}, 200)
/**
* 对话
* 打开对话id
*/
function getChartOpenId(chat?: any) {
loading.value = true
const openChatId: () => Promise<string> = () => {
const obj = props.applicationDetails
if (props.appId) {
return applicationApi
.getChatOpen(props.appId)
.then((res) => {
chartOpenId.value = res.data
chatMessage(chat)
return res.data
})
.catch((res) => {
if (res.response.status === 403) {
application.asyncAppAuthentication(accessToken).then(() => {
getChartOpenId(chat)
return application.asyncAppAuthentication(accessToken).then(() => {
return openChatId()
})
} else {
loading.value = false
return Promise.reject(res)
}
return Promise.reject(res)
})
} else {
if (isWorkFlow(obj.type)) {
const submitObj = {
work_flow: obj.work_flow
}
return applicationApi
.postWorkflowChatOpen(submitObj)
.then((res) => {
chartOpenId.value = res.data
chatMessage(chat)
})
.catch((res) => {
loading.value = false
return Promise.reject(res)
})
return applicationApi.postWorkflowChatOpen(submitObj).then((res) => {
chartOpenId.value = res.data
return res.data
})
} else {
return applicationApi
.postChatOpen(obj)
.then((res) => {
chartOpenId.value = res.data
chatMessage(chat)
})
.catch((res) => {
loading.value = false
return Promise.reject(res)
})
return applicationApi.postChatOpen(obj).then((res) => {
chartOpenId.value = res.data
return res.data
})
}
}
}
/**
* 对话
*/
function getChartOpenId(chat?: any) {
return openChatId().then(() => {
chatMessage(chat)
})
}
/**
* 获取一个递归函数,处理流式数据

View File

@ -1,66 +1,90 @@
import en from 'element-plus/es/locale/lang/en';
import components from './components';
import layout from './layout';
import views from './views';
import en from 'element-plus/es/locale/lang/en'
import components from './components'
import layout from './layout'
import views from './views'
export default {
lang: 'English',
layout,
views,
components,
en,
login: {
authentication: 'Login Authentication',
ldap: {
title: 'LDAP Settings',
address: 'LDAP Address',
serverPlaceholder: 'Please enter LDAP address',
bindDN: 'Bind DN',
bindDNPlaceholder: 'Please enter Bind DN',
password: 'Password',
passwordPlaceholder: 'Please enter password',
ou: 'User OU',
ouPlaceholder: 'Please enter User OU',
ldap_filter: 'User Filter',
ldap_filterPlaceholder: 'Please enter User Filter',
ldap_mapping: 'LDAP Attribute Mapping',
ldap_mappingPlaceholder: 'Please enter LDAP Attribute Mapping',
test: 'Test Connection',
enableAuthentication: 'Enable LDAP Authentication',
save: 'Save',
testConnectionSuccess: 'Test Connection Success',
testConnectionFailed: 'Test Connection Failed',
saveSuccess: 'Save Success',
},
cas: {
title: 'CAS Settings',
ldpUri: 'ldpUri',
ldpUriPlaceholder: 'Please enter ldpUri',
redirectUrl: 'Callback Address',
redirectUrlPlaceholder: 'Please enter Callback Address',
enableAuthentication: 'Enable CAS Authentication',
saveSuccess: 'Save Success',
save: 'Save',
},
oidc: {
title: 'OIDC Settings',
authEndpoint: 'Auth Endpoint',
authEndpointPlaceholder: 'Please enter Auth Endpoint',
tokenEndpoint: 'Token Endpoint',
tokenEndpointPlaceholder: 'Please enter Token Endpoint',
userInfoEndpoint: 'User Info Endpoint',
userInfoEndpointPlaceholder: 'Please enter User Info Endpoint',
clientId: 'Client ID',
clientIdPlaceholder: 'Please enter Client ID',
clientSecret: 'Client Secret',
clientSecretPlaceholder: 'Please enter Client Secret',
logoutEndpoint: 'Logout Endpoint',
logoutEndpointPlaceholder: 'Please enter Logout Endpoint',
redirectUrl: 'Redirect URL',
redirectUrlPlaceholder: 'Please enter Redirect URL',
enableAuthentication: 'Enable OIDC Authentication',
},
jump_tip: 'Jumping to the authentication source page for authentication',
jump: 'Jump',
lang: 'English',
layout,
views,
components,
en,
login: {
authentication: 'Login Authentication',
ldap: {
title: 'LDAP Settings',
address: 'LDAP Address',
serverPlaceholder: 'Please enter LDAP address',
bindDN: 'Bind DN',
bindDNPlaceholder: 'Please enter Bind DN',
password: 'Password',
passwordPlaceholder: 'Please enter password',
ou: 'User OU',
ouPlaceholder: 'Please enter User OU',
ldap_filter: 'User Filter',
ldap_filterPlaceholder: 'Please enter User Filter',
ldap_mapping: 'LDAP Attribute Mapping',
ldap_mappingPlaceholder: 'Please enter LDAP Attribute Mapping',
test: 'Test Connection',
enableAuthentication: 'Enable LDAP Authentication',
save: 'Save',
testConnectionSuccess: 'Test Connection Success',
testConnectionFailed: 'Test Connection Failed',
saveSuccess: 'Save Success'
},
};
cas: {
title: 'CAS Settings',
ldpUri: 'ldpUri',
ldpUriPlaceholder: 'Please enter ldpUri',
validateUrl: 'Validation Address',
validateUrlPlaceholder: 'Please enter Validation Address',
redirectUrl: 'Callback Address',
redirectUrlPlaceholder: 'Please enter Callback Address',
enableAuthentication: 'Enable CAS Authentication',
saveSuccess: 'Save Success',
save: 'Save'
},
oidc: {
title: 'OIDC Settings',
authEndpoint: 'Auth Endpoint',
authEndpointPlaceholder: 'Please enter Auth Endpoint',
tokenEndpoint: 'Token Endpoint',
tokenEndpointPlaceholder: 'Please enter Token Endpoint',
userInfoEndpoint: 'User Info Endpoint',
userInfoEndpointPlaceholder: 'Please enter User Info Endpoint',
clientId: 'Client ID',
clientIdPlaceholder: 'Please enter Client ID',
clientSecret: 'Client Secret',
clientSecretPlaceholder: 'Please enter Client Secret',
logoutEndpoint: 'Logout Endpoint',
logoutEndpointPlaceholder: 'Please enter Logout Endpoint',
redirectUrl: 'Redirect URL',
redirectUrlPlaceholder: 'Please enter Redirect URL',
enableAuthentication: 'Enable OIDC Authentication'
},
jump_tip: 'Jumping to the authentication source page for authentication',
jump: 'Jump',
oauth2: {
title: 'OAUTH2 Settings',
authEndpoint: 'Auth Endpoint',
authEndpointPlaceholder: 'Please enter Auth Endpoint',
tokenEndpoint: 'Token Endpoint',
tokenEndpointPlaceholder: 'Please enter Token Endpoint',
userInfoEndpoint: 'User Info Endpoint',
userInfoEndpointPlaceholder: 'Please enter User Info Endpoint',
scope: 'Scope',
scopePlaceholder: 'Please enter Scope',
clientId: 'Client ID',
clientIdPlaceholder: 'Please enter Client ID',
clientSecret: 'Client Secret',
clientSecretPlaceholder: 'Please enter Client Secret',
redirectUrl: 'Redirect URL',
redirectUrlPlaceholder: 'Please enter Redirect URL',
filedMapping: 'Field Mapping',
filedMappingPlaceholder: 'Please enter Field Mapping',
enableAuthentication: 'Enable OAUTH2 Authentication',
save: 'Save',
saveSuccess: 'Save Success'
}
}
}

View File

@ -1,66 +1,90 @@
import zhCn from 'element-plus/es/locale/lang/zh-cn';
import components from './components';
import layout from './layout';
import views from './views';
import zhCn from 'element-plus/es/locale/lang/zh-cn'
import components from './components'
import layout from './layout'
import views from './views'
export default {
lang: '简体中文',
layout,
views,
components,
zhCn,
login: {
authentication: '登录认证',
ldap: {
title: 'LDAP 设置',
address: 'LDAP 地址',
serverPlaceholder: '请输入LDAP 地址',
bindDN: '绑定DN',
bindDNPlaceholder: '请输入绑定 DN',
password: '密码',
passwordPlaceholder: '请输入密码',
ou: '用户OU',
ouPlaceholder: '请输入用户 OU',
ldap_filter: '用户过滤器',
ldap_filterPlaceholder: '请输入用户过滤器',
ldap_mapping: 'LDAP 属性映射',
ldap_mappingPlaceholder: '请输入 LDAP 属性映射',
test: '测试连接',
enableAuthentication: '启用 LDAP 认证',
save: '保存',
testConnectionSuccess: '测试连接成功',
testConnectionFailed: '测试连接失败',
saveSuccess: '保存成功',
},
cas: {
title: 'CAS 设置',
ldpUri: 'ldpUri',
ldpUriPlaceholder: '请输入ldpUri',
redirectUrl: '回调地址',
redirectUrlPlaceholder: '请输入回调地址',
enableAuthentication: '启用CAS认证',
saveSuccess: '保存成功',
save: '保存',
},
oidc: {
title: 'OIDC 设置',
authEndpoint: '授权端地址',
authEndpointPlaceholder: '请输入授权端地址',
tokenEndpoint: 'Token端地址',
tokenEndpointPlaceholder: '请输入Token端地址',
userInfoEndpoint: '用户信息端地址',
userInfoEndpointPlaceholder: '请输入用户信息端地址',
clientId: '客户端ID',
clientIdPlaceholder: '请输入客户端ID',
clientSecret: '客户端密钥',
clientSecretPlaceholder: '请输入客户端密钥',
logoutEndpoint: '注销端地址',
logoutEndpointPlaceholder: '请输入注销端地址',
redirectUrl: '回调地址',
redirectUrlPlaceholder: '请输入回调地址',
enableAuthentication: '启用OIDC认证',
},
jump_tip: '即将跳转至认证源页面进行认证',
jump: '跳转',
lang: '简体中文',
layout,
views,
components,
zhCn,
login: {
authentication: '登录认证',
ldap: {
title: 'LDAP 设置',
address: 'LDAP 地址',
serverPlaceholder: '请输入LDAP 地址',
bindDN: '绑定DN',
bindDNPlaceholder: '请输入绑定 DN',
password: '密码',
passwordPlaceholder: '请输入密码',
ou: '用户OU',
ouPlaceholder: '请输入用户 OU',
ldap_filter: '用户过滤器',
ldap_filterPlaceholder: '请输入用户过滤器',
ldap_mapping: 'LDAP 属性映射',
ldap_mappingPlaceholder: '请输入 LDAP 属性映射',
test: '测试连接',
enableAuthentication: '启用 LDAP 认证',
save: '保存',
testConnectionSuccess: '测试连接成功',
testConnectionFailed: '测试连接失败',
saveSuccess: '保存成功'
},
};
cas: {
title: 'CAS 设置',
ldpUri: 'ldpUri',
ldpUriPlaceholder: '请输入ldpUri',
validateUrl: '验证地址',
validateUrlPlaceholder: '请输入验证地址',
redirectUrl: '回调地址',
redirectUrlPlaceholder: '请输入回调地址',
enableAuthentication: '启用CAS认证',
saveSuccess: '保存成功',
save: '保存'
},
oidc: {
title: 'OIDC 设置',
authEndpoint: '授权端地址',
authEndpointPlaceholder: '请输入授权端地址',
tokenEndpoint: 'Token端地址',
tokenEndpointPlaceholder: '请输入Token端地址',
userInfoEndpoint: '用户信息端地址',
userInfoEndpointPlaceholder: '请输入用户信息端地址',
clientId: '客户端ID',
clientIdPlaceholder: '请输入客户端ID',
clientSecret: '客户端密钥',
clientSecretPlaceholder: '请输入客户端密钥',
logoutEndpoint: '注销端地址',
logoutEndpointPlaceholder: '请输入注销端地址',
redirectUrl: '回调地址',
redirectUrlPlaceholder: '请输入回调地址',
enableAuthentication: '启用OIDC认证'
},
jump_tip: '即将跳转至认证源页面进行认证',
jump: '跳转',
oauth2: {
title: 'OAUTH2 设置',
authEndpoint: '授权端地址',
authEndpointPlaceholder: '请输入授权端地址',
tokenEndpoint: 'Token端地址',
tokenEndpointPlaceholder: '请输入Token端地址',
userInfoEndpoint: '用户信息端地址',
userInfoEndpointPlaceholder: '请输入用户信息端地址',
scope: '连接范围',
scopePlaceholder: '请输入连接范围',
clientId: '客户端ID',
clientIdPlaceholder: '请输入客户端ID',
clientSecret: '客户端密钥',
clientSecretPlaceholder: '请输入客户端密钥',
redirectUrl: '回调地址',
redirectUrlPlaceholder: '请输入回调地址',
filedMapping: '字段映射',
filedMappingPlaceholder: '请输入字段映射',
enableAuthentication: '启用OAUTH2认证',
save: '保存',
saveSuccess: '保存成功'
}
}
}

View File

@ -15,6 +15,12 @@
:placeholder="$t('login.cas.ldpUriPlaceholder')"
/>
</el-form-item>
<el-form-item :label="$t('login.cas.validateUrl')" prop="config_data.validateUrl">
<el-input
v-model="form.config_data.validateUrl"
:placeholder="$t('login.cas.validateUrlPlaceholder')"
/>
</el-form-item>
<el-form-item :label="$t('login.cas.redirectUrl')" prop="config_data.redirectUrl">
<el-input
v-model="form.config_data.redirectUrl"
@ -49,6 +55,7 @@ const form = ref<any>({
auth_type: 'CAS',
config_data: {
ldpUri: '',
validateUrl: '',
redirectUrl: ''
},
is_active: true
@ -62,6 +69,9 @@ const rules = reactive<FormRules<any>>({
'config_data.ldpUri': [
{ required: true, message: t('login.cas.ldpUriPlaceholder'), trigger: 'blur' }
],
'config_data.validateUrl': [
{ required: true, message: t('login.cas.validateUrlPlaceholder'), trigger: 'blur' }
],
'config_data.redirectUrl': [
{
required: true,
@ -85,6 +95,9 @@ const submit = async (formEl: FormInstance | undefined) => {
function getDetail() {
authApi.getAuthSetting(form.value.auth_type, loading).then((res: any) => {
if (res.data && JSON.stringify(res.data) !== '{}') {
if (!res.data.config_data.validateUrl) {
res.data.config_data.validateUrl = res.data.config_data.ldpUri
}
form.value = res.data
}
})

View File

@ -0,0 +1,161 @@
<template>
<div class="authentication-setting__main main-calc-height">
<el-scrollbar>
<div class="form-container p-24" v-loading="loading">
<el-form
ref="authFormRef"
:rules="rules"
:model="form"
label-position="top"
require-asterisk-position="right"
>
<el-form-item :label="$t('login.oauth2.authEndpoint')" prop="config_data.authEndpoint">
<el-input
v-model="form.config_data.authEndpoint"
:placeholder="$t('login.oauth2.authEndpointPlaceholder')"
/>
</el-form-item>
<el-form-item :label="$t('login.oauth2.tokenEndpoint')" prop="config_data.tokenEndpoint">
<el-input
v-model="form.config_data.tokenEndpoint"
:placeholder="$t('login.oauth2.tokenEndpointPlaceholder')"
/>
</el-form-item>
<el-form-item
:label="$t('login.oauth2.userInfoEndpoint')"
prop="config_data.userInfoEndpoint"
>
<el-input
v-model="form.config_data.userInfoEndpoint"
:placeholder="$t('login.oauth2.userInfoEndpointPlaceholder')"
/>
</el-form-item>
<el-form-item :label="$t('login.oauth2.scope')" prop="config_data.scope">
<el-input
v-model="form.config_data.scope"
:placeholder="$t('login.oauth2.scopePlaceholder')"
/>
</el-form-item>
<el-form-item :label="$t('login.oauth2.clientId')" prop="config_data.clientId">
<el-input
v-model="form.config_data.clientId"
:placeholder="$t('login.oauth2.clientIdPlaceholder')"
/>
</el-form-item>
<el-form-item :label="$t('login.oauth2.clientSecret')" prop="config_data.clientSecret">
<el-input
v-model="form.config_data.clientSecret"
:placeholder="$t('login.oauth2.clientSecretPlaceholder')"
show-password
/>
</el-form-item>
<el-form-item :label="$t('login.oauth2.redirectUrl')" prop="config_data.redirectUrl">
<el-input
v-model="form.config_data.redirectUrl"
:placeholder="$t('login.oauth2.redirectUrlPlaceholder')"
/>
</el-form-item>
<el-form-item :label="$t('login.oauth2.filedMapping')" prop="config_data.fieldMapping">
<el-input
v-model="form.config_data.fieldMapping"
:placeholder="$t('login.oauth2.filedMappingPlaceholder')"
/>
</el-form-item>
<el-form-item>
<el-checkbox v-model="form.is_active"
>{{ $t('login.oauth2.enableAuthentication') }}
</el-checkbox>
</el-form-item>
</el-form>
<div class="text-right">
<el-button @click="submit(authFormRef)" type="primary" :disabled="loading">
{{ $t('login.ldap.save') }}
</el-button>
</div>
</div>
</el-scrollbar>
</div>
</template>
<script setup lang="ts">
import { reactive, ref, watch, onMounted } from 'vue'
import authApi from '@/api/auth-setting'
import type { FormInstance, FormRules } from 'element-plus'
import { t } from '@/locales'
import { MsgSuccess } from '@/utils/message'
const form = ref<any>({
id: '',
auth_type: 'OAUTH2',
config_data: {
authEndpoint: '',
tokenEndpoint: '',
userInfoEndpoint: '',
scope: '',
clientId: '',
clientSecret: '',
redirectUrl: '',
fieldMapping: ''
},
is_active: true
})
const authFormRef = ref()
const loading = ref(false)
const rules = reactive<FormRules<any>>({
'config_data.authEndpoint': [
{ required: true, message: t('login.oauth2.authEndpointPlaceholder'), trigger: 'blur' }
],
'config_data.tokenEndpoint': [
{ required: true, message: t('login.oauth2.tokenEndpointPlaceholder'), trigger: 'blur' }
],
'config_data.userInfoEndpoint': [
{
required: true,
message: t('login.oauth2.userInfoEndpointPlaceholder'),
trigger: 'blur'
}
],
'config_data.scope': [
{ required: true, message: t('login.oauth2.scopePlaceholder'), trigger: 'blur' }
],
'config_data.clientId': [
{ required: true, message: t('login.oauth2.clientIdPlaceholder'), trigger: 'blur' }
],
'config_data.clientSecret': [
{ required: true, message: t('login.oauth2.clientSecretPlaceholder'), trigger: 'blur' }
],
'config_data.redirectUrl': [
{ required: true, message: t('login.oauth2.redirectUrlPlaceholder'), trigger: 'blur' }
],
'config_data.fieldMapping': [
{ required: true, message: t('login.oauth2.filedMappingPlaceholder'), trigger: 'blur' }
]
})
const submit = async (formEl: FormInstance | undefined, test?: string) => {
if (!formEl) return
await formEl.validate((valid, fields) => {
if (valid) {
authApi.putAuthSetting(form.value.auth_type, form.value, loading).then((res) => {
MsgSuccess(t('login.ldap.saveSuccess'))
})
}
})
}
function getDetail() {
authApi.getAuthSetting(form.value.auth_type, loading).then((res: any) => {
if (res.data && JSON.stringify(res.data) !== '{}') {
form.value = res.data
}
})
}
onMounted(() => {
getDetail()
})
</script>
<style lang="scss" scoped></style>

View File

@ -19,6 +19,7 @@ import OIDC from './component/OIDC.vue'
import SCAN from './component/SCAN.vue'
import { t } from '@/locales'
import useStore from '@/stores'
import OAUTH2 from '@/views/authentication/component/OAUTH2.vue'
const { user } = useStore()
const router = useRouter()
@ -40,6 +41,11 @@ const tabList = [
name: 'OIDC',
component: OIDC
},
{
label: t('login.oauth2.title'),
name: 'OAUTH2',
component: OAUTH2
},
{
label: '扫码登录',
name: 'SCAN',

View File

@ -66,7 +66,8 @@
:key="item"
class="login-button-circle color-secondary"
@click="changeMode(item)"
>{{ item }}
>
<span style="font-size: 10px">{{ item }}</span>
</el-button>
<el-button
v-if="item === 'QR_CODE' && loginMode !== item"
@ -162,6 +163,14 @@ function redirectAuth(authType: string) {
if (authType === 'OIDC') {
url = `${config.authEndpoint}?client_id=${config.clientId}&redirect_uri=${redirectUrl}&response_type=code&scope=openid+profile+email`
}
if (authType === 'OAUTH2') {
url =
`${config.authEndpoint}?client_id=${config.clientId}&response_type=code` +
`&redirect_uri=${redirectUrl}&state=${res.data.id}`
if (config.scope) {
url += `&scope=${config.scope}`
}
}
if (url) {
window.location.href = url
}

View File

@ -128,7 +128,16 @@
@submitDialog="submitDialog"
/>
</el-form-item>
<el-form-item label="历史聊天记录">
<el-form-item>
<template #label>
<div class="flex-between">
<div>历史聊天记录</div>
<el-select v-model="form_data.dialogue_type" class="w-120">
<el-option label="节点" value="NODE"/>
<el-option label="工作流" value="WORKFLOW"/>
</el-select>
</div>
</template>
<el-input-number
v-model="form_data.dialogue_number"
:min="0"
@ -213,6 +222,7 @@ const form = {
system: '',
prompt: defaultPrompt,
dialogue_number: 0,
dialogue_type: 'NODE',
is_result: true,
temperature: null,
max_tokens: null,