From cdcbf7d7cbf7a9e5951fdbab37de111b5e8485fe Mon Sep 17 00:00:00 2001 From: dnslin Date: Fri, 12 Dec 2025 11:12:32 +0800 Subject: [PATCH] =?UTF-8?q?fix(security):=20=E5=A2=9E=E5=8A=A0=E5=A4=9A?= =?UTF-8?q?=E5=A4=84=E5=AE=89=E5=85=A8=E6=A3=80=E6=9F=A5=E9=98=B2=E6=AD=A2?= =?UTF-8?q?=E8=B7=AF=E5=BE=84=E9=81=8D=E5=8E=86=E5=92=8C=E6=81=B6=E6=84=8F?= =?UTF-8?q?=E8=BE=93=E5=85=A5=EF=BC=8C=E5=AE=8C=E5=96=84=E8=B5=84=E6=BA=90?= =?UTF-8?q?=E5=85=B3=E9=97=AD=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/aria2/installer.py | 13 ++++++++++ src/aria2/rpc.py | 8 +++++- src/core/config.py | 28 ++++++++++++++++----- src/telegram/handlers.py | 54 +++++++++++++++++++++++++++++++++++++--- 4 files changed, 92 insertions(+), 11 deletions(-) diff --git a/src/aria2/installer.py b/src/aria2/installer.py index 5a1b96e..d05d631 100644 --- a/src/aria2/installer.py +++ b/src/aria2/installer.py @@ -50,6 +50,15 @@ class Aria2Installer: self.arch = detect_arch() self._executor = ThreadPoolExecutor(max_workers=4) + def __del__(self): + """确保线程池被关闭,防止资源泄漏""" + if hasattr(self, '_executor'): + self._executor.shutdown(wait=False) + + def close(self): + """显式关闭资源""" + self._executor.shutdown(wait=True) + async def get_latest_version(self) -> str: """从 GitHub API 获取最新版本号""" logger.info("正在获取 aria2 最新版本...") @@ -281,6 +290,10 @@ class Aria2Installer: @staticmethod def _extract_binary(archive_path: Path, extract_dir: Path) -> Path: with tarfile.open(archive_path, "r:gz") as tar: + # 安全检查:验证所有成员路径,防止 Zip Slip 攻击 + for member in tar.getmembers(): + if member.name.startswith('/') or '..' in member.name: + raise DownloadError(f"不安全的 tar 成员: {member.name}") tar.extractall(extract_dir) for candidate in extract_dir.rglob("aria2c"): if candidate.is_file(): diff --git a/src/aria2/rpc.py b/src/aria2/rpc.py index 323774a..564570c 100644 --- a/src/aria2/rpc.py +++ b/src/aria2/rpc.py @@ -180,7 +180,13 @@ class Aria2RpcClient: if not task.dir or not task.name: return False try: - file_path = Path(task.dir) / task.name + file_path = (Path(task.dir) / task.name).resolve() + # 安全检查:验证路径在下载目录内,防止路径遍历攻击 + from src.core.constants import DOWNLOAD_DIR + download_dir = DOWNLOAD_DIR.resolve() + if not str(file_path).startswith(str(download_dir) + "/"): + logger.error(f"路径遍历尝试被阻止: {file_path}") + return False if file_path.exists(): if file_path.is_dir(): import shutil diff --git a/src/core/config.py b/src/core/config.py index 98faad5..281bf9f 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -29,21 +29,37 @@ class BotConfig: def from_env(cls) -> "BotConfig": """从环境变量加载配置""" from dotenv import load_dotenv + from src.core.exceptions import ConfigError load_dotenv() - token = os.environ.get("TELEGRAM_BOT_TOKEN", "") + # 验证必需的 Token + token = os.environ.get("TELEGRAM_BOT_TOKEN", "").strip() + if not token: + raise ConfigError("TELEGRAM_BOT_TOKEN 环境变量未设置") + + # 安全解析 RPC 端口 + port_str = os.environ.get("ARIA2_RPC_PORT", "6800") + try: + rpc_port = int(port_str) + if not (1 <= rpc_port <= 65535): + raise ValueError("端口必须在 1-65535 范围内") + except ValueError as e: + raise ConfigError(f"无效的 ARIA2_RPC_PORT: {e}") from e # 解析允许的用户 ID 列表 allowed_users_str = os.environ.get("ALLOWED_USERS", "") allowed_users = set() if allowed_users_str: - allowed_users = { - int(uid.strip()) for uid in allowed_users_str.split(",") - if uid.strip().isdigit() - } + for uid in allowed_users_str.split(","): + uid = uid.strip() + if uid.isdigit(): + user_id = int(uid) + # 验证用户 ID 在合理范围内 + if 0 < user_id < 2**63: + allowed_users.add(user_id) aria2 = Aria2Config( - rpc_port=int(os.environ.get("ARIA2_RPC_PORT", "6800")), + rpc_port=rpc_port, rpc_secret=os.environ.get("ARIA2_RPC_SECRET", ""), ) return cls( diff --git a/src/telegram/handlers.py b/src/telegram/handlers.py index 9e29aa8..c5917de 100644 --- a/src/telegram/handlers.py +++ b/src/telegram/handlers.py @@ -1,6 +1,8 @@ """Telegram bot command handlers.""" from __future__ import annotations +from urllib.parse import urlparse + from telegram import Update from telegram.ext import ContextTypes, CommandHandler, CallbackQueryHandler, MessageHandler, filters @@ -55,6 +57,28 @@ def _get_user_info(update: Update) -> str: return "未知用户" +def _validate_download_url(url: str) -> tuple[bool, str]: + """验证下载 URL 的有效性,防止恶意输入""" + # 检查 URL 长度 + if len(url) > 2048: + return False, "URL 过长(最大 2048 字符)" + + # 磁力链接直接通过 + if url.startswith("magnet:"): + return True, "" + + # 验证 HTTP/HTTPS URL + try: + parsed = urlparse(url) + if parsed.scheme not in ("http", "https"): + return False, f"不支持的协议: {parsed.scheme or '无'},仅支持 HTTP/HTTPS/磁力链接" + if not parsed.netloc: + return False, "无效的 URL 格式" + return True, "" + except Exception: + return False, "URL 解析失败" + + import asyncio from functools import wraps @@ -400,6 +424,13 @@ class Aria2BotAPI: return url = context.args[0] + + # 验证 URL 格式 + is_valid, error_msg = _validate_download_url(url) + if not is_valid: + await self._reply(update, context, f"❌ URL 无效: {error_msg}") + return + try: rpc = self._get_rpc_client() gid = await rpc.add_uri(url) @@ -488,8 +519,20 @@ class Aria2BotAPI: return parts = data.split(":") + if not parts: + await query.edit_message_text("❌ 无效操作") + return action = parts[0] + # 安全检查:验证回调数据格式,防止索引越界 + required_parts = { + "pause": 2, "resume": 2, "delete": 2, "detail": 2, "refresh": 2, + "confirm_del": 3, "cancel_del": 3, + } + if action in required_parts and len(parts) < required_parts[action]: + await query.edit_message_text("❌ 无效操作") + return + # 点击非详情相关按钮时,停止该消息的自动刷新 if action not in ("detail", "refresh", "pause", "resume"): key = f"{query.message.chat_id}:{query.message.message_id}" @@ -652,10 +695,12 @@ class Aria2BotAPI: await query.edit_message_text(msg, parse_mode="Markdown") def _stop_auto_refresh(self, key: str) -> None: - """停止自动刷新任务""" + """停止自动刷新任务并等待清理""" if key in self._auto_refresh_tasks: - self._auto_refresh_tasks[key].cancel() - del self._auto_refresh_tasks[key] + task = self._auto_refresh_tasks.pop(key) + task.cancel() + # 注意:这里不等待任务完成,因为是同步方法 + # 任务会在 finally 块中自行清理 async def _handle_detail_callback(self, query, rpc: Aria2RpcClient, gid: str) -> None: """处理详情回调,启动自动刷新""" @@ -702,7 +747,8 @@ class Aria2BotAPI: try: await message.edit_text(text, parse_mode="Markdown", reply_markup=keyboard) last_text = text - except Exception: + except Exception as e: + logger.warning(f"编辑消息失败 (GID={gid}): {e}") break # 任务完成或出错时停止刷新