Files
BiliNote/backend/app/services/chat_service.py
huangjianwu 05877a2197 feat(chat): 支持 function calling,模型可主动查询原文数据
新增三个工具供 LLM 调用:
- lookup_transcript: 查询转录原文(按时间范围、关键词、位置筛选)
- get_video_info: 获取视频元信息(标题、作者、简介、标签等)
- get_note_content: 获取完整笔记 Markdown 内容

实现 tool calling 循环(最多 3 轮),LLM 可根据问题
主动调用工具获取所需信息,不再完全依赖 RAG 检索。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-23 15:48:23 +08:00

159 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
from typing import Optional
from app.gpt.gpt_factory import GPTFactory
from app.models.model_config import ModelConfig
from app.services.provider import ProviderService
from app.services.vector_store import VectorStoreManager
from app.services.chat_tools import TOOLS, execute_tool
from app.utils.logger import get_logger
# Module-level logger named after this module.
logger = get_logger(__name__)
# System prompt template for the chat assistant. The text is written in
# Chinese: it introduces the assistant as a video-note Q&A helper, lists the
# three callable tools (lookup_transcript, get_video_info, get_note_content),
# embeds the initially retrieved content, and asks for concise answers in
# Chinese. `{context}` is the only placeholder; chat() fills it in with
# SYSTEM_PROMPT.format(context=...).
SYSTEM_PROMPT = """你是一个视频笔记问答助手。你拥有以下能力:
1. 系统已自动检索了一些相关内容作为初始参考(见下方)
2. 你可以调用工具主动查询更多信息:
- lookup_transcript: 查询视频原始转录文本(支持按时间、关键词、位置筛选)
- get_video_info: 获取视频元信息(标题、作者、简介、标签等)
- get_note_content: 获取完整笔记内容
--- 初始检索内容 ---
{context}
---
回答要求:
- 如果初始检索内容不足以回答问题,请主动调用工具获取更多信息
- 回答关于视频具体原话、细节时,用 lookup_transcript 查询原文
- 回答关于作者、标题等基本信息时,用 get_video_info 查询
- 请用中文回答,保持简洁准确"""
def _build_context(chunks: list[dict]) -> str:
"""将检索到的片段拼接为上下文文本。"""
parts = []
for chunk in chunks:
meta = chunk.get("metadata", {})
source_type = meta.get("source_type", "unknown")
if source_type == "meta":
label = "[视频信息]"
elif source_type == "markdown":
label = f"[笔记 - {meta.get('section_title', '')}]"
else:
start = meta.get("start_time", 0)
end = meta.get("end_time", 0)
label = f"[转录 - {start:.0f}s~{end:.0f}s]"
parts.append(f"{label}\n{chunk['text']}")
return "\n\n".join(parts)
def _build_sources(chunks: list[dict]) -> list[dict]:
"""从检索片段中提取来源信息。"""
sources = []
for chunk in chunks:
meta = chunk.get("metadata", {})
source = {
"text": chunk["text"][:200],
"source_type": meta.get("source_type", "unknown"),
}
if meta.get("section_title"):
source["section_title"] = meta["section_title"]
if meta.get("start_time") is not None:
source["start_time"] = meta["start_time"]
if meta.get("end_time") is not None:
source["end_time"] = meta["end_time"]
sources.append(source)
return sources
def _parse_tool_arguments(fn_name: str, raw_arguments) -> dict:
    """Decode a tool call's JSON argument payload into a dict.

    Returns an empty dict when the payload is empty, is not valid JSON, or
    decodes to something other than a JSON object, so the tool executor
    always receives a mapping.
    """
    if not raw_arguments:
        return {}
    try:
        parsed = json.loads(raw_arguments)
    except (json.JSONDecodeError, TypeError):
        logger.warning(
            "Tool call %s: arguments are not valid JSON, using {}: %r",
            fn_name,
            raw_arguments,
        )
        return {}
    if not isinstance(parsed, dict):
        logger.warning(
            "Tool call %s: arguments are not a JSON object, using {}: %r",
            fn_name,
            raw_arguments,
        )
        return {}
    return parsed


def chat(
    task_id: str,
    question: str,
    history: list[dict],
    provider_id: str,
    model_name: str,
) -> dict:
    """Answer a question about a task's note using RAG plus tool calling.

    Steps:
      1. Retrieve initial context from the vector store.
      2. Call the LLM with the tool definitions attached.
      3. If the LLM requests tools, execute them and feed the results back.
      4. Repeat until the LLM answers without tool calls. After three tool
         rounds a final call is made without tools.

    Args:
        task_id: Task identifier passed to the vector store query and to
            every tool execution.
        question: The user's current question.
        history: Earlier conversation turns as dicts with ``"role"`` and
            ``"content"``; only the last 20 are sent. ``None`` or an empty
            list means no history.
        provider_id: Identifier used to look up the model provider.
        model_name: Model name placed in the model configuration.

    Returns:
        A dict with ``"answer"`` (the model's reply text, ``""`` when the
        reply has no content) and ``"sources"`` (descriptors of the
        initially retrieved chunks; tool results are not included).

    Raises:
        ValueError: If no provider matches ``provider_id``.
    """
    vector_store = VectorStoreManager()

    # 1. Retrieve the initial context.
    chunks = vector_store.query(task_id, question, n_results=6)
    context = _build_context(chunks) if chunks else "(未检索到相关内容,请使用工具查询)"
    sources = _build_sources(chunks) if chunks else []

    # 2. Build the message list: system prompt, recent history, question.
    system_msg = SYSTEM_PROMPT.format(context=context)
    messages = [{"role": "system", "content": system_msg}]
    for past in (history or [])[-20:]:
        messages.append({"role": past["role"], "content": past["content"]})
    messages.append({"role": "user", "content": question})

    # 3. Resolve the LLM client for the requested provider.
    provider = ProviderService.get_provider_by_id(provider_id)
    if not provider:
        raise ValueError(f"未找到模型供应商: {provider_id}")
    config = ModelConfig(
        api_key=provider["api_key"],
        base_url=provider["base_url"],
        model_name=model_name,
        provider=provider["type"],
        name=provider["name"],
    )
    gpt = GPTFactory.from_config(config)
    logger.info(f"Chat: task_id={task_id}, model={model_name}")

    # 4. Tool-calling loop (at most 3 rounds).
    max_rounds = 3
    for round_i in range(max_rounds):
        response = gpt.client.chat.completions.create(
            model=gpt.model,
            messages=messages,
            tools=TOOLS,
            temperature=0.7,
        )
        msg = response.choices[0].message

        # No tool calls: this is the final answer.
        if not msg.tool_calls:
            return {"answer": msg.content or "", "sources": sources}

        # Echo the assistant's tool-call message back first so the tool
        # results appended below can be matched to it by tool_call_id.
        messages.append(msg)
        for tool_call in msg.tool_calls:
            fn_name = tool_call.function.name
            fn_args = _parse_tool_arguments(fn_name, tool_call.function.arguments)
            logger.info(f"Tool call [{round_i+1}/{max_rounds}]: {fn_name}({fn_args})")
            result = execute_tool(task_id, fn_name, fn_args)
            messages.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": result,
            })

    # Round limit reached: make one last call without tools.
    response = gpt.client.chat.completions.create(
        model=gpt.model,
        messages=messages,
        temperature=0.7,
    )
    return {"answer": response.choices[0].message.content or "", "sources": sources}