feat(chat): 支持 function calling,模型可主动查询原文数据

新增三个工具供 LLM 调用:
- lookup_transcript: 查询转录原文(按时间范围、关键词、位置筛选)
- get_video_info: 获取视频元信息(标题、作者、简介、标签等)
- get_note_content: 获取完整笔记 Markdown 内容

实现 tool calling 循环(最多 3 轮),LLM 可根据问题
主动调用工具获取所需信息,不再完全依赖 RAG 检索。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
huangjianwu
2026-03-23 15:48:23 +08:00
parent 3e9f908d7b
commit 05877a2197
2 changed files with 247 additions and 35 deletions

View File

@@ -1,35 +1,38 @@
import json
from typing import Optional
from app.gpt.gpt_factory import GPTFactory
from app.models.model_config import ModelConfig
from app.services.provider import ProviderService
from app.services.vector_store import VectorStoreManager
from app.services.chat_tools import TOOLS, execute_tool
from app.utils.logger import get_logger
logger = get_logger(__name__)
SYSTEM_PROMPT = """你是一个视频笔记问答助手。你可以参考两种来源回答用户的问题
1. [视频信息] — 视频标题、作者、简介、标签等元信息
2. [笔记] — AI 生成的视频摘要笔记
3. [转录] — 视频原始语音转录文本(含时间戳)
SYSTEM_PROMPT = """你是一个视频笔记问答助手。你拥有以下能力
以下是检索到的相关内容:
1. 系统已自动检索了一些相关内容作为初始参考(见下方)
2. 你可以调用工具主动查询更多信息:
- lookup_transcript: 查询视频原始转录文本(支持按时间、关键词、位置筛选)
- get_video_info: 获取视频元信息(标题、作者、简介、标签等)
- get_note_content: 获取完整笔记内容
--- 相关内容 ---
--- 初始检索内容 ---
{context}
---
回答要求:
- 优先使用转录原文回答关于视频具体内容、原话、细节的问题
- 优先使用笔记回答关于总结、要点、结构的问题
- 如果确实没有相关信息,请诚实告知
- 如果初始检索内容不足以回答问题,请主动调用工具获取更多信息
- 回答关于视频具体原话、细节时,用 lookup_transcript 查询原文
- 回答关于作者、标题等基本信息时,用 get_video_info 查询
- 请用中文回答,保持简洁准确"""
def _build_context(chunks: list[dict]) -> str:
"""将检索到的片段拼接为上下文文本。"""
parts = []
for i, chunk in enumerate(chunks, 1):
for chunk in chunks:
meta = chunk.get("metadata", {})
source_type = meta.get("source_type", "unknown")
if source_type == "meta":
@@ -71,39 +74,29 @@ def chat(
model_name: str,
) -> dict:
"""
RAG 问答:检索相关片段 → 构建 prompt → 调用 LLM → 返回答案 + 来源
Returns:
{"answer": str, "sources": list[dict]}
RAG + Tool Calling 问答
1. 向量检索初始上下文
2. 调用 LLM带 tools
3. 如果 LLM 调用了工具,执行工具并将结果返回给 LLM
4. 循环直到 LLM 给出最终回答
"""
vector_store = VectorStoreManager()
# 1. 检索相关片段
chunks = vector_store.query(task_id, question, n_results=5)
print(
f"检索到 {len(chunks)} 个相关片段: {[c['metadata'].get('source_type') for c in chunks]}"
)
if not chunks:
return {
"answer": "暂未找到相关笔记内容,请确认笔记已生成并完成索引。",
"sources": [],
}
# 1. 检索初始上下文
chunks = vector_store.query(task_id, question, n_results=6)
context = _build_context(chunks) if chunks else "(未检索到相关内容,请使用工具查询)"
sources = _build_sources(chunks) if chunks else []
# 2. 构建上下文和来源
context = _build_context(chunks)
sources = _build_sources(chunks)
# 3. 构建消息
# 2. 构建消息
system_msg = SYSTEM_PROMPT.format(context=context)
messages = [{"role": "system", "content": system_msg}]
# 加入历史对话(最近 10 轮)
for msg in history[-20:]:
messages.append({"role": msg["role"], "content": msg["content"]})
messages.append({"role": "user", "content": question})
# 4. 调用 LLM
# 3. 获取 LLM client
provider = ProviderService.get_provider_by_id(provider_id)
if not provider:
raise ValueError(f"未找到模型供应商: {provider_id}")
@@ -117,14 +110,49 @@ def chat(
)
gpt = GPTFactory.from_config(config)
logger.info(f"Chat RAG: task_id={task_id}, provider={provider['name']}, model={model_name}")
logger.info(f"Chat: task_id={task_id}, model={model_name}")
# 4. Tool calling 循环(最多 3 轮)
max_rounds = 3
for round_i in range(max_rounds):
response = gpt.client.chat.completions.create(
model=gpt.model,
messages=messages,
tools=TOOLS,
temperature=0.7,
)
msg = response.choices[0].message
# 没有工具调用,直接返回
if not msg.tool_calls:
return {"answer": msg.content or "", "sources": sources}
# 处理工具调用
messages.append(msg)
for tool_call in msg.tool_calls:
fn_name = tool_call.function.name
try:
fn_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError:
fn_args = {}
logger.info(f"Tool call [{round_i+1}/{max_rounds}]: {fn_name}({fn_args})")
result = execute_tool(task_id, fn_name, fn_args)
messages.append({
"role": "tool",
"tool_call_id": tool_call.id,
"content": result,
})
# 超过最大轮次,做最后一次不带 tools 的调用
response = gpt.client.chat.completions.create(
model=gpt.model,
messages=messages,
temperature=0.7,
)
answer = response.choices[0].message.content
return {"answer": answer, "sources": sources}
return {"answer": response.choices[0].message.content or "", "sources": sources}

View File

@@ -0,0 +1,184 @@
"""
Chat function calling 工具定义与执行。
提供给 LLM 调用,用于主动查询视频原文、笔记、元信息。
"""
import json
import os
from typing import Optional
from app.utils.logger import get_logger
logger = get_logger(__name__)
NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results")
def _load_note_data(task_id: str) -> Optional[dict]:
path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json")
if not os.path.exists(path):
return None
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
# ── 工具定义OpenAI function calling 格式)──────────────────────
TOOLS = [
{
"type": "function",
"function": {
"name": "lookup_transcript",
"description": "查询视频原始转录文本。可按时间范围筛选、按关键词搜索、或获取指定位置的内容。",
"parameters": {
"type": "object",
"properties": {
"start_time": {
"type": "number",
"description": "起始时间(秒),例如 0 表示视频开头60 表示第1分钟",
},
"end_time": {
"type": "number",
"description": "结束时间(秒),不传则到末尾",
},
"keyword": {
"type": "string",
"description": "搜索关键词,返回包含该关键词的转录片段",
},
"position": {
"type": "string",
"enum": ["start", "end"],
"description": "快捷位置start=视频开头前30句end=视频结尾后30句",
},
},
"required": [],
},
},
},
{
"type": "function",
"function": {
"name": "get_video_info",
"description": "获取视频的完整元信息,包括标题、作者、简介、标签、时长、播放量等。",
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
},
},
{
"type": "function",
"function": {
"name": "get_note_content",
"description": "获取 AI 生成的完整笔记内容Markdown 格式)。",
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
},
},
]
# ── Tool execution ────────────────────────────────────────────────────────
def execute_tool(task_id: str, tool_name: str, arguments: dict) -> str:
    """Dispatch a tool call against the note data of *task_id*.

    Returns a JSON string (always ensure_ascii=False) — either the tool's
    result or an {"error": ...} payload when data is missing or the tool
    name is unknown.
    """
    data = _load_note_data(task_id)
    if not data:
        return json.dumps({"error": "笔记数据不存在"}, ensure_ascii=False)
    # Dispatch table instead of an if/elif chain; each entry defers the call.
    handlers = {
        "lookup_transcript": lambda: _lookup_transcript(data, arguments),
        "get_video_info": lambda: _get_video_info(data),
        "get_note_content": lambda: _get_note_content(data),
    }
    handler = handlers.get(tool_name)
    if handler is None:
        return json.dumps({"error": f"未知工具: {tool_name}"}, ensure_ascii=False)
    return handler()
def _lookup_transcript(data: dict, args: dict) -> str:
segments = data.get("transcript", {}).get("segments", [])
if not segments:
return json.dumps({"error": "没有转录数据"}, ensure_ascii=False)
position = args.get("position")
start_time = args.get("start_time")
end_time = args.get("end_time")
keyword = args.get("keyword", "").strip()
# 快捷位置
if position == "start":
filtered = segments[:30]
elif position == "end":
filtered = segments[-30:]
else:
filtered = segments
# 时间筛选
if start_time is not None:
filtered = [s for s in filtered if s.get("end", 0) >= start_time]
if end_time is not None:
filtered = [s for s in filtered if s.get("start", 0) <= end_time]
# 关键词筛选
if keyword:
filtered = [s for s in filtered if keyword.lower() in s.get("text", "").lower()]
# 限制返回量,避免 token 爆炸
if len(filtered) > 50:
filtered = filtered[:50]
truncated = True
else:
truncated = False
result = {
"total_segments": len(data.get("transcript", {}).get("segments", [])),
"returned": len(filtered),
"truncated": truncated,
"segments": [
{
"start": round(s.get("start", 0), 1),
"end": round(s.get("end", 0), 1),
"text": s.get("text", ""),
}
for s in filtered
],
}
return json.dumps(result, ensure_ascii=False)
def _get_video_info(data: dict) -> str:
am = data.get("audio_meta", {})
raw = am.get("raw_info", {}) or {}
info = {
"title": am.get("title") or raw.get("title", ""),
"uploader": raw.get("uploader", ""),
"description": raw.get("description", "")[:1000],
"tags": raw.get("tags", [])[:20] if isinstance(raw.get("tags"), list) else [],
"duration_seconds": am.get("duration", 0),
"platform": am.get("platform", ""),
"video_id": am.get("video_id", ""),
"url": raw.get("webpage_url", ""),
"view_count": raw.get("view_count"),
"like_count": raw.get("like_count"),
"comment_count": raw.get("comment_count"),
}
# 去除 None 值
info = {k: v for k, v in info.items() if v is not None and v != ""}
return json.dumps(info, ensure_ascii=False)
def _get_note_content(data: dict) -> str:
md = data.get("markdown", "")
if isinstance(md, list):
# 多版本,取最新
md = md[-1].get("content", "") if md else ""
# 限制长度
if len(md) > 5000:
md = md[:5000] + "\n\n... (内容过长已截断)"
return json.dumps({"markdown": md}, ensure_ascii=False)