diff --git a/apps/common/management/commands/services/services/local_model.py b/apps/common/management/commands/services/services/local_model.py index 1e5e4bc13..4511f8f5f 100644 --- a/apps/common/management/commands/services/services/local_model.py +++ b/apps/common/management/commands/services/services/local_model.py @@ -23,7 +23,7 @@ class GunicornLocalModelService(BaseService): print("\n- Start Gunicorn Local Model WSGI HTTP Server") os.environ.setdefault('SERVER_NAME', 'local_model') log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s ' - bind = f'127.0.0.1:5432' + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' cmd = [ 'gunicorn', 'smartdoc.wsgi:application', '-b', bind, diff --git a/apps/setting/models_provider/impl/local_model_provider/model/embedding.py b/apps/setting/models_provider/impl/local_model_provider/model/embedding.py index dd4918909..38cf1aeb9 100644 --- a/apps/setting/models_provider/impl/local_model_provider/model/embedding.py +++ b/apps/setting/models_provider/impl/local_model_provider/model/embedding.py @@ -14,6 +14,7 @@ from langchain_core.pydantic_v1 import BaseModel from langchain_huggingface import HuggingFaceEmbeddings from setting.models_provider.base_model_provider import MaxKBBaseModel +from smartdoc.const import CONFIG class WebLocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings): @@ -28,14 +29,18 @@ class WebLocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings): self.model_id = kwargs.get('model_id', None) def embed_query(self, text: str) -> List[float]: - res = requests.post(f'http://127.0.0.1:5432/api/model/{self.model_id}/embed_query', {'text': text}) + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + res = requests.post(f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/api/model/{self.model_id}/embed_query', + {'text': text}) result = res.json() if result.get('code', 500) == 200: return result.get('data') raise Exception(result.get('msg')) def embed_documents(self, texts: List[str]) -> List[List[float]]: - res = requests.post(f'http://127.0.0.1:5432/api/model/{self.model_id}/embed_documents', {'texts': texts}) + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + res = requests.post(f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/api/model/{self.model_id}/embed_documents', + {'texts': texts}) result = res.json() if result.get('code', 500) == 200: return result.get('data') diff --git a/apps/smartdoc/conf.py b/apps/smartdoc/conf.py index a7665149f..034973981 100644 --- a/apps/smartdoc/conf.py +++ b/apps/smartdoc/conf.py @@ -88,7 +88,10 @@ class Config(dict): # 向量库配置 "VECTOR_STORE_NAME": 'pg_vector', "DEBUG": False, - 'SANDBOX': False + 'SANDBOX': False, + 'LOCAL_MODEL_HOST': '127.0.0.1', + 'LOCAL_MODEL_PORT': '11636', + 'LOCAL_MODEL_PROTOCOL': "http" } diff --git a/main.py b/main.py index 1ae33fcb4..a8bd74af4 100644 --- a/main.py +++ b/main.py @@ -75,7 +75,9 @@ def dev(): management.call_command('celery', 'celery') elif services.__contains__('local_model'): os.environ.setdefault('SERVER_NAME', 'local_model') - management.call_command('runserver', "127.0.0.1:5432") + from smartdoc.const import CONFIG + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + management.call_command('runserver', bind) if __name__ == '__main__':