mirror of
https://github.com/dnslin/aria2bot.git
synced 2026-01-11 20:12:20 +08:00
feat: 增加下载成功失败通知
This commit is contained in:
@@ -3,13 +3,15 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from telegram import BotCommand
|
from telegram import Bot, BotCommand
|
||||||
from telegram.ext import Application
|
from telegram.ext import Application
|
||||||
|
|
||||||
from src.core import BotConfig
|
from src.core import BotConfig
|
||||||
from src.telegram.handlers import Aria2BotAPI, build_handlers
|
from src.telegram.handlers import Aria2BotAPI, build_handlers
|
||||||
from src.utils import setup_logger
|
from src.utils import setup_logger
|
||||||
|
|
||||||
|
# 全局 bot 实例,用于自动上传等功能发送消息
|
||||||
|
_bot_instance: Bot | None = None
|
||||||
|
|
||||||
# Bot 命令列表,用于 Telegram 命令自动补全
|
# Bot 命令列表,用于 Telegram 命令自动补全
|
||||||
BOT_COMMANDS = [
|
BOT_COMMANDS = [
|
||||||
@@ -34,9 +36,11 @@ BOT_COMMANDS = [
|
|||||||
|
|
||||||
async def post_init(application: Application) -> None:
|
async def post_init(application: Application) -> None:
|
||||||
"""应用初始化后设置命令菜单"""
|
"""应用初始化后设置命令菜单"""
|
||||||
|
global _bot_instance
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
logger.info("Setting bot commands...")
|
logger.info("Setting bot commands...")
|
||||||
await application.bot.set_my_commands(BOT_COMMANDS)
|
await application.bot.set_my_commands(BOT_COMMANDS)
|
||||||
|
_bot_instance = application.bot
|
||||||
|
|
||||||
|
|
||||||
def create_app(config: BotConfig) -> Application:
|
def create_app(config: BotConfig) -> Application:
|
||||||
|
|||||||
@@ -29,14 +29,11 @@ from src.aria2.rpc import Aria2RpcClient, DownloadTask, _format_size
|
|||||||
from src.telegram.keyboards import (
|
from src.telegram.keyboards import (
|
||||||
STATUS_EMOJI,
|
STATUS_EMOJI,
|
||||||
build_list_type_keyboard,
|
build_list_type_keyboard,
|
||||||
build_task_keyboard,
|
|
||||||
build_task_list_keyboard,
|
build_task_list_keyboard,
|
||||||
build_delete_confirm_keyboard,
|
build_delete_confirm_keyboard,
|
||||||
build_detail_keyboard,
|
|
||||||
build_after_add_keyboard,
|
build_after_add_keyboard,
|
||||||
build_main_reply_keyboard,
|
build_main_reply_keyboard,
|
||||||
build_cloud_menu_keyboard,
|
build_cloud_menu_keyboard,
|
||||||
build_upload_choice_keyboard,
|
|
||||||
build_cloud_settings_keyboard,
|
build_cloud_settings_keyboard,
|
||||||
build_detail_keyboard_with_upload,
|
build_detail_keyboard_with_upload,
|
||||||
)
|
)
|
||||||
@@ -98,6 +95,9 @@ class Aria2BotAPI:
|
|||||||
self.service = Aria2ServiceManager()
|
self.service = Aria2ServiceManager()
|
||||||
self._rpc: Aria2RpcClient | None = None
|
self._rpc: Aria2RpcClient | None = None
|
||||||
self._auto_refresh_tasks: dict[str, asyncio.Task] = {} # chat_id:msg_id -> task
|
self._auto_refresh_tasks: dict[str, asyncio.Task] = {} # chat_id:msg_id -> task
|
||||||
|
self._auto_uploaded_gids: set[str] = set() # 已自动上传的任务GID,防止重复上传
|
||||||
|
self._download_monitors: dict[str, asyncio.Task] = {} # gid -> 监控任务
|
||||||
|
self._notified_gids: set[str] = set() # 已通知的 GID,防止重复通知
|
||||||
# 云存储相关
|
# 云存储相关
|
||||||
self._onedrive_config = onedrive_config
|
self._onedrive_config = onedrive_config
|
||||||
self._onedrive = None
|
self._onedrive = None
|
||||||
@@ -632,6 +632,114 @@ class Aria2BotAPI:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def _trigger_auto_upload(self, chat_id: int, gid: str) -> None:
|
||||||
|
"""自动上传触发(下载完成后自动调用)"""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
logger.info(f"触发自动上传 GID={gid}")
|
||||||
|
|
||||||
|
client = self._get_onedrive_client()
|
||||||
|
if not client or not await client.is_authenticated():
|
||||||
|
logger.warning(f"自动上传跳过:OneDrive 未认证 GID={gid}")
|
||||||
|
return
|
||||||
|
|
||||||
|
rpc = self._get_rpc_client()
|
||||||
|
try:
|
||||||
|
task = await rpc.get_status(gid)
|
||||||
|
except RpcError as e:
|
||||||
|
logger.error(f"自动上传失败:获取任务信息失败 GID={gid}: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if task.status != "complete":
|
||||||
|
logger.warning(f"自动上传跳过:任务未完成 GID={gid}")
|
||||||
|
return
|
||||||
|
|
||||||
|
local_path = Path(task.dir) / task.name
|
||||||
|
if not local_path.exists():
|
||||||
|
logger.error(f"自动上传失败:本地文件不存在 GID={gid}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 计算远程路径
|
||||||
|
try:
|
||||||
|
download_dir = DOWNLOAD_DIR.resolve()
|
||||||
|
relative_path = local_path.resolve().relative_to(download_dir)
|
||||||
|
remote_path = f"{self._onedrive_config.remote_path}/{relative_path.parent}"
|
||||||
|
except ValueError:
|
||||||
|
remote_path = self._onedrive_config.remote_path
|
||||||
|
|
||||||
|
# 启动后台上传任务
|
||||||
|
asyncio.create_task(self._do_auto_upload(
|
||||||
|
client, local_path, remote_path, task.name, chat_id, gid
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _do_auto_upload(
|
||||||
|
self, client, local_path, remote_path: str, task_name: str, chat_id: int, gid: str
|
||||||
|
) -> None:
|
||||||
|
"""后台执行自动上传任务"""
|
||||||
|
import shutil
|
||||||
|
from .app import _bot_instance # 获取全局 bot 实例
|
||||||
|
|
||||||
|
if _bot_instance is None:
|
||||||
|
logger.error(f"自动上传失败:无法获取 bot 实例 GID={gid}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 发送上传开始通知
|
||||||
|
try:
|
||||||
|
msg = await _bot_instance.send_message(
|
||||||
|
chat_id=chat_id,
|
||||||
|
text=f"☁️ 自动上传开始: {task_name}\n⏳ 请稍候..."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"自动上传失败:发送消息失败 GID={gid}: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
# 进度回调函数
|
||||||
|
async def update_progress(progress):
|
||||||
|
if progress.status == UploadStatus.UPLOADING and progress.total_size > 0:
|
||||||
|
percent = progress.progress
|
||||||
|
uploaded_mb = progress.uploaded_size / (1024 * 1024)
|
||||||
|
total_mb = progress.total_size / (1024 * 1024)
|
||||||
|
progress_text = (
|
||||||
|
f"☁️ 自动上传: {task_name}\n"
|
||||||
|
f"📤 {percent:.1f}% ({uploaded_mb:.1f}MB / {total_mb:.1f}MB)"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await msg.edit_text(progress_text)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def sync_progress_callback(progress):
|
||||||
|
if progress.status == UploadStatus.UPLOADING:
|
||||||
|
asyncio.run_coroutine_threadsafe(update_progress(progress), loop)
|
||||||
|
|
||||||
|
try:
|
||||||
|
success = await client.upload_file(local_path, remote_path, progress_callback=sync_progress_callback)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
result_text = f"✅ 自动上传成功: {task_name}"
|
||||||
|
if self._onedrive_config and self._onedrive_config.delete_after_upload:
|
||||||
|
try:
|
||||||
|
if local_path.is_dir():
|
||||||
|
shutil.rmtree(local_path)
|
||||||
|
else:
|
||||||
|
local_path.unlink()
|
||||||
|
result_text += "\n🗑️ 本地文件已删除"
|
||||||
|
except Exception as e:
|
||||||
|
result_text += f"\n⚠️ 删除本地文件失败: {e}"
|
||||||
|
await msg.edit_text(result_text)
|
||||||
|
logger.info(f"自动上传成功 GID={gid}")
|
||||||
|
else:
|
||||||
|
await msg.edit_text(f"❌ 自动上传失败: {task_name}")
|
||||||
|
logger.error(f"自动上传失败 GID={gid}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"自动上传异常 GID={gid}: {e}")
|
||||||
|
try:
|
||||||
|
await msg.edit_text(f"❌ 自动上传失败: {task_name}\n错误: {e}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
async def handle_button_text(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
async def handle_button_text(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
"""处理 Reply Keyboard 按钮点击"""
|
"""处理 Reply Keyboard 按钮点击"""
|
||||||
text = update.message.text
|
text = update.message.text
|
||||||
@@ -679,6 +787,9 @@ class Aria2BotAPI:
|
|||||||
keyboard = build_after_add_keyboard(gid)
|
keyboard = build_after_add_keyboard(gid)
|
||||||
await self._reply(update, context, text, parse_mode="Markdown", reply_markup=keyboard)
|
await self._reply(update, context, text, parse_mode="Markdown", reply_markup=keyboard)
|
||||||
logger.info(f"/add 命令执行成功, GID={gid} - {_get_user_info(update)}")
|
logger.info(f"/add 命令执行成功, GID={gid} - {_get_user_info(update)}")
|
||||||
|
# 启动下载监控,完成或失败时通知用户
|
||||||
|
chat_id = update.effective_chat.id
|
||||||
|
asyncio.create_task(self._start_download_monitor(gid, chat_id))
|
||||||
except RpcError as e:
|
except RpcError as e:
|
||||||
logger.error(f"/add 命令执行失败: {e} - {_get_user_info(update)}")
|
logger.error(f"/add 命令执行失败: {e} - {_get_user_info(update)}")
|
||||||
await self._reply(update, context, f"❌ 添加失败: {e}")
|
await self._reply(update, context, f"❌ 添加失败: {e}")
|
||||||
@@ -702,6 +813,9 @@ class Aria2BotAPI:
|
|||||||
keyboard = build_after_add_keyboard(gid)
|
keyboard = build_after_add_keyboard(gid)
|
||||||
await self._reply(update, context, text, parse_mode="Markdown", reply_markup=keyboard)
|
await self._reply(update, context, text, parse_mode="Markdown", reply_markup=keyboard)
|
||||||
logger.info(f"种子任务添加成功, GID={gid} - {_get_user_info(update)}")
|
logger.info(f"种子任务添加成功, GID={gid} - {_get_user_info(update)}")
|
||||||
|
# 启动下载监控,完成或失败时通知用户
|
||||||
|
chat_id = update.effective_chat.id
|
||||||
|
asyncio.create_task(self._start_download_monitor(gid, chat_id))
|
||||||
except RpcError as e:
|
except RpcError as e:
|
||||||
logger.error(f"种子任务添加失败: {e} - {_get_user_info(update)}")
|
logger.error(f"种子任务添加失败: {e} - {_get_user_info(update)}")
|
||||||
await self._reply(update, context, f"❌ 添加种子失败: {e}")
|
await self._reply(update, context, f"❌ 添加种子失败: {e}")
|
||||||
@@ -1002,12 +1116,81 @@ class Aria2BotAPI:
|
|||||||
|
|
||||||
# 任务完成或出错时停止刷新
|
# 任务完成或出错时停止刷新
|
||||||
if task.status in ("complete", "error", "removed"):
|
if task.status in ("complete", "error", "removed"):
|
||||||
|
# 任务完成时检查是否需要自动上传
|
||||||
|
if (task.status == "complete" and
|
||||||
|
gid not in self._auto_uploaded_gids and
|
||||||
|
self._onedrive_config and
|
||||||
|
self._onedrive_config.enabled and
|
||||||
|
self._onedrive_config.auto_upload):
|
||||||
|
self._auto_uploaded_gids.add(gid)
|
||||||
|
asyncio.create_task(self._trigger_auto_upload(message.chat_id, gid))
|
||||||
break
|
break
|
||||||
|
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
finally:
|
finally:
|
||||||
self._auto_refresh_tasks.pop(key, None)
|
self._auto_refresh_tasks.pop(key, None)
|
||||||
|
|
||||||
|
# === 下载任务监控和通知 ===
|
||||||
|
|
||||||
|
async def _start_download_monitor(self, gid: str, chat_id: int) -> None:
|
||||||
|
"""启动下载任务监控"""
|
||||||
|
if gid in self._download_monitors:
|
||||||
|
return
|
||||||
|
task = asyncio.create_task(self._monitor_download(gid, chat_id))
|
||||||
|
self._download_monitors[gid] = task
|
||||||
|
|
||||||
|
async def _monitor_download(self, gid: str, chat_id: int) -> None:
|
||||||
|
"""监控下载任务直到完成或失败"""
|
||||||
|
from .app import _bot_instance
|
||||||
|
try:
|
||||||
|
rpc = self._get_rpc_client()
|
||||||
|
for _ in range(17280): # 最长 24 小时 (5秒 * 17280)
|
||||||
|
try:
|
||||||
|
task = await rpc.get_status(gid)
|
||||||
|
except RpcError:
|
||||||
|
break # 任务可能已被删除
|
||||||
|
|
||||||
|
if task.status == "complete":
|
||||||
|
if gid not in self._notified_gids:
|
||||||
|
self._notified_gids.add(gid)
|
||||||
|
await self._send_completion_notification(chat_id, task)
|
||||||
|
break
|
||||||
|
elif task.status == "error":
|
||||||
|
if gid not in self._notified_gids:
|
||||||
|
self._notified_gids.add(gid)
|
||||||
|
await self._send_error_notification(chat_id, task)
|
||||||
|
break
|
||||||
|
elif task.status == "removed":
|
||||||
|
break
|
||||||
|
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
finally:
|
||||||
|
self._download_monitors.pop(gid, None)
|
||||||
|
|
||||||
|
async def _send_completion_notification(self, chat_id: int, task: DownloadTask) -> None:
|
||||||
|
"""发送下载完成通知"""
|
||||||
|
from .app import _bot_instance
|
||||||
|
if _bot_instance is None:
|
||||||
|
return
|
||||||
|
safe_name = task.name.replace("_", "\\_").replace("*", "\\*").replace("`", "\\`")
|
||||||
|
text = f"✅ *下载完成*\n📄 {safe_name}\n📦 大小: {task.size_str}\n🆔 GID: `{task.gid}`"
|
||||||
|
try:
|
||||||
|
await _bot_instance.send_message(chat_id=chat_id, text=text, parse_mode="Markdown")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"发送完成通知失败 (GID={task.gid}): {e}")
|
||||||
|
|
||||||
|
async def _send_error_notification(self, chat_id: int, task: DownloadTask) -> None:
|
||||||
|
"""发送下载失败通知"""
|
||||||
|
from .app import _bot_instance
|
||||||
|
if _bot_instance is None:
|
||||||
|
return
|
||||||
|
safe_name = task.name.replace("_", "\\_").replace("*", "\\*").replace("`", "\\`")
|
||||||
|
text = f"❌ *下载失败*\n📄 {safe_name}\n🆔 GID: `{task.gid}`\n⚠️ 原因: {task.error_message or '未知错误'}"
|
||||||
|
try:
|
||||||
|
await _bot_instance.send_message(chat_id=chat_id, text=text, parse_mode="Markdown")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"发送失败通知失败 (GID={task.gid}): {e}")
|
||||||
|
|
||||||
async def _handle_stats_callback(self, query, rpc: Aria2RpcClient) -> None:
|
async def _handle_stats_callback(self, query, rpc: Aria2RpcClient) -> None:
|
||||||
"""处理统计回调"""
|
"""处理统计回调"""
|
||||||
stat = await rpc.get_global_stat()
|
stat = await rpc.get_global_stat()
|
||||||
|
|||||||
Reference in New Issue
Block a user