mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-05-06 20:42:52 +08:00
新增三个工具供 LLM 调用: - lookup_transcript: 查询转录原文(按时间范围、关键词、位置筛选) - get_video_info: 获取视频元信息(标题、作者、简介、标签等) - get_note_content: 获取完整笔记 Markdown 内容 实现 tool calling 循环(最多 3 轮),LLM 可根据问题 主动调用工具获取所需信息,不再完全依赖 RAG 检索。 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
159 lines
5.2 KiB
Python
159 lines
5.2 KiB
Python
import json
|
||
from typing import Optional
|
||
|
||
from app.gpt.gpt_factory import GPTFactory
|
||
from app.models.model_config import ModelConfig
|
||
from app.services.provider import ProviderService
|
||
from app.services.vector_store import VectorStoreManager
|
||
from app.services.chat_tools import TOOLS, execute_tool
|
||
from app.utils.logger import get_logger
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
# System prompt for the video-note Q&A assistant.
# It contains exactly one ``{context}`` placeholder, which ``chat`` fills via
# ``str.format`` with the initially retrieved chunks; do not add any other
# braces to this text or ``str.format`` will fail.
SYSTEM_PROMPT = """你是一个视频笔记问答助手。你拥有以下能力:

1. 系统已自动检索了一些相关内容作为初始参考(见下方)
2. 你可以调用工具主动查询更多信息:
   - lookup_transcript: 查询视频原始转录文本(支持按时间、关键词、位置筛选)
   - get_video_info: 获取视频元信息(标题、作者、简介、标签等)
   - get_note_content: 获取完整笔记内容

--- 初始检索内容 ---
{context}
---

回答要求:
- 如果初始检索内容不足以回答问题,请主动调用工具获取更多信息
- 回答关于视频具体原话、细节时,用 lookup_transcript 查询原文
- 回答关于作者、标题等基本信息时,用 get_video_info 查询
- 请用中文回答,保持简洁准确"""
|
||
|
||
|
||
def _build_context(chunks: list[dict]) -> str:
    """Join retrieved chunks into one context string for the system prompt.

    Each chunk is rendered as a label line followed by the chunk text, and
    the rendered chunks are separated by a blank line. The label depends on
    the chunk's ``source_type`` metadata: ``"meta"`` is video information,
    ``"markdown"`` is a note section (labelled with its section title), and
    anything else is treated as a transcript span labelled with its
    start/end time in whole seconds.

    Args:
        chunks: Retrieved chunks. Each is a dict with a ``"text"`` entry and
            an optional ``"metadata"`` dict.

    Returns:
        The labelled chunks joined by blank lines; ``""`` for an empty list.

    Raises:
        KeyError: If a chunk has no ``"text"`` entry.
    """
    parts = []
    for chunk in chunks:
        # "metadata" may be missing or explicitly None; treat both as empty.
        meta = chunk.get("metadata") or {}
        source_type = meta.get("source_type", "unknown")
        if source_type == "meta":
            label = "[视频信息]"
        elif source_type == "markdown":
            # A None title must not be rendered as the literal "None".
            label = f"[笔记 - {meta.get('section_title') or ''}]"
        else:
            # Times can be present but None, which would break the ":.0f"
            # format spec, so fall back to 0 in that case as well.
            start = meta.get("start_time") or 0
            end = meta.get("end_time") or 0
            label = f"[转录 - {start:.0f}s~{end:.0f}s]"
        parts.append(f"{label}\n{chunk['text']}")
    return "\n\n".join(parts)
|
||
|
||
|
||
def _build_sources(chunks: list[dict]) -> list[dict]:
    """Extract source descriptors from retrieved chunks.

    Args:
        chunks: Retrieved chunks. Each is a dict with a ``"text"`` entry and
            an optional ``"metadata"`` dict.

    Returns:
        One dict per chunk, in input order, with:
          - ``"text"``: the first 200 characters of the chunk text;
          - ``"source_type"``: the metadata value, or ``"unknown"``;
          - ``"section_title"``: only when the metadata value is truthy;
          - ``"start_time"`` / ``"end_time"``: only when the metadata value
            is not None.

    Raises:
        KeyError: If a chunk has no ``"text"`` entry.
    """
    sources = []
    for chunk in chunks:
        # "metadata" may be missing or explicitly None; treat both as empty.
        meta = chunk.get("metadata") or {}
        source = {
            "text": chunk["text"][:200],
            "source_type": meta.get("source_type", "unknown"),
        }
        if meta.get("section_title"):
            source["section_title"] = meta["section_title"]
        # Compare against None rather than truthiness: 0 is a valid timestamp.
        for key in ("start_time", "end_time"):
            if meta.get(key) is not None:
                source[key] = meta[key]
        sources.append(source)
    return sources
|
||
|
||
|
||
def chat(
    task_id: str,
    question: str,
    history: list[dict],
    provider_id: str,
    model_name: str,
) -> dict:
    """Answer a question about a video note using RAG plus tool calling.

    Steps:
      1. Retrieve initial context chunks from the vector store.
      2. Build the message list (system prompt, recent history, question).
      3. Create the LLM client for the requested provider and model.
      4. Call the LLM with tools enabled; while it requests tool calls,
         execute them and feed the results back (at most 3 rounds).
      5. If the rounds are exhausted, make one final call without tools so
         the model has to produce an answer.

    Args:
        task_id: Identifier passed to the vector store query and to
            ``execute_tool``.
        question: The user's current question.
        history: Earlier conversation turns as dicts with ``"role"`` and
            ``"content"``; only the last 20 are sent. ``None`` is treated as
            an empty history.
        provider_id: Identifier used to look up the model provider.
        model_name: Name of the model to use with that provider.

    Returns:
        A dict with ``"answer"`` (the reply text, ``""`` if the model
        returned no content) and ``"sources"`` (descriptors of the initially
        retrieved chunks; tool results are not included).

    Raises:
        ValueError: If no provider is found for ``provider_id``.
    """
    vector_store = VectorStoreManager()

    # 1. Retrieve the initial context.
    chunks = vector_store.query(task_id, question, n_results=6)
    context = _build_context(chunks) if chunks else "(未检索到相关内容,请使用工具查询)"
    sources = _build_sources(chunks) if chunks else []

    # 2. Build the messages.
    system_msg = SYSTEM_PROMPT.format(context=context)
    messages = [{"role": "system", "content": system_msg}]

    # Only the most recent turns are forwarded, to bound the prompt size.
    for past in (history or [])[-20:]:
        messages.append({"role": past["role"], "content": past["content"]})

    messages.append({"role": "user", "content": question})

    # 3. Create the LLM client.
    provider = ProviderService.get_provider_by_id(provider_id)
    if not provider:
        raise ValueError(f"未找到模型供应商: {provider_id}")

    config = ModelConfig(
        api_key=provider["api_key"],
        base_url=provider["base_url"],
        model_name=model_name,
        provider=provider["type"],
        name=provider["name"],
    )
    gpt = GPTFactory.from_config(config)

    logger.info(f"Chat: task_id={task_id}, model={model_name}")

    # 4. Tool-calling loop (at most 3 rounds).
    max_rounds = 3
    for round_i in range(max_rounds):
        response = gpt.client.chat.completions.create(
            model=gpt.model,
            messages=messages,
            tools=TOOLS,
            temperature=0.7,
        )

        reply = response.choices[0].message

        # No tool calls: this is the final answer.
        if not reply.tool_calls:
            return {"answer": reply.content or "", "sources": sources}

        # The assistant message carrying the tool calls must precede the
        # tool results in the conversation.
        messages.append(reply)

        for tool_call in reply.tool_calls:
            fn_name = tool_call.function.name
            try:
                fn_args = json.loads(tool_call.function.arguments)
            except (json.JSONDecodeError, TypeError):
                # Malformed JSON, or arguments missing entirely (None).
                fn_args = {}
            if not isinstance(fn_args, dict):
                # Valid JSON that is not an object (null, list, ...) cannot
                # be used as keyword-style arguments; use no arguments.
                fn_args = {}

            logger.info(f"Tool call [{round_i+1}/{max_rounds}]: {fn_name}({fn_args})")

            result = execute_tool(task_id, fn_name, fn_args)

            messages.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": result,
            })

    # 5. Rounds exhausted: one last call without tools to force an answer.
    response = gpt.client.chat.completions.create(
        model=gpt.model,
        messages=messages,
        temperature=0.7,
    )

    return {"answer": response.choices[0].message.content or "", "sources": sources}
|