From 85f4c8a1318b644a7372a288a7ee332cc762fc58 Mon Sep 17 00:00:00 2001 From: dnslin Date: Fri, 12 Dec 2025 16:21:40 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E4=B8=8B=E8=BD=BD?= =?UTF-8?q?=E6=88=90=E5=8A=9F=E5=A4=B1=E8=B4=A5=E9=80=9A=E7=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/telegram/app.py | 6 +- src/telegram/handlers.py | 189 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 191 insertions(+), 4 deletions(-) diff --git a/src/telegram/app.py b/src/telegram/app.py index d1d7658..6e0be85 100644 --- a/src/telegram/app.py +++ b/src/telegram/app.py @@ -3,13 +3,15 @@ from __future__ import annotations import sys -from telegram import BotCommand +from telegram import Bot, BotCommand from telegram.ext import Application from src.core import BotConfig from src.telegram.handlers import Aria2BotAPI, build_handlers from src.utils import setup_logger +# 全局 bot 实例,用于自动上传等功能发送消息 +_bot_instance: Bot | None = None # Bot 命令列表,用于 Telegram 命令自动补全 BOT_COMMANDS = [ @@ -34,9 +36,11 @@ BOT_COMMANDS = [ async def post_init(application: Application) -> None: """应用初始化后设置命令菜单""" + global _bot_instance logger = setup_logger() logger.info("Setting bot commands...") await application.bot.set_my_commands(BOT_COMMANDS) + _bot_instance = application.bot def create_app(config: BotConfig) -> Application: diff --git a/src/telegram/handlers.py b/src/telegram/handlers.py index e8b7753..51654ff 100644 --- a/src/telegram/handlers.py +++ b/src/telegram/handlers.py @@ -29,14 +29,11 @@ from src.aria2.rpc import Aria2RpcClient, DownloadTask, _format_size from src.telegram.keyboards import ( STATUS_EMOJI, build_list_type_keyboard, - build_task_keyboard, build_task_list_keyboard, build_delete_confirm_keyboard, - build_detail_keyboard, build_after_add_keyboard, build_main_reply_keyboard, build_cloud_menu_keyboard, - build_upload_choice_keyboard, build_cloud_settings_keyboard, build_detail_keyboard_with_upload, ) @@ -98,6 +95,9 @@ class Aria2BotAPI: self.service = Aria2ServiceManager() self._rpc: Aria2RpcClient | None = None self._auto_refresh_tasks: dict[str, asyncio.Task] = {} # chat_id:msg_id -> task + self._auto_uploaded_gids: set[str] = set() # 已自动上传的任务GID,防止重复上传 + self._download_monitors: dict[str, asyncio.Task] = {} # gid -> 监控任务 + self._notified_gids: set[str] = set() # 已通知的 GID,防止重复通知 # 云存储相关 self._onedrive_config = onedrive_config self._onedrive = None @@ -632,6 +632,114 @@ class Aria2BotAPI: except Exception: pass + async def _trigger_auto_upload(self, chat_id: int, gid: str) -> None: + """自动上传触发(下载完成后自动调用)""" + from pathlib import Path + + logger.info(f"触发自动上传 GID={gid}") + + client = self._get_onedrive_client() + if not client or not await client.is_authenticated(): + logger.warning(f"自动上传跳过:OneDrive 未认证 GID={gid}") + return + + rpc = self._get_rpc_client() + try: + task = await rpc.get_status(gid) + except RpcError as e: + logger.error(f"自动上传失败:获取任务信息失败 GID={gid}: {e}") + return + + if task.status != "complete": + logger.warning(f"自动上传跳过:任务未完成 GID={gid}") + return + + local_path = Path(task.dir) / task.name + if not local_path.exists(): + logger.error(f"自动上传失败:本地文件不存在 GID={gid}") + return + + # 计算远程路径 + try: + download_dir = DOWNLOAD_DIR.resolve() + relative_path = local_path.resolve().relative_to(download_dir) + remote_path = f"{self._onedrive_config.remote_path}/{relative_path.parent}" + except ValueError: + remote_path = self._onedrive_config.remote_path + + # 启动后台上传任务 + asyncio.create_task(self._do_auto_upload( + client, local_path, remote_path, task.name, chat_id, gid + )) + + async def _do_auto_upload( + self, client, local_path, remote_path: str, task_name: str, chat_id: int, gid: str + ) -> None: + """后台执行自动上传任务""" + import shutil + from .app import _bot_instance # 获取全局 bot 实例 + + if _bot_instance is None: + logger.error(f"自动上传失败:无法获取 bot 实例 GID={gid}") + return + + # 发送上传开始通知 + try: + msg = await _bot_instance.send_message( + chat_id=chat_id, + text=f"☁️ 自动上传开始: {task_name}\n⏳ 请稍候..." + ) + except Exception as e: + logger.error(f"自动上传失败:发送消息失败 GID={gid}: {e}") + return + + loop = asyncio.get_event_loop() + + # 进度回调函数 + async def update_progress(progress): + if progress.status == UploadStatus.UPLOADING and progress.total_size > 0: + percent = progress.progress + uploaded_mb = progress.uploaded_size / (1024 * 1024) + total_mb = progress.total_size / (1024 * 1024) + progress_text = ( + f"☁️ 自动上传: {task_name}\n" + f"📤 {percent:.1f}% ({uploaded_mb:.1f}MB / {total_mb:.1f}MB)" + ) + try: + await msg.edit_text(progress_text) + except Exception: + pass + + def sync_progress_callback(progress): + if progress.status == UploadStatus.UPLOADING: + asyncio.run_coroutine_threadsafe(update_progress(progress), loop) + + try: + success = await client.upload_file(local_path, remote_path, progress_callback=sync_progress_callback) + + if success: + result_text = f"✅ 自动上传成功: {task_name}" + if self._onedrive_config and self._onedrive_config.delete_after_upload: + try: + if local_path.is_dir(): + shutil.rmtree(local_path) + else: + local_path.unlink() + result_text += "\n🗑️ 本地文件已删除" + except Exception as e: + result_text += f"\n⚠️ 删除本地文件失败: {e}" + await msg.edit_text(result_text) + logger.info(f"自动上传成功 GID={gid}") + else: + await msg.edit_text(f"❌ 自动上传失败: {task_name}") + logger.error(f"自动上传失败 GID={gid}") + except Exception as e: + logger.error(f"自动上传异常 GID={gid}: {e}") + try: + await msg.edit_text(f"❌ 自动上传失败: {task_name}\n错误: {e}") + except Exception: + pass + async def handle_button_text(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """处理 Reply Keyboard 按钮点击""" text = update.message.text @@ -679,6 +787,9 @@ class Aria2BotAPI: keyboard = build_after_add_keyboard(gid) await self._reply(update, context, text, parse_mode="Markdown", reply_markup=keyboard) logger.info(f"/add 命令执行成功, GID={gid} - {_get_user_info(update)}") + # 启动下载监控,完成或失败时通知用户 + chat_id = update.effective_chat.id + asyncio.create_task(self._start_download_monitor(gid, chat_id)) except RpcError as e: logger.error(f"/add 命令执行失败: {e} - {_get_user_info(update)}") await self._reply(update, context, f"❌ 添加失败: {e}") @@ -702,6 +813,9 @@ class Aria2BotAPI: keyboard = build_after_add_keyboard(gid) await self._reply(update, context, text, parse_mode="Markdown", reply_markup=keyboard) logger.info(f"种子任务添加成功, GID={gid} - {_get_user_info(update)}") + # 启动下载监控,完成或失败时通知用户 + chat_id = update.effective_chat.id + asyncio.create_task(self._start_download_monitor(gid, chat_id)) except RpcError as e: logger.error(f"种子任务添加失败: {e} - {_get_user_info(update)}") await self._reply(update, context, f"❌ 添加种子失败: {e}") @@ -1002,12 +1116,81 @@ class Aria2BotAPI: # 任务完成或出错时停止刷新 if task.status in ("complete", "error", "removed"): + # 任务完成时检查是否需要自动上传 + if (task.status == "complete" and + gid not in self._auto_uploaded_gids and + self._onedrive_config and + self._onedrive_config.enabled and + self._onedrive_config.auto_upload): + self._auto_uploaded_gids.add(gid) + asyncio.create_task(self._trigger_auto_upload(message.chat_id, gid)) break await asyncio.sleep(2) finally: self._auto_refresh_tasks.pop(key, None) + # === 下载任务监控和通知 === + + async def _start_download_monitor(self, gid: str, chat_id: int) -> None: + """启动下载任务监控""" + if gid in self._download_monitors: + return + task = asyncio.create_task(self._monitor_download(gid, chat_id)) + self._download_monitors[gid] = task + + async def _monitor_download(self, gid: str, chat_id: int) -> None: + """监控下载任务直到完成或失败""" + from .app import _bot_instance + try: + rpc = self._get_rpc_client() + for _ in range(17280): # 最长 24 小时 (5秒 * 17280) + try: + task = await rpc.get_status(gid) + except RpcError: + break # 任务可能已被删除 + + if task.status == "complete": + if gid not in self._notified_gids: + self._notified_gids.add(gid) + await self._send_completion_notification(chat_id, task) + break + elif task.status == "error": + if gid not in self._notified_gids: + self._notified_gids.add(gid) + await self._send_error_notification(chat_id, task) + break + elif task.status == "removed": + break + + await asyncio.sleep(5) + finally: + self._download_monitors.pop(gid, None) + + async def _send_completion_notification(self, chat_id: int, task: DownloadTask) -> None: + """发送下载完成通知""" + from .app import _bot_instance + if _bot_instance is None: + return + safe_name = task.name.replace("_", "\\_").replace("*", "\\*").replace("`", "\\`") + text = f"✅ *下载完成*\n📄 {safe_name}\n📦 大小: {task.size_str}\n🆔 GID: `{task.gid}`" + try: + await _bot_instance.send_message(chat_id=chat_id, text=text, parse_mode="Markdown") + except Exception as e: + logger.warning(f"发送完成通知失败 (GID={task.gid}): {e}") + + async def _send_error_notification(self, chat_id: int, task: DownloadTask) -> None: + """发送下载失败通知""" + from .app import _bot_instance + if _bot_instance is None: + return + safe_name = task.name.replace("_", "\\_").replace("*", "\\*").replace("`", "\\`") + text = f"❌ *下载失败*\n📄 {safe_name}\n🆔 GID: `{task.gid}`\n⚠️ 原因: {task.error_message or '未知错误'}" + try: + await _bot_instance.send_message(chat_id=chat_id, text=text, parse_mode="Markdown") + except Exception as e: + logger.warning(f"发送失败通知失败 (GID={task.gid}): {e}") + async def _handle_stats_callback(self, query, rpc: Aria2RpcClient) -> None: """处理统计回调""" stat = await rpc.get_global_stat()