diff --git a/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py index 176230d2d..2663af097 100644 --- a/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py +++ b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py @@ -1,23 +1,36 @@ # coding=utf-8 +import io + from django.db.models import QuerySet from application.flow.i_step_node import NodeResult from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode from dataset.models import File +from dataset.serializers.document_serializers import split_handles, parse_table_handle_list, FileBufferHandle class BaseDocumentExtractNode(IDocumentExtractNode): def execute(self, document, **kwargs): + get_buffer = FileBufferHandle().get_buffer + self.context['document_list'] = document content = '' - spliter = '\n-----------------------------------\n' - if len(document) > 0: - for doc in document: - file = QuerySet(File).filter(id=doc['file_id']).first() - file_type = doc['name'].split('.')[-1] - if file_type.lower() in ['txt', 'md', 'csv', 'html']: - content += spliter + doc['name'] + '\n' + file.get_byte().tobytes().decode('utf-8') + splitter = '\n-----------------------------------\n' + if document is None: + return NodeResult({'content': content}, {}) + for doc in document: + file = QuerySet(File).filter(id=doc['file_id']).first() + buffer = io.BytesIO(file.get_byte().tobytes()) + buffer.name = doc['name'] # this is the important line + + for split_handle in (parse_table_handle_list + split_handles): + if split_handle.support(buffer, get_buffer): + # 回到文件头 + buffer.seek(0) + file_content = split_handle.get_content(buffer) + content += splitter + '## ' + doc['name'] + '\n' + file_content + break return NodeResult({'content': content}, {}) diff --git a/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py index 26fb431d0..31aaa0205 100644 --- a/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py +++ b/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py @@ -16,6 +16,8 @@ class ImageUnderstandNodeSerializer(serializers.Serializer): # 多轮对话数量 dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量")) + dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char("对话存储类型")) + is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片")) @@ -32,7 +34,7 @@ class IImageUnderstandNode(INode): self.node_params_serializer.data.get('image_list')[1:]) return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data) - def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, + def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id, chat_record_id, image, **kwargs) -> NodeResult: diff --git a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py index 046d6f783..7f0394fe5 100644 --- a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py +++ b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py @@ -63,7 +63,7 @@ class BaseImageUnderstandNode(IImageUnderstandNode): self.context['question'] = details.get('question') self.answer_text = details.get('answer') - def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, + def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id, chat_record_id, image, **kwargs) -> NodeResult: image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id')) @@ -71,10 +71,10 @@ class BaseImageUnderstandNode(IImageUnderstandNode): self.context['history_message'] = history_message question = self.generate_prompt_question(prompt) self.context['question'] = question.content - # todo 处理上传图片 message_list = self.generate_message_list(image_model, system, prompt, history_message, image) self.context['message_list'] = message_list self.context['image_list'] = image + self.context['dialogue_type'] = dialogue_type if stream: r = image_model.stream(message_list) return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list, @@ -86,15 +86,31 @@ class BaseImageUnderstandNode(IImageUnderstandNode): 'history_message': history_message, 'question': question.content}, {}, _write_context=write_context) - @staticmethod - def get_history_message(history_chat_record, dialogue_number): + def get_history_message(self, history_chat_record, dialogue_number): start_index = len(history_chat_record) - dialogue_number history_message = reduce(lambda x, y: [*x, *y], [ - [history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()] + [self.generate_history_human_message(history_chat_record[index]), history_chat_record[index].get_ai_message()] for index in range(start_index if start_index > 0 else 0, len(history_chat_record))], []) return history_message + def generate_history_human_message(self, chat_record): + + for data in chat_record.details.values(): + if self.node.id == data['node_id'] and 'image_list' in data: + image_list = data['image_list'] + if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW': + return HumanMessage(content=chat_record.problem_text) + file_id = image_list[0]['file_id'] + file = QuerySet(File).filter(id=file_id).first() + base64_image = base64.b64encode(file.get_byte()).decode("utf-8") + return HumanMessage( + content=[ + {'type': 'text', 'text': data['question']}, + {'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}}, + ]) + return HumanMessage(content=chat_record.problem_text) + def generate_prompt_question(self, prompt): return HumanMessage(self.workflow_manage.generate_prompt(prompt)) @@ -148,5 +164,6 @@ class BaseImageUnderstandNode(IImageUnderstandNode): 'answer_tokens': self.context.get('answer_tokens'), 'status': self.status, 'err_message': self.err_message, - 'image_list': self.context.get('image_list') + 'image_list': self.context.get('image_list'), + 'dialogue_type': self.context.get('dialogue_type') } diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 0cd92e8e1..d36a58b1f 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -143,7 +143,6 @@ class Flow: if model_params_setting is None: model_params_setting = model_params_setting_form.get_default_form_data() node.properties.get('node_data', {})['model_params_setting'] = model_params_setting - model_params_setting_form.valid_form(model_params_setting) if node.properties.get('status', 200) != 200: raise ValidationError(ErrorDetail(f'节点{node.properties.get("stepName")} 不可用')) node_list = [node for node in self.nodes if (node.type == 'function-lib-node')] diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 7e5cc8deb..23b058040 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -836,8 +836,6 @@ class ApplicationSerializer(serializers.Serializer): ApplicationSerializer.Edit(data=instance).is_valid( raise_exception=True) application_id = self.data.get("application_id") - valid_model_params_setting(instance.get('model_id'), - instance.get('model_params_setting')) application = QuerySet(Application).get(id=application_id) if instance.get('model_id') is None or len(instance.get('model_id')) == 0: diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 45e18a1ed..f3710fe00 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -294,7 +294,6 @@ class ChatSerializers(serializers.Serializer): return chat_id def open_simple(self, application): - valid_model_params_setting(application.model_id, application.model_params_setting) application_id = self.data.get('application_id') dataset_id_list = [str(row.dataset_id) for row in QuerySet(ApplicationDatasetMapping).filter( @@ -376,7 +375,6 @@ class ChatSerializers(serializers.Serializer): model_id = self.data.get('model_id') dataset_id_list = self.data.get('dataset_id_list') dialogue_number = 3 if self.data.get('multiple_rounds_dialogue', False) else 0 - valid_model_params_setting(model_id, self.data.get('model_params_setting')) application = Application(id=None, dialogue_number=dialogue_number, model_id=model_id, dataset_setting=self.data.get('dataset_setting'), model_setting=self.data.get('model_setting'), diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index 790277860..74b6fc361 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -442,7 +442,8 @@ class ChatView(APIView): def post(self, request: Request, application_id: str, chat_id: str): files = request.FILES.getlist('file') file_ids = [] - meta = {'application_id': application_id, 'chat_id': chat_id} + debug = request.data.get("debug", "false").lower() == "true" + meta = {'application_id': application_id, 'chat_id': chat_id, 'debug': debug} for file in files: file_url = FileSerializer(data={'file': file, 'meta': meta}).upload() file_ids.append({'name': file.name, 'url': file_url, 'file_id': file_url.split('/')[-1]}) diff --git a/apps/common/handle/base_parse_table_handle.py b/apps/common/handle/base_parse_table_handle.py index 487290378..b84690859 100644 --- a/apps/common/handle/base_parse_table_handle.py +++ b/apps/common/handle/base_parse_table_handle.py @@ -17,3 +17,7 @@ class BaseParseTableHandle(ABC): @abstractmethod def handle(self, file, get_buffer,save_image): pass + + @abstractmethod + def get_content(self, file): + pass \ No newline at end of file diff --git a/apps/common/handle/base_split_handle.py b/apps/common/handle/base_split_handle.py index f9b573f0f..ea48e6868 100644 --- a/apps/common/handle/base_split_handle.py +++ b/apps/common/handle/base_split_handle.py @@ -18,3 +18,7 @@ class BaseSplitHandle(ABC): @abstractmethod def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image): pass + + @abstractmethod + def get_content(self, file): + pass diff --git a/apps/common/handle/impl/doc_split_handle.py b/apps/common/handle/impl/doc_split_handle.py index c31c53ec1..350a3921a 100644 --- a/apps/common/handle/impl/doc_split_handle.py +++ b/apps/common/handle/impl/doc_split_handle.py @@ -189,3 +189,13 @@ class DocSplitHandle(BaseSplitHandle): ".DOC") or file_name.endswith(".DOCX"): return True return False + + def get_content(self, file): + try: + image_list = [] + buffer = file.read() + doc = Document(io.BytesIO(buffer)) + return self.to_md(doc, image_list, get_image_id_func()) + except BaseException as e: + traceback.print_exception(e) + return '' \ No newline at end of file diff --git a/apps/common/handle/impl/html_split_handle.py b/apps/common/handle/impl/html_split_handle.py index 878d9edda..688904567 100644 --- a/apps/common/handle/impl/html_split_handle.py +++ b/apps/common/handle/impl/html_split_handle.py @@ -7,6 +7,7 @@ @desc: """ import re +import traceback from typing import List from bs4 import BeautifulSoup @@ -59,3 +60,14 @@ class HTMLSplitHandle(BaseSplitHandle): return {'name': file.name, 'content': split_model.parse(content) } + + def get_content(self, file): + buffer = file.read() + + try: + encoding = get_encoding(buffer) + content = buffer.decode(encoding) + return html2text(content) + except BaseException as e: + traceback.print_exception(e) + return '' \ No newline at end of file diff --git a/apps/common/handle/impl/pdf_split_handle.py b/apps/common/handle/impl/pdf_split_handle.py index 52a33b0de..828196b7b 100644 --- a/apps/common/handle/impl/pdf_split_handle.py +++ b/apps/common/handle/impl/pdf_split_handle.py @@ -11,6 +11,7 @@ import os import re import tempfile import time +import traceback from typing import List import fitz @@ -297,3 +298,17 @@ class PdfSplitHandle(BaseSplitHandle): if file_name.endswith(".pdf") or file_name.endswith(".PDF"): return True return False + + def get_content(self, file): + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + # 将上传的文件保存到临时文件中 + temp_file.write(file.read()) + # 获取临时文件的路径 + temp_file_path = temp_file.name + + pdf_document = fitz.open(temp_file_path) + try: + return self.handle_pdf_content(file, pdf_document) + except BaseException as e: + traceback.print_exception(e) + return '' \ No newline at end of file diff --git a/apps/common/handle/impl/table/csv_parse_table_handle.py b/apps/common/handle/impl/table/csv_parse_table_handle.py index 71152f38e..c3a85db86 100644 --- a/apps/common/handle/impl/table/csv_parse_table_handle.py +++ b/apps/common/handle/impl/table/csv_parse_table_handle.py @@ -34,3 +34,11 @@ class CsvSplitHandle(BaseParseTableHandle): paragraphs.append({'title': '', 'content': line}) return [{'name': file.name, 'paragraphs': paragraphs}] + + def get_content(self, file): + buffer = file.read() + try: + return buffer.decode(detect(buffer)['encoding']) + except BaseException as e: + max_kb.error(f'csv split handle error: {e}') + return [{'name': file.name, 'paragraphs': []}] \ No newline at end of file diff --git a/apps/common/handle/impl/table/xls_parse_table_handle.py b/apps/common/handle/impl/table/xls_parse_table_handle.py index 6c30d49de..a3ef14443 100644 --- a/apps/common/handle/impl/table/xls_parse_table_handle.py +++ b/apps/common/handle/impl/table/xls_parse_table_handle.py @@ -60,3 +60,24 @@ class XlsSplitHandle(BaseParseTableHandle): max_kb.error(f'excel split handle error: {e}') return [{'name': file.name, 'paragraphs': []}] return result + + def get_content(self, file): + # 打开 .xls 文件 + workbook = xlrd.open_workbook(file_contents=file.read(), formatting_info=True) + sheets = workbook.sheets() + md_tables = '' + for sheet in sheets: + + # 获取表头和内容 + headers = sheet.row_values(0) + data = [sheet.row_values(row_idx) for row_idx in range(1, sheet.nrows)] + + # 构建 Markdown 表格 + md_table = '| ' + ' | '.join(headers) + ' |\n' + md_table += '| ' + ' | '.join(['---'] * len(headers)) + ' |\n' + for row in data: + # 将每个单元格中的内容替换换行符为 以保留原始格式 + md_table += '| ' + ' | '.join([str(cell).replace('\n', '') if cell else '' for cell in row]) + ' |\n' + md_tables += md_table + '\n\n' + + return md_tables diff --git a/apps/common/handle/impl/table/xlsx_parse_table_handle.py b/apps/common/handle/impl/table/xlsx_parse_table_handle.py index 35ef2f14b..e92d3c11a 100644 --- a/apps/common/handle/impl/table/xlsx_parse_table_handle.py +++ b/apps/common/handle/impl/table/xlsx_parse_table_handle.py @@ -72,3 +72,31 @@ class XlsxSplitHandle(BaseParseTableHandle): max_kb.error(f'excel split handle error: {e}') return [{'name': file.name, 'paragraphs': []}] return result + + + def get_content(self, file): + # 加载 Excel 文件 + workbook = load_workbook(file) + md_tables = '' + # 如果未指定 sheet_name,则使用第一个工作表 + for sheetname in workbook.sheetnames: + sheet = workbook[sheetname] if sheetname else workbook.active + + # 获取工作表的所有行 + rows = list(sheet.iter_rows(values_only=True)) + if not rows: + continue + + # 提取表头和内容 + headers = rows[0] + data = rows[1:] + + # 构建 Markdown 表格 + md_table = '| ' + ' | '.join(headers) + ' |\n' + md_table += '| ' + ' | '.join(['---'] * len(headers)) + ' |\n' + for row in data: + md_table += '| ' + ' | '.join( + [str(cell).replace('\n', '') if cell is not None else '' for cell in row]) + ' |\n' + + md_tables += md_table + '\n\n' + return md_tables \ No newline at end of file diff --git a/apps/common/handle/impl/text_split_handle.py b/apps/common/handle/impl/text_split_handle.py index 467607ff5..984c4e1e9 100644 --- a/apps/common/handle/impl/text_split_handle.py +++ b/apps/common/handle/impl/text_split_handle.py @@ -7,6 +7,7 @@ @desc: """ import re +import traceback from typing import List from charset_normalizer import detect @@ -49,3 +50,11 @@ class TextSplitHandle(BaseSplitHandle): return {'name': file.name, 'content': split_model.parse(content) } + + def get_content(self, file): + buffer = file.read() + try: + return buffer.decode(detect(buffer)['encoding']) + except BaseException as e: + traceback.print_exception(e) + return '' \ No newline at end of file diff --git a/apps/common/job/__init__.py b/apps/common/job/__init__.py index 2f4ef2697..286c81cae 100644 --- a/apps/common/job/__init__.py +++ b/apps/common/job/__init__.py @@ -8,8 +8,10 @@ """ from .client_access_num_job import * from .clean_chat_job import * +from .clean_debug_file_job import * def run(): client_access_num_job.run() clean_chat_job.run() + clean_debug_file_job.run() diff --git a/apps/common/job/clean_debug_file_job.py b/apps/common/job/clean_debug_file_job.py new file mode 100644 index 000000000..c409f4404 --- /dev/null +++ b/apps/common/job/clean_debug_file_job.py @@ -0,0 +1,36 @@ +# coding=utf-8 + +import logging +from datetime import timedelta + +from apscheduler.schedulers.background import BackgroundScheduler +from django.db.models import Q +from django.utils import timezone +from django_apscheduler.jobstores import DjangoJobStore + +from common.lock.impl.file_lock import FileLock +from dataset.models import File + +scheduler = BackgroundScheduler() +scheduler.add_jobstore(DjangoJobStore(), "default") +lock = FileLock() + + +def clean_debug_file(): + logging.getLogger("max_kb").info('开始清理debug文件') + two_hours_ago = timezone.now() - timedelta(hours=2) + # 删除对应的文件 + File.objects.filter(Q(create_time__lt=two_hours_ago) & Q(meta__debug=True)).delete() + logging.getLogger("max_kb").info('结束清理debug文件') + + +def run(): + if lock.try_lock('clean_debug_file', 30 * 30): + try: + scheduler.start() + clean_debug_file_job = scheduler.get_job(job_id='clean_debug_file') + if clean_debug_file_job is not None: + clean_debug_file_job.remove() + scheduler.add_job(clean_debug_file, 'cron', hour='2', minute='0', second='0', id='clean_debug_file') + finally: + lock.un_lock('clean_debug_file') diff --git a/ui/src/components/ai-chat/ExecutionDetailDialog.vue b/ui/src/components/ai-chat/ExecutionDetailDialog.vue index f63c7392d..bc3f2e160 100644 --- a/ui/src/components/ai-chat/ExecutionDetailDialog.vue +++ b/ui/src/components/ai-chat/ExecutionDetailDialog.vue @@ -182,6 +182,25 @@ + + + + 参数输出 + + + + - + + + + + Promise = () => { const obj = props.applicationDetails if (props.appId) { return applicationApi .getChatOpen(props.appId) .then((res) => { chartOpenId.value = res.data - chatMessage(chat) + return res.data }) .catch((res) => { if (res.response.status === 403) { - application.asyncAppAuthentication(accessToken).then(() => { - getChartOpenId(chat) + return application.asyncAppAuthentication(accessToken).then(() => { + return openChatId() }) - } else { - loading.value = false - return Promise.reject(res) } + return Promise.reject(res) }) } else { if (isWorkFlow(obj.type)) { const submitObj = { work_flow: obj.work_flow } - return applicationApi - .postWorkflowChatOpen(submitObj) - .then((res) => { - chartOpenId.value = res.data - chatMessage(chat) - }) - .catch((res) => { - loading.value = false - return Promise.reject(res) - }) + return applicationApi.postWorkflowChatOpen(submitObj).then((res) => { + chartOpenId.value = res.data + return res.data + }) } else { - return applicationApi - .postChatOpen(obj) - .then((res) => { - chartOpenId.value = res.data - chatMessage(chat) - }) - .catch((res) => { - loading.value = false - return Promise.reject(res) - }) + return applicationApi.postChatOpen(obj).then((res) => { + chartOpenId.value = res.data + return res.data + }) } } } +/** + * 对话 + */ +function getChartOpenId(chat?: any) { + return openChatId().then(() => { + chatMessage(chat) + }) +} /** * 获取一个递归函数,处理流式数据 diff --git a/ui/src/locales/lang/en_US/index.ts b/ui/src/locales/lang/en_US/index.ts index 3e80c0adc..536afc7e0 100644 --- a/ui/src/locales/lang/en_US/index.ts +++ b/ui/src/locales/lang/en_US/index.ts @@ -1,66 +1,90 @@ -import en from 'element-plus/es/locale/lang/en'; -import components from './components'; -import layout from './layout'; -import views from './views'; +import en from 'element-plus/es/locale/lang/en' +import components from './components' +import layout from './layout' +import views from './views' export default { - lang: 'English', - layout, - views, - components, - en, - login: { - authentication: 'Login Authentication', - ldap: { - title: 'LDAP Settings', - address: 'LDAP Address', - serverPlaceholder: 'Please enter LDAP address', - bindDN: 'Bind DN', - bindDNPlaceholder: 'Please enter Bind DN', - password: 'Password', - passwordPlaceholder: 'Please enter password', - ou: 'User OU', - ouPlaceholder: 'Please enter User OU', - ldap_filter: 'User Filter', - ldap_filterPlaceholder: 'Please enter User Filter', - ldap_mapping: 'LDAP Attribute Mapping', - ldap_mappingPlaceholder: 'Please enter LDAP Attribute Mapping', - test: 'Test Connection', - enableAuthentication: 'Enable LDAP Authentication', - save: 'Save', - testConnectionSuccess: 'Test Connection Success', - testConnectionFailed: 'Test Connection Failed', - saveSuccess: 'Save Success', - }, - cas: { - title: 'CAS Settings', - ldpUri: 'ldpUri', - ldpUriPlaceholder: 'Please enter ldpUri', - redirectUrl: 'Callback Address', - redirectUrlPlaceholder: 'Please enter Callback Address', - enableAuthentication: 'Enable CAS Authentication', - saveSuccess: 'Save Success', - save: 'Save', - }, - oidc: { - title: 'OIDC Settings', - authEndpoint: 'Auth Endpoint', - authEndpointPlaceholder: 'Please enter Auth Endpoint', - tokenEndpoint: 'Token Endpoint', - tokenEndpointPlaceholder: 'Please enter Token Endpoint', - userInfoEndpoint: 'User Info Endpoint', - userInfoEndpointPlaceholder: 'Please enter User Info Endpoint', - clientId: 'Client ID', - clientIdPlaceholder: 'Please enter Client ID', - clientSecret: 'Client Secret', - clientSecretPlaceholder: 'Please enter Client Secret', - logoutEndpoint: 'Logout Endpoint', - logoutEndpointPlaceholder: 'Please enter Logout Endpoint', - redirectUrl: 'Redirect URL', - redirectUrlPlaceholder: 'Please enter Redirect URL', - enableAuthentication: 'Enable OIDC Authentication', - }, - jump_tip: 'Jumping to the authentication source page for authentication', - jump: 'Jump', + lang: 'English', + layout, + views, + components, + en, + login: { + authentication: 'Login Authentication', + ldap: { + title: 'LDAP Settings', + address: 'LDAP Address', + serverPlaceholder: 'Please enter LDAP address', + bindDN: 'Bind DN', + bindDNPlaceholder: 'Please enter Bind DN', + password: 'Password', + passwordPlaceholder: 'Please enter password', + ou: 'User OU', + ouPlaceholder: 'Please enter User OU', + ldap_filter: 'User Filter', + ldap_filterPlaceholder: 'Please enter User Filter', + ldap_mapping: 'LDAP Attribute Mapping', + ldap_mappingPlaceholder: 'Please enter LDAP Attribute Mapping', + test: 'Test Connection', + enableAuthentication: 'Enable LDAP Authentication', + save: 'Save', + testConnectionSuccess: 'Test Connection Success', + testConnectionFailed: 'Test Connection Failed', + saveSuccess: 'Save Success' }, -}; + cas: { + title: 'CAS Settings', + ldpUri: 'ldpUri', + ldpUriPlaceholder: 'Please enter ldpUri', + validateUrl: 'Validation Address', + validateUrlPlaceholder: 'Please enter Validation Address', + redirectUrl: 'Callback Address', + redirectUrlPlaceholder: 'Please enter Callback Address', + enableAuthentication: 'Enable CAS Authentication', + saveSuccess: 'Save Success', + save: 'Save' + }, + oidc: { + title: 'OIDC Settings', + authEndpoint: 'Auth Endpoint', + authEndpointPlaceholder: 'Please enter Auth Endpoint', + tokenEndpoint: 'Token Endpoint', + tokenEndpointPlaceholder: 'Please enter Token Endpoint', + userInfoEndpoint: 'User Info Endpoint', + userInfoEndpointPlaceholder: 'Please enter User Info Endpoint', + clientId: 'Client ID', + clientIdPlaceholder: 'Please enter Client ID', + clientSecret: 'Client Secret', + clientSecretPlaceholder: 'Please enter Client Secret', + logoutEndpoint: 'Logout Endpoint', + logoutEndpointPlaceholder: 'Please enter Logout Endpoint', + redirectUrl: 'Redirect URL', + redirectUrlPlaceholder: 'Please enter Redirect URL', + enableAuthentication: 'Enable OIDC Authentication' + }, + jump_tip: 'Jumping to the authentication source page for authentication', + jump: 'Jump', + oauth2: { + title: 'OAUTH2 Settings', + authEndpoint: 'Auth Endpoint', + authEndpointPlaceholder: 'Please enter Auth Endpoint', + tokenEndpoint: 'Token Endpoint', + tokenEndpointPlaceholder: 'Please enter Token Endpoint', + userInfoEndpoint: 'User Info Endpoint', + userInfoEndpointPlaceholder: 'Please enter User Info Endpoint', + scope: 'Scope', + scopePlaceholder: 'Please enter Scope', + clientId: 'Client ID', + clientIdPlaceholder: 'Please enter Client ID', + clientSecret: 'Client Secret', + clientSecretPlaceholder: 'Please enter Client Secret', + redirectUrl: 'Redirect URL', + redirectUrlPlaceholder: 'Please enter Redirect URL', + filedMapping: 'Field Mapping', + filedMappingPlaceholder: 'Please enter Field Mapping', + enableAuthentication: 'Enable OAUTH2 Authentication', + save: 'Save', + saveSuccess: 'Save Success' + } + } +} diff --git a/ui/src/locales/lang/zh_CN/index.ts b/ui/src/locales/lang/zh_CN/index.ts index 26a9a0c10..62ba593c1 100644 --- a/ui/src/locales/lang/zh_CN/index.ts +++ b/ui/src/locales/lang/zh_CN/index.ts @@ -1,66 +1,90 @@ -import zhCn from 'element-plus/es/locale/lang/zh-cn'; -import components from './components'; -import layout from './layout'; -import views from './views'; +import zhCn from 'element-plus/es/locale/lang/zh-cn' +import components from './components' +import layout from './layout' +import views from './views' export default { - lang: '简体中文', - layout, - views, - components, - zhCn, - login: { - authentication: '登录认证', - ldap: { - title: 'LDAP 设置', - address: 'LDAP 地址', - serverPlaceholder: '请输入LDAP 地址', - bindDN: '绑定DN', - bindDNPlaceholder: '请输入绑定 DN', - password: '密码', - passwordPlaceholder: '请输入密码', - ou: '用户OU', - ouPlaceholder: '请输入用户 OU', - ldap_filter: '用户过滤器', - ldap_filterPlaceholder: '请输入用户过滤器', - ldap_mapping: 'LDAP 属性映射', - ldap_mappingPlaceholder: '请输入 LDAP 属性映射', - test: '测试连接', - enableAuthentication: '启用 LDAP 认证', - save: '保存', - testConnectionSuccess: '测试连接成功', - testConnectionFailed: '测试连接失败', - saveSuccess: '保存成功', - }, - cas: { - title: 'CAS 设置', - ldpUri: 'ldpUri', - ldpUriPlaceholder: '请输入ldpUri', - redirectUrl: '回调地址', - redirectUrlPlaceholder: '请输入回调地址', - enableAuthentication: '启用CAS认证', - saveSuccess: '保存成功', - save: '保存', - }, - oidc: { - title: 'OIDC 设置', - authEndpoint: '授权端地址', - authEndpointPlaceholder: '请输入授权端地址', - tokenEndpoint: 'Token端地址', - tokenEndpointPlaceholder: '请输入Token端地址', - userInfoEndpoint: '用户信息端地址', - userInfoEndpointPlaceholder: '请输入用户信息端地址', - clientId: '客户端ID', - clientIdPlaceholder: '请输入客户端ID', - clientSecret: '客户端密钥', - clientSecretPlaceholder: '请输入客户端密钥', - logoutEndpoint: '注销端地址', - logoutEndpointPlaceholder: '请输入注销端地址', - redirectUrl: '回调地址', - redirectUrlPlaceholder: '请输入回调地址', - enableAuthentication: '启用OIDC认证', - }, - jump_tip: '即将跳转至认证源页面进行认证', - jump: '跳转', + lang: '简体中文', + layout, + views, + components, + zhCn, + login: { + authentication: '登录认证', + ldap: { + title: 'LDAP 设置', + address: 'LDAP 地址', + serverPlaceholder: '请输入LDAP 地址', + bindDN: '绑定DN', + bindDNPlaceholder: '请输入绑定 DN', + password: '密码', + passwordPlaceholder: '请输入密码', + ou: '用户OU', + ouPlaceholder: '请输入用户 OU', + ldap_filter: '用户过滤器', + ldap_filterPlaceholder: '请输入用户过滤器', + ldap_mapping: 'LDAP 属性映射', + ldap_mappingPlaceholder: '请输入 LDAP 属性映射', + test: '测试连接', + enableAuthentication: '启用 LDAP 认证', + save: '保存', + testConnectionSuccess: '测试连接成功', + testConnectionFailed: '测试连接失败', + saveSuccess: '保存成功' }, -}; + cas: { + title: 'CAS 设置', + ldpUri: 'ldpUri', + ldpUriPlaceholder: '请输入ldpUri', + validateUrl: '验证地址', + validateUrlPlaceholder: '请输入验证地址', + redirectUrl: '回调地址', + redirectUrlPlaceholder: '请输入回调地址', + enableAuthentication: '启用CAS认证', + saveSuccess: '保存成功', + save: '保存' + }, + oidc: { + title: 'OIDC 设置', + authEndpoint: '授权端地址', + authEndpointPlaceholder: '请输入授权端地址', + tokenEndpoint: 'Token端地址', + tokenEndpointPlaceholder: '请输入Token端地址', + userInfoEndpoint: '用户信息端地址', + userInfoEndpointPlaceholder: '请输入用户信息端地址', + clientId: '客户端ID', + clientIdPlaceholder: '请输入客户端ID', + clientSecret: '客户端密钥', + clientSecretPlaceholder: '请输入客户端密钥', + logoutEndpoint: '注销端地址', + logoutEndpointPlaceholder: '请输入注销端地址', + redirectUrl: '回调地址', + redirectUrlPlaceholder: '请输入回调地址', + enableAuthentication: '启用OIDC认证' + }, + jump_tip: '即将跳转至认证源页面进行认证', + jump: '跳转', + oauth2: { + title: 'OAUTH2 设置', + authEndpoint: '授权端地址', + authEndpointPlaceholder: '请输入授权端地址', + tokenEndpoint: 'Token端地址', + tokenEndpointPlaceholder: '请输入Token端地址', + userInfoEndpoint: '用户信息端地址', + userInfoEndpointPlaceholder: '请输入用户信息端地址', + scope: '连接范围', + scopePlaceholder: '请输入连接范围', + clientId: '客户端ID', + clientIdPlaceholder: '请输入客户端ID', + clientSecret: '客户端密钥', + clientSecretPlaceholder: '请输入客户端密钥', + redirectUrl: '回调地址', + redirectUrlPlaceholder: '请输入回调地址', + filedMapping: '字段映射', + filedMappingPlaceholder: '请输入字段映射', + enableAuthentication: '启用OAUTH2认证', + save: '保存', + saveSuccess: '保存成功' + } + } +} diff --git a/ui/src/views/authentication/component/CAS.vue b/ui/src/views/authentication/component/CAS.vue index 4f790db66..70aed9061 100644 --- a/ui/src/views/authentication/component/CAS.vue +++ b/ui/src/views/authentication/component/CAS.vue @@ -15,6 +15,12 @@ :placeholder="$t('login.cas.ldpUriPlaceholder')" /> + + + ({ auth_type: 'CAS', config_data: { ldpUri: '', + validateUrl: '', redirectUrl: '' }, is_active: true @@ -62,6 +69,9 @@ const rules = reactive>({ 'config_data.ldpUri': [ { required: true, message: t('login.cas.ldpUriPlaceholder'), trigger: 'blur' } ], + 'config_data.validateUrl': [ + { required: true, message: t('login.cas.validateUrlPlaceholder'), trigger: 'blur' } + ], 'config_data.redirectUrl': [ { required: true, @@ -85,6 +95,9 @@ const submit = async (formEl: FormInstance | undefined) => { function getDetail() { authApi.getAuthSetting(form.value.auth_type, loading).then((res: any) => { if (res.data && JSON.stringify(res.data) !== '{}') { + if (!res.data.config_data.validateUrl) { + res.data.config_data.validateUrl = res.data.config_data.ldpUri + } form.value = res.data } }) diff --git a/ui/src/views/authentication/component/OAUTH2.vue b/ui/src/views/authentication/component/OAUTH2.vue new file mode 100644 index 000000000..f428bd7f9 --- /dev/null +++ b/ui/src/views/authentication/component/OAUTH2.vue @@ -0,0 +1,161 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + {{ $t('login.oauth2.enableAuthentication') }} + + + + + + + {{ $t('login.ldap.save') }} + + + + + + + + diff --git a/ui/src/views/authentication/index.vue b/ui/src/views/authentication/index.vue index 484f915eb..d12d9a12a 100644 --- a/ui/src/views/authentication/index.vue +++ b/ui/src/views/authentication/index.vue @@ -19,6 +19,7 @@ import OIDC from './component/OIDC.vue' import SCAN from './component/SCAN.vue' import { t } from '@/locales' import useStore from '@/stores' +import OAUTH2 from '@/views/authentication/component/OAUTH2.vue' const { user } = useStore() const router = useRouter() @@ -40,6 +41,11 @@ const tabList = [ name: 'OIDC', component: OIDC }, + { + label: t('login.oauth2.title'), + name: 'OAUTH2', + component: OAUTH2 + }, { label: '扫码登录', name: 'SCAN', diff --git a/ui/src/views/login/index.vue b/ui/src/views/login/index.vue index dbdba29a4..dbc2ab691 100644 --- a/ui/src/views/login/index.vue +++ b/ui/src/views/login/index.vue @@ -66,7 +66,8 @@ :key="item" class="login-button-circle color-secondary" @click="changeMode(item)" - >{{ item }} + > + {{ item }} - + + + + 历史聊天记录 + + + + + +