mirror of
https://github.com/dnslin/aria2bot.git
synced 2026-01-11 04:02:20 +08:00
feat(auth): 添加 Bot 用户白名单功能及权限校验机制
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
# Telegram Bot Token (required)
|
||||
TELEGRAM_BOT_TOKEN=
|
||||
|
||||
# 允许使用 Bot 的用户 ID 列表(逗号分隔,必须配置,否则拒绝所有用户)
|
||||
# 获取用户 ID:向 @userinfobot 发送消息
|
||||
ALLOWED_USERS=
|
||||
|
||||
# Custom Telegram Bot API URL (optional, for self-hosted API)
|
||||
TELEGRAM_API_BASE_URL=
|
||||
# Aria2 RPC Port (default: 6800)
|
||||
|
||||
@@ -22,6 +22,7 @@ class Aria2Config:
|
||||
class BotConfig:
|
||||
token: str = ""
|
||||
api_base_url: str = ""
|
||||
allowed_users: set[int] = field(default_factory=set)
|
||||
aria2: Aria2Config = field(default_factory=Aria2Config)
|
||||
|
||||
@classmethod
|
||||
@@ -31,6 +32,16 @@ class BotConfig:
|
||||
load_dotenv()
|
||||
|
||||
token = os.environ.get("TELEGRAM_BOT_TOKEN", "")
|
||||
|
||||
# 解析允许的用户 ID 列表
|
||||
allowed_users_str = os.environ.get("ALLOWED_USERS", "")
|
||||
allowed_users = set()
|
||||
if allowed_users_str:
|
||||
allowed_users = {
|
||||
int(uid.strip()) for uid in allowed_users_str.split(",")
|
||||
if uid.strip().isdigit()
|
||||
}
|
||||
|
||||
aria2 = Aria2Config(
|
||||
rpc_port=int(os.environ.get("ARIA2_RPC_PORT", "6800")),
|
||||
rpc_secret=os.environ.get("ARIA2_RPC_SECRET", ""),
|
||||
@@ -38,5 +49,6 @@ class BotConfig:
|
||||
return cls(
|
||||
token=token,
|
||||
api_base_url=os.environ.get("TELEGRAM_API_BASE_URL", ""),
|
||||
allowed_users=allowed_users,
|
||||
aria2=aria2,
|
||||
)
|
||||
|
||||
@@ -45,7 +45,7 @@ def create_app(config: BotConfig) -> Application:
|
||||
builder = builder.base_url(config.api_base_url).base_file_url(config.api_base_url + "/file")
|
||||
app = builder.build()
|
||||
|
||||
api = Aria2BotAPI(config.aria2)
|
||||
api = Aria2BotAPI(config.aria2, config.allowed_users)
|
||||
for handler in build_handlers(api):
|
||||
app.add_handler(handler)
|
||||
|
||||
|
||||
@@ -56,15 +56,31 @@ def _get_user_info(update: Update) -> str:
|
||||
|
||||
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
|
||||
class Aria2BotAPI:
|
||||
def __init__(self, config: Aria2Config | None = None):
|
||||
def __init__(self, config: Aria2Config | None = None, allowed_users: set[int] | None = None):
|
||||
self.config = config or Aria2Config()
|
||||
self.allowed_users = allowed_users or set()
|
||||
self.installer = Aria2Installer(self.config)
|
||||
self.service = Aria2ServiceManager()
|
||||
self._rpc: Aria2RpcClient | None = None
|
||||
self._auto_refresh_tasks: dict[str, asyncio.Task] = {} # chat_id:msg_id -> task
|
||||
|
||||
async def _check_permission(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> bool:
|
||||
"""检查用户权限,返回 True 表示有权限"""
|
||||
# 未配置白名单时拒绝所有用户
|
||||
if not self.allowed_users:
|
||||
logger.warning(f"未配置 ALLOWED_USERS,拒绝访问 - {_get_user_info(update)}")
|
||||
await self._reply(update, context, "⚠️ Bot 未配置允许的用户,请联系管理员")
|
||||
return False
|
||||
user_id = update.effective_user.id if update.effective_user else None
|
||||
if user_id and user_id in self.allowed_users:
|
||||
return True
|
||||
logger.warning(f"未授权访问 - {_get_user_info(update)}")
|
||||
await self._reply(update, context, "🚫 您没有权限使用此 Bot")
|
||||
return False
|
||||
|
||||
def _get_rpc_client(self) -> Aria2RpcClient:
|
||||
"""获取或创建 RPC 客户端"""
|
||||
if self._rpc is None:
|
||||
@@ -715,31 +731,41 @@ class Aria2BotAPI:
|
||||
|
||||
def build_handlers(api: Aria2BotAPI) -> list:
|
||||
"""构建 Handler 列表"""
|
||||
|
||||
def wrap_with_permission(handler_func):
|
||||
"""包装处理函数,添加权限检查"""
|
||||
@wraps(handler_func)
|
||||
async def wrapped(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
if not await api._check_permission(update, context):
|
||||
return
|
||||
return await handler_func(update, context)
|
||||
return wrapped
|
||||
|
||||
# 构建按钮文本过滤器
|
||||
button_pattern = "^(" + "|".join(BUTTON_COMMANDS.keys()).replace("▶️", "▶️").replace("⏹", "⏹") + ")$"
|
||||
|
||||
return [
|
||||
# 服务管理命令
|
||||
CommandHandler("install", api.install),
|
||||
CommandHandler("uninstall", api.uninstall),
|
||||
CommandHandler("start", api.start_service),
|
||||
CommandHandler("stop", api.stop_service),
|
||||
CommandHandler("restart", api.restart_service),
|
||||
CommandHandler("status", api.status),
|
||||
CommandHandler("logs", api.view_logs),
|
||||
CommandHandler("clear_logs", api.clear_logs),
|
||||
CommandHandler("set_secret", api.set_secret),
|
||||
CommandHandler("reset_secret", api.reset_secret),
|
||||
CommandHandler("help", api.help_command),
|
||||
CommandHandler("menu", api.menu_command),
|
||||
CommandHandler("install", wrap_with_permission(api.install)),
|
||||
CommandHandler("uninstall", wrap_with_permission(api.uninstall)),
|
||||
CommandHandler("start", wrap_with_permission(api.start_service)),
|
||||
CommandHandler("stop", wrap_with_permission(api.stop_service)),
|
||||
CommandHandler("restart", wrap_with_permission(api.restart_service)),
|
||||
CommandHandler("status", wrap_with_permission(api.status)),
|
||||
CommandHandler("logs", wrap_with_permission(api.view_logs)),
|
||||
CommandHandler("clear_logs", wrap_with_permission(api.clear_logs)),
|
||||
CommandHandler("set_secret", wrap_with_permission(api.set_secret)),
|
||||
CommandHandler("reset_secret", wrap_with_permission(api.reset_secret)),
|
||||
CommandHandler("help", wrap_with_permission(api.help_command)),
|
||||
CommandHandler("menu", wrap_with_permission(api.menu_command)),
|
||||
# 下载管理命令
|
||||
CommandHandler("add", api.add_download),
|
||||
CommandHandler("list", api.list_downloads),
|
||||
CommandHandler("stats", api.global_stats),
|
||||
CommandHandler("add", wrap_with_permission(api.add_download)),
|
||||
CommandHandler("list", wrap_with_permission(api.list_downloads)),
|
||||
CommandHandler("stats", wrap_with_permission(api.global_stats)),
|
||||
# Reply Keyboard 按钮文本处理
|
||||
MessageHandler(filters.TEXT & filters.Regex(button_pattern), api.handle_button_text),
|
||||
MessageHandler(filters.TEXT & filters.Regex(button_pattern), wrap_with_permission(api.handle_button_text)),
|
||||
# 种子文件处理
|
||||
MessageHandler(filters.Document.FileExtension("torrent"), api.handle_torrent),
|
||||
MessageHandler(filters.Document.FileExtension("torrent"), wrap_with_permission(api.handle_torrent)),
|
||||
# Callback Query 处理
|
||||
CallbackQueryHandler(api.handle_callback),
|
||||
CallbackQueryHandler(wrap_with_permission(api.handle_callback)),
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user