mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-05-11 18:10:06 +08:00
feat(chat): 支持 function calling,模型可主动查询原文数据
新增三个工具供 LLM 调用: - lookup_transcript: 查询转录原文(按时间范围、关键词、位置筛选) - get_video_info: 获取视频元信息(标题、作者、简介、标签等) - get_note_content: 获取完整笔记 Markdown 内容 实现 tool calling 循环(最多 3 轮),LLM 可根据问题 主动调用工具获取所需信息,不再完全依赖 RAG 检索。 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,35 +1,38 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from app.gpt.gpt_factory import GPTFactory
|
||||
from app.models.model_config import ModelConfig
|
||||
from app.services.provider import ProviderService
|
||||
from app.services.vector_store import VectorStoreManager
|
||||
from app.services.chat_tools import TOOLS, execute_tool
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
SYSTEM_PROMPT = """你是一个视频笔记问答助手。你可以参考两种来源回答用户的问题:
|
||||
1. [视频信息] — 视频标题、作者、简介、标签等元信息
|
||||
2. [笔记] — AI 生成的视频摘要笔记
|
||||
3. [转录] — 视频原始语音转录文本(含时间戳)
|
||||
SYSTEM_PROMPT = """你是一个视频笔记问答助手。你拥有以下能力:
|
||||
|
||||
以下是检索到的相关内容:
|
||||
1. 系统已自动检索了一些相关内容作为初始参考(见下方)
|
||||
2. 你可以调用工具主动查询更多信息:
|
||||
- lookup_transcript: 查询视频原始转录文本(支持按时间、关键词、位置筛选)
|
||||
- get_video_info: 获取视频元信息(标题、作者、简介、标签等)
|
||||
- get_note_content: 获取完整笔记内容
|
||||
|
||||
--- 相关内容 ---
|
||||
--- 初始检索内容 ---
|
||||
{context}
|
||||
---
|
||||
|
||||
回答要求:
|
||||
- 优先使用转录原文回答关于视频具体内容、原话、细节的问题
|
||||
- 优先使用笔记回答关于总结、要点、结构的问题
|
||||
- 如果确实没有相关信息,请诚实告知
|
||||
- 如果初始检索内容不足以回答问题,请主动调用工具获取更多信息
|
||||
- 回答关于视频具体原话、细节时,用 lookup_transcript 查询原文
|
||||
- 回答关于作者、标题等基本信息时,用 get_video_info 查询
|
||||
- 请用中文回答,保持简洁准确"""
|
||||
|
||||
|
||||
def _build_context(chunks: list[dict]) -> str:
|
||||
"""将检索到的片段拼接为上下文文本。"""
|
||||
parts = []
|
||||
for i, chunk in enumerate(chunks, 1):
|
||||
for chunk in chunks:
|
||||
meta = chunk.get("metadata", {})
|
||||
source_type = meta.get("source_type", "unknown")
|
||||
if source_type == "meta":
|
||||
@@ -71,39 +74,29 @@ def chat(
|
||||
model_name: str,
|
||||
) -> dict:
|
||||
"""
|
||||
RAG 问答:检索相关片段 → 构建 prompt → 调用 LLM → 返回答案 + 来源。
|
||||
|
||||
Returns:
|
||||
{"answer": str, "sources": list[dict]}
|
||||
RAG + Tool Calling 问答。
|
||||
1. 向量检索初始上下文
|
||||
2. 调用 LLM(带 tools)
|
||||
3. 如果 LLM 调用了工具,执行工具并将结果返回给 LLM
|
||||
4. 循环直到 LLM 给出最终回答
|
||||
"""
|
||||
vector_store = VectorStoreManager()
|
||||
|
||||
# 1. 检索相关片段
|
||||
chunks = vector_store.query(task_id, question, n_results=5)
|
||||
print(
|
||||
f"检索到 {len(chunks)} 个相关片段: {[c['metadata'].get('source_type') for c in chunks]}"
|
||||
)
|
||||
if not chunks:
|
||||
return {
|
||||
"answer": "暂未找到相关笔记内容,请确认笔记已生成并完成索引。",
|
||||
"sources": [],
|
||||
}
|
||||
# 1. 检索初始上下文
|
||||
chunks = vector_store.query(task_id, question, n_results=6)
|
||||
context = _build_context(chunks) if chunks else "(未检索到相关内容,请使用工具查询)"
|
||||
sources = _build_sources(chunks) if chunks else []
|
||||
|
||||
# 2. 构建上下文和来源
|
||||
context = _build_context(chunks)
|
||||
sources = _build_sources(chunks)
|
||||
|
||||
# 3. 构建消息
|
||||
# 2. 构建消息
|
||||
system_msg = SYSTEM_PROMPT.format(context=context)
|
||||
messages = [{"role": "system", "content": system_msg}]
|
||||
|
||||
# 加入历史对话(最近 10 轮)
|
||||
for msg in history[-20:]:
|
||||
messages.append({"role": msg["role"], "content": msg["content"]})
|
||||
|
||||
messages.append({"role": "user", "content": question})
|
||||
|
||||
# 4. 调用 LLM
|
||||
# 3. 获取 LLM client
|
||||
provider = ProviderService.get_provider_by_id(provider_id)
|
||||
if not provider:
|
||||
raise ValueError(f"未找到模型供应商: {provider_id}")
|
||||
@@ -117,14 +110,49 @@ def chat(
|
||||
)
|
||||
gpt = GPTFactory.from_config(config)
|
||||
|
||||
logger.info(f"Chat RAG: task_id={task_id}, provider={provider['name']}, model={model_name}")
|
||||
logger.info(f"Chat: task_id={task_id}, model={model_name}")
|
||||
|
||||
# 4. Tool calling 循环(最多 3 轮)
|
||||
max_rounds = 3
|
||||
for round_i in range(max_rounds):
|
||||
response = gpt.client.chat.completions.create(
|
||||
model=gpt.model,
|
||||
messages=messages,
|
||||
tools=TOOLS,
|
||||
temperature=0.7,
|
||||
)
|
||||
|
||||
msg = response.choices[0].message
|
||||
|
||||
# 没有工具调用,直接返回
|
||||
if not msg.tool_calls:
|
||||
return {"answer": msg.content or "", "sources": sources}
|
||||
|
||||
# 处理工具调用
|
||||
messages.append(msg)
|
||||
|
||||
for tool_call in msg.tool_calls:
|
||||
fn_name = tool_call.function.name
|
||||
try:
|
||||
fn_args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError:
|
||||
fn_args = {}
|
||||
|
||||
logger.info(f"Tool call [{round_i+1}/{max_rounds}]: {fn_name}({fn_args})")
|
||||
|
||||
result = execute_tool(task_id, fn_name, fn_args)
|
||||
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": result,
|
||||
})
|
||||
|
||||
# 超过最大轮次,做最后一次不带 tools 的调用
|
||||
response = gpt.client.chat.completions.create(
|
||||
model=gpt.model,
|
||||
messages=messages,
|
||||
temperature=0.7,
|
||||
)
|
||||
|
||||
answer = response.choices[0].message.content
|
||||
|
||||
return {"answer": answer, "sources": sources}
|
||||
return {"answer": response.choices[0].message.content or "", "sources": sources}
|
||||
|
||||
184
backend/app/services/chat_tools.py
Normal file
184
backend/app/services/chat_tools.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
Chat function calling 工具定义与执行。
|
||||
提供给 LLM 调用,用于主动查询视频原文、笔记、元信息。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results")
|
||||
|
||||
|
||||
def _load_note_data(task_id: str) -> Optional[dict]:
    """Load the persisted note JSON for *task_id*.

    Returns:
        The parsed note dict, or ``None`` when no result file has been
        written for this task yet.
    """
    note_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json")
    try:
        # EAFP: opening directly avoids a separate existence check.
        with open(note_path, "r", encoding="utf-8") as fh:
            return json.load(fh)
    except FileNotFoundError:
        return None
|
||||
|
||||
|
||||
# ── Tool definitions (OpenAI function-calling schema) ──────────────────
# Passed verbatim as the `tools` argument of chat.completions.create.
# NOTE: the "description" strings below are runtime payload sent to the
# model (kept in Chinese to match the system prompt), not comments.

TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "lookup_transcript",
            "description": "查询视频原始转录文本。可按时间范围筛选、按关键词搜索、或获取指定位置的内容。",
            "parameters": {
                "type": "object",
                "properties": {
                    # All filters are optional; see _lookup_transcript for
                    # how they combine (position slice, then time, then keyword).
                    "start_time": {
                        "type": "number",
                        "description": "起始时间(秒),例如 0 表示视频开头,60 表示第1分钟",
                    },
                    "end_time": {
                        "type": "number",
                        "description": "结束时间(秒),不传则到末尾",
                    },
                    "keyword": {
                        "type": "string",
                        "description": "搜索关键词,返回包含该关键词的转录片段",
                    },
                    "position": {
                        "type": "string",
                        "enum": ["start", "end"],
                        "description": "快捷位置:start=视频开头前30句,end=视频结尾后30句",
                    },
                },
                "required": [],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_video_info",
            "description": "获取视频的完整元信息,包括标题、作者、简介、标签、时长、播放量等。",
            "parameters": {
                "type": "object",
                "properties": {},
                "required": [],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_note_content",
            "description": "获取 AI 生成的完整笔记内容(Markdown 格式)。",
            "parameters": {
                "type": "object",
                "properties": {},
                "required": [],
            },
        },
    },
]
|
||||
|
||||
|
||||
# ── 工具执行 ──────────────────────────────────────────────────
|
||||
|
||||
def execute_tool(task_id: str, tool_name: str, arguments: dict) -> str:
    """Dispatch a function-calling request by tool name.

    Loads the persisted note data for ``task_id`` first, since every tool
    reads from it. Missing data and unknown tool names are reported as JSON
    error payloads (never exceptions) so the result can be fed straight
    back to the LLM as a tool message.

    Returns:
        A JSON-encoded string with the tool result or an ``error`` key.
    """
    note = _load_note_data(task_id)
    if not note:
        return json.dumps({"error": "笔记数据不存在"}, ensure_ascii=False)

    if tool_name == "lookup_transcript":
        return _lookup_transcript(note, arguments)
    if tool_name == "get_video_info":
        return _get_video_info(note)
    if tool_name == "get_note_content":
        return _get_note_content(note)
    return json.dumps({"error": f"未知工具: {tool_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def _lookup_transcript(data: dict, args: dict) -> str:
|
||||
segments = data.get("transcript", {}).get("segments", [])
|
||||
if not segments:
|
||||
return json.dumps({"error": "没有转录数据"}, ensure_ascii=False)
|
||||
|
||||
position = args.get("position")
|
||||
start_time = args.get("start_time")
|
||||
end_time = args.get("end_time")
|
||||
keyword = args.get("keyword", "").strip()
|
||||
|
||||
# 快捷位置
|
||||
if position == "start":
|
||||
filtered = segments[:30]
|
||||
elif position == "end":
|
||||
filtered = segments[-30:]
|
||||
else:
|
||||
filtered = segments
|
||||
|
||||
# 时间筛选
|
||||
if start_time is not None:
|
||||
filtered = [s for s in filtered if s.get("end", 0) >= start_time]
|
||||
if end_time is not None:
|
||||
filtered = [s for s in filtered if s.get("start", 0) <= end_time]
|
||||
|
||||
# 关键词筛选
|
||||
if keyword:
|
||||
filtered = [s for s in filtered if keyword.lower() in s.get("text", "").lower()]
|
||||
|
||||
# 限制返回量,避免 token 爆炸
|
||||
if len(filtered) > 50:
|
||||
filtered = filtered[:50]
|
||||
truncated = True
|
||||
else:
|
||||
truncated = False
|
||||
|
||||
result = {
|
||||
"total_segments": len(data.get("transcript", {}).get("segments", [])),
|
||||
"returned": len(filtered),
|
||||
"truncated": truncated,
|
||||
"segments": [
|
||||
{
|
||||
"start": round(s.get("start", 0), 1),
|
||||
"end": round(s.get("end", 0), 1),
|
||||
"text": s.get("text", ""),
|
||||
}
|
||||
for s in filtered
|
||||
],
|
||||
}
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
|
||||
|
||||
def _get_video_info(data: dict) -> str:
|
||||
am = data.get("audio_meta", {})
|
||||
raw = am.get("raw_info", {}) or {}
|
||||
|
||||
info = {
|
||||
"title": am.get("title") or raw.get("title", ""),
|
||||
"uploader": raw.get("uploader", ""),
|
||||
"description": raw.get("description", "")[:1000],
|
||||
"tags": raw.get("tags", [])[:20] if isinstance(raw.get("tags"), list) else [],
|
||||
"duration_seconds": am.get("duration", 0),
|
||||
"platform": am.get("platform", ""),
|
||||
"video_id": am.get("video_id", ""),
|
||||
"url": raw.get("webpage_url", ""),
|
||||
"view_count": raw.get("view_count"),
|
||||
"like_count": raw.get("like_count"),
|
||||
"comment_count": raw.get("comment_count"),
|
||||
}
|
||||
# 去除 None 值
|
||||
info = {k: v for k, v in info.items() if v is not None and v != ""}
|
||||
return json.dumps(info, ensure_ascii=False)
|
||||
|
||||
|
||||
def _get_note_content(data: dict) -> str:
|
||||
md = data.get("markdown", "")
|
||||
if isinstance(md, list):
|
||||
# 多版本,取最新
|
||||
md = md[-1].get("content", "") if md else ""
|
||||
# 限制长度
|
||||
if len(md) > 5000:
|
||||
md = md[:5000] + "\n\n... (内容过长已截断)"
|
||||
return json.dumps({"markdown": md}, ensure_ascii=False)
|
||||
Reference in New Issue
Block a user