fix(security): 增加多处安全检查防止路径遍历和恶意输入,完善资源关闭逻辑

This commit is contained in:
dnslin
2025-12-12 11:12:32 +08:00
parent 7e8317c970
commit cdcbf7d7cb
4 changed files with 92 additions and 11 deletions

View File

@@ -50,6 +50,15 @@ class Aria2Installer:
self.arch = detect_arch() self.arch = detect_arch()
self._executor = ThreadPoolExecutor(max_workers=4) self._executor = ThreadPoolExecutor(max_workers=4)
def __del__(self):
"""确保线程池被关闭,防止资源泄漏"""
if hasattr(self, '_executor'):
self._executor.shutdown(wait=False)
def close(self):
"""显式关闭资源"""
self._executor.shutdown(wait=True)
async def get_latest_version(self) -> str: async def get_latest_version(self) -> str:
"""从 GitHub API 获取最新版本号""" """从 GitHub API 获取最新版本号"""
logger.info("正在获取 aria2 最新版本...") logger.info("正在获取 aria2 最新版本...")
@@ -281,6 +290,10 @@ class Aria2Installer:
@staticmethod @staticmethod
def _extract_binary(archive_path: Path, extract_dir: Path) -> Path: def _extract_binary(archive_path: Path, extract_dir: Path) -> Path:
with tarfile.open(archive_path, "r:gz") as tar: with tarfile.open(archive_path, "r:gz") as tar:
# 安全检查:验证所有成员路径,防止 Zip Slip 攻击
for member in tar.getmembers():
if member.name.startswith('/') or '..' in member.name:
raise DownloadError(f"不安全的 tar 成员: {member.name}")
tar.extractall(extract_dir) tar.extractall(extract_dir)
for candidate in extract_dir.rglob("aria2c"): for candidate in extract_dir.rglob("aria2c"):
if candidate.is_file(): if candidate.is_file():

View File

@@ -180,7 +180,13 @@ class Aria2RpcClient:
if not task.dir or not task.name: if not task.dir or not task.name:
return False return False
try: try:
file_path = Path(task.dir) / task.name file_path = (Path(task.dir) / task.name).resolve()
# 安全检查:验证路径在下载目录内,防止路径遍历攻击
from src.core.constants import DOWNLOAD_DIR
download_dir = DOWNLOAD_DIR.resolve()
if not str(file_path).startswith(str(download_dir) + "/"):
logger.error(f"路径遍历尝试被阻止: {file_path}")
return False
if file_path.exists(): if file_path.exists():
if file_path.is_dir(): if file_path.is_dir():
import shutil import shutil

View File

@@ -29,21 +29,37 @@ class BotConfig:
def from_env(cls) -> "BotConfig": def from_env(cls) -> "BotConfig":
"""从环境变量加载配置""" """从环境变量加载配置"""
from dotenv import load_dotenv from dotenv import load_dotenv
from src.core.exceptions import ConfigError
load_dotenv() load_dotenv()
token = os.environ.get("TELEGRAM_BOT_TOKEN", "") # 验证必需的 Token
token = os.environ.get("TELEGRAM_BOT_TOKEN", "").strip()
if not token:
raise ConfigError("TELEGRAM_BOT_TOKEN 环境变量未设置")
# 安全解析 RPC 端口
port_str = os.environ.get("ARIA2_RPC_PORT", "6800")
try:
rpc_port = int(port_str)
if not (1 <= rpc_port <= 65535):
raise ValueError("端口必须在 1-65535 范围内")
except ValueError as e:
raise ConfigError(f"无效的 ARIA2_RPC_PORT: {e}") from e
# 解析允许的用户 ID 列表 # 解析允许的用户 ID 列表
allowed_users_str = os.environ.get("ALLOWED_USERS", "") allowed_users_str = os.environ.get("ALLOWED_USERS", "")
allowed_users = set() allowed_users = set()
if allowed_users_str: if allowed_users_str:
allowed_users = { for uid in allowed_users_str.split(","):
int(uid.strip()) for uid in allowed_users_str.split(",") uid = uid.strip()
if uid.strip().isdigit() if uid.isdigit():
} user_id = int(uid)
# 验证用户 ID 在合理范围内
if 0 < user_id < 2**63:
allowed_users.add(user_id)
aria2 = Aria2Config( aria2 = Aria2Config(
rpc_port=int(os.environ.get("ARIA2_RPC_PORT", "6800")), rpc_port=rpc_port,
rpc_secret=os.environ.get("ARIA2_RPC_SECRET", ""), rpc_secret=os.environ.get("ARIA2_RPC_SECRET", ""),
) )
return cls( return cls(

View File

@@ -1,6 +1,8 @@
"""Telegram bot command handlers.""" """Telegram bot command handlers."""
from __future__ import annotations from __future__ import annotations
from urllib.parse import urlparse
from telegram import Update from telegram import Update
from telegram.ext import ContextTypes, CommandHandler, CallbackQueryHandler, MessageHandler, filters from telegram.ext import ContextTypes, CommandHandler, CallbackQueryHandler, MessageHandler, filters
@@ -55,6 +57,28 @@ def _get_user_info(update: Update) -> str:
return "未知用户" return "未知用户"
def _validate_download_url(url: str) -> tuple[bool, str]:
"""验证下载 URL 的有效性,防止恶意输入"""
# 检查 URL 长度
if len(url) > 2048:
return False, "URL 过长(最大 2048 字符)"
# 磁力链接直接通过
if url.startswith("magnet:"):
return True, ""
# 验证 HTTP/HTTPS URL
try:
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return False, f"不支持的协议: {parsed.scheme or ''},仅支持 HTTP/HTTPS/磁力链接"
if not parsed.netloc:
return False, "无效的 URL 格式"
return True, ""
except Exception:
return False, "URL 解析失败"
import asyncio import asyncio
from functools import wraps from functools import wraps
@@ -400,6 +424,13 @@ class Aria2BotAPI:
return return
url = context.args[0] url = context.args[0]
# 验证 URL 格式
is_valid, error_msg = _validate_download_url(url)
if not is_valid:
await self._reply(update, context, f"❌ URL 无效: {error_msg}")
return
try: try:
rpc = self._get_rpc_client() rpc = self._get_rpc_client()
gid = await rpc.add_uri(url) gid = await rpc.add_uri(url)
@@ -488,8 +519,20 @@ class Aria2BotAPI:
return return
parts = data.split(":") parts = data.split(":")
if not parts:
await query.edit_message_text("❌ 无效操作")
return
action = parts[0] action = parts[0]
# 安全检查:验证回调数据格式,防止索引越界
required_parts = {
"pause": 2, "resume": 2, "delete": 2, "detail": 2, "refresh": 2,
"confirm_del": 3, "cancel_del": 3,
}
if action in required_parts and len(parts) < required_parts[action]:
await query.edit_message_text("❌ 无效操作")
return
# 点击非详情相关按钮时,停止该消息的自动刷新 # 点击非详情相关按钮时,停止该消息的自动刷新
if action not in ("detail", "refresh", "pause", "resume"): if action not in ("detail", "refresh", "pause", "resume"):
key = f"{query.message.chat_id}:{query.message.message_id}" key = f"{query.message.chat_id}:{query.message.message_id}"
@@ -652,10 +695,12 @@ class Aria2BotAPI:
await query.edit_message_text(msg, parse_mode="Markdown") await query.edit_message_text(msg, parse_mode="Markdown")
def _stop_auto_refresh(self, key: str) -> None: def _stop_auto_refresh(self, key: str) -> None:
"""停止自动刷新任务""" """停止自动刷新任务并等待清理"""
if key in self._auto_refresh_tasks: if key in self._auto_refresh_tasks:
self._auto_refresh_tasks[key].cancel() task = self._auto_refresh_tasks.pop(key)
del self._auto_refresh_tasks[key] task.cancel()
# 注意:这里不等待任务完成,因为是同步方法
# 任务会在 finally 块中自行清理
async def _handle_detail_callback(self, query, rpc: Aria2RpcClient, gid: str) -> None: async def _handle_detail_callback(self, query, rpc: Aria2RpcClient, gid: str) -> None:
"""处理详情回调,启动自动刷新""" """处理详情回调,启动自动刷新"""
@@ -702,7 +747,8 @@ class Aria2BotAPI:
try: try:
await message.edit_text(text, parse_mode="Markdown", reply_markup=keyboard) await message.edit_text(text, parse_mode="Markdown", reply_markup=keyboard)
last_text = text last_text = text
except Exception: except Exception as e:
logger.warning(f"编辑消息失败 (GID={gid}): {e}")
break break
# 任务完成或出错时停止刷新 # 任务完成或出错时停止刷新