feat: 增加下载成功失败通知

This commit is contained in:
dnslin
2025-12-12 16:21:40 +08:00
parent 566d9912cf
commit 85f4c8a131
2 changed files with 191 additions and 4 deletions

View File

@@ -3,13 +3,15 @@ from __future__ import annotations
import sys import sys
from telegram import BotCommand from telegram import Bot, BotCommand
from telegram.ext import Application from telegram.ext import Application
from src.core import BotConfig from src.core import BotConfig
from src.telegram.handlers import Aria2BotAPI, build_handlers from src.telegram.handlers import Aria2BotAPI, build_handlers
from src.utils import setup_logger from src.utils import setup_logger
# 全局 bot 实例,用于自动上传等功能发送消息
_bot_instance: Bot | None = None
# Bot 命令列表,用于 Telegram 命令自动补全 # Bot 命令列表,用于 Telegram 命令自动补全
BOT_COMMANDS = [ BOT_COMMANDS = [
@@ -34,9 +36,11 @@ BOT_COMMANDS = [
async def post_init(application: Application) -> None: async def post_init(application: Application) -> None:
"""应用初始化后设置命令菜单""" """应用初始化后设置命令菜单"""
global _bot_instance
logger = setup_logger() logger = setup_logger()
logger.info("Setting bot commands...") logger.info("Setting bot commands...")
await application.bot.set_my_commands(BOT_COMMANDS) await application.bot.set_my_commands(BOT_COMMANDS)
_bot_instance = application.bot
def create_app(config: BotConfig) -> Application: def create_app(config: BotConfig) -> Application:

View File

@@ -29,14 +29,11 @@ from src.aria2.rpc import Aria2RpcClient, DownloadTask, _format_size
from src.telegram.keyboards import ( from src.telegram.keyboards import (
STATUS_EMOJI, STATUS_EMOJI,
build_list_type_keyboard, build_list_type_keyboard,
build_task_keyboard,
build_task_list_keyboard, build_task_list_keyboard,
build_delete_confirm_keyboard, build_delete_confirm_keyboard,
build_detail_keyboard,
build_after_add_keyboard, build_after_add_keyboard,
build_main_reply_keyboard, build_main_reply_keyboard,
build_cloud_menu_keyboard, build_cloud_menu_keyboard,
build_upload_choice_keyboard,
build_cloud_settings_keyboard, build_cloud_settings_keyboard,
build_detail_keyboard_with_upload, build_detail_keyboard_with_upload,
) )
@@ -98,6 +95,9 @@ class Aria2BotAPI:
self.service = Aria2ServiceManager() self.service = Aria2ServiceManager()
self._rpc: Aria2RpcClient | None = None self._rpc: Aria2RpcClient | None = None
self._auto_refresh_tasks: dict[str, asyncio.Task] = {} # chat_id:msg_id -> task self._auto_refresh_tasks: dict[str, asyncio.Task] = {} # chat_id:msg_id -> task
self._auto_uploaded_gids: set[str] = set() # 已自动上传的任务GID防止重复上传
self._download_monitors: dict[str, asyncio.Task] = {} # gid -> 监控任务
self._notified_gids: set[str] = set() # 已通知的 GID防止重复通知
# 云存储相关 # 云存储相关
self._onedrive_config = onedrive_config self._onedrive_config = onedrive_config
self._onedrive = None self._onedrive = None
@@ -632,6 +632,114 @@ class Aria2BotAPI:
except Exception: except Exception:
pass pass
async def _trigger_auto_upload(self, chat_id: int, gid: str) -> None:
"""自动上传触发(下载完成后自动调用)"""
from pathlib import Path
logger.info(f"触发自动上传 GID={gid}")
client = self._get_onedrive_client()
if not client or not await client.is_authenticated():
logger.warning(f"自动上传跳过OneDrive 未认证 GID={gid}")
return
rpc = self._get_rpc_client()
try:
task = await rpc.get_status(gid)
except RpcError as e:
logger.error(f"自动上传失败:获取任务信息失败 GID={gid}: {e}")
return
if task.status != "complete":
logger.warning(f"自动上传跳过:任务未完成 GID={gid}")
return
local_path = Path(task.dir) / task.name
if not local_path.exists():
logger.error(f"自动上传失败:本地文件不存在 GID={gid}")
return
# 计算远程路径
try:
download_dir = DOWNLOAD_DIR.resolve()
relative_path = local_path.resolve().relative_to(download_dir)
remote_path = f"{self._onedrive_config.remote_path}/{relative_path.parent}"
except ValueError:
remote_path = self._onedrive_config.remote_path
# 启动后台上传任务
asyncio.create_task(self._do_auto_upload(
client, local_path, remote_path, task.name, chat_id, gid
))
async def _do_auto_upload(
self, client, local_path, remote_path: str, task_name: str, chat_id: int, gid: str
) -> None:
"""后台执行自动上传任务"""
import shutil
from .app import _bot_instance # 获取全局 bot 实例
if _bot_instance is None:
logger.error(f"自动上传失败:无法获取 bot 实例 GID={gid}")
return
# 发送上传开始通知
try:
msg = await _bot_instance.send_message(
chat_id=chat_id,
text=f"☁️ 自动上传开始: {task_name}\n⏳ 请稍候..."
)
except Exception as e:
logger.error(f"自动上传失败:发送消息失败 GID={gid}: {e}")
return
loop = asyncio.get_event_loop()
# 进度回调函数
async def update_progress(progress):
if progress.status == UploadStatus.UPLOADING and progress.total_size > 0:
percent = progress.progress
uploaded_mb = progress.uploaded_size / (1024 * 1024)
total_mb = progress.total_size / (1024 * 1024)
progress_text = (
f"☁️ 自动上传: {task_name}\n"
f"📤 {percent:.1f}% ({uploaded_mb:.1f}MB / {total_mb:.1f}MB)"
)
try:
await msg.edit_text(progress_text)
except Exception:
pass
def sync_progress_callback(progress):
if progress.status == UploadStatus.UPLOADING:
asyncio.run_coroutine_threadsafe(update_progress(progress), loop)
try:
success = await client.upload_file(local_path, remote_path, progress_callback=sync_progress_callback)
if success:
result_text = f"✅ 自动上传成功: {task_name}"
if self._onedrive_config and self._onedrive_config.delete_after_upload:
try:
if local_path.is_dir():
shutil.rmtree(local_path)
else:
local_path.unlink()
result_text += "\n🗑️ 本地文件已删除"
except Exception as e:
result_text += f"\n⚠️ 删除本地文件失败: {e}"
await msg.edit_text(result_text)
logger.info(f"自动上传成功 GID={gid}")
else:
await msg.edit_text(f"❌ 自动上传失败: {task_name}")
logger.error(f"自动上传失败 GID={gid}")
except Exception as e:
logger.error(f"自动上传异常 GID={gid}: {e}")
try:
await msg.edit_text(f"❌ 自动上传失败: {task_name}\n错误: {e}")
except Exception:
pass
async def handle_button_text(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def handle_button_text(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""处理 Reply Keyboard 按钮点击""" """处理 Reply Keyboard 按钮点击"""
text = update.message.text text = update.message.text
@@ -679,6 +787,9 @@ class Aria2BotAPI:
keyboard = build_after_add_keyboard(gid) keyboard = build_after_add_keyboard(gid)
await self._reply(update, context, text, parse_mode="Markdown", reply_markup=keyboard) await self._reply(update, context, text, parse_mode="Markdown", reply_markup=keyboard)
logger.info(f"/add 命令执行成功, GID={gid} - {_get_user_info(update)}") logger.info(f"/add 命令执行成功, GID={gid} - {_get_user_info(update)}")
# 启动下载监控,完成或失败时通知用户
chat_id = update.effective_chat.id
asyncio.create_task(self._start_download_monitor(gid, chat_id))
except RpcError as e: except RpcError as e:
logger.error(f"/add 命令执行失败: {e} - {_get_user_info(update)}") logger.error(f"/add 命令执行失败: {e} - {_get_user_info(update)}")
await self._reply(update, context, f"❌ 添加失败: {e}") await self._reply(update, context, f"❌ 添加失败: {e}")
@@ -702,6 +813,9 @@ class Aria2BotAPI:
keyboard = build_after_add_keyboard(gid) keyboard = build_after_add_keyboard(gid)
await self._reply(update, context, text, parse_mode="Markdown", reply_markup=keyboard) await self._reply(update, context, text, parse_mode="Markdown", reply_markup=keyboard)
logger.info(f"种子任务添加成功, GID={gid} - {_get_user_info(update)}") logger.info(f"种子任务添加成功, GID={gid} - {_get_user_info(update)}")
# 启动下载监控,完成或失败时通知用户
chat_id = update.effective_chat.id
asyncio.create_task(self._start_download_monitor(gid, chat_id))
except RpcError as e: except RpcError as e:
logger.error(f"种子任务添加失败: {e} - {_get_user_info(update)}") logger.error(f"种子任务添加失败: {e} - {_get_user_info(update)}")
await self._reply(update, context, f"❌ 添加种子失败: {e}") await self._reply(update, context, f"❌ 添加种子失败: {e}")
@@ -1002,12 +1116,81 @@ class Aria2BotAPI:
# 任务完成或出错时停止刷新 # 任务完成或出错时停止刷新
if task.status in ("complete", "error", "removed"): if task.status in ("complete", "error", "removed"):
# 任务完成时检查是否需要自动上传
if (task.status == "complete" and
gid not in self._auto_uploaded_gids and
self._onedrive_config and
self._onedrive_config.enabled and
self._onedrive_config.auto_upload):
self._auto_uploaded_gids.add(gid)
asyncio.create_task(self._trigger_auto_upload(message.chat_id, gid))
break break
await asyncio.sleep(2) await asyncio.sleep(2)
finally: finally:
self._auto_refresh_tasks.pop(key, None) self._auto_refresh_tasks.pop(key, None)
# === 下载任务监控和通知 ===
async def _start_download_monitor(self, gid: str, chat_id: int) -> None:
"""启动下载任务监控"""
if gid in self._download_monitors:
return
task = asyncio.create_task(self._monitor_download(gid, chat_id))
self._download_monitors[gid] = task
async def _monitor_download(self, gid: str, chat_id: int) -> None:
"""监控下载任务直到完成或失败"""
from .app import _bot_instance
try:
rpc = self._get_rpc_client()
for _ in range(17280): # 最长 24 小时 (5秒 * 17280)
try:
task = await rpc.get_status(gid)
except RpcError:
break # 任务可能已被删除
if task.status == "complete":
if gid not in self._notified_gids:
self._notified_gids.add(gid)
await self._send_completion_notification(chat_id, task)
break
elif task.status == "error":
if gid not in self._notified_gids:
self._notified_gids.add(gid)
await self._send_error_notification(chat_id, task)
break
elif task.status == "removed":
break
await asyncio.sleep(5)
finally:
self._download_monitors.pop(gid, None)
async def _send_completion_notification(self, chat_id: int, task: DownloadTask) -> None:
"""发送下载完成通知"""
from .app import _bot_instance
if _bot_instance is None:
return
safe_name = task.name.replace("_", "\\_").replace("*", "\\*").replace("`", "\\`")
text = f"✅ *下载完成*\n📄 {safe_name}\n📦 大小: {task.size_str}\n🆔 GID: `{task.gid}`"
try:
await _bot_instance.send_message(chat_id=chat_id, text=text, parse_mode="Markdown")
except Exception as e:
logger.warning(f"发送完成通知失败 (GID={task.gid}): {e}")
async def _send_error_notification(self, chat_id: int, task: DownloadTask) -> None:
"""发送下载失败通知"""
from .app import _bot_instance
if _bot_instance is None:
return
safe_name = task.name.replace("_", "\\_").replace("*", "\\*").replace("`", "\\`")
text = f"❌ *下载失败*\n📄 {safe_name}\n🆔 GID: `{task.gid}`\n⚠️ 原因: {task.error_message or '未知错误'}"
try:
await _bot_instance.send_message(chat_id=chat_id, text=text, parse_mode="Markdown")
except Exception as e:
logger.warning(f"发送失败通知失败 (GID={task.gid}): {e}")
async def _handle_stats_callback(self, query, rpc: Aria2RpcClient) -> None: async def _handle_stats_callback(self, query, rpc: Aria2RpcClient) -> None:
"""处理统计回调""" """处理统计回调"""
stat = await rpc.get_global_stat() stat = await rpc.get_global_stat()