From a9b8bdd36533310c8d2e6787ade9acd978d85666 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Fri, 19 Jul 2024 16:44:53 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E5=90=91=E9=87=8F?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/event/listener_manage.py | 4 +-- .../ollama_model_provider/model/embedding.py | 30 +++++++++++++++++-- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index 1e874062c..0de80acee 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -159,7 +159,7 @@ class ListenerManagement: @param embedding_model 向量模型 :return: None """ - if not try_lock('embedding' + document_id): + if not try_lock('embedding' + str(document_id)): return max_kb.info(f"开始--->向量化文档:{document_id}") QuerySet(Document).filter(id=document_id).update(**{'status': Status.embedding}) @@ -186,7 +186,7 @@ class ListenerManagement: **{'status': status, 'update_time': datetime.datetime.now()}) QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status}) max_kb.info(f"结束--->向量化文档:{document_id}") - un_lock('embedding' + document_id) + un_lock('embedding' + str(document_id)) @staticmethod @embedding_poxy diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/embedding.py b/apps/setting/models_provider/impl/ollama_model_provider/model/embedding.py index 2bb3bb659..d1a68ebc7 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/model/embedding.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/model/embedding.py @@ -6,7 +6,7 @@ @date:2024/7/12 15:02 @desc: """ -from typing import Dict +from typing import Dict, List from langchain_community.embeddings import OllamaEmbeddings @@ -16,7 +16,33 @@ from setting.models_provider.base_model_provider import MaxKBBaseModel class OllamaEmbedding(MaxKBBaseModel, OllamaEmbeddings): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - return OllamaEmbeddings( + return OllamaEmbedding( model=model_name, base_url=model_credential.get('api_base'), ) + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed documents using an Ollama deployed embedding model. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + instruction_pairs = [f"{text}" for text in texts] + embeddings = self._embed(instruction_pairs) + return embeddings + + def embed_query(self, text: str) -> List[float]: + """Embed a query using a Ollama deployed embedding model. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + instruction_pair = f"{text}" + embedding = self._embed([instruction_pair])[0] + return embedding