feat(auth): 添加 Bot 用户白名单功能及权限校验机制

This commit is contained in:
dnslin
2025-12-12 10:40:04 +08:00
parent cba19b5fe0
commit bc3ec59c26
4 changed files with 62 additions and 20 deletions

View File

@@ -1,6 +1,10 @@
# Telegram Bot Token (required)
TELEGRAM_BOT_TOKEN=
# 允许使用 Bot 的用户 ID 列表(逗号分隔,必须配置,否则拒绝所有用户)
# 获取用户 ID向 @userinfobot 发送消息
ALLOWED_USERS=
# Custom Telegram Bot API URL (optional, for self-hosted API)
TELEGRAM_API_BASE_URL=
# Aria2 RPC Port (default: 6800)

View File

@@ -22,6 +22,7 @@ class Aria2Config:
class BotConfig:
token: str = ""
api_base_url: str = ""
allowed_users: set[int] = field(default_factory=set)
aria2: Aria2Config = field(default_factory=Aria2Config)
@classmethod
@@ -31,6 +32,16 @@ class BotConfig:
load_dotenv()
token = os.environ.get("TELEGRAM_BOT_TOKEN", "")
# 解析允许的用户 ID 列表
allowed_users_str = os.environ.get("ALLOWED_USERS", "")
allowed_users = set()
if allowed_users_str:
allowed_users = {
int(uid.strip()) for uid in allowed_users_str.split(",")
if uid.strip().isdigit()
}
aria2 = Aria2Config(
rpc_port=int(os.environ.get("ARIA2_RPC_PORT", "6800")),
rpc_secret=os.environ.get("ARIA2_RPC_SECRET", ""),
@@ -38,5 +49,6 @@ class BotConfig:
return cls(
token=token,
api_base_url=os.environ.get("TELEGRAM_API_BASE_URL", ""),
allowed_users=allowed_users,
aria2=aria2,
)

View File

@@ -45,7 +45,7 @@ def create_app(config: BotConfig) -> Application:
builder = builder.base_url(config.api_base_url).base_file_url(config.api_base_url + "/file")
app = builder.build()
api = Aria2BotAPI(config.aria2)
api = Aria2BotAPI(config.aria2, config.allowed_users)
for handler in build_handlers(api):
app.add_handler(handler)

View File

@@ -56,15 +56,31 @@ def _get_user_info(update: Update) -> str:
import asyncio
from functools import wraps
class Aria2BotAPI:
def __init__(self, config: Aria2Config | None = None):
def __init__(self, config: Aria2Config | None = None, allowed_users: set[int] | None = None):
self.config = config or Aria2Config()
self.allowed_users = allowed_users or set()
self.installer = Aria2Installer(self.config)
self.service = Aria2ServiceManager()
self._rpc: Aria2RpcClient | None = None
self._auto_refresh_tasks: dict[str, asyncio.Task] = {} # chat_id:msg_id -> task
async def _check_permission(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> bool:
"""检查用户权限,返回 True 表示有权限"""
# 未配置白名单时拒绝所有用户
if not self.allowed_users:
logger.warning(f"未配置 ALLOWED_USERS拒绝访问 - {_get_user_info(update)}")
await self._reply(update, context, "⚠️ Bot 未配置允许的用户,请联系管理员")
return False
user_id = update.effective_user.id if update.effective_user else None
if user_id and user_id in self.allowed_users:
return True
logger.warning(f"未授权访问 - {_get_user_info(update)}")
await self._reply(update, context, "🚫 您没有权限使用此 Bot")
return False
def _get_rpc_client(self) -> Aria2RpcClient:
"""获取或创建 RPC 客户端"""
if self._rpc is None:
@@ -715,31 +731,41 @@ class Aria2BotAPI:
def build_handlers(api: Aria2BotAPI) -> list:
"""构建 Handler 列表"""
def wrap_with_permission(handler_func):
"""包装处理函数,添加权限检查"""
@wraps(handler_func)
async def wrapped(update: Update, context: ContextTypes.DEFAULT_TYPE):
if not await api._check_permission(update, context):
return
return await handler_func(update, context)
return wrapped
# 构建按钮文本过滤器
button_pattern = "^(" + "|".join(BUTTON_COMMANDS.keys()).replace("▶️", "▶️").replace("", "") + ")$"
return [
# 服务管理命令
CommandHandler("install", api.install),
CommandHandler("uninstall", api.uninstall),
CommandHandler("start", api.start_service),
CommandHandler("stop", api.stop_service),
CommandHandler("restart", api.restart_service),
CommandHandler("status", api.status),
CommandHandler("logs", api.view_logs),
CommandHandler("clear_logs", api.clear_logs),
CommandHandler("set_secret", api.set_secret),
CommandHandler("reset_secret", api.reset_secret),
CommandHandler("help", api.help_command),
CommandHandler("menu", api.menu_command),
CommandHandler("install", wrap_with_permission(api.install)),
CommandHandler("uninstall", wrap_with_permission(api.uninstall)),
CommandHandler("start", wrap_with_permission(api.start_service)),
CommandHandler("stop", wrap_with_permission(api.stop_service)),
CommandHandler("restart", wrap_with_permission(api.restart_service)),
CommandHandler("status", wrap_with_permission(api.status)),
CommandHandler("logs", wrap_with_permission(api.view_logs)),
CommandHandler("clear_logs", wrap_with_permission(api.clear_logs)),
CommandHandler("set_secret", wrap_with_permission(api.set_secret)),
CommandHandler("reset_secret", wrap_with_permission(api.reset_secret)),
CommandHandler("help", wrap_with_permission(api.help_command)),
CommandHandler("menu", wrap_with_permission(api.menu_command)),
# 下载管理命令
CommandHandler("add", api.add_download),
CommandHandler("list", api.list_downloads),
CommandHandler("stats", api.global_stats),
CommandHandler("add", wrap_with_permission(api.add_download)),
CommandHandler("list", wrap_with_permission(api.list_downloads)),
CommandHandler("stats", wrap_with_permission(api.global_stats)),
# Reply Keyboard 按钮文本处理
MessageHandler(filters.TEXT & filters.Regex(button_pattern), api.handle_button_text),
MessageHandler(filters.TEXT & filters.Regex(button_pattern), wrap_with_permission(api.handle_button_text)),
# 种子文件处理
MessageHandler(filters.Document.FileExtension("torrent"), api.handle_torrent),
MessageHandler(filters.Document.FileExtension("torrent"), wrap_with_permission(api.handle_torrent)),
# Callback Query 处理
CallbackQueryHandler(api.handle_callback),
CallbackQueryHandler(wrap_with_permission(api.handle_callback)),
]