diff --git a/backend/app/services/vector_store.py b/backend/app/services/vector_store.py index e938c25..464e9f2 100644 --- a/backend/app/services/vector_store.py +++ b/backend/app/services/vector_store.py @@ -173,10 +173,10 @@ class VectorStoreManager: }) return chunks - def query(self, task_id: str, query_text: str, n_results: int = 5) -> list[dict]: + def query(self, task_id: str, query_text: str, n_results: int = 6) -> list[dict]: """ - 分别从 markdown 和 transcript 各检索,确保两种来源都被召回, - 最后按距离排序返回 top-n。 + 按固定配额从各来源检索:meta 1 条、markdown 2 条、transcript 3 条, + 确保三种来源都被召回。 """ col_name = self._collection_name(task_id) try: @@ -187,21 +187,21 @@ class VectorStoreManager: all_chunks = [] - # 分别从各来源检索 - for source_type in ("meta", "markdown", "transcript"): + # 每种来源的配额 + quotas = {"meta": 1, "markdown": 2, "transcript": 3} + + for source_type, quota in quotas.items(): try: results = collection.query( query_texts=[query_text], - n_results=n_results, + n_results=quota, where={"source_type": source_type}, ) all_chunks.extend(self._parse_results(results)) except Exception: pass - # 按距离排序,取 top-n - all_chunks.sort(key=lambda c: c.get("distance", 999)) - return all_chunks[:n_results] + return all_chunks def delete_index(self, task_id: str) -> None: """删除指定任务的向量索引。"""