fix(security): 增加多处安全检查防止路径遍历和恶意输入,完善资源关闭逻辑

This commit is contained in:
dnslin
2025-12-12 11:12:32 +08:00
parent 7e8317c970
commit cdcbf7d7cb
4 changed files with 92 additions and 11 deletions

View File

@@ -50,6 +50,15 @@ class Aria2Installer:
self.arch = detect_arch()
self._executor = ThreadPoolExecutor(max_workers=4)
def __del__(self):
"""确保线程池被关闭,防止资源泄漏"""
if hasattr(self, '_executor'):
self._executor.shutdown(wait=False)
def close(self):
"""显式关闭资源"""
self._executor.shutdown(wait=True)
async def get_latest_version(self) -> str:
"""从 GitHub API 获取最新版本号"""
logger.info("正在获取 aria2 最新版本...")
@@ -281,6 +290,10 @@ class Aria2Installer:
@staticmethod
def _extract_binary(archive_path: Path, extract_dir: Path) -> Path:
with tarfile.open(archive_path, "r:gz") as tar:
# 安全检查:验证所有成员路径,防止 Zip Slip 攻击
for member in tar.getmembers():
if member.name.startswith('/') or '..' in member.name:
raise DownloadError(f"不安全的 tar 成员: {member.name}")
tar.extractall(extract_dir)
for candidate in extract_dir.rglob("aria2c"):
if candidate.is_file():

View File

@@ -180,7 +180,13 @@ class Aria2RpcClient:
if not task.dir or not task.name:
return False
try:
file_path = Path(task.dir) / task.name
file_path = (Path(task.dir) / task.name).resolve()
# 安全检查:验证路径在下载目录内,防止路径遍历攻击
from src.core.constants import DOWNLOAD_DIR
download_dir = DOWNLOAD_DIR.resolve()
if not str(file_path).startswith(str(download_dir) + "/"):
logger.error(f"路径遍历尝试被阻止: {file_path}")
return False
if file_path.exists():
if file_path.is_dir():
import shutil

View File

@@ -29,21 +29,37 @@ class BotConfig:
def from_env(cls) -> "BotConfig":
"""从环境变量加载配置"""
from dotenv import load_dotenv
from src.core.exceptions import ConfigError
load_dotenv()
token = os.environ.get("TELEGRAM_BOT_TOKEN", "")
# 验证必需的 Token
token = os.environ.get("TELEGRAM_BOT_TOKEN", "").strip()
if not token:
raise ConfigError("TELEGRAM_BOT_TOKEN 环境变量未设置")
# 安全解析 RPC 端口
port_str = os.environ.get("ARIA2_RPC_PORT", "6800")
try:
rpc_port = int(port_str)
if not (1 <= rpc_port <= 65535):
raise ValueError("端口必须在 1-65535 范围内")
except ValueError as e:
raise ConfigError(f"无效的 ARIA2_RPC_PORT: {e}") from e
# 解析允许的用户 ID 列表
allowed_users_str = os.environ.get("ALLOWED_USERS", "")
allowed_users = set()
if allowed_users_str:
allowed_users = {
int(uid.strip()) for uid in allowed_users_str.split(",")
if uid.strip().isdigit()
}
for uid in allowed_users_str.split(","):
uid = uid.strip()
if uid.isdigit():
user_id = int(uid)
# 验证用户 ID 在合理范围内
if 0 < user_id < 2**63:
allowed_users.add(user_id)
aria2 = Aria2Config(
rpc_port=int(os.environ.get("ARIA2_RPC_PORT", "6800")),
rpc_port=rpc_port,
rpc_secret=os.environ.get("ARIA2_RPC_SECRET", ""),
)
return cls(

View File

@@ -1,6 +1,8 @@
"""Telegram bot command handlers."""
from __future__ import annotations
from urllib.parse import urlparse
from telegram import Update
from telegram.ext import ContextTypes, CommandHandler, CallbackQueryHandler, MessageHandler, filters
@@ -55,6 +57,28 @@ def _get_user_info(update: Update) -> str:
return "未知用户"
def _validate_download_url(url: str) -> tuple[bool, str]:
"""验证下载 URL 的有效性,防止恶意输入"""
# 检查 URL 长度
if len(url) > 2048:
return False, "URL 过长(最大 2048 字符)"
# 磁力链接直接通过
if url.startswith("magnet:"):
return True, ""
# 验证 HTTP/HTTPS URL
try:
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return False, f"不支持的协议: {parsed.scheme or ''},仅支持 HTTP/HTTPS/磁力链接"
if not parsed.netloc:
return False, "无效的 URL 格式"
return True, ""
except Exception:
return False, "URL 解析失败"
import asyncio
from functools import wraps
@@ -400,6 +424,13 @@ class Aria2BotAPI:
return
url = context.args[0]
# 验证 URL 格式
is_valid, error_msg = _validate_download_url(url)
if not is_valid:
await self._reply(update, context, f"❌ URL 无效: {error_msg}")
return
try:
rpc = self._get_rpc_client()
gid = await rpc.add_uri(url)
@@ -488,8 +519,20 @@ class Aria2BotAPI:
return
parts = data.split(":")
if not parts:
await query.edit_message_text("❌ 无效操作")
return
action = parts[0]
# 安全检查:验证回调数据格式,防止索引越界
required_parts = {
"pause": 2, "resume": 2, "delete": 2, "detail": 2, "refresh": 2,
"confirm_del": 3, "cancel_del": 3,
}
if action in required_parts and len(parts) < required_parts[action]:
await query.edit_message_text("❌ 无效操作")
return
# 点击非详情相关按钮时,停止该消息的自动刷新
if action not in ("detail", "refresh", "pause", "resume"):
key = f"{query.message.chat_id}:{query.message.message_id}"
@@ -652,10 +695,12 @@ class Aria2BotAPI:
await query.edit_message_text(msg, parse_mode="Markdown")
def _stop_auto_refresh(self, key: str) -> None:
"""停止自动刷新任务"""
"""停止自动刷新任务并等待清理"""
if key in self._auto_refresh_tasks:
self._auto_refresh_tasks[key].cancel()
del self._auto_refresh_tasks[key]
task = self._auto_refresh_tasks.pop(key)
task.cancel()
# 注意:这里不等待任务完成,因为是同步方法
# 任务会在 finally 块中自行清理
async def _handle_detail_callback(self, query, rpc: Aria2RpcClient, gid: str) -> None:
"""处理详情回调,启动自动刷新"""
@@ -702,7 +747,8 @@ class Aria2BotAPI:
try:
await message.edit_text(text, parse_mode="Markdown", reply_markup=keyboard)
last_text = text
except Exception:
except Exception as e:
logger.warning(f"编辑消息失败 (GID={gid}): {e}")
break
# 任务完成或出错时停止刷新