Files
aria2bot/src/aria2/rpc.py
dnslin bf4fdf1377 fix: 修复安全漏洞和代码质量问题
安全修复:
- 修复路径遍历检查,使用 Path.relative_to() 替代字符串前缀检查
- 修复 Zip Slip 漏洞,添加符号链接检查和路径验证
- 隐藏 RPC 密钥显示,防止敏感信息泄露
- 设置配置文件权限为 0o600

Bug 修复:
- 修复 HTTP 状态码检查(resp.status → resp.code)
- 修复 OneDrive 认证 flow 参数类型
- 修复 RPC 请求缺少状态码验证
- 修复配置文件渲染会替换注释行的问题

代码改进:
- 添加 subprocess 超时处理,防止进程挂起
- 修复异步代码问题(get_event_loop → get_running_loop)
- 使用 asyncio.to_thread 避免阻塞事件循环
- 添加 httpx 超时和状态码异常处理
- 移除无用的 ONEDRIVE_CLIENT_SECRET 配置

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-12 16:42:48 +08:00

236 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""aria2 JSON-RPC 2.0 客户端"""
from __future__ import annotations
import base64
import json
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import httpx
from src.core.exceptions import RpcError
from src.utils.logger import get_logger
logger = get_logger("rpc")
def _format_size(size: int) -> str:
"""格式化字节大小"""
for unit in ("B", "KB", "MB", "GB"):
if size < 1024:
return f"{size:.1f}{unit}"
size /= 1024
return f"{size:.1f}TB"
@dataclass
class DownloadTask:
"""下载任务数据类"""
gid: str
status: str # active, waiting, paused, error, complete, removed
name: str
total_length: int
completed_length: int
download_speed: int
upload_speed: int = 0
error_message: str = ""
dir: str = ""
@property
def progress(self) -> float:
"""计算下载进度百分比"""
if self.total_length == 0:
return 0.0
return (self.completed_length / self.total_length) * 100
@property
def progress_bar(self) -> str:
"""生成进度条"""
pct = int(self.progress / 10)
return "" * pct + "" * (10 - pct)
@property
def speed_str(self) -> str:
"""格式化下载速度"""
return _format_size(self.download_speed) + "/s"
@property
def size_str(self) -> str:
"""格式化文件大小"""
return f"{_format_size(self.completed_length)}/{_format_size(self.total_length)}"
class Aria2RpcClient:
"""aria2 RPC 客户端"""
def __init__(self, host: str = "localhost", port: int = 6800, secret: str = ""):
self.url = f"http://{host}:{port}/jsonrpc"
self.secret = secret
async def _call(self, method: str, params: list | None = None) -> Any:
"""发送 RPC 请求"""
payload = {
"jsonrpc": "2.0",
"id": str(uuid.uuid4()),
"method": method,
"params": [],
}
# 添加 token 认证
if self.secret:
payload["params"].append(f"token:{self.secret}")
if params:
payload["params"].extend(params)
try:
async with httpx.AsyncClient(timeout=10) as client:
resp = await client.post(self.url, json=payload)
resp.raise_for_status()
data = resp.json()
except httpx.ConnectError:
raise RpcError("aria2 服务可能未运行,请先使用 /start 命令启动服务") from None
except httpx.TimeoutException:
raise RpcError("RPC 请求超时aria2 服务响应缓慢") from None
except httpx.HTTPStatusError as e:
raise RpcError(f"RPC 请求失败HTTP 状态码: {e.response.status_code}") from e
except httpx.RequestError as e:
raise RpcError(f"RPC 请求失败: {e}") from e
except json.JSONDecodeError as e:
raise RpcError(f"RPC 响应解析失败: {e}") from e
if "error" in data:
raise RpcError(data["error"].get("message", "未知错误"))
return data.get("result")
# === 添加任务 ===
async def add_uri(self, uri: str) -> str:
"""添加 URL 下载任务,返回 GID"""
result = await self._call("aria2.addUri", [[uri]])
logger.info(f"添加下载任务: {uri[:50]}..., GID={result}")
return result
async def add_torrent(self, torrent_data: bytes) -> str:
"""添加种子下载任务,返回 GID"""
b64_data = base64.b64encode(torrent_data).decode("utf-8")
result = await self._call("aria2.addTorrent", [b64_data])
logger.info(f"添加种子任务, GID={result}")
return result
# === 任务控制 ===
async def pause(self, gid: str) -> str:
"""暂停任务"""
return await self._call("aria2.pause", [gid])
async def unpause(self, gid: str) -> str:
"""恢复任务"""
return await self._call("aria2.unpause", [gid])
async def remove(self, gid: str) -> str:
"""删除任务(仅从队列移除)"""
return await self._call("aria2.remove", [gid])
async def force_remove(self, gid: str) -> str:
"""强制删除任务"""
return await self._call("aria2.forceRemove", [gid])
async def remove_download_result(self, gid: str) -> str:
"""删除已完成/错误任务的记录"""
return await self._call("aria2.removeDownloadResult", [gid])
# === 查询任务 ===
async def get_status(self, gid: str) -> DownloadTask:
"""获取单个任务状态"""
keys = ["gid", "status", "totalLength", "completedLength",
"downloadSpeed", "uploadSpeed", "files", "errorMessage", "dir"]
result = await self._call("aria2.tellStatus", [gid, keys])
return self._parse_task(result)
async def get_active(self) -> list[DownloadTask]:
"""获取活动任务列表"""
keys = ["gid", "status", "totalLength", "completedLength",
"downloadSpeed", "uploadSpeed", "files", "dir"]
result = await self._call("aria2.tellActive", [keys])
return [self._parse_task(t) for t in result]
async def get_waiting(self, offset: int = 0, num: int = 100) -> list[DownloadTask]:
"""获取等待/暂停任务列表"""
keys = ["gid", "status", "totalLength", "completedLength",
"downloadSpeed", "uploadSpeed", "files", "dir"]
result = await self._call("aria2.tellWaiting", [offset, num, keys])
return [self._parse_task(t) for t in result]
async def get_stopped(self, offset: int = 0, num: int = 100) -> list[DownloadTask]:
"""获取已停止任务列表(完成/错误)"""
keys = ["gid", "status", "totalLength", "completedLength",
"downloadSpeed", "uploadSpeed", "files", "errorMessage", "dir"]
result = await self._call("aria2.tellStopped", [offset, num, keys])
return [self._parse_task(t) for t in result]
async def get_global_stat(self) -> dict:
"""获取全局统计"""
return await self._call("aria2.getGlobalStat")
# === 文件操作 ===
async def get_files(self, gid: str) -> list[dict]:
"""获取任务文件列表"""
return await self._call("aria2.getFiles", [gid])
def delete_files(self, task: DownloadTask) -> bool:
"""删除任务对应的文件(同步方法)"""
if not task.dir or not task.name:
return False
try:
file_path = (Path(task.dir) / task.name).resolve()
# 安全检查:验证路径在下载目录内,防止路径遍历攻击
from src.core.constants import DOWNLOAD_DIR
download_dir = DOWNLOAD_DIR.resolve()
try:
file_path.relative_to(download_dir)
except ValueError:
logger.error(f"路径遍历尝试被阻止: {file_path}")
return False
if file_path.exists():
if file_path.is_dir():
import shutil
shutil.rmtree(file_path)
else:
file_path.unlink()
logger.info(f"已删除文件: {file_path}")
return True
except OSError as e:
logger.error(f"删除文件失败: {e}")
return False
# === 内部方法 ===
def _parse_task(self, data: dict) -> DownloadTask:
"""解析任务数据"""
# 从 files 中提取文件名
name = "未知文件"
if data.get("files"):
path = data["files"][0].get("path", "")
if path:
name = path.split("/")[-1]
elif data["files"][0].get("uris"):
uris = data["files"][0]["uris"]
if uris:
uri = uris[0].get("uri", "")
name = uri.split("/")[-1].split("?")[0] or uri[:30]
return DownloadTask(
gid=data.get("gid", ""),
status=data.get("status", "unknown"),
name=name[:40] if len(name) > 40 else name, # 截断文件名
total_length=int(data.get("totalLength", 0)),
completed_length=int(data.get("completedLength", 0)),
download_speed=int(data.get("downloadSpeed", 0)),
upload_speed=int(data.get("uploadSpeed", 0)),
error_message=data.get("errorMessage", ""),
dir=data.get("dir", ""),
)