feat: 增加下载暂停等aria2功能

This commit is contained in:
dnslin
2025-12-12 10:01:15 +08:00
parent cf1d98610e
commit cde94d3287
10 changed files with 693 additions and 13 deletions

View File

@@ -1,5 +1,6 @@
"""Aria2 operations module - installer and service management."""
"""Aria2 operations module - installer, service management, and RPC client."""
from src.aria2.installer import Aria2Installer
from src.aria2.service import Aria2ServiceManager
from src.aria2.rpc import Aria2RpcClient, DownloadTask
__all__ = ["Aria2Installer", "Aria2ServiceManager"]
__all__ = ["Aria2Installer", "Aria2ServiceManager", "Aria2RpcClient", "DownloadTask"]

222
src/aria2/rpc.py Normal file
View File

@@ -0,0 +1,222 @@
"""aria2 JSON-RPC 2.0 客户端"""
from __future__ import annotations
import base64
import json
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import httpx
from src.core.exceptions import RpcError
from src.utils.logger import get_logger
logger = get_logger("rpc")
def _format_size(size: int) -> str:
"""格式化字节大小"""
for unit in ("B", "KB", "MB", "GB"):
if size < 1024:
return f"{size:.1f}{unit}"
size /= 1024
return f"{size:.1f}TB"
@dataclass
class DownloadTask:
"""下载任务数据类"""
gid: str
status: str # active, waiting, paused, error, complete, removed
name: str
total_length: int
completed_length: int
download_speed: int
upload_speed: int = 0
error_message: str = ""
dir: str = ""
@property
def progress(self) -> float:
"""计算下载进度百分比"""
if self.total_length == 0:
return 0.0
return (self.completed_length / self.total_length) * 100
@property
def progress_bar(self) -> str:
"""生成进度条"""
pct = int(self.progress / 10)
return "" * pct + "" * (10 - pct)
@property
def speed_str(self) -> str:
"""格式化下载速度"""
return _format_size(self.download_speed) + "/s"
@property
def size_str(self) -> str:
"""格式化文件大小"""
return f"{_format_size(self.completed_length)}/{_format_size(self.total_length)}"
class Aria2RpcClient:
"""aria2 RPC 客户端"""
def __init__(self, host: str = "localhost", port: int = 6800, secret: str = ""):
self.url = f"http://{host}:{port}/jsonrpc"
self.secret = secret
async def _call(self, method: str, params: list | None = None) -> Any:
"""发送 RPC 请求"""
payload = {
"jsonrpc": "2.0",
"id": str(uuid.uuid4()),
"method": method,
"params": [],
}
# 添加 token 认证
if self.secret:
payload["params"].append(f"token:{self.secret}")
if params:
payload["params"].extend(params)
try:
async with httpx.AsyncClient(timeout=10) as client:
resp = await client.post(self.url, json=payload)
data = resp.json()
except httpx.ConnectError:
raise RpcError("aria2 服务可能未运行,请先使用 /start 命令启动服务") from None
except httpx.RequestError as e:
raise RpcError(f"RPC 请求失败: {e}") from e
except json.JSONDecodeError as e:
raise RpcError(f"RPC 响应解析失败: {e}") from e
if "error" in data:
raise RpcError(data["error"].get("message", "未知错误"))
return data.get("result")
# === 添加任务 ===
async def add_uri(self, uri: str) -> str:
"""添加 URL 下载任务,返回 GID"""
result = await self._call("aria2.addUri", [[uri]])
logger.info(f"添加下载任务: {uri[:50]}..., GID={result}")
return result
async def add_torrent(self, torrent_data: bytes) -> str:
"""添加种子下载任务,返回 GID"""
b64_data = base64.b64encode(torrent_data).decode("utf-8")
result = await self._call("aria2.addTorrent", [b64_data])
logger.info(f"添加种子任务, GID={result}")
return result
# === 任务控制 ===
async def pause(self, gid: str) -> str:
"""暂停任务"""
return await self._call("aria2.pause", [gid])
async def unpause(self, gid: str) -> str:
"""恢复任务"""
return await self._call("aria2.unpause", [gid])
async def remove(self, gid: str) -> str:
"""删除任务(仅从队列移除)"""
return await self._call("aria2.remove", [gid])
async def force_remove(self, gid: str) -> str:
"""强制删除任务"""
return await self._call("aria2.forceRemove", [gid])
async def remove_download_result(self, gid: str) -> str:
"""删除已完成/错误任务的记录"""
return await self._call("aria2.removeDownloadResult", [gid])
# === 查询任务 ===
async def get_status(self, gid: str) -> DownloadTask:
"""获取单个任务状态"""
keys = ["gid", "status", "totalLength", "completedLength",
"downloadSpeed", "uploadSpeed", "files", "errorMessage", "dir"]
result = await self._call("aria2.tellStatus", [gid, keys])
return self._parse_task(result)
async def get_active(self) -> list[DownloadTask]:
"""获取活动任务列表"""
keys = ["gid", "status", "totalLength", "completedLength",
"downloadSpeed", "uploadSpeed", "files", "dir"]
result = await self._call("aria2.tellActive", [keys])
return [self._parse_task(t) for t in result]
async def get_waiting(self, offset: int = 0, num: int = 100) -> list[DownloadTask]:
"""获取等待/暂停任务列表"""
keys = ["gid", "status", "totalLength", "completedLength",
"downloadSpeed", "uploadSpeed", "files", "dir"]
result = await self._call("aria2.tellWaiting", [offset, num, keys])
return [self._parse_task(t) for t in result]
async def get_stopped(self, offset: int = 0, num: int = 100) -> list[DownloadTask]:
"""获取已停止任务列表(完成/错误)"""
keys = ["gid", "status", "totalLength", "completedLength",
"downloadSpeed", "uploadSpeed", "files", "errorMessage", "dir"]
result = await self._call("aria2.tellStopped", [offset, num, keys])
return [self._parse_task(t) for t in result]
async def get_global_stat(self) -> dict:
"""获取全局统计"""
return await self._call("aria2.getGlobalStat")
# === 文件操作 ===
async def get_files(self, gid: str) -> list[dict]:
"""获取任务文件列表"""
return await self._call("aria2.getFiles", [gid])
def delete_files(self, task: DownloadTask) -> bool:
"""删除任务对应的文件(同步方法)"""
if not task.dir or not task.name:
return False
try:
file_path = Path(task.dir) / task.name
if file_path.exists():
if file_path.is_dir():
import shutil
shutil.rmtree(file_path)
else:
file_path.unlink()
logger.info(f"已删除文件: {file_path}")
return True
except OSError as e:
logger.error(f"删除文件失败: {e}")
return False
# === 内部方法 ===
def _parse_task(self, data: dict) -> DownloadTask:
"""解析任务数据"""
# 从 files 中提取文件名
name = "未知文件"
if data.get("files"):
path = data["files"][0].get("path", "")
if path:
name = path.split("/")[-1]
elif data["files"][0].get("uris"):
uris = data["files"][0]["uris"]
if uris:
uri = uris[0].get("uri", "")
name = uri.split("/")[-1].split("?")[0] or uri[:30]
return DownloadTask(
gid=data.get("gid", ""),
status=data.get("status", "unknown"),
name=name[:40] if len(name) > 40 else name, # 截断文件名
total_length=int(data.get("totalLength", 0)),
completed_length=int(data.get("completedLength", 0)),
download_speed=int(data.get("downloadSpeed", 0)),
upload_speed=int(data.get("uploadSpeed", 0)),
error_message=data.get("errorMessage", ""),
dir=data.get("dir", ""),
)