From bc3ec59c2646039f0986dab1419f96a462f81302 Mon Sep 17 00:00:00 2001 From: dnslin Date: Fri, 12 Dec 2025 10:40:04 +0800 Subject: [PATCH] =?UTF-8?q?feat(auth):=20=E6=B7=BB=E5=8A=A0=20Bot=20?= =?UTF-8?q?=E7=94=A8=E6=88=B7=E7=99=BD=E5=90=8D=E5=8D=95=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E5=8F=8A=E6=9D=83=E9=99=90=E6=A0=A1=E9=AA=8C=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 4 +++ src/core/config.py | 12 ++++++++ src/telegram/app.py | 2 +- src/telegram/handlers.py | 64 ++++++++++++++++++++++++++++------------ 4 files changed, 62 insertions(+), 20 deletions(-) diff --git a/.env.example b/.env.example index 05b9edc..fde0787 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,10 @@ # Telegram Bot Token (required) TELEGRAM_BOT_TOKEN= +# 允许使用 Bot 的用户 ID 列表(逗号分隔,必须配置,否则拒绝所有用户) +# 获取用户 ID:向 @userinfobot 发送消息 +ALLOWED_USERS= + # Custom Telegram Bot API URL (optional, for self-hosted API) TELEGRAM_API_BASE_URL= # Aria2 RPC Port (default: 6800) diff --git a/src/core/config.py b/src/core/config.py index d619e22..98faad5 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -22,6 +22,7 @@ class Aria2Config: class BotConfig: token: str = "" api_base_url: str = "" + allowed_users: set[int] = field(default_factory=set) aria2: Aria2Config = field(default_factory=Aria2Config) @classmethod @@ -31,6 +32,16 @@ class BotConfig: load_dotenv() token = os.environ.get("TELEGRAM_BOT_TOKEN", "") + + # 解析允许的用户 ID 列表 + allowed_users_str = os.environ.get("ALLOWED_USERS", "") + allowed_users = set() + if allowed_users_str: + allowed_users = { + int(uid.strip()) for uid in allowed_users_str.split(",") + if uid.strip().isdigit() + } + aria2 = Aria2Config( rpc_port=int(os.environ.get("ARIA2_RPC_PORT", "6800")), rpc_secret=os.environ.get("ARIA2_RPC_SECRET", ""), @@ -38,5 +49,6 @@ class BotConfig: return cls( token=token, api_base_url=os.environ.get("TELEGRAM_API_BASE_URL", ""), + allowed_users=allowed_users, aria2=aria2, ) diff --git a/src/telegram/app.py b/src/telegram/app.py index b51b7d7..14de3d4 100644 --- a/src/telegram/app.py +++ b/src/telegram/app.py @@ -45,7 +45,7 @@ def create_app(config: BotConfig) -> Application: builder = builder.base_url(config.api_base_url).base_file_url(config.api_base_url + "/file") app = builder.build() - api = Aria2BotAPI(config.aria2) + api = Aria2BotAPI(config.aria2, config.allowed_users) for handler in build_handlers(api): app.add_handler(handler) diff --git a/src/telegram/handlers.py b/src/telegram/handlers.py index 623f173..9e29aa8 100644 --- a/src/telegram/handlers.py +++ b/src/telegram/handlers.py @@ -56,15 +56,31 @@ def _get_user_info(update: Update) -> str: import asyncio +from functools import wraps class Aria2BotAPI: - def __init__(self, config: Aria2Config | None = None): + def __init__(self, config: Aria2Config | None = None, allowed_users: set[int] | None = None): self.config = config or Aria2Config() + self.allowed_users = allowed_users or set() self.installer = Aria2Installer(self.config) self.service = Aria2ServiceManager() self._rpc: Aria2RpcClient | None = None self._auto_refresh_tasks: dict[str, asyncio.Task] = {} # chat_id:msg_id -> task + async def _check_permission(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> bool: + """检查用户权限,返回 True 表示有权限""" + # 未配置白名单时拒绝所有用户 + if not self.allowed_users: + logger.warning(f"未配置 ALLOWED_USERS,拒绝访问 - {_get_user_info(update)}") + await self._reply(update, context, "⚠️ Bot 未配置允许的用户,请联系管理员") + return False + user_id = update.effective_user.id if update.effective_user else None + if user_id and user_id in self.allowed_users: + return True + logger.warning(f"未授权访问 - {_get_user_info(update)}") + await self._reply(update, context, "🚫 您没有权限使用此 Bot") + return False + def _get_rpc_client(self) -> Aria2RpcClient: """获取或创建 RPC 客户端""" if self._rpc is None: @@ -715,31 +731,41 @@ class Aria2BotAPI: def build_handlers(api: Aria2BotAPI) -> list: """构建 Handler 列表""" + + def wrap_with_permission(handler_func): + """包装处理函数,添加权限检查""" + @wraps(handler_func) + async def wrapped(update: Update, context: ContextTypes.DEFAULT_TYPE): + if not await api._check_permission(update, context): + return + return await handler_func(update, context) + return wrapped + # 构建按钮文本过滤器 button_pattern = "^(" + "|".join(BUTTON_COMMANDS.keys()).replace("▶️", "▶️").replace("⏹", "⏹") + ")$" return [ # 服务管理命令 - CommandHandler("install", api.install), - CommandHandler("uninstall", api.uninstall), - CommandHandler("start", api.start_service), - CommandHandler("stop", api.stop_service), - CommandHandler("restart", api.restart_service), - CommandHandler("status", api.status), - CommandHandler("logs", api.view_logs), - CommandHandler("clear_logs", api.clear_logs), - CommandHandler("set_secret", api.set_secret), - CommandHandler("reset_secret", api.reset_secret), - CommandHandler("help", api.help_command), - CommandHandler("menu", api.menu_command), + CommandHandler("install", wrap_with_permission(api.install)), + CommandHandler("uninstall", wrap_with_permission(api.uninstall)), + CommandHandler("start", wrap_with_permission(api.start_service)), + CommandHandler("stop", wrap_with_permission(api.stop_service)), + CommandHandler("restart", wrap_with_permission(api.restart_service)), + CommandHandler("status", wrap_with_permission(api.status)), + CommandHandler("logs", wrap_with_permission(api.view_logs)), + CommandHandler("clear_logs", wrap_with_permission(api.clear_logs)), + CommandHandler("set_secret", wrap_with_permission(api.set_secret)), + CommandHandler("reset_secret", wrap_with_permission(api.reset_secret)), + CommandHandler("help", wrap_with_permission(api.help_command)), + CommandHandler("menu", wrap_with_permission(api.menu_command)), # 下载管理命令 - CommandHandler("add", api.add_download), - CommandHandler("list", api.list_downloads), - CommandHandler("stats", api.global_stats), + CommandHandler("add", wrap_with_permission(api.add_download)), + CommandHandler("list", wrap_with_permission(api.list_downloads)), + CommandHandler("stats", wrap_with_permission(api.global_stats)), # Reply Keyboard 按钮文本处理 - MessageHandler(filters.TEXT & filters.Regex(button_pattern), api.handle_button_text), + MessageHandler(filters.TEXT & filters.Regex(button_pattern), wrap_with_permission(api.handle_button_text)), # 种子文件处理 - MessageHandler(filters.Document.FileExtension("torrent"), api.handle_torrent), + MessageHandler(filters.Document.FileExtension("torrent"), wrap_with_permission(api.handle_torrent)), # Callback Query 处理 - CallbackQueryHandler(api.handle_callback), + CallbackQueryHandler(wrap_with_permission(api.handle_callback)), ]