diff --git a/backend/app/services/chat_service.py b/backend/app/services/chat_service.py index cac5bf1..e7239cf 100644 --- a/backend/app/services/chat_service.py +++ b/backend/app/services/chat_service.py @@ -1,35 +1,38 @@ +import json from typing import Optional from app.gpt.gpt_factory import GPTFactory from app.models.model_config import ModelConfig from app.services.provider import ProviderService from app.services.vector_store import VectorStoreManager +from app.services.chat_tools import TOOLS, execute_tool from app.utils.logger import get_logger logger = get_logger(__name__) -SYSTEM_PROMPT = """你是一个视频笔记问答助手。你可以参考两种来源回答用户的问题: -1. [视频信息] — 视频标题、作者、简介、标签等元信息 -2. [笔记] — AI 生成的视频摘要笔记 -3. [转录] — 视频原始语音转录文本(含时间戳) +SYSTEM_PROMPT = """你是一个视频笔记问答助手。你拥有以下能力: -以下是检索到的相关内容: +1. 系统已自动检索了一些相关内容作为初始参考(见下方) +2. 你可以调用工具主动查询更多信息: + - lookup_transcript: 查询视频原始转录文本(支持按时间、关键词、位置筛选) + - get_video_info: 获取视频元信息(标题、作者、简介、标签等) + - get_note_content: 获取完整笔记内容 ---- 相关内容 --- +--- 初始检索内容 --- {context} --- 回答要求: -- 优先使用转录原文回答关于视频具体内容、原话、细节的问题 -- 优先使用笔记回答关于总结、要点、结构的问题 -- 如果确实没有相关信息,请诚实告知 +- 如果初始检索内容不足以回答问题,请主动调用工具获取更多信息 +- 回答关于视频具体原话、细节时,用 lookup_transcript 查询原文 +- 回答关于作者、标题等基本信息时,用 get_video_info 查询 - 请用中文回答,保持简洁准确""" def _build_context(chunks: list[dict]) -> str: """将检索到的片段拼接为上下文文本。""" parts = [] - for i, chunk in enumerate(chunks, 1): + for chunk in chunks: meta = chunk.get("metadata", {}) source_type = meta.get("source_type", "unknown") if source_type == "meta": @@ -71,39 +74,29 @@ def chat( model_name: str, ) -> dict: """ - RAG 问答:检索相关片段 → 构建 prompt → 调用 LLM → 返回答案 + 来源。 - - Returns: - {"answer": str, "sources": list[dict]} + RAG + Tool Calling 问答。 + 1. 向量检索初始上下文 + 2. 调用 LLM(带 tools) + 3. 如果 LLM 调用了工具,执行工具并将结果返回给 LLM + 4. 循环直到 LLM 给出最终回答 """ vector_store = VectorStoreManager() - # 1. 
检索相关片段 - chunks = vector_store.query(task_id, question, n_results=5) - print( - f"检索到 {len(chunks)} 个相关片段: {[c['metadata'].get('source_type') for c in chunks]}" - ) - if not chunks: - return { - "answer": "暂未找到相关笔记内容,请确认笔记已生成并完成索引。", - "sources": [], - } + # 1. 检索初始上下文 + chunks = vector_store.query(task_id, question, n_results=6) + context = _build_context(chunks) if chunks else "(未检索到相关内容,请使用工具查询)" + sources = _build_sources(chunks) if chunks else [] - # 2. 构建上下文和来源 - context = _build_context(chunks) - sources = _build_sources(chunks) - - # 3. 构建消息 + # 2. 构建消息 system_msg = SYSTEM_PROMPT.format(context=context) messages = [{"role": "system", "content": system_msg}] - # 加入历史对话(最近 10 轮) for msg in history[-20:]: messages.append({"role": msg["role"], "content": msg["content"]}) messages.append({"role": "user", "content": question}) - # 4. 调用 LLM + # 3. 获取 LLM client provider = ProviderService.get_provider_by_id(provider_id) if not provider: raise ValueError(f"未找到模型供应商: {provider_id}") @@ -117,14 +110,49 @@ def chat( ) gpt = GPTFactory.from_config(config) - logger.info(f"Chat RAG: task_id={task_id}, provider={provider['name']}, model={model_name}") + logger.info(f"Chat: task_id={task_id}, model={model_name}") + # 4. 
"""
Chat function-calling tool definitions and execution.

These tools are exposed to the LLM via OpenAI-style function calling so it
can actively query the raw video transcript, the video metadata, and the
AI-generated note for a given task. Every executor returns a JSON string
(``ensure_ascii=False``) so results can be fed straight back as a ``tool``
message.
"""

import json
import os
from typing import Optional

from app.utils.logger import get_logger

logger = get_logger(__name__)

# Directory holding pipeline output, one "<task_id>.json" file per task.
NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results")


def _load_note_data(task_id: str) -> Optional[dict]:
    """Load the persisted note JSON for *task_id*; return None when absent.

    Uses EAFP (open + FileNotFoundError) instead of exists()+open() to avoid
    a TOCTOU race if the file is removed between the check and the read.
    A present-but-malformed file still raises (json.JSONDecodeError), which
    is intentional: that is a pipeline bug, not a missing task.
    """
    path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json")
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return None


# ── Tool definitions (OpenAI function-calling schema) ─────────────────

TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "lookup_transcript",
            "description": "查询视频原始转录文本。可按时间范围筛选、按关键词搜索、或获取指定位置的内容。",
            "parameters": {
                "type": "object",
                "properties": {
                    "start_time": {
                        "type": "number",
                        "description": "起始时间(秒),例如 0 表示视频开头,60 表示第1分钟",
                    },
                    "end_time": {
                        "type": "number",
                        "description": "结束时间(秒),不传则到末尾",
                    },
                    "keyword": {
                        "type": "string",
                        "description": "搜索关键词,返回包含该关键词的转录片段",
                    },
                    "position": {
                        "type": "string",
                        "enum": ["start", "end"],
                        "description": "快捷位置:start=视频开头前30句,end=视频结尾后30句",
                    },
                },
                "required": [],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_video_info",
            "description": "获取视频的完整元信息,包括标题、作者、简介、标签、时长、播放量等。",
            "parameters": {
                "type": "object",
                "properties": {},
                "required": [],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_note_content",
            "description": "获取 AI 生成的完整笔记内容(Markdown 格式)。",
            "parameters": {
                "type": "object",
                "properties": {},
                "required": [],
            },
        },
    },
]


# ── Tool execution ────────────────────────────────────────────────────

def execute_tool(task_id: str, tool_name: str, arguments: dict) -> str:
    """Execute one tool call issued by the LLM.

    Args:
        task_id: task whose note data backs the query.
        tool_name: one of the names declared in ``TOOLS``.
        arguments: parsed JSON arguments from the tool call (may contain
            explicit nulls — the executors must tolerate them).

    Returns:
        A JSON string; on failure a ``{"error": ...}`` object, never an
        exception, so the LLM loop can keep going.
    """
    data = _load_note_data(task_id)
    if not data:
        return json.dumps({"error": "笔记数据不存在"}, ensure_ascii=False)

    if tool_name == "lookup_transcript":
        return _lookup_transcript(data, arguments)
    if tool_name == "get_video_info":
        return _get_video_info(data)
    if tool_name == "get_note_content":
        return _get_note_content(data)
    return json.dumps({"error": f"未知工具: {tool_name}"}, ensure_ascii=False)


def _lookup_transcript(data: dict, args: dict) -> str:
    """Filter transcript segments by position shortcut, time range and keyword."""
    all_segments = data.get("transcript", {}).get("segments", [])
    if not all_segments:
        return json.dumps({"error": "没有转录数据"}, ensure_ascii=False)

    position = args.get("position")
    start_time = args.get("start_time")
    end_time = args.get("end_time")
    # FIX: models sometimes emit "keyword": null; .get's default does not
    # apply to an explicit null, so guard before calling .strip().
    keyword = (args.get("keyword") or "").strip()

    # Positional shortcut first; time/keyword filters still apply on top.
    if position == "start":
        filtered = all_segments[:30]
    elif position == "end":
        filtered = all_segments[-30:]
    else:
        filtered = all_segments

    # Keep any segment overlapping the [start_time, end_time] window.
    if start_time is not None:
        filtered = [s for s in filtered if s.get("end", 0) >= start_time]
    if end_time is not None:
        filtered = [s for s in filtered if s.get("start", 0) <= end_time]

    # Case-insensitive keyword match on the segment text.
    if keyword:
        needle = keyword.lower()
        filtered = [s for s in filtered if needle in s.get("text", "").lower()]

    # Cap the payload so a broad query cannot blow up the token budget.
    truncated = len(filtered) > 50
    if truncated:
        filtered = filtered[:50]

    result = {
        "total_segments": len(all_segments),
        "returned": len(filtered),
        "truncated": truncated,
        "segments": [
            {
                "start": round(s.get("start", 0), 1),
                "end": round(s.get("end", 0), 1),
                "text": s.get("text", ""),
            }
            for s in filtered
        ],
    }
    return json.dumps(result, ensure_ascii=False)


def _get_video_info(data: dict) -> str:
    """Build a compact metadata summary from audio_meta / raw_info."""
    am = data.get("audio_meta", {})
    raw = am.get("raw_info", {}) or {}

    tags = raw.get("tags")
    info = {
        "title": am.get("title") or raw.get("title", ""),
        "uploader": raw.get("uploader", ""),
        # FIX: raw_info may carry an explicit null description; slicing None
        # raised TypeError before.
        "description": (raw.get("description") or "")[:1000],
        "tags": tags[:20] if isinstance(tags, list) else [],
        "duration_seconds": am.get("duration", 0),
        "platform": am.get("platform", ""),
        "video_id": am.get("video_id", ""),
        "url": raw.get("webpage_url", ""),
        "view_count": raw.get("view_count"),
        "like_count": raw.get("like_count"),
        "comment_count": raw.get("comment_count"),
    }
    # Drop absent fields so the LLM is not shown empty noise.
    info = {k: v for k, v in info.items() if v is not None and v != ""}
    return json.dumps(info, ensure_ascii=False)


def _get_note_content(data: dict) -> str:
    """Return the generated note markdown, truncated to a safe length."""
    # FIX: tolerate an explicit null "markdown" value (treated as empty).
    md = data.get("markdown") or ""
    if isinstance(md, list):
        # Multiple note versions are stored as a list; take the latest.
        md = md[-1].get("content", "") if md else ""
    # Hard cap to keep the tool result within the token budget.
    if len(md) > 5000:
        md = md[:5000] + "\n\n... (内容过长已截断)"
    return json.dumps({"markdown": md}, ensure_ascii=False)