mirror of
https://github.com/labring/FastGPT.git
synced 2025-12-25 20:02:47 +00:00
perf: vector format (#5516)
* perf: vector format * feat: embedding batch size
This commit is contained in:
parent
a92917c05f
commit
95325346ff
|
|
@ -6,9 +6,12 @@ description: 'FastGPT V4.12.2 更新说明'
|
|||
|
||||
## 🚀 新增内容
|
||||
|
||||
1. 向量模型并发请求设置,不统一设置成 10,避免部分向量模型不支持并发,默认均为 1,可在模型配置中设置。
|
||||
|
||||
## ⚙️ 优化
|
||||
|
||||
1. 增加工作流**独立分支**异常检测。
|
||||
2. 向量模型超过 1536 维度进行截断时,强制进行归一化。其他维度是否归一化,完全由配置决定,减少自动判断的计算量。
|
||||
|
||||
## 🐛 修复
|
||||
|
||||
|
|
@ -17,6 +20,7 @@ description: 'FastGPT V4.12.2 更新说明'
|
|||
3. 移动端,分享链接,异常加载了登录态对话页的导航。
|
||||
4. 用户同步可能出现写冲突问题。
|
||||
5. 无法完全关闭系统套餐,会存在空对象默认值,导致鉴权异常。
|
||||
6. 工作流,添加团队应用,搜索无效。
|
||||
|
||||
## 🔨 工具更新
|
||||
|
||||
|
|
|
|||
|
|
@ -97,13 +97,14 @@
|
|||
"document/content/docs/protocol/terms.en.mdx": "2025-08-03T22:37:45+08:00",
|
||||
"document/content/docs/protocol/terms.mdx": "2025-08-03T22:37:45+08:00",
|
||||
"document/content/docs/toc.en.mdx": "2025-08-04T13:42:36+08:00",
|
||||
"document/content/docs/toc.mdx": "2025-08-13T14:29:13+08:00",
|
||||
"document/content/docs/toc.mdx": "2025-08-20T21:58:13+08:00",
|
||||
"document/content/docs/upgrading/4-10/4100.mdx": "2025-08-02T19:38:37+08:00",
|
||||
"document/content/docs/upgrading/4-10/4101.mdx": "2025-08-02T19:38:37+08:00",
|
||||
"document/content/docs/upgrading/4-11/4110.mdx": "2025-08-05T23:20:39+08:00",
|
||||
"document/content/docs/upgrading/4-11/4111.mdx": "2025-08-07T22:49:09+08:00",
|
||||
"document/content/docs/upgrading/4-12/4120.mdx": "2025-08-12T22:45:19+08:00",
|
||||
"document/content/docs/upgrading/4-12/4121.mdx": "2025-08-15T22:53:06+08:00",
|
||||
"document/content/docs/upgrading/4-12/4122.mdx": "2025-08-22T09:38:44+08:00",
|
||||
"document/content/docs/upgrading/4-8/40.mdx": "2025-08-02T19:38:37+08:00",
|
||||
"document/content/docs/upgrading/4-8/41.mdx": "2025-08-02T19:38:37+08:00",
|
||||
"document/content/docs/upgrading/4-8/42.mdx": "2025-08-02T19:38:37+08:00",
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ export type EmbeddingModelItemType = PriceType &
|
|||
weight: number; // training weight
|
||||
hidden?: boolean; // Disallow creation
|
||||
normalization?: boolean; // normalization processing
|
||||
batchSize?: number;
|
||||
defaultConfig?: Record<string, any>; // post request config
|
||||
dbConfig?: Record<string, any>; // Custom parameters for storage
|
||||
queryConfig?: Record<string, any>; // Custom parameters for query
|
||||
|
|
|
|||
|
|
@ -110,6 +110,8 @@ export enum FlowNodeTypeEnum {
|
|||
systemConfig = 'userGuide',
|
||||
pluginConfig = 'pluginConfig',
|
||||
globalVariable = 'globalVariable',
|
||||
comment = 'comment',
|
||||
|
||||
workflowStart = 'workflowStart',
|
||||
chatNode = 'chatNode',
|
||||
|
||||
|
|
@ -141,7 +143,6 @@ export enum FlowNodeTypeEnum {
|
|||
loopStart = 'loopStart',
|
||||
loopEnd = 'loopEnd',
|
||||
formInput = 'formInput',
|
||||
comment = 'comment',
|
||||
tool = 'tool',
|
||||
toolSet = 'toolSet'
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,8 +23,9 @@ export async function getVectorsByText({ model, input, type, headers }: GetVecto
|
|||
|
||||
const formatInput = Array.isArray(input) ? input : [input];
|
||||
|
||||
// 20 size every request
|
||||
const chunkSize = parseInt(process.env.EMBEDDING_CHUNK_SIZE || '10');
|
||||
let chunkSize = Number(model.batchSize || 1);
|
||||
chunkSize = isNaN(chunkSize) ? 1 : chunkSize;
|
||||
|
||||
const chunks = [];
|
||||
for (let i = 0; i < formatInput.length; i += chunkSize) {
|
||||
chunks.push(formatInput.slice(i, i + chunkSize));
|
||||
|
|
@ -74,14 +75,7 @@ export async function getVectorsByText({ model, input, type, headers }: GetVecto
|
|||
const tokens = await Promise.all(chunk.map((item) => countPromptTokens(item)));
|
||||
return tokens.reduce((sum, item) => sum + item, 0);
|
||||
})(),
|
||||
Promise.all(
|
||||
res.data
|
||||
.map((item) => unityDimensional(item.embedding))
|
||||
.map((item) => {
|
||||
if (model.normalization) return normalization(item);
|
||||
return item;
|
||||
})
|
||||
)
|
||||
Promise.all(res.data.map((item) => formatVectors(item.embedding, model.normalization)))
|
||||
]);
|
||||
|
||||
return {
|
||||
|
|
@ -105,29 +99,35 @@ export async function getVectorsByText({ model, input, type, headers }: GetVecto
|
|||
}
|
||||
}
|
||||
|
||||
function unityDimensional(vector: number[]) {
|
||||
if (vector.length > 1536) {
|
||||
console.log(
|
||||
`The current vector dimension is ${vector.length}, and the vector dimension cannot exceed 1536. The first 1536 dimensions are automatically captured`
|
||||
);
|
||||
return vector.slice(0, 1536);
|
||||
}
|
||||
let resultVector = vector;
|
||||
const vectorLen = vector.length;
|
||||
|
||||
const zeroVector = new Array(1536 - vectorLen).fill(0);
|
||||
|
||||
return resultVector.concat(zeroVector);
|
||||
}
|
||||
// normalization processing
|
||||
function normalization(vector: number[]) {
|
||||
if (vector.some((item) => item > 1)) {
|
||||
export function formatVectors(vector: number[], normalization = false) {
|
||||
// normalization processing
|
||||
function normalizationVector(vector: number[]) {
|
||||
// Calculate the Euclidean norm (L2 norm)
|
||||
const norm = Math.sqrt(vector.reduce((sum, val) => sum + val * val, 0));
|
||||
|
||||
if (norm === 0) {
|
||||
return vector;
|
||||
}
|
||||
// Normalize the vector by dividing each component by the norm
|
||||
return vector.map((val) => val / norm);
|
||||
}
|
||||
|
||||
// 超过上限,截断,并强制归一化
|
||||
if (vector.length > 1536) {
|
||||
console.log(
|
||||
`The current vector dimension is ${vector.length}, and the vector dimension cannot exceed 1536. The first 1536 dimensions are automatically captured`
|
||||
);
|
||||
return normalizationVector(vector.slice(0, 1536));
|
||||
} else if (vector.length < 1536) {
|
||||
const vectorLen = vector.length;
|
||||
|
||||
const zeroVector = new Array(1536 - vectorLen).fill(0);
|
||||
|
||||
vector = vector.concat(zeroVector);
|
||||
}
|
||||
|
||||
if (normalization) {
|
||||
return normalizationVector(vector);
|
||||
}
|
||||
|
||||
return vector;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
"avg_ttfb": "Average first word duration (seconds)",
|
||||
"azure": "Azure",
|
||||
"base_url": "Base url",
|
||||
"batch_size": "Number of concurrent requests",
|
||||
"channel_name": "Channel",
|
||||
"channel_priority": "Priority",
|
||||
"channel_priority_tip": "The higher the priority channel, the easier it is to be requested",
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
"avg_ttfb": "平均首字时长 (秒)",
|
||||
"azure": "微软 Azure",
|
||||
"base_url": "代理地址",
|
||||
"batch_size": "并发请求数",
|
||||
"channel_name": "渠道名",
|
||||
"channel_priority": "优先级",
|
||||
"channel_priority_tip": "优先级越高的渠道,越容易被请求到",
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
"avg_ttfb": "平均首字時長 (秒)",
|
||||
"azure": "Azure",
|
||||
"base_url": "代理地址",
|
||||
"batch_size": "並發請求數",
|
||||
"channel_name": "管道名稱",
|
||||
"channel_priority": "優先順序",
|
||||
"channel_priority_tip": "優先順序越高的管道,越容易被請求到",
|
||||
|
|
|
|||
|
|
@ -10,12 +10,14 @@ import MyTooltip from '@fastgpt/web/components/common/MyTooltip';
|
|||
import { ModelProviderList } from '@fastgpt/global/core/ai/provider';
|
||||
import MultipleRowSelect from '@fastgpt/web/components/common/MySelect/MultipleRowSelect';
|
||||
import { getModelFromList } from '@fastgpt/global/core/ai/model';
|
||||
import type { ResponsiveValue } from '@chakra-ui/system';
|
||||
|
||||
type Props = SelectProps & {
|
||||
disableTip?: string;
|
||||
noOfLines?: ResponsiveValue<number>;
|
||||
};
|
||||
|
||||
const OneRowSelector = ({ list, onChange, disableTip, ...props }: Props) => {
|
||||
const OneRowSelector = ({ list, onChange, disableTip, noOfLines, ...props }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const { llmModelList, embeddingModelList, ttsModelList, sttModelList, reRankModelList } =
|
||||
useSystemStore();
|
||||
|
|
@ -55,7 +57,7 @@ const OneRowSelector = ({ list, onChange, disableTip, ...props }: Props) => {
|
|||
fallbackSrc={HUGGING_FACE_ICON}
|
||||
/>
|
||||
|
||||
<Box noOfLines={1}>{modelData.name}</Box>
|
||||
<Box noOfLines={noOfLines}>{modelData.name}</Box>
|
||||
</Flex>
|
||||
)
|
||||
};
|
||||
|
|
@ -99,7 +101,14 @@ const OneRowSelector = ({ list, onChange, disableTip, ...props }: Props) => {
|
|||
);
|
||||
};
|
||||
|
||||
const MultipleRowSelector = ({ list, onChange, disableTip, placeholder, ...props }: Props) => {
|
||||
const MultipleRowSelector = ({
|
||||
list,
|
||||
onChange,
|
||||
disableTip,
|
||||
placeholder,
|
||||
noOfLines,
|
||||
...props
|
||||
}: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const { llmModelList, embeddingModelList, ttsModelList, sttModelList, reRankModelList } =
|
||||
useSystemStore();
|
||||
|
|
@ -189,7 +198,7 @@ const MultipleRowSelector = ({ list, onChange, disableTip, placeholder, ...props
|
|||
fallbackSrc={HUGGING_FACE_ICON}
|
||||
w={avatarSize}
|
||||
/>
|
||||
<Box noOfLines={1}>{modelData?.name}</Box>
|
||||
<Box noOfLines={noOfLines}>{modelData?.name}</Box>
|
||||
</Flex>
|
||||
);
|
||||
}, [modelList, props.value, t, avatarSize]);
|
||||
|
|
@ -222,7 +231,7 @@ const MultipleRowSelector = ({ list, onChange, disableTip, placeholder, ...props
|
|||
};
|
||||
|
||||
const AIModelSelector = (props: Props) => {
|
||||
return props.list.length > 100 ? (
|
||||
return props.list.length > 10 ? (
|
||||
<MultipleRowSelector {...props} />
|
||||
) : (
|
||||
<OneRowSelector {...props} />
|
||||
|
|
|
|||
|
|
@ -476,6 +476,26 @@ export const ModelEditModal = ({
|
|||
</Flex>
|
||||
</Td>
|
||||
</Tr>
|
||||
<Tr>
|
||||
<Td>
|
||||
<HStack spacing={1}>
|
||||
<Box>{t('account_model:batch_size')}</Box>
|
||||
</HStack>
|
||||
</Td>
|
||||
<Td textAlign={'right'}>
|
||||
<Flex justifyContent={'flex-end'}>
|
||||
<MyNumberInput
|
||||
defaultValue={1}
|
||||
register={register}
|
||||
name="batchSize"
|
||||
min={1}
|
||||
step={1}
|
||||
isRequired
|
||||
{...InputStyles}
|
||||
/>
|
||||
</Flex>
|
||||
</Td>
|
||||
</Tr>
|
||||
<Tr>
|
||||
<Td>
|
||||
<HStack spacing={1}>
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ export const useDebug = () => {
|
|||
const getNodes = useContextSelector(WorkflowNodeEdgeContext, (v) => v.getNodes);
|
||||
const edges = useContextSelector(WorkflowNodeEdgeContext, (v) => v.edges);
|
||||
const onUpdateNodeError = useContextSelector(WorkflowContext, (v) => v.onUpdateNodeError);
|
||||
const onRemoveError = useContextSelector(WorkflowContext, (v) => v.onRemoveError);
|
||||
const onStartNodeDebug = useContextSelector(WorkflowContext, (v) => v.onStartNodeDebug);
|
||||
|
||||
const appDetail = useContextSelector(AppContext, (v) => v.appDetail);
|
||||
|
|
@ -80,6 +81,7 @@ export const useDebug = () => {
|
|||
|
||||
const checkResults = checkWorkflowNodeAndConnection({ nodes, edges });
|
||||
if (!checkResults) {
|
||||
onRemoveError();
|
||||
const storeNodes = uiWorkflow2StoreWorkflow({ nodes, edges });
|
||||
|
||||
return JSON.stringify(storeNodes);
|
||||
|
|
|
|||
|
|
@ -157,6 +157,7 @@ type WorkflowContextType = {
|
|||
nodeList: FlowNodeItemType[];
|
||||
|
||||
onUpdateNodeError: (node: string, isError: Boolean) => void;
|
||||
onRemoveError: () => void;
|
||||
onResetNode: (e: { id: string; node: FlowNodeTemplateType }) => void;
|
||||
onChangeNode: (e: FlowNodeChangeProps) => void;
|
||||
getNodeDynamicInputs: (nodeId: string) => FlowNodeInputItemType[];
|
||||
|
|
@ -401,6 +402,9 @@ export const WorkflowContext = createContext<WorkflowContextType>({
|
|||
isSaved?: boolean;
|
||||
}): boolean {
|
||||
throw new Error('Function not implemented.');
|
||||
},
|
||||
onRemoveError: function (): void {
|
||||
throw new Error('Function not implemented.');
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -473,6 +477,17 @@ const WorkflowContextProvider = ({
|
|||
});
|
||||
});
|
||||
});
|
||||
const onRemoveError = useMemoizedFn(() => {
|
||||
setNodes((state) => {
|
||||
return state.map((item) => {
|
||||
if (item.data.isError) {
|
||||
item.data.isError = false;
|
||||
item.selected = false;
|
||||
}
|
||||
return item;
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// reset a node data. delete edge and replace it
|
||||
const onResetNode = useMemoizedFn(({ id, node }: { id: string; node: FlowNodeTemplateType }) => {
|
||||
|
|
@ -625,6 +640,7 @@ const WorkflowContextProvider = ({
|
|||
const checkResults = checkWorkflowNodeAndConnection({ nodes, edges });
|
||||
|
||||
if (!checkResults) {
|
||||
onRemoveError();
|
||||
const storeWorkflow = uiWorkflow2StoreWorkflow({ nodes, edges });
|
||||
|
||||
return storeWorkflow;
|
||||
|
|
@ -1025,6 +1041,7 @@ const WorkflowContextProvider = ({
|
|||
// node
|
||||
nodeList,
|
||||
onUpdateNodeError,
|
||||
onRemoveError,
|
||||
onResetNode,
|
||||
onChangeNode,
|
||||
getNodeDynamicInputs,
|
||||
|
|
@ -1075,6 +1092,7 @@ const WorkflowContextProvider = ({
|
|||
onChangeNode,
|
||||
onDelEdge,
|
||||
onNextNodeDebug,
|
||||
onRemoveError,
|
||||
onResetNode,
|
||||
onStartNodeDebug,
|
||||
onStopNodeDebug,
|
||||
|
|
|
|||
|
|
@ -248,6 +248,7 @@ const HomeChatWindow = ({ myApps }: Props) => {
|
|||
size="sm"
|
||||
bg={'myGray.50'}
|
||||
rounded="full"
|
||||
noOfLines={[1, 3]}
|
||||
list={availableModels}
|
||||
value={selectedModel}
|
||||
onChange={async (model) => {
|
||||
|
|
|
|||
|
|
@ -381,6 +381,11 @@ export const getNodeAllSource = ({
|
|||
};
|
||||
|
||||
/* ====== Connection ======= */
|
||||
// Connectivity check result type
|
||||
type ConnectivityIssue = {
|
||||
nodeId: string;
|
||||
issue: 'isolated' | 'no_input' | 'unreachable_from_start';
|
||||
};
|
||||
export const checkWorkflowNodeAndConnection = ({
|
||||
nodes,
|
||||
edges
|
||||
|
|
@ -388,7 +393,7 @@ export const checkWorkflowNodeAndConnection = ({
|
|||
nodes: Node<FlowNodeItemType, string | undefined>[];
|
||||
edges: Edge<any>[];
|
||||
}): string[] | undefined => {
|
||||
// 1. reference check. Required value
|
||||
// Node check
|
||||
for (const node of nodes) {
|
||||
const data = node.data;
|
||||
const inputs = data.inputs;
|
||||
|
|
@ -453,6 +458,15 @@ export const checkWorkflowNodeAndConnection = ({
|
|||
return [data.nodeId];
|
||||
}
|
||||
}
|
||||
if (data.flowNodeType === FlowNodeTypeEnum.agent) {
|
||||
const toolConnections = edges.filter(
|
||||
(edge) =>
|
||||
edge.source === data.nodeId && edge.sourceHandle === NodeOutputKeyEnum.selectedTools
|
||||
);
|
||||
if (toolConnections.length === 0) {
|
||||
return [data.nodeId];
|
||||
}
|
||||
}
|
||||
|
||||
// check node input
|
||||
if (
|
||||
|
|
@ -506,7 +520,7 @@ export const checkWorkflowNodeAndConnection = ({
|
|||
return [data.nodeId];
|
||||
}
|
||||
|
||||
// filter tools node edge
|
||||
// Check node has invalid edge
|
||||
const edgeFilted = edges.filter(
|
||||
(edge) =>
|
||||
!(
|
||||
|
|
@ -514,7 +528,7 @@ export const checkWorkflowNodeAndConnection = ({
|
|||
edge.sourceHandle === NodeOutputKeyEnum.selectedTools
|
||||
)
|
||||
);
|
||||
// check node has edge
|
||||
// Check node has edge
|
||||
const hasEdge = edgeFilted.some(
|
||||
(edge) => edge.source === data.nodeId || edge.target === data.nodeId
|
||||
);
|
||||
|
|
@ -522,6 +536,106 @@ export const checkWorkflowNodeAndConnection = ({
|
|||
return [data.nodeId];
|
||||
}
|
||||
}
|
||||
|
||||
// Edge check
|
||||
|
||||
/**
 * Check graph connectivity and identify connectivity issues.
 *
 * Builds forward/backward adjacency lists from the edges, marks every node
 * reachable from the start node (and from each loopStart node), then reports
 * nodes that cannot be reached. Returns the offending nodeIds, or an empty
 * array when no issue is found. If no start node exists at all, every nodeId
 * is returned.
 */
const checkConnectivity = (
  nodes: Node<FlowNodeItemType, string | undefined>[],
  edges: Edge<any>[]
): string[] => {
  // Find start node: a workflow entry or a plugin input node.
  const startNode = nodes.find(
    (node) =>
      node.data.flowNodeType === FlowNodeTypeEnum.workflowStart ||
      node.data.flowNodeType === FlowNodeTypeEnum.pluginInput
  );

  if (!startNode) {
    // No start node found - this is a critical issue; flag every node.
    return nodes.map((node) => node.data.nodeId);
  }

  const issues: ConnectivityIssue[] = [];

  // Build adjacency lists for both directions.
  const outgoing = new Map<string, string[]>();
  const incoming = new Map<string, string[]>();

  nodes.forEach((node) => {
    outgoing.set(node.data.nodeId, []);
    incoming.set(node.data.nodeId, []);
  });

  edges.forEach((edge) => {
    const outList = outgoing.get(edge.source) || [];
    outList.push(edge.target);
    outgoing.set(edge.source, outList);

    const inList = incoming.get(edge.target) || [];
    inList.push(edge.source);
    incoming.set(edge.target, inList);
  });

  // Mark everything reachable from the start node. Loop-start nodes are also
  // used as extra DFS roots (presumably because a loop body is entered through
  // the loop node itself rather than a direct edge — TODO confirm).
  // NOTE(review): the DFS is recursive, so an extremely deep workflow could in
  // principle hit the call-stack limit.
  const reachableFromStart = new Set<string>();
  const dfsFromStart = (nodeId: string) => {
    if (reachableFromStart.has(nodeId)) return;
    reachableFromStart.add(nodeId);

    const neighbors = outgoing.get(nodeId) || [];
    neighbors.forEach((neighbor) => dfsFromStart(neighbor));
  };
  dfsFromStart(startNode.data.nodeId);
  nodes.forEach((node) => {
    if (node.data.flowNodeType === FlowNodeTypeEnum.loopStart) {
      dfsFromStart(node.data.nodeId);
    }
  });

  // Check each node for connectivity issues.
  for (const node of nodes) {
    const nodeId = node.data.nodeId;
    const nodeType = node.data.flowNodeType;

    // Skip system nodes that don't need connectivity checks.
    if (
      nodeType === FlowNodeTypeEnum.systemConfig ||
      nodeType === FlowNodeTypeEnum.pluginConfig ||
      nodeType === FlowNodeTypeEnum.comment ||
      nodeType === FlowNodeTypeEnum.globalVariable ||
      nodeType === FlowNodeTypeEnum.emptyNode
    ) {
      continue;
    }

    // NOTE(review): hasIncoming/hasOutgoing are computed but never read below;
    // only the 'unreachable_from_start' issue is ever reported, even though
    // ConnectivityIssue also allows 'isolated' and 'no_input'.
    const hasIncoming = (incoming.get(nodeId) || []).length > 0;
    const hasOutgoing = (outgoing.get(nodeId) || []).length > 0;
    const isStartNode = [
      FlowNodeTypeEnum.workflowStart,
      FlowNodeTypeEnum.pluginInput,
      FlowNodeTypeEnum.loopStart
    ].includes(nodeType);

    // Check if node is reachable from start. The loop stops at the first
    // offender, so at most one nodeId is reported per call.
    if (!isStartNode && !reachableFromStart.has(nodeId)) {
      issues.push({
        nodeId,
        issue: 'unreachable_from_start'
      });
      break;
    }
  }

  return issues.map((issue) => issue.nodeId);
};
|
||||
|
||||
const connectivityIssues = checkConnectivity(nodes, edges);
|
||||
if (connectivityIssues.length > 0) {
|
||||
return connectivityIssues;
|
||||
}
|
||||
};
|
||||
|
||||
/* ====== Variables ======= */
|
||||
|
|
|
|||
|
|
@ -0,0 +1,224 @@
|
|||
import { formatVectors } from '@fastgpt/service/core/ai/embedding/index';
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
describe('formatVectors function test', () => {
|
||||
// Helper function to create a normalized vector (L2 norm = 1)
|
||||
const createNormalizedVector = (length: number): number[] => {
|
||||
const vector = Array.from({ length }, (_, i) => (i + 1) / length);
|
||||
const norm = Math.sqrt(vector.reduce((sum, val) => sum + val * val, 0));
|
||||
return vector.map((val) => val / norm);
|
||||
};
|
||||
|
||||
// Helper function to create an unnormalized vector
|
||||
const createUnnormalizedVector = (length: number): number[] => {
|
||||
return Array.from({ length }, (_, i) => (i + 1) * 10);
|
||||
};
|
||||
|
||||
// Helper function to calculate L2 norm
|
||||
const calculateNorm = (vector: number[]): number => {
|
||||
return Math.sqrt(vector.reduce((sum, val) => sum + val * val, 0));
|
||||
};
|
||||
|
||||
// Helper function to check if vector is normalized (L2 norm H 1)
|
||||
const isNormalized = (vector: number[]): boolean => {
|
||||
const norm = calculateNorm(vector);
|
||||
return Math.abs(norm - 1) < 1e-10;
|
||||
};
|
||||
|
||||
describe('1536 dimension vectors', () => {
|
||||
it('should handle normalized 1536-dim vector with normalization=true', () => {
|
||||
const inputVector = createNormalizedVector(1536);
|
||||
const result = formatVectors(inputVector, true);
|
||||
|
||||
expect(result).toHaveLength(1536);
|
||||
expect(isNormalized(result)).toBe(true);
|
||||
// Since input is already normalized, result should be very similar
|
||||
expect(result).toEqual(
|
||||
expect.arrayContaining(inputVector.map((val) => expect.closeTo(val, 10)))
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle normalized 1536-dim vector with normalization=false', () => {
|
||||
const inputVector = createNormalizedVector(1536);
|
||||
const result = formatVectors(inputVector, false);
|
||||
|
||||
expect(result).toHaveLength(1536);
|
||||
expect(result).toEqual(inputVector);
|
||||
expect(isNormalized(result)).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle unnormalized 1536-dim vector with normalization=true', () => {
|
||||
const inputVector = createUnnormalizedVector(1536);
|
||||
const result = formatVectors(inputVector, true);
|
||||
|
||||
expect(result).toHaveLength(1536);
|
||||
expect(isNormalized(result)).toBe(true);
|
||||
// Result should be different from input (normalized)
|
||||
expect(result).not.toEqual(inputVector);
|
||||
});
|
||||
|
||||
it('should handle unnormalized 1536-dim vector with normalization=false', () => {
|
||||
const inputVector = createUnnormalizedVector(1536);
|
||||
const result = formatVectors(inputVector, false);
|
||||
|
||||
expect(result).toHaveLength(1536);
|
||||
expect(result).toEqual(inputVector);
|
||||
expect(isNormalized(result)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Greater than 1536 dimension vectors', () => {
|
||||
it('should handle normalized >1536-dim vector with normalization=true', () => {
|
||||
const inputVector = createNormalizedVector(2048);
|
||||
const result = formatVectors(inputVector, true);
|
||||
|
||||
expect(result).toHaveLength(1536);
|
||||
expect(isNormalized(result)).toBe(true);
|
||||
// Should be truncated to first 1536 elements and then normalized
|
||||
expect(result).toEqual(
|
||||
expect.arrayContaining(inputVector.slice(0, 1536).map((val) => expect.any(Number)))
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle normalized >1536-dim vector with normalization=false', () => {
|
||||
const inputVector = createNormalizedVector(2048);
|
||||
const result = formatVectors(inputVector, true); // Always normalized for >1536 dims
|
||||
|
||||
expect(result).toHaveLength(1536);
|
||||
expect(isNormalized(result)).toBe(true);
|
||||
// Should be truncated and normalized regardless of normalization flag
|
||||
});
|
||||
|
||||
it('should handle unnormalized >1536-dim vector with normalization=true', () => {
|
||||
const inputVector = createUnnormalizedVector(2048);
|
||||
const result = formatVectors(inputVector, true);
|
||||
|
||||
expect(result).toHaveLength(1536);
|
||||
expect(isNormalized(result)).toBe(true);
|
||||
// Should be truncated to first 1536 elements and then normalized
|
||||
});
|
||||
|
||||
it('should handle unnormalized >1536-dim vector with normalization=false', () => {
|
||||
const inputVector = createUnnormalizedVector(2048);
|
||||
const result = formatVectors(inputVector, false); // Always normalized for >1536 dims
|
||||
|
||||
expect(result).toHaveLength(1536);
|
||||
expect(isNormalized(result)).toBe(true);
|
||||
// Should be truncated and normalized regardless of normalization flag
|
||||
});
|
||||
|
||||
it('should log warning for vectors with length > 1536', () => {
|
||||
const consoleSpy = vi.spyOn(console, 'log').mockImplementation(() => {});
|
||||
const inputVector = createNormalizedVector(2000);
|
||||
|
||||
formatVectors(inputVector, false);
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
'The current vector dimension is 2000, and the vector dimension cannot exceed 1536'
|
||||
)
|
||||
);
|
||||
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Less than 1536 dimension vectors', () => {
|
||||
it('should handle normalized <1536-dim vector with normalization=true', () => {
|
||||
const inputVector = createNormalizedVector(512);
|
||||
const result = formatVectors(inputVector, true);
|
||||
|
||||
expect(result).toHaveLength(1536);
|
||||
expect(isNormalized(result)).toBe(true);
|
||||
// First 512 elements should match input, rest should be 0
|
||||
expect(result.slice(0, 512)).toEqual(
|
||||
expect.arrayContaining(inputVector.map((val) => expect.any(Number)))
|
||||
);
|
||||
expect(result.slice(512)).toEqual(new Array(1024).fill(0));
|
||||
});
|
||||
|
||||
it('should handle normalized <1536-dim vector with normalization=false', () => {
|
||||
const inputVector = createNormalizedVector(512);
|
||||
const result = formatVectors(inputVector, false);
|
||||
|
||||
expect(result).toHaveLength(1536);
|
||||
// First 512 elements should match input exactly, rest should be 0
|
||||
expect(result.slice(0, 512)).toEqual(inputVector);
|
||||
expect(result.slice(512)).toEqual(new Array(1024).fill(0));
|
||||
// The result remains normalized because adding zeros doesn't change the L2 norm
|
||||
expect(isNormalized(result)).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle unnormalized <1536-dim vector with normalization=true', () => {
|
||||
const inputVector = createUnnormalizedVector(512);
|
||||
const result = formatVectors(inputVector, true);
|
||||
|
||||
expect(result).toHaveLength(1536);
|
||||
expect(isNormalized(result)).toBe(true);
|
||||
// Should be padded with zeros and then normalized
|
||||
expect(result.slice(512)).toEqual(new Array(1024).fill(0));
|
||||
});
|
||||
|
||||
it('should handle unnormalized <1536-dim vector with normalization=false', () => {
|
||||
const inputVector = createUnnormalizedVector(512);
|
||||
const result = formatVectors(inputVector, false);
|
||||
|
||||
expect(result).toHaveLength(1536);
|
||||
// First 512 elements should match input exactly, rest should be 0
|
||||
expect(result.slice(0, 512)).toEqual(inputVector);
|
||||
expect(result.slice(512)).toEqual(new Array(1024).fill(0));
|
||||
expect(isNormalized(result)).toBe(false);
|
||||
});
|
||||
|
||||
it('should demonstrate that padding preserves normalization status', () => {
|
||||
// Create a vector that becomes unnormalized after some scaling
|
||||
const baseVector = [3, 4]; // norm = 5, not normalized
|
||||
const result = formatVectors(baseVector, false);
|
||||
|
||||
expect(result).toHaveLength(1536);
|
||||
expect(result[0]).toBe(3);
|
||||
expect(result[1]).toBe(4);
|
||||
expect(result.slice(2)).toEqual(new Array(1534).fill(0));
|
||||
expect(isNormalized(result)).toBe(false);
|
||||
expect(calculateNorm(result)).toBeCloseTo(5, 10);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Edge cases', () => {
|
||||
it('should handle zero vector', () => {
|
||||
const inputVector = new Array(1536).fill(0);
|
||||
const result = formatVectors(inputVector, true);
|
||||
|
||||
expect(result).toHaveLength(1536);
|
||||
expect(result).toEqual(inputVector); // Zero vector remains zero after normalization
|
||||
});
|
||||
|
||||
it('should handle single element vector', () => {
|
||||
const inputVector = [5.0];
|
||||
const result = formatVectors(inputVector, true);
|
||||
|
||||
expect(result).toHaveLength(1536);
|
||||
expect(result[0]).toBeCloseTo(1.0, 10); // Normalized single element should be 1
|
||||
expect(result.slice(1)).toEqual(new Array(1535).fill(0));
|
||||
});
|
||||
|
||||
it('should handle exactly 1536 dimension vector', () => {
|
||||
const inputVector = createNormalizedVector(1536);
|
||||
const result = formatVectors(inputVector, true);
|
||||
|
||||
expect(result).toHaveLength(1536);
|
||||
expect(isNormalized(result)).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle vector with negative values', () => {
|
||||
const inputVector = [-1, -2, -3];
|
||||
const result = formatVectors(inputVector, true);
|
||||
|
||||
expect(result).toHaveLength(1536);
|
||||
expect(isNormalized(result)).toBe(true);
|
||||
expect(result[0]).toBeLessThan(0); // Should preserve negative values
|
||||
expect(result[1]).toBeLessThan(0);
|
||||
expect(result[2]).toBeLessThan(0);
|
||||
});
|
||||
});
|
||||
});
|
||||
Loading…
Reference in New Issue