refactor: 查询模型优化

This commit is contained in:
CaptainB 2024-10-11 18:55:56 +08:00 committed by 刘瑞斌
parent 2f8aa5e5fa
commit 5f0bc2401e
3 changed files with 74 additions and 11 deletions

View File

@ -235,7 +235,6 @@ class Dataset(APIView):
dynamic_tag=keywords.get('dataset_id'))],
compare=CompareConstants.AND))
def get(self, request: Request, dataset_id: str):
print(dataset_id)
return result.success(
ModelSerializer.Query(
data={'user_id': request.user.id, 'model_type': 'LLM'}).list(

View File

@ -75,12 +75,26 @@ class ModelSerializer(serializers.Serializer):
provider = serializers.CharField(required=False, error_messages=ErrMessage.char("供应商"))
permission_type = serializers.CharField(required=False, error_messages=ErrMessage.char("权限"))
create_user = serializers.CharField(required=False, error_messages=ErrMessage.char("创建者"))
def list(self, with_valid):
if with_valid:
self.is_valid(raise_exception=True)
user_id = self.data.get('user_id')
name = self.data.get('name')
model_query_set = QuerySet(Model).filter((Q(user_id=user_id) | Q(permission_type='PUBLIC')))
create_user = self.data.get('create_user')
if create_user is not None:
# 当前用户能查看自己的模型,包括公开和私有的
if create_user == user_id:
model_query_set = QuerySet(Model).filter(Q(user_id=create_user))
# 当前用户能查看其他人的模型,只能查看公开的
else:
model_query_set = QuerySet(Model).filter((Q(user_id=self.data.get('create_user')) & Q(permission_type='PUBLIC')))
else:
model_query_set = QuerySet(Model).filter((Q(user_id=user_id) | Q(permission_type='PUBLIC')))
query_params = {}
if name is not None:
query_params['name__contains'] = name
@ -90,6 +104,9 @@ class ModelSerializer(serializers.Serializer):
query_params['model_name'] = self.data.get('model_name')
if self.data.get('provider') is not None:
query_params['provider'] = self.data.get('provider')
if self.data.get('permission_type') is not None:
query_params['permission_type'] = self.data.get('permission_type')
return [
{'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,

View File

@ -41,14 +41,42 @@
<h4>{{ active_provider?.name }}</h4>
<div class="flex-between mt-16 mb-16">
<el-button type="primary" @click="openCreateModel(active_provider)">添加模型</el-button>
<el-input
v-model="model_search_form.name"
@change="list_model"
placeholder="按名称搜索"
prefix-icon="Search"
style="max-width: 240px"
clearable
/>
<div class="flex-between">
<el-select v-model="search_type" style="width: 200px" @change="search_type_change">
<el-option label="创建者" value="create_user" />
<el-option label="权限" value="permission_type" />
<el-option label="模型类型" value="model_type" />
<el-option label="模型名称" value="name" />
</el-select>
<el-input
v-if="search_type === 'name'"
v-model="model_search_form.name"
@change="list_model"
placeholder="按名称搜索"
prefix-icon="Search"
style="max-width: 240px"
clearable
/>
<el-select v-else-if="search_type === 'create_user'" v-model="model_search_form.create_user" @change="list_model"
clearable>
<el-option v-for="u in user_options" :key="u.id" :value="u.id" :label="u.username" />
</el-select>
<el-select v-else-if="search_type === 'permission_type'" v-model="model_search_form.permission_type"
clearable
@change="list_model">
<el-option label="公有" value="PUBLIC" />
<el-option label="私有" value="PRIVATE" />
</el-select>
<el-select v-else-if="search_type === 'model_type'" v-model="model_search_form.model_type"
clearable
@change="list_model">
<el-option label="大语言模型" value="LLM" />
<el-option label="向量模型" value="EMBEDDING" />
<el-option label="重排模型" value="RERANKER" />
<el-option label="语音识别" value="STT" />
<el-option label="语音合成" value="TTS" />
</el-select>
</div>
</div>
</div>
<div class="model-list-height">
@ -114,7 +142,14 @@ const allObj = {
const loading = ref<boolean>(false)
const active_provider = ref<Provider>()
const model_search_form = ref<{ name: string }>({ name: '' })
const search_type = ref('name')
const model_search_form = ref<{ name: string, create_user: string, permission_type: string, model_type: string }>({
name: '',
create_user: '',
permission_type: '',
model_type: ''
})
const user_options = ref<any[]>([])
const list_model_loading = ref<boolean>(false)
const provider_list = ref<Array<Provider>>([])
@ -150,9 +185,20 @@ const list_model = () => {
const params = active_provider.value?.provider ? { provider: active_provider.value.provider } : {}
ModelApi.getModel({ ...model_search_form.value, ...params }, list_model_loading).then((ok) => {
model_list.value = ok.data
const v = model_list.value.map((m) => ({ id: m.user_id, username: m.username }))
if (user_options.value.length === 0){
user_options.value = Array.from(
new Map(v.map(item => [item.id, item])).values()
)
}
})
}
const search_type_change = () => {
model_search_form.value = { name: '', create_user: '', permission_type: '', model_type: '' }
}
onMounted(() => {
ModelApi.getProvider(loading).then((ok) => {
active_provider.value = allObj
@ -173,6 +219,7 @@ onMounted(() => {
.model-list-height {
height: calc(var(--create-dataset-height) - 70px);
}
.model-list-height-left {
height: calc(var(--create-dataset-height));
}