mirror of
https://github.com/dnslin/aria2bot.git
synced 2026-01-12 04:22:21 +08:00
feat: 增加OneDrive认证功能
This commit is contained in:
@@ -20,7 +20,9 @@ from src.core import (
|
||||
get_aria2_version,
|
||||
generate_rpc_secret,
|
||||
ARIA2_CONF,
|
||||
DOWNLOAD_DIR,
|
||||
)
|
||||
from src.core.config import OneDriveConfig
|
||||
from src.aria2 import Aria2Installer, Aria2ServiceManager
|
||||
from src.aria2.rpc import Aria2RpcClient, DownloadTask, _format_size
|
||||
from src.telegram.keyboards import (
|
||||
@@ -32,6 +34,10 @@ from src.telegram.keyboards import (
|
||||
build_detail_keyboard,
|
||||
build_after_add_keyboard,
|
||||
build_main_reply_keyboard,
|
||||
build_cloud_menu_keyboard,
|
||||
build_upload_choice_keyboard,
|
||||
build_cloud_settings_keyboard,
|
||||
build_detail_keyboard_with_upload,
|
||||
)
|
||||
|
||||
# Reply Keyboard 按钮文本到命令的映射
|
||||
@@ -83,13 +89,18 @@ import asyncio
|
||||
from functools import wraps
|
||||
|
||||
class Aria2BotAPI:
|
||||
def __init__(self, config: Aria2Config | None = None, allowed_users: set[int] | None = None):
|
||||
def __init__(self, config: Aria2Config | None = None, allowed_users: set[int] | None = None,
|
||||
onedrive_config: OneDriveConfig | None = None):
|
||||
self.config = config or Aria2Config()
|
||||
self.allowed_users = allowed_users or set()
|
||||
self.installer = Aria2Installer(self.config)
|
||||
self.service = Aria2ServiceManager()
|
||||
self._rpc: Aria2RpcClient | None = None
|
||||
self._auto_refresh_tasks: dict[str, asyncio.Task] = {} # chat_id:msg_id -> task
|
||||
# 云存储相关
|
||||
self._onedrive_config = onedrive_config
|
||||
self._onedrive = None
|
||||
self._pending_auth: dict[int, dict] = {} # user_id -> flow
|
||||
|
||||
async def _check_permission(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> bool:
|
||||
"""检查用户权限,返回 True 表示有权限"""
|
||||
@@ -113,6 +124,13 @@ class Aria2BotAPI:
|
||||
self._rpc = Aria2RpcClient(port=port, secret=secret)
|
||||
return self._rpc
|
||||
|
||||
def _get_onedrive_client(self):
|
||||
"""获取或创建 OneDrive 客户端"""
|
||||
if self._onedrive is None and self._onedrive_config and self._onedrive_config.enabled:
|
||||
from src.cloud.onedrive import OneDriveClient
|
||||
self._onedrive = OneDriveClient(self._onedrive_config)
|
||||
return self._onedrive
|
||||
|
||||
async def _reply(self, update: Update, context: ContextTypes.DEFAULT_TYPE, text: str, **kwargs):
|
||||
if update.effective_message:
|
||||
return await update.effective_message.reply_text(text, **kwargs)
|
||||
@@ -120,6 +138,19 @@ class Aria2BotAPI:
|
||||
return await context.bot.send_message(chat_id=update.effective_chat.id, text=text, **kwargs)
|
||||
return None
|
||||
|
||||
async def _delayed_delete_messages(self, messages: list, delay: int = 5) -> None:
|
||||
"""延迟删除多条消息"""
|
||||
try:
|
||||
await asyncio.sleep(delay)
|
||||
for msg in messages:
|
||||
try:
|
||||
await msg.delete()
|
||||
except Exception as e:
|
||||
logger.warning(f"删除消息失败: {e}")
|
||||
logger.debug("已删除敏感认证消息")
|
||||
except Exception as e:
|
||||
logger.warning(f"延迟删除任务失败: {e}")
|
||||
|
||||
def _get_rpc_secret(self) -> str:
|
||||
if self.config.rpc_secret:
|
||||
return self.config.rpc_secret
|
||||
@@ -378,6 +409,9 @@ class Aria2BotAPI:
|
||||
"/list - 查看下载列表",
|
||||
"/stats - 全局下载统计",
|
||||
"",
|
||||
"*云存储*",
|
||||
"/cloud - 云存储管理菜单",
|
||||
"",
|
||||
"/menu - 显示快捷菜单",
|
||||
"/help - 显示此帮助",
|
||||
]
|
||||
@@ -394,6 +428,169 @@ class Aria2BotAPI:
|
||||
reply_markup=keyboard
|
||||
)
|
||||
|
||||
# === 云存储命令 ===
|
||||
|
||||
async def cloud_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""云存储管理菜单"""
|
||||
logger.info(f"收到 /cloud 命令 - {_get_user_info(update)}")
|
||||
if not self._onedrive_config or not self._onedrive_config.enabled:
|
||||
await self._reply(update, context, "❌ 云存储功能未启用,请在配置中设置 ONEDRIVE_ENABLED=true")
|
||||
return
|
||||
keyboard = build_cloud_menu_keyboard()
|
||||
await self._reply(update, context, "☁️ *云存储管理*", parse_mode="Markdown", reply_markup=keyboard)
|
||||
|
||||
async def cloud_auth(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""开始 OneDrive 认证"""
|
||||
logger.info(f"收到云存储认证请求 - {_get_user_info(update)}")
|
||||
client = self._get_onedrive_client()
|
||||
if not client:
|
||||
await self._reply(update, context, "❌ OneDrive 未配置")
|
||||
return
|
||||
|
||||
if await client.is_authenticated():
|
||||
await self._reply(update, context, "✅ OneDrive 已认证")
|
||||
return
|
||||
|
||||
url, state = await client.get_auth_url()
|
||||
user_id = update.effective_user.id
|
||||
|
||||
auth_message = await self._reply(
|
||||
update, context,
|
||||
f"🔐 *OneDrive 认证*\n\n"
|
||||
f"1\\. 点击下方链接登录 Microsoft 账户\n"
|
||||
f"2\\. 授权后会跳转到一个空白页面\n"
|
||||
f"3\\. 复制该页面的完整 URL 发送给我\n\n"
|
||||
f"[点击认证]({url})",
|
||||
parse_mode="Markdown"
|
||||
)
|
||||
self._pending_auth[user_id] = {"state": state, "message": auth_message}
|
||||
|
||||
async def handle_auth_callback(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""处理用户发送的认证回调 URL"""
|
||||
text = update.message.text
|
||||
if not text or not text.startswith("https://login.microsoftonline.com"):
|
||||
return
|
||||
|
||||
user_id = update.effective_user.id
|
||||
if user_id not in self._pending_auth:
|
||||
return
|
||||
|
||||
client = self._get_onedrive_client()
|
||||
if not client:
|
||||
return
|
||||
|
||||
user_message = update.message # 保存用户消息引用
|
||||
pending = self._pending_auth[user_id]
|
||||
flow = pending["state"]
|
||||
auth_message = pending.get("message") # 认证指引消息
|
||||
|
||||
if await client.authenticate_with_code(text, flow=flow):
|
||||
del self._pending_auth[user_id]
|
||||
reply_message = await self._reply(update, context, "✅ OneDrive 认证成功!")
|
||||
logger.info(f"OneDrive 认证成功 - {_get_user_info(update)}")
|
||||
else:
|
||||
# 认证失败时清理认证信息
|
||||
del self._pending_auth[user_id]
|
||||
await client.logout() # 删除可能存在的旧 token
|
||||
reply_message = await self._reply(update, context, "❌ 认证失败,请重试")
|
||||
logger.error(f"OneDrive 认证失败 - {_get_user_info(update)}")
|
||||
|
||||
# 延迟 5 秒后删除敏感消息(包括认证指引消息)
|
||||
messages_to_delete = [msg for msg in [user_message, reply_message, auth_message] if msg]
|
||||
if messages_to_delete:
|
||||
asyncio.create_task(self._delayed_delete_messages(messages_to_delete))
|
||||
|
||||
async def cloud_logout(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""登出云存储"""
|
||||
logger.info(f"收到云存储登出请求 - {_get_user_info(update)}")
|
||||
client = self._get_onedrive_client()
|
||||
if not client:
|
||||
await self._reply(update, context, "❌ OneDrive 未配置")
|
||||
return
|
||||
|
||||
if await client.logout():
|
||||
await self._reply(update, context, "✅ 已登出 OneDrive")
|
||||
else:
|
||||
await self._reply(update, context, "❌ 登出失败")
|
||||
|
||||
async def cloud_status(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""查看云存储状态"""
|
||||
logger.info(f"收到云存储状态查询 - {_get_user_info(update)}")
|
||||
client = self._get_onedrive_client()
|
||||
if not client:
|
||||
await self._reply(update, context, "❌ OneDrive 未配置")
|
||||
return
|
||||
|
||||
is_auth = await client.is_authenticated()
|
||||
auto_upload = self._onedrive_config.auto_upload if self._onedrive_config else False
|
||||
delete_after = self._onedrive_config.delete_after_upload if self._onedrive_config else False
|
||||
remote_path = self._onedrive_config.remote_path if self._onedrive_config else "/aria2bot"
|
||||
|
||||
text = (
|
||||
"☁️ *OneDrive 状态*\n\n"
|
||||
f"🔐 认证状态: {'✅ 已认证' if is_auth else '❌ 未认证'}\n"
|
||||
f"📤 自动上传: {'✅ 开启' if auto_upload else '❌ 关闭'}\n"
|
||||
f"🗑️ 上传后删除: {'✅ 开启' if delete_after else '❌ 关闭'}\n"
|
||||
f"📁 远程路径: `{remote_path}`"
|
||||
)
|
||||
await self._reply(update, context, text, parse_mode="Markdown")
|
||||
|
||||
async def upload_to_cloud(self, update: Update, context: ContextTypes.DEFAULT_TYPE, gid: str) -> None:
|
||||
"""上传文件到云存储"""
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
logger.info(f"收到上传请求 GID={gid} - {_get_user_info(update)}")
|
||||
client = self._get_onedrive_client()
|
||||
if not client or not await client.is_authenticated():
|
||||
await self._reply(update, context, "❌ OneDrive 未认证,请先使用 /cloud 进行认证")
|
||||
return
|
||||
|
||||
rpc = self._get_rpc_client()
|
||||
try:
|
||||
task = await rpc.get_status(gid)
|
||||
except RpcError as e:
|
||||
await self._reply(update, context, f"❌ 获取任务信息失败: {e}")
|
||||
return
|
||||
|
||||
if task.status != "complete":
|
||||
await self._reply(update, context, "❌ 任务未完成,无法上传")
|
||||
return
|
||||
|
||||
local_path = Path(task.dir) / task.name
|
||||
if not local_path.exists():
|
||||
await self._reply(update, context, "❌ 本地文件不存在")
|
||||
return
|
||||
|
||||
# 计算远程路径(保持目录结构)
|
||||
try:
|
||||
download_dir = DOWNLOAD_DIR.resolve()
|
||||
relative_path = local_path.resolve().relative_to(download_dir)
|
||||
remote_path = f"{self._onedrive_config.remote_path}/{relative_path.parent}"
|
||||
except ValueError:
|
||||
remote_path = self._onedrive_config.remote_path
|
||||
|
||||
msg = await self._reply(update, context, f"☁️ 正在上传: {task.name}\n⏳ 请稍候...")
|
||||
|
||||
success = await client.upload_file(local_path, remote_path)
|
||||
|
||||
if success:
|
||||
result_text = f"✅ 上传成功: {task.name}"
|
||||
if self._onedrive_config and self._onedrive_config.delete_after_upload:
|
||||
try:
|
||||
if local_path.is_dir():
|
||||
shutil.rmtree(local_path)
|
||||
else:
|
||||
local_path.unlink()
|
||||
result_text += "\n🗑️ 本地文件已删除"
|
||||
except Exception as e:
|
||||
result_text += f"\n⚠️ 删除本地文件失败: {e}"
|
||||
await msg.edit_text(result_text)
|
||||
logger.info(f"上传成功 GID={gid} - {_get_user_info(update)}")
|
||||
else:
|
||||
await msg.edit_text(f"❌ 上传失败: {task.name}")
|
||||
logger.error(f"上传失败 GID={gid} - {_get_user_info(update)}")
|
||||
|
||||
async def handle_button_text(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""处理 Reply Keyboard 按钮点击"""
|
||||
text = update.message.text
|
||||
@@ -559,6 +756,11 @@ class Aria2BotAPI:
|
||||
await self._handle_stats_callback(query, rpc)
|
||||
elif action == "cancel":
|
||||
await query.edit_message_text("❌ 操作已取消")
|
||||
# 云存储相关回调
|
||||
elif action == "cloud":
|
||||
await self._handle_cloud_callback(query, update, context, parts)
|
||||
elif action == "upload":
|
||||
await self._handle_upload_callback(query, update, context, parts)
|
||||
|
||||
except RpcError as e:
|
||||
await query.edit_message_text(f"❌ 操作失败: {e}")
|
||||
@@ -740,7 +942,13 @@ class Aria2BotAPI:
|
||||
if task.error_message:
|
||||
text += f"\n❌ 错误: {task.error_message}"
|
||||
|
||||
keyboard = build_detail_keyboard(gid, task.status)
|
||||
# 检查是否显示上传按钮(任务完成且云存储已配置)
|
||||
show_upload = (
|
||||
task.status == "complete" and
|
||||
self._onedrive_config and
|
||||
self._onedrive_config.enabled
|
||||
)
|
||||
keyboard = build_detail_keyboard_with_upload(gid, task.status, show_upload)
|
||||
|
||||
# 只有内容变化时才更新
|
||||
if text != last_text:
|
||||
@@ -774,6 +982,83 @@ class Aria2BotAPI:
|
||||
keyboard = InlineKeyboardMarkup([[InlineKeyboardButton("🔙 返回列表", callback_data="list:menu")]])
|
||||
await query.edit_message_text(text, parse_mode="Markdown", reply_markup=keyboard)
|
||||
|
||||
# === 云存储回调处理 ===
|
||||
|
||||
async def _handle_cloud_callback(self, query, update: Update, context: ContextTypes.DEFAULT_TYPE, parts: list) -> None:
|
||||
"""处理云存储相关回调"""
|
||||
if len(parts) < 2:
|
||||
await query.edit_message_text("❌ 无效操作")
|
||||
return
|
||||
|
||||
sub_action = parts[1]
|
||||
|
||||
if sub_action == "auth":
|
||||
# 认证请求
|
||||
await self.cloud_auth(update, context)
|
||||
elif sub_action == "status":
|
||||
# 状态查询
|
||||
client = self._get_onedrive_client()
|
||||
if not client:
|
||||
await query.edit_message_text("❌ OneDrive 未配置")
|
||||
return
|
||||
is_auth = await client.is_authenticated()
|
||||
auto_upload = self._onedrive_config.auto_upload if self._onedrive_config else False
|
||||
delete_after = self._onedrive_config.delete_after_upload if self._onedrive_config else False
|
||||
remote_path = self._onedrive_config.remote_path if self._onedrive_config else "/aria2bot"
|
||||
text = (
|
||||
"☁️ *OneDrive 状态*\n\n"
|
||||
f"🔐 认证状态: {'✅ 已认证' if is_auth else '❌ 未认证'}\n"
|
||||
f"📤 自动上传: {'✅ 开启' if auto_upload else '❌ 关闭'}\n"
|
||||
f"🗑️ 上传后删除: {'✅ 开启' if delete_after else '❌ 关闭'}\n"
|
||||
f"📁 远程路径: `{remote_path}`"
|
||||
)
|
||||
keyboard = build_cloud_menu_keyboard()
|
||||
await query.edit_message_text(text, parse_mode="Markdown", reply_markup=keyboard)
|
||||
elif sub_action == "settings":
|
||||
# 设置页面
|
||||
auto_upload = self._onedrive_config.auto_upload if self._onedrive_config else False
|
||||
delete_after = self._onedrive_config.delete_after_upload if self._onedrive_config else False
|
||||
keyboard = build_cloud_settings_keyboard(auto_upload, delete_after)
|
||||
await query.edit_message_text("⚙️ *云存储设置*\n\n点击切换设置:", parse_mode="Markdown", reply_markup=keyboard)
|
||||
elif sub_action == "logout":
|
||||
# 登出
|
||||
client = self._get_onedrive_client()
|
||||
if client and await client.logout():
|
||||
await query.edit_message_text("✅ 已登出 OneDrive")
|
||||
else:
|
||||
await query.edit_message_text("❌ 登出失败")
|
||||
elif sub_action == "menu":
|
||||
# 返回菜单
|
||||
keyboard = build_cloud_menu_keyboard()
|
||||
await query.edit_message_text("☁️ *云存储管理*", parse_mode="Markdown", reply_markup=keyboard)
|
||||
elif sub_action == "toggle":
|
||||
# 切换设置(注意:运行时修改配置,重启后会重置)
|
||||
if len(parts) < 3:
|
||||
return
|
||||
setting = parts[2]
|
||||
if self._onedrive_config:
|
||||
if setting == "auto_upload":
|
||||
self._onedrive_config.auto_upload = not self._onedrive_config.auto_upload
|
||||
elif setting == "delete_after":
|
||||
self._onedrive_config.delete_after_upload = not self._onedrive_config.delete_after_upload
|
||||
auto_upload = self._onedrive_config.auto_upload if self._onedrive_config else False
|
||||
delete_after = self._onedrive_config.delete_after_upload if self._onedrive_config else False
|
||||
keyboard = build_cloud_settings_keyboard(auto_upload, delete_after)
|
||||
await query.edit_message_text("⚙️ *云存储设置*\n\n点击切换设置:", parse_mode="Markdown", reply_markup=keyboard)
|
||||
|
||||
async def _handle_upload_callback(self, query, update: Update, context: ContextTypes.DEFAULT_TYPE, parts: list) -> None:
|
||||
"""处理上传回调"""
|
||||
if len(parts) < 3:
|
||||
await query.edit_message_text("❌ 无效操作")
|
||||
return
|
||||
|
||||
provider = parts[1] # onedrive
|
||||
gid = parts[2]
|
||||
|
||||
if provider == "onedrive":
|
||||
await query.edit_message_text("☁️ 正在准备上传...")
|
||||
await self.upload_to_cloud(update, context, gid)
|
||||
|
||||
|
||||
def build_handlers(api: Aria2BotAPI) -> list:
|
||||
"""构建 Handler 列表"""
|
||||
@@ -808,8 +1093,12 @@ def build_handlers(api: Aria2BotAPI) -> list:
|
||||
CommandHandler("add", wrap_with_permission(api.add_download)),
|
||||
CommandHandler("list", wrap_with_permission(api.list_downloads)),
|
||||
CommandHandler("stats", wrap_with_permission(api.global_stats)),
|
||||
# 云存储命令
|
||||
CommandHandler("cloud", wrap_with_permission(api.cloud_command)),
|
||||
# Reply Keyboard 按钮文本处理
|
||||
MessageHandler(filters.TEXT & filters.Regex(button_pattern), wrap_with_permission(api.handle_button_text)),
|
||||
# OneDrive 认证回调 URL 处理
|
||||
MessageHandler(filters.TEXT & filters.Regex(r"^https://login\.microsoftonline\.com"), wrap_with_permission(api.handle_auth_callback)),
|
||||
# 种子文件处理
|
||||
MessageHandler(filters.Document.FileExtension("torrent"), wrap_with_permission(api.handle_torrent)),
|
||||
# Callback Query 处理
|
||||
|
||||
Reference in New Issue
Block a user