mirror of
https://github.com/dnslin/aria2bot.git
synced 2026-01-11 04:02:20 +08:00
fix(security): 增加多处安全检查防止路径遍历和恶意输入,完善资源关闭逻辑
This commit is contained in:
@@ -50,6 +50,15 @@ class Aria2Installer:
|
||||
self.arch = detect_arch()
|
||||
self._executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
def __del__(self):
|
||||
"""确保线程池被关闭,防止资源泄漏"""
|
||||
if hasattr(self, '_executor'):
|
||||
self._executor.shutdown(wait=False)
|
||||
|
||||
def close(self):
|
||||
"""显式关闭资源"""
|
||||
self._executor.shutdown(wait=True)
|
||||
|
||||
async def get_latest_version(self) -> str:
|
||||
"""从 GitHub API 获取最新版本号"""
|
||||
logger.info("正在获取 aria2 最新版本...")
|
||||
@@ -281,6 +290,10 @@ class Aria2Installer:
|
||||
@staticmethod
|
||||
def _extract_binary(archive_path: Path, extract_dir: Path) -> Path:
|
||||
with tarfile.open(archive_path, "r:gz") as tar:
|
||||
# 安全检查:验证所有成员路径,防止 Zip Slip 攻击
|
||||
for member in tar.getmembers():
|
||||
if member.name.startswith('/') or '..' in member.name:
|
||||
raise DownloadError(f"不安全的 tar 成员: {member.name}")
|
||||
tar.extractall(extract_dir)
|
||||
for candidate in extract_dir.rglob("aria2c"):
|
||||
if candidate.is_file():
|
||||
|
||||
@@ -180,7 +180,13 @@ class Aria2RpcClient:
|
||||
if not task.dir or not task.name:
|
||||
return False
|
||||
try:
|
||||
file_path = Path(task.dir) / task.name
|
||||
file_path = (Path(task.dir) / task.name).resolve()
|
||||
# 安全检查:验证路径在下载目录内,防止路径遍历攻击
|
||||
from src.core.constants import DOWNLOAD_DIR
|
||||
download_dir = DOWNLOAD_DIR.resolve()
|
||||
if not str(file_path).startswith(str(download_dir) + "/"):
|
||||
logger.error(f"路径遍历尝试被阻止: {file_path}")
|
||||
return False
|
||||
if file_path.exists():
|
||||
if file_path.is_dir():
|
||||
import shutil
|
||||
|
||||
@@ -29,21 +29,37 @@ class BotConfig:
|
||||
def from_env(cls) -> "BotConfig":
|
||||
"""从环境变量加载配置"""
|
||||
from dotenv import load_dotenv
|
||||
from src.core.exceptions import ConfigError
|
||||
load_dotenv()
|
||||
|
||||
token = os.environ.get("TELEGRAM_BOT_TOKEN", "")
|
||||
# 验证必需的 Token
|
||||
token = os.environ.get("TELEGRAM_BOT_TOKEN", "").strip()
|
||||
if not token:
|
||||
raise ConfigError("TELEGRAM_BOT_TOKEN 环境变量未设置")
|
||||
|
||||
# 安全解析 RPC 端口
|
||||
port_str = os.environ.get("ARIA2_RPC_PORT", "6800")
|
||||
try:
|
||||
rpc_port = int(port_str)
|
||||
if not (1 <= rpc_port <= 65535):
|
||||
raise ValueError("端口必须在 1-65535 范围内")
|
||||
except ValueError as e:
|
||||
raise ConfigError(f"无效的 ARIA2_RPC_PORT: {e}") from e
|
||||
|
||||
# 解析允许的用户 ID 列表
|
||||
allowed_users_str = os.environ.get("ALLOWED_USERS", "")
|
||||
allowed_users = set()
|
||||
if allowed_users_str:
|
||||
allowed_users = {
|
||||
int(uid.strip()) for uid in allowed_users_str.split(",")
|
||||
if uid.strip().isdigit()
|
||||
}
|
||||
for uid in allowed_users_str.split(","):
|
||||
uid = uid.strip()
|
||||
if uid.isdigit():
|
||||
user_id = int(uid)
|
||||
# 验证用户 ID 在合理范围内
|
||||
if 0 < user_id < 2**63:
|
||||
allowed_users.add(user_id)
|
||||
|
||||
aria2 = Aria2Config(
|
||||
rpc_port=int(os.environ.get("ARIA2_RPC_PORT", "6800")),
|
||||
rpc_port=rpc_port,
|
||||
rpc_secret=os.environ.get("ARIA2_RPC_SECRET", ""),
|
||||
)
|
||||
return cls(
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Telegram bot command handlers."""
|
||||
from __future__ import annotations
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from telegram import Update
|
||||
from telegram.ext import ContextTypes, CommandHandler, CallbackQueryHandler, MessageHandler, filters
|
||||
|
||||
@@ -55,6 +57,28 @@ def _get_user_info(update: Update) -> str:
|
||||
return "未知用户"
|
||||
|
||||
|
||||
def _validate_download_url(url: str) -> tuple[bool, str]:
|
||||
"""验证下载 URL 的有效性,防止恶意输入"""
|
||||
# 检查 URL 长度
|
||||
if len(url) > 2048:
|
||||
return False, "URL 过长(最大 2048 字符)"
|
||||
|
||||
# 磁力链接直接通过
|
||||
if url.startswith("magnet:"):
|
||||
return True, ""
|
||||
|
||||
# 验证 HTTP/HTTPS URL
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return False, f"不支持的协议: {parsed.scheme or '无'},仅支持 HTTP/HTTPS/磁力链接"
|
||||
if not parsed.netloc:
|
||||
return False, "无效的 URL 格式"
|
||||
return True, ""
|
||||
except Exception:
|
||||
return False, "URL 解析失败"
|
||||
|
||||
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
|
||||
@@ -400,6 +424,13 @@ class Aria2BotAPI:
|
||||
return
|
||||
|
||||
url = context.args[0]
|
||||
|
||||
# 验证 URL 格式
|
||||
is_valid, error_msg = _validate_download_url(url)
|
||||
if not is_valid:
|
||||
await self._reply(update, context, f"❌ URL 无效: {error_msg}")
|
||||
return
|
||||
|
||||
try:
|
||||
rpc = self._get_rpc_client()
|
||||
gid = await rpc.add_uri(url)
|
||||
@@ -488,8 +519,20 @@ class Aria2BotAPI:
|
||||
return
|
||||
|
||||
parts = data.split(":")
|
||||
if not parts:
|
||||
await query.edit_message_text("❌ 无效操作")
|
||||
return
|
||||
action = parts[0]
|
||||
|
||||
# 安全检查:验证回调数据格式,防止索引越界
|
||||
required_parts = {
|
||||
"pause": 2, "resume": 2, "delete": 2, "detail": 2, "refresh": 2,
|
||||
"confirm_del": 3, "cancel_del": 3,
|
||||
}
|
||||
if action in required_parts and len(parts) < required_parts[action]:
|
||||
await query.edit_message_text("❌ 无效操作")
|
||||
return
|
||||
|
||||
# 点击非详情相关按钮时,停止该消息的自动刷新
|
||||
if action not in ("detail", "refresh", "pause", "resume"):
|
||||
key = f"{query.message.chat_id}:{query.message.message_id}"
|
||||
@@ -652,10 +695,12 @@ class Aria2BotAPI:
|
||||
await query.edit_message_text(msg, parse_mode="Markdown")
|
||||
|
||||
def _stop_auto_refresh(self, key: str) -> None:
|
||||
"""停止自动刷新任务"""
|
||||
"""停止自动刷新任务并等待清理"""
|
||||
if key in self._auto_refresh_tasks:
|
||||
self._auto_refresh_tasks[key].cancel()
|
||||
del self._auto_refresh_tasks[key]
|
||||
task = self._auto_refresh_tasks.pop(key)
|
||||
task.cancel()
|
||||
# 注意:这里不等待任务完成,因为是同步方法
|
||||
# 任务会在 finally 块中自行清理
|
||||
|
||||
async def _handle_detail_callback(self, query, rpc: Aria2RpcClient, gid: str) -> None:
|
||||
"""处理详情回调,启动自动刷新"""
|
||||
@@ -702,7 +747,8 @@ class Aria2BotAPI:
|
||||
try:
|
||||
await message.edit_text(text, parse_mode="Markdown", reply_markup=keyboard)
|
||||
last_text = text
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning(f"编辑消息失败 (GID={gid}): {e}")
|
||||
break
|
||||
|
||||
# 任务完成或出错时停止刷新
|
||||
|
||||
Reference in New Issue
Block a user