mirror of
https://github.com/dnslin/aria2bot.git
synced 2026-01-12 04:22:21 +08:00
feat: 增加下载暂停等aria2功能
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from telegram import Update
|
||||
from telegram.ext import ContextTypes, CommandHandler
|
||||
from telegram.ext import ContextTypes, CommandHandler, CallbackQueryHandler, MessageHandler, filters
|
||||
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
@@ -13,12 +13,23 @@ from src.core import (
|
||||
ServiceError,
|
||||
DownloadError,
|
||||
ConfigError,
|
||||
RpcError,
|
||||
is_aria2_installed,
|
||||
get_aria2_version,
|
||||
generate_rpc_secret,
|
||||
ARIA2_CONF,
|
||||
)
|
||||
from src.aria2 import Aria2Installer, Aria2ServiceManager
|
||||
from src.aria2.rpc import Aria2RpcClient, DownloadTask, _format_size
|
||||
from src.telegram.keyboards import (
|
||||
STATUS_EMOJI,
|
||||
build_list_type_keyboard,
|
||||
build_task_keyboard,
|
||||
build_task_list_keyboard,
|
||||
build_delete_confirm_keyboard,
|
||||
build_detail_keyboard,
|
||||
build_after_add_keyboard,
|
||||
)
|
||||
|
||||
logger = get_logger("handlers")
|
||||
|
||||
@@ -36,6 +47,15 @@ class Aria2BotAPI:
|
||||
self.config = config or Aria2Config()
|
||||
self.installer = Aria2Installer(self.config)
|
||||
self.service = Aria2ServiceManager()
|
||||
self._rpc: Aria2RpcClient | None = None
|
||||
|
||||
def _get_rpc_client(self) -> Aria2RpcClient:
|
||||
"""获取或创建 RPC 客户端"""
|
||||
if self._rpc is None:
|
||||
secret = self._get_rpc_secret()
|
||||
port = self._get_rpc_port() or 6800
|
||||
self._rpc = Aria2RpcClient(port=port, secret=secret)
|
||||
return self._rpc
|
||||
|
||||
async def _reply(self, update: Update, context: ContextTypes.DEFAULT_TYPE, text: str, **kwargs):
|
||||
if update.effective_message:
|
||||
@@ -285,6 +305,7 @@ class Aria2BotAPI:
|
||||
async def help_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
logger.info(f"收到 /help 命令 - {_get_user_info(update)}")
|
||||
commands = [
|
||||
"*服务管理*",
|
||||
"/install - 安装 aria2",
|
||||
"/uninstall - 卸载 aria2",
|
||||
"/start - 启动 aria2 服务",
|
||||
@@ -292,17 +313,322 @@ class Aria2BotAPI:
|
||||
"/restart - 重启 aria2 服务",
|
||||
"/status - 查看 aria2 状态",
|
||||
"/logs - 查看最近日志",
|
||||
"/clear_logs - 清空日志",
|
||||
"/set_secret <密钥> - 设置自定义 RPC 密钥",
|
||||
"/reset_secret - 重新生成随机 RPC 密钥",
|
||||
"/clear\\_logs - 清空日志",
|
||||
"/set\\_secret <密钥> - 设置 RPC 密钥",
|
||||
"/reset\\_secret - 重新生成 RPC 密钥",
|
||||
"",
|
||||
"*下载管理*",
|
||||
"/add <URL> - 添加下载任务",
|
||||
"/list - 查看下载列表",
|
||||
"/stats - 全局下载统计",
|
||||
"",
|
||||
"/help - 显示此帮助",
|
||||
]
|
||||
await self._reply(update, context, "可用命令:\n" + "\n".join(commands))
|
||||
await self._reply(update, context, "可用命令:\n" + "\n".join(commands), parse_mode="Markdown")
|
||||
|
||||
# === 下载管理命令 ===
|
||||
|
||||
async def add_download(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""/add <url> - 添加下载任务"""
|
||||
logger.info(f"收到 /add 命令 - {_get_user_info(update)}")
|
||||
if not context.args:
|
||||
await self._reply(update, context, "用法: /add <URL>\n支持 HTTP/HTTPS/磁力链接")
|
||||
return
|
||||
|
||||
url = context.args[0]
|
||||
try:
|
||||
rpc = self._get_rpc_client()
|
||||
gid = await rpc.add_uri(url)
|
||||
task = await rpc.get_status(gid)
|
||||
# 转义文件名中的 Markdown 特殊字符
|
||||
safe_name = task.name.replace("_", "\\_").replace("*", "\\*").replace("`", "\\`")
|
||||
text = f"✅ 任务已添加\n📄 {safe_name}\n🆔 GID: `{gid}`"
|
||||
keyboard = build_after_add_keyboard(gid)
|
||||
await self._reply(update, context, text, parse_mode="Markdown", reply_markup=keyboard)
|
||||
logger.info(f"/add 命令执行成功, GID={gid} - {_get_user_info(update)}")
|
||||
except RpcError as e:
|
||||
logger.error(f"/add 命令执行失败: {e} - {_get_user_info(update)}")
|
||||
await self._reply(update, context, f"❌ 添加失败: {e}")
|
||||
|
||||
async def handle_torrent(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""处理用户发送的种子文件"""
|
||||
logger.info(f"收到种子文件 - {_get_user_info(update)}")
|
||||
document = update.message.document
|
||||
if not document or not document.file_name.endswith(".torrent"):
|
||||
return
|
||||
|
||||
try:
|
||||
file = await context.bot.get_file(document.file_id)
|
||||
torrent_data = await file.download_as_bytearray()
|
||||
rpc = self._get_rpc_client()
|
||||
gid = await rpc.add_torrent(bytes(torrent_data))
|
||||
task = await rpc.get_status(gid)
|
||||
# 转义文件名中的 Markdown 特殊字符
|
||||
safe_name = task.name.replace("_", "\\_").replace("*", "\\*").replace("`", "\\`")
|
||||
text = f"✅ 种子任务已添加\n📄 {safe_name}\n🆔 GID: `{gid}`"
|
||||
keyboard = build_after_add_keyboard(gid)
|
||||
await self._reply(update, context, text, parse_mode="Markdown", reply_markup=keyboard)
|
||||
logger.info(f"种子任务添加成功, GID={gid} - {_get_user_info(update)}")
|
||||
except RpcError as e:
|
||||
logger.error(f"种子任务添加失败: {e} - {_get_user_info(update)}")
|
||||
await self._reply(update, context, f"❌ 添加种子失败: {e}")
|
||||
|
||||
async def list_downloads(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""/list - 查看下载列表"""
|
||||
logger.info(f"收到 /list 命令 - {_get_user_info(update)}")
|
||||
try:
|
||||
rpc = self._get_rpc_client()
|
||||
stat = await rpc.get_global_stat()
|
||||
active_count = int(stat.get("numActive", 0))
|
||||
waiting_count = int(stat.get("numWaiting", 0))
|
||||
stopped_count = int(stat.get("numStopped", 0))
|
||||
|
||||
keyboard = build_list_type_keyboard(active_count, waiting_count, stopped_count)
|
||||
await self._reply(update, context, "📥 选择查看类型:", reply_markup=keyboard)
|
||||
except RpcError as e:
|
||||
logger.error(f"/list 命令执行失败: {e} - {_get_user_info(update)}")
|
||||
await self._reply(update, context, f"❌ 获取列表失败: {e}")
|
||||
|
||||
async def global_stats(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""/stats - 全局下载统计"""
|
||||
logger.info(f"收到 /stats 命令 - {_get_user_info(update)}")
|
||||
try:
|
||||
rpc = self._get_rpc_client()
|
||||
stat = await rpc.get_global_stat()
|
||||
text = (
|
||||
"📊 *全局统计*\n"
|
||||
f"⬇️ 下载速度: {_format_size(int(stat.get('downloadSpeed', 0)))}/s\n"
|
||||
f"⬆️ 上传速度: {_format_size(int(stat.get('uploadSpeed', 0)))}/s\n"
|
||||
f"▶️ 活动任务: {stat.get('numActive', 0)}\n"
|
||||
f"⏳ 等待任务: {stat.get('numWaiting', 0)}\n"
|
||||
f"⏹️ 已停止: {stat.get('numStopped', 0)}"
|
||||
)
|
||||
await self._reply(update, context, text, parse_mode="Markdown")
|
||||
except RpcError as e:
|
||||
logger.error(f"/stats 命令执行失败: {e} - {_get_user_info(update)}")
|
||||
await self._reply(update, context, f"❌ 获取统计失败: {e}")
|
||||
|
||||
# === Callback Query 处理 ===
|
||||
|
||||
async def handle_callback(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""处理 Inline Keyboard 回调"""
|
||||
query = update.callback_query
|
||||
|
||||
try:
|
||||
await query.answer()
|
||||
except Exception as e:
|
||||
logger.warning(f"回调应答失败 (可忽略): {e}")
|
||||
|
||||
data = query.data
|
||||
if not data:
|
||||
return
|
||||
|
||||
parts = data.split(":")
|
||||
action = parts[0]
|
||||
|
||||
try:
|
||||
rpc = self._get_rpc_client()
|
||||
|
||||
if action == "list":
|
||||
await self._handle_list_callback(query, rpc, parts)
|
||||
elif action == "pause":
|
||||
await self._handle_pause_callback(query, rpc, parts[1])
|
||||
elif action == "resume":
|
||||
await self._handle_resume_callback(query, rpc, parts[1])
|
||||
elif action == "delete":
|
||||
await self._handle_delete_callback(query, parts[1])
|
||||
elif action == "confirm_del":
|
||||
await self._handle_confirm_delete_callback(query, rpc, parts[1], parts[2])
|
||||
elif action == "detail":
|
||||
await self._handle_detail_callback(query, rpc, parts[1])
|
||||
elif action == "stats":
|
||||
await self._handle_stats_callback(query, rpc)
|
||||
elif action == "cancel":
|
||||
await query.edit_message_text("❌ 操作已取消")
|
||||
|
||||
except RpcError as e:
|
||||
await query.edit_message_text(f"❌ 操作失败: {e}")
|
||||
|
||||
async def _handle_list_callback(self, query, rpc: Aria2RpcClient, parts: list) -> None:
|
||||
"""处理列表相关回调"""
|
||||
if parts[1] == "menu":
|
||||
stat = await rpc.get_global_stat()
|
||||
keyboard = build_list_type_keyboard(
|
||||
int(stat.get("numActive", 0)),
|
||||
int(stat.get("numWaiting", 0)),
|
||||
int(stat.get("numStopped", 0)),
|
||||
)
|
||||
await query.edit_message_text("📥 选择查看类型:", reply_markup=keyboard)
|
||||
return
|
||||
|
||||
list_type = parts[1]
|
||||
page = int(parts[2]) if len(parts) > 2 else 1
|
||||
|
||||
if list_type == "active":
|
||||
tasks = await rpc.get_active()
|
||||
title = "▶️ 活动任务"
|
||||
elif list_type == "waiting":
|
||||
tasks = await rpc.get_waiting()
|
||||
title = "⏳ 等待任务"
|
||||
else: # stopped
|
||||
tasks = await rpc.get_stopped()
|
||||
title = "✅ 已完成/错误"
|
||||
|
||||
await self._send_task_list(query, tasks, page, list_type, title)
|
||||
|
||||
async def _send_task_list(self, query, tasks: list[DownloadTask], page: int, list_type: str, title: str) -> None:
|
||||
"""发送任务列表"""
|
||||
page_size = 5
|
||||
total_pages = max(1, (len(tasks) + page_size - 1) // page_size)
|
||||
start = (page - 1) * page_size
|
||||
page_tasks = tasks[start:start + page_size]
|
||||
|
||||
if not tasks:
|
||||
keyboard = build_task_list_keyboard(1, 1, list_type)
|
||||
await query.edit_message_text(f"{title}\n\n📭 暂无任务", reply_markup=keyboard)
|
||||
return
|
||||
|
||||
lines = [f"{title} ({page}/{total_pages})\n"]
|
||||
for t in page_tasks:
|
||||
emoji = STATUS_EMOJI.get(t.status, "❓")
|
||||
lines.append(f"{emoji} {t.name}")
|
||||
lines.append(f" {t.progress_bar} {t.progress:.1f}%")
|
||||
lines.append(f" {t.size_str} | {t.speed_str}")
|
||||
# 添加操作按钮提示
|
||||
if t.status == "active":
|
||||
lines.append(f" ⏸ /pause\\_{t.gid[:8]}")
|
||||
elif t.status in ("paused", "waiting"):
|
||||
lines.append(f" ▶️ /resume\\_{t.gid[:8]}")
|
||||
lines.append(f" 📋 详情: 点击下方按钮\n")
|
||||
|
||||
# 为每个任务添加操作按钮
|
||||
task_buttons = []
|
||||
for t in page_tasks:
|
||||
row = []
|
||||
if t.status == "active":
|
||||
row.append({"text": f"⏸ {t.gid[:6]}", "callback_data": f"pause:{t.gid}"})
|
||||
elif t.status in ("paused", "waiting"):
|
||||
row.append({"text": f"▶️ {t.gid[:6]}", "callback_data": f"resume:{t.gid}"})
|
||||
row.append({"text": f"🗑 {t.gid[:6]}", "callback_data": f"delete:{t.gid}"})
|
||||
row.append({"text": f"📋 {t.gid[:6]}", "callback_data": f"detail:{t.gid}"})
|
||||
task_buttons.append(row)
|
||||
|
||||
# 构建完整键盘
|
||||
from telegram import InlineKeyboardButton, InlineKeyboardMarkup
|
||||
keyboard_rows = []
|
||||
for row in task_buttons:
|
||||
keyboard_rows.append([InlineKeyboardButton(b["text"], callback_data=b["callback_data"]) for b in row])
|
||||
|
||||
# 添加翻页按钮
|
||||
nav_buttons = []
|
||||
if page > 1:
|
||||
nav_buttons.append(InlineKeyboardButton("⬅️ 上一页", callback_data=f"list:{list_type}:{page - 1}"))
|
||||
if page < total_pages:
|
||||
nav_buttons.append(InlineKeyboardButton("➡️ 下一页", callback_data=f"list:{list_type}:{page + 1}"))
|
||||
if nav_buttons:
|
||||
keyboard_rows.append(nav_buttons)
|
||||
|
||||
keyboard_rows.append([InlineKeyboardButton("🔙 返回列表", callback_data="list:menu")])
|
||||
|
||||
await query.edit_message_text("\n".join(lines), reply_markup=InlineKeyboardMarkup(keyboard_rows))
|
||||
|
||||
async def _handle_pause_callback(self, query, rpc: Aria2RpcClient, gid: str) -> None:
|
||||
"""处理暂停回调"""
|
||||
await rpc.pause(gid)
|
||||
task = await rpc.get_status(gid)
|
||||
safe_name = task.name.replace("_", "\\_").replace("*", "\\*").replace("`", "\\`")
|
||||
keyboard = build_task_keyboard(gid, task.status)
|
||||
await query.edit_message_text(f"⏸️ 任务已暂停\n📄 {safe_name}\n🆔 GID: `{gid}`",
|
||||
parse_mode="Markdown", reply_markup=keyboard)
|
||||
|
||||
async def _handle_resume_callback(self, query, rpc: Aria2RpcClient, gid: str) -> None:
|
||||
"""处理恢复回调"""
|
||||
await rpc.unpause(gid)
|
||||
task = await rpc.get_status(gid)
|
||||
safe_name = task.name.replace("_", "\\_").replace("*", "\\*").replace("`", "\\`")
|
||||
keyboard = build_task_keyboard(gid, task.status)
|
||||
await query.edit_message_text(f"▶️ 任务已恢复\n📄 {safe_name}\n🆔 GID: `{gid}`",
|
||||
parse_mode="Markdown", reply_markup=keyboard)
|
||||
|
||||
async def _handle_delete_callback(self, query, gid: str) -> None:
|
||||
"""处理删除确认回调"""
|
||||
keyboard = build_delete_confirm_keyboard(gid)
|
||||
await query.edit_message_text(f"⚠️ 确认删除任务?\n🆔 GID: `{gid}`",
|
||||
parse_mode="Markdown", reply_markup=keyboard)
|
||||
|
||||
async def _handle_confirm_delete_callback(self, query, rpc: Aria2RpcClient, gid: str, delete_file: str) -> None:
|
||||
"""处理确认删除回调"""
|
||||
task = None
|
||||
try:
|
||||
task = await rpc.get_status(gid)
|
||||
except RpcError:
|
||||
pass
|
||||
|
||||
# 尝试删除任务
|
||||
try:
|
||||
await rpc.remove(gid)
|
||||
except RpcError:
|
||||
try:
|
||||
await rpc.force_remove(gid)
|
||||
except RpcError:
|
||||
pass
|
||||
try:
|
||||
await rpc.remove_download_result(gid)
|
||||
except RpcError:
|
||||
pass
|
||||
|
||||
# 如果需要删除文件
|
||||
file_deleted = False
|
||||
if delete_file == "1" and task:
|
||||
file_deleted = rpc.delete_files(task)
|
||||
|
||||
msg = f"🗑️ 任务已删除\n🆔 GID: `{gid}`"
|
||||
if delete_file == "1":
|
||||
msg += f"\n📁 文件: {'已删除' if file_deleted else '删除失败或不存在'}"
|
||||
|
||||
await query.edit_message_text(msg, parse_mode="Markdown")
|
||||
|
||||
async def _handle_detail_callback(self, query, rpc: Aria2RpcClient, gid: str) -> None:
|
||||
"""处理详情回调"""
|
||||
task = await rpc.get_status(gid)
|
||||
emoji = STATUS_EMOJI.get(task.status, "❓")
|
||||
safe_name = task.name.replace("_", "\\_").replace("*", "\\*").replace("`", "\\`")
|
||||
text = (
|
||||
f"📋 *任务详情*\n"
|
||||
f"📄 文件: {safe_name}\n"
|
||||
f"🆔 GID: `{task.gid}`\n"
|
||||
f"📊 状态: {emoji} {task.status}\n"
|
||||
f"📈 进度: {task.progress_bar} {task.progress:.1f}%\n"
|
||||
f"📦 大小: {task.size_str}\n"
|
||||
f"⬇️ 下载: {task.speed_str}\n"
|
||||
f"⬆️ 上传: {_format_size(task.upload_speed)}/s"
|
||||
)
|
||||
if task.error_message:
|
||||
text += f"\n❌ 错误: {task.error_message}"
|
||||
|
||||
keyboard = build_detail_keyboard(gid, task.status)
|
||||
await query.edit_message_text(text, parse_mode="Markdown", reply_markup=keyboard)
|
||||
|
||||
async def _handle_stats_callback(self, query, rpc: Aria2RpcClient) -> None:
|
||||
"""处理统计回调"""
|
||||
stat = await rpc.get_global_stat()
|
||||
text = (
|
||||
"📊 *全局统计*\n"
|
||||
f"⬇️ 下载速度: {_format_size(int(stat.get('downloadSpeed', 0)))}/s\n"
|
||||
f"⬆️ 上传速度: {_format_size(int(stat.get('uploadSpeed', 0)))}/s\n"
|
||||
f"▶️ 活动任务: {stat.get('numActive', 0)}\n"
|
||||
f"⏳ 等待任务: {stat.get('numWaiting', 0)}\n"
|
||||
f"⏹️ 已停止: {stat.get('numStopped', 0)}"
|
||||
)
|
||||
from telegram import InlineKeyboardButton, InlineKeyboardMarkup
|
||||
keyboard = InlineKeyboardMarkup([[InlineKeyboardButton("🔙 返回列表", callback_data="list:menu")]])
|
||||
await query.edit_message_text(text, parse_mode="Markdown", reply_markup=keyboard)
|
||||
|
||||
|
||||
def build_handlers(api: Aria2BotAPI) -> list[CommandHandler]:
|
||||
"""构建 CommandHandler 列表"""
|
||||
def build_handlers(api: Aria2BotAPI) -> list:
|
||||
"""构建 Handler 列表"""
|
||||
return [
|
||||
# 服务管理命令
|
||||
CommandHandler("install", api.install),
|
||||
CommandHandler("uninstall", api.uninstall),
|
||||
CommandHandler("start", api.start_service),
|
||||
@@ -314,4 +640,12 @@ def build_handlers(api: Aria2BotAPI) -> list[CommandHandler]:
|
||||
CommandHandler("set_secret", api.set_secret),
|
||||
CommandHandler("reset_secret", api.reset_secret),
|
||||
CommandHandler("help", api.help_command),
|
||||
# 下载管理命令
|
||||
CommandHandler("add", api.add_download),
|
||||
CommandHandler("list", api.list_downloads),
|
||||
CommandHandler("stats", api.global_stats),
|
||||
# 种子文件处理
|
||||
MessageHandler(filters.Document.FileExtension("torrent"), api.handle_torrent),
|
||||
# Callback Query 处理
|
||||
CallbackQueryHandler(api.handle_callback),
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user