mirror of
https://github.com/dnslin/aria2bot.git
synced 2026-01-11 20:12:20 +08:00
fix(security): 增加多处安全检查防止路径遍历和恶意输入,完善资源关闭逻辑
This commit is contained in:
@@ -50,6 +50,15 @@ class Aria2Installer:
|
|||||||
self.arch = detect_arch()
|
self.arch = detect_arch()
|
||||||
self._executor = ThreadPoolExecutor(max_workers=4)
|
self._executor = ThreadPoolExecutor(max_workers=4)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""确保线程池被关闭,防止资源泄漏"""
|
||||||
|
if hasattr(self, '_executor'):
|
||||||
|
self._executor.shutdown(wait=False)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""显式关闭资源"""
|
||||||
|
self._executor.shutdown(wait=True)
|
||||||
|
|
||||||
async def get_latest_version(self) -> str:
|
async def get_latest_version(self) -> str:
|
||||||
"""从 GitHub API 获取最新版本号"""
|
"""从 GitHub API 获取最新版本号"""
|
||||||
logger.info("正在获取 aria2 最新版本...")
|
logger.info("正在获取 aria2 最新版本...")
|
||||||
@@ -281,6 +290,10 @@ class Aria2Installer:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_binary(archive_path: Path, extract_dir: Path) -> Path:
|
def _extract_binary(archive_path: Path, extract_dir: Path) -> Path:
|
||||||
with tarfile.open(archive_path, "r:gz") as tar:
|
with tarfile.open(archive_path, "r:gz") as tar:
|
||||||
|
# 安全检查:验证所有成员路径,防止 Zip Slip 攻击
|
||||||
|
for member in tar.getmembers():
|
||||||
|
if member.name.startswith('/') or '..' in member.name:
|
||||||
|
raise DownloadError(f"不安全的 tar 成员: {member.name}")
|
||||||
tar.extractall(extract_dir)
|
tar.extractall(extract_dir)
|
||||||
for candidate in extract_dir.rglob("aria2c"):
|
for candidate in extract_dir.rglob("aria2c"):
|
||||||
if candidate.is_file():
|
if candidate.is_file():
|
||||||
|
|||||||
@@ -180,7 +180,13 @@ class Aria2RpcClient:
|
|||||||
if not task.dir or not task.name:
|
if not task.dir or not task.name:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
file_path = Path(task.dir) / task.name
|
file_path = (Path(task.dir) / task.name).resolve()
|
||||||
|
# 安全检查:验证路径在下载目录内,防止路径遍历攻击
|
||||||
|
from src.core.constants import DOWNLOAD_DIR
|
||||||
|
download_dir = DOWNLOAD_DIR.resolve()
|
||||||
|
if not str(file_path).startswith(str(download_dir) + "/"):
|
||||||
|
logger.error(f"路径遍历尝试被阻止: {file_path}")
|
||||||
|
return False
|
||||||
if file_path.exists():
|
if file_path.exists():
|
||||||
if file_path.is_dir():
|
if file_path.is_dir():
|
||||||
import shutil
|
import shutil
|
||||||
|
|||||||
@@ -29,21 +29,37 @@ class BotConfig:
|
|||||||
def from_env(cls) -> "BotConfig":
|
def from_env(cls) -> "BotConfig":
|
||||||
"""从环境变量加载配置"""
|
"""从环境变量加载配置"""
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from src.core.exceptions import ConfigError
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
token = os.environ.get("TELEGRAM_BOT_TOKEN", "")
|
# 验证必需的 Token
|
||||||
|
token = os.environ.get("TELEGRAM_BOT_TOKEN", "").strip()
|
||||||
|
if not token:
|
||||||
|
raise ConfigError("TELEGRAM_BOT_TOKEN 环境变量未设置")
|
||||||
|
|
||||||
|
# 安全解析 RPC 端口
|
||||||
|
port_str = os.environ.get("ARIA2_RPC_PORT", "6800")
|
||||||
|
try:
|
||||||
|
rpc_port = int(port_str)
|
||||||
|
if not (1 <= rpc_port <= 65535):
|
||||||
|
raise ValueError("端口必须在 1-65535 范围内")
|
||||||
|
except ValueError as e:
|
||||||
|
raise ConfigError(f"无效的 ARIA2_RPC_PORT: {e}") from e
|
||||||
|
|
||||||
# 解析允许的用户 ID 列表
|
# 解析允许的用户 ID 列表
|
||||||
allowed_users_str = os.environ.get("ALLOWED_USERS", "")
|
allowed_users_str = os.environ.get("ALLOWED_USERS", "")
|
||||||
allowed_users = set()
|
allowed_users = set()
|
||||||
if allowed_users_str:
|
if allowed_users_str:
|
||||||
allowed_users = {
|
for uid in allowed_users_str.split(","):
|
||||||
int(uid.strip()) for uid in allowed_users_str.split(",")
|
uid = uid.strip()
|
||||||
if uid.strip().isdigit()
|
if uid.isdigit():
|
||||||
}
|
user_id = int(uid)
|
||||||
|
# 验证用户 ID 在合理范围内
|
||||||
|
if 0 < user_id < 2**63:
|
||||||
|
allowed_users.add(user_id)
|
||||||
|
|
||||||
aria2 = Aria2Config(
|
aria2 = Aria2Config(
|
||||||
rpc_port=int(os.environ.get("ARIA2_RPC_PORT", "6800")),
|
rpc_port=rpc_port,
|
||||||
rpc_secret=os.environ.get("ARIA2_RPC_SECRET", ""),
|
rpc_secret=os.environ.get("ARIA2_RPC_SECRET", ""),
|
||||||
)
|
)
|
||||||
return cls(
|
return cls(
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
"""Telegram bot command handlers."""
|
"""Telegram bot command handlers."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from telegram import Update
|
from telegram import Update
|
||||||
from telegram.ext import ContextTypes, CommandHandler, CallbackQueryHandler, MessageHandler, filters
|
from telegram.ext import ContextTypes, CommandHandler, CallbackQueryHandler, MessageHandler, filters
|
||||||
|
|
||||||
@@ -55,6 +57,28 @@ def _get_user_info(update: Update) -> str:
|
|||||||
return "未知用户"
|
return "未知用户"
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_download_url(url: str) -> tuple[bool, str]:
|
||||||
|
"""验证下载 URL 的有效性,防止恶意输入"""
|
||||||
|
# 检查 URL 长度
|
||||||
|
if len(url) > 2048:
|
||||||
|
return False, "URL 过长(最大 2048 字符)"
|
||||||
|
|
||||||
|
# 磁力链接直接通过
|
||||||
|
if url.startswith("magnet:"):
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
# 验证 HTTP/HTTPS URL
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
if parsed.scheme not in ("http", "https"):
|
||||||
|
return False, f"不支持的协议: {parsed.scheme or '无'},仅支持 HTTP/HTTPS/磁力链接"
|
||||||
|
if not parsed.netloc:
|
||||||
|
return False, "无效的 URL 格式"
|
||||||
|
return True, ""
|
||||||
|
except Exception:
|
||||||
|
return False, "URL 解析失败"
|
||||||
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
@@ -400,6 +424,13 @@ class Aria2BotAPI:
|
|||||||
return
|
return
|
||||||
|
|
||||||
url = context.args[0]
|
url = context.args[0]
|
||||||
|
|
||||||
|
# 验证 URL 格式
|
||||||
|
is_valid, error_msg = _validate_download_url(url)
|
||||||
|
if not is_valid:
|
||||||
|
await self._reply(update, context, f"❌ URL 无效: {error_msg}")
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rpc = self._get_rpc_client()
|
rpc = self._get_rpc_client()
|
||||||
gid = await rpc.add_uri(url)
|
gid = await rpc.add_uri(url)
|
||||||
@@ -488,8 +519,20 @@ class Aria2BotAPI:
|
|||||||
return
|
return
|
||||||
|
|
||||||
parts = data.split(":")
|
parts = data.split(":")
|
||||||
|
if not parts:
|
||||||
|
await query.edit_message_text("❌ 无效操作")
|
||||||
|
return
|
||||||
action = parts[0]
|
action = parts[0]
|
||||||
|
|
||||||
|
# 安全检查:验证回调数据格式,防止索引越界
|
||||||
|
required_parts = {
|
||||||
|
"pause": 2, "resume": 2, "delete": 2, "detail": 2, "refresh": 2,
|
||||||
|
"confirm_del": 3, "cancel_del": 3,
|
||||||
|
}
|
||||||
|
if action in required_parts and len(parts) < required_parts[action]:
|
||||||
|
await query.edit_message_text("❌ 无效操作")
|
||||||
|
return
|
||||||
|
|
||||||
# 点击非详情相关按钮时,停止该消息的自动刷新
|
# 点击非详情相关按钮时,停止该消息的自动刷新
|
||||||
if action not in ("detail", "refresh", "pause", "resume"):
|
if action not in ("detail", "refresh", "pause", "resume"):
|
||||||
key = f"{query.message.chat_id}:{query.message.message_id}"
|
key = f"{query.message.chat_id}:{query.message.message_id}"
|
||||||
@@ -652,10 +695,12 @@ class Aria2BotAPI:
|
|||||||
await query.edit_message_text(msg, parse_mode="Markdown")
|
await query.edit_message_text(msg, parse_mode="Markdown")
|
||||||
|
|
||||||
def _stop_auto_refresh(self, key: str) -> None:
|
def _stop_auto_refresh(self, key: str) -> None:
|
||||||
"""停止自动刷新任务"""
|
"""停止自动刷新任务并等待清理"""
|
||||||
if key in self._auto_refresh_tasks:
|
if key in self._auto_refresh_tasks:
|
||||||
self._auto_refresh_tasks[key].cancel()
|
task = self._auto_refresh_tasks.pop(key)
|
||||||
del self._auto_refresh_tasks[key]
|
task.cancel()
|
||||||
|
# 注意:这里不等待任务完成,因为是同步方法
|
||||||
|
# 任务会在 finally 块中自行清理
|
||||||
|
|
||||||
async def _handle_detail_callback(self, query, rpc: Aria2RpcClient, gid: str) -> None:
|
async def _handle_detail_callback(self, query, rpc: Aria2RpcClient, gid: str) -> None:
|
||||||
"""处理详情回调,启动自动刷新"""
|
"""处理详情回调,启动自动刷新"""
|
||||||
@@ -702,7 +747,8 @@ class Aria2BotAPI:
|
|||||||
try:
|
try:
|
||||||
await message.edit_text(text, parse_mode="Markdown", reply_markup=keyboard)
|
await message.edit_text(text, parse_mode="Markdown", reply_markup=keyboard)
|
||||||
last_text = text
|
last_text = text
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
logger.warning(f"编辑消息失败 (GID={gid}): {e}")
|
||||||
break
|
break
|
||||||
|
|
||||||
# 任务完成或出错时停止刷新
|
# 任务完成或出错时停止刷新
|
||||||
|
|||||||
Reference in New Issue
Block a user