feat: 增加OneDrive认证功能

This commit is contained in:
dnslin
2025-12-12 15:09:02 +08:00
parent 1aa80a652d
commit be9ce53561
12 changed files with 1045 additions and 3 deletions

View File

@@ -20,7 +20,9 @@ from src.core import (
get_aria2_version,
generate_rpc_secret,
ARIA2_CONF,
DOWNLOAD_DIR,
)
from src.core.config import OneDriveConfig
from src.aria2 import Aria2Installer, Aria2ServiceManager
from src.aria2.rpc import Aria2RpcClient, DownloadTask, _format_size
from src.telegram.keyboards import (
@@ -32,6 +34,10 @@ from src.telegram.keyboards import (
build_detail_keyboard,
build_after_add_keyboard,
build_main_reply_keyboard,
build_cloud_menu_keyboard,
build_upload_choice_keyboard,
build_cloud_settings_keyboard,
build_detail_keyboard_with_upload,
)
# Reply Keyboard 按钮文本到命令的映射
@@ -83,13 +89,18 @@ import asyncio
from functools import wraps
class Aria2BotAPI:
def __init__(self, config: Aria2Config | None = None, allowed_users: set[int] | None = None):
def __init__(self, config: Aria2Config | None = None, allowed_users: set[int] | None = None,
onedrive_config: OneDriveConfig | None = None):
self.config = config or Aria2Config()
self.allowed_users = allowed_users or set()
self.installer = Aria2Installer(self.config)
self.service = Aria2ServiceManager()
self._rpc: Aria2RpcClient | None = None
self._auto_refresh_tasks: dict[str, asyncio.Task] = {} # chat_id:msg_id -> task
# 云存储相关
self._onedrive_config = onedrive_config
self._onedrive = None
self._pending_auth: dict[int, dict] = {} # user_id -> flow
async def _check_permission(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> bool:
"""检查用户权限,返回 True 表示有权限"""
@@ -113,6 +124,13 @@ class Aria2BotAPI:
self._rpc = Aria2RpcClient(port=port, secret=secret)
return self._rpc
def _get_onedrive_client(self):
"""获取或创建 OneDrive 客户端"""
if self._onedrive is None and self._onedrive_config and self._onedrive_config.enabled:
from src.cloud.onedrive import OneDriveClient
self._onedrive = OneDriveClient(self._onedrive_config)
return self._onedrive
async def _reply(self, update: Update, context: ContextTypes.DEFAULT_TYPE, text: str, **kwargs):
if update.effective_message:
return await update.effective_message.reply_text(text, **kwargs)
@@ -120,6 +138,19 @@ class Aria2BotAPI:
return await context.bot.send_message(chat_id=update.effective_chat.id, text=text, **kwargs)
return None
async def _delayed_delete_messages(self, messages: list, delay: int = 5) -> None:
"""延迟删除多条消息"""
try:
await asyncio.sleep(delay)
for msg in messages:
try:
await msg.delete()
except Exception as e:
logger.warning(f"删除消息失败: {e}")
logger.debug("已删除敏感认证消息")
except Exception as e:
logger.warning(f"延迟删除任务失败: {e}")
def _get_rpc_secret(self) -> str:
if self.config.rpc_secret:
return self.config.rpc_secret
@@ -378,6 +409,9 @@ class Aria2BotAPI:
"/list - 查看下载列表",
"/stats - 全局下载统计",
"",
"*云存储*",
"/cloud - 云存储管理菜单",
"",
"/menu - 显示快捷菜单",
"/help - 显示此帮助",
]
@@ -394,6 +428,169 @@ class Aria2BotAPI:
reply_markup=keyboard
)
# === 云存储命令 ===
async def cloud_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""云存储管理菜单"""
logger.info(f"收到 /cloud 命令 - {_get_user_info(update)}")
if not self._onedrive_config or not self._onedrive_config.enabled:
await self._reply(update, context, "❌ 云存储功能未启用,请在配置中设置 ONEDRIVE_ENABLED=true")
return
keyboard = build_cloud_menu_keyboard()
await self._reply(update, context, "☁️ *云存储管理*", parse_mode="Markdown", reply_markup=keyboard)
async def cloud_auth(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""开始 OneDrive 认证"""
logger.info(f"收到云存储认证请求 - {_get_user_info(update)}")
client = self._get_onedrive_client()
if not client:
await self._reply(update, context, "❌ OneDrive 未配置")
return
if await client.is_authenticated():
await self._reply(update, context, "✅ OneDrive 已认证")
return
url, state = await client.get_auth_url()
user_id = update.effective_user.id
auth_message = await self._reply(
update, context,
f"🔐 *OneDrive 认证*\n\n"
f"1\\. 点击下方链接登录 Microsoft 账户\n"
f"2\\. 授权后会跳转到一个空白页面\n"
f"3\\. 复制该页面的完整 URL 发送给我\n\n"
f"[点击认证]({url})",
parse_mode="Markdown"
)
self._pending_auth[user_id] = {"state": state, "message": auth_message}
async def handle_auth_callback(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""处理用户发送的认证回调 URL"""
text = update.message.text
if not text or not text.startswith("https://login.microsoftonline.com"):
return
user_id = update.effective_user.id
if user_id not in self._pending_auth:
return
client = self._get_onedrive_client()
if not client:
return
user_message = update.message # 保存用户消息引用
pending = self._pending_auth[user_id]
flow = pending["state"]
auth_message = pending.get("message") # 认证指引消息
if await client.authenticate_with_code(text, flow=flow):
del self._pending_auth[user_id]
reply_message = await self._reply(update, context, "✅ OneDrive 认证成功!")
logger.info(f"OneDrive 认证成功 - {_get_user_info(update)}")
else:
# 认证失败时清理认证信息
del self._pending_auth[user_id]
await client.logout() # 删除可能存在的旧 token
reply_message = await self._reply(update, context, "❌ 认证失败,请重试")
logger.error(f"OneDrive 认证失败 - {_get_user_info(update)}")
# 延迟 5 秒后删除敏感消息(包括认证指引消息)
messages_to_delete = [msg for msg in [user_message, reply_message, auth_message] if msg]
if messages_to_delete:
asyncio.create_task(self._delayed_delete_messages(messages_to_delete))
async def cloud_logout(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""登出云存储"""
logger.info(f"收到云存储登出请求 - {_get_user_info(update)}")
client = self._get_onedrive_client()
if not client:
await self._reply(update, context, "❌ OneDrive 未配置")
return
if await client.logout():
await self._reply(update, context, "✅ 已登出 OneDrive")
else:
await self._reply(update, context, "❌ 登出失败")
async def cloud_status(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""查看云存储状态"""
logger.info(f"收到云存储状态查询 - {_get_user_info(update)}")
client = self._get_onedrive_client()
if not client:
await self._reply(update, context, "❌ OneDrive 未配置")
return
is_auth = await client.is_authenticated()
auto_upload = self._onedrive_config.auto_upload if self._onedrive_config else False
delete_after = self._onedrive_config.delete_after_upload if self._onedrive_config else False
remote_path = self._onedrive_config.remote_path if self._onedrive_config else "/aria2bot"
text = (
"☁️ *OneDrive 状态*\n\n"
f"🔐 认证状态: {'✅ 已认证' if is_auth else '❌ 未认证'}\n"
f"📤 自动上传: {'✅ 开启' if auto_upload else '❌ 关闭'}\n"
f"🗑️ 上传后删除: {'✅ 开启' if delete_after else '❌ 关闭'}\n"
f"📁 远程路径: `{remote_path}`"
)
await self._reply(update, context, text, parse_mode="Markdown")
async def upload_to_cloud(self, update: Update, context: ContextTypes.DEFAULT_TYPE, gid: str) -> None:
"""上传文件到云存储"""
from pathlib import Path
import shutil
logger.info(f"收到上传请求 GID={gid} - {_get_user_info(update)}")
client = self._get_onedrive_client()
if not client or not await client.is_authenticated():
await self._reply(update, context, "❌ OneDrive 未认证,请先使用 /cloud 进行认证")
return
rpc = self._get_rpc_client()
try:
task = await rpc.get_status(gid)
except RpcError as e:
await self._reply(update, context, f"❌ 获取任务信息失败: {e}")
return
if task.status != "complete":
await self._reply(update, context, "❌ 任务未完成,无法上传")
return
local_path = Path(task.dir) / task.name
if not local_path.exists():
await self._reply(update, context, "❌ 本地文件不存在")
return
# 计算远程路径(保持目录结构)
try:
download_dir = DOWNLOAD_DIR.resolve()
relative_path = local_path.resolve().relative_to(download_dir)
remote_path = f"{self._onedrive_config.remote_path}/{relative_path.parent}"
except ValueError:
remote_path = self._onedrive_config.remote_path
msg = await self._reply(update, context, f"☁️ 正在上传: {task.name}\n⏳ 请稍候...")
success = await client.upload_file(local_path, remote_path)
if success:
result_text = f"✅ 上传成功: {task.name}"
if self._onedrive_config and self._onedrive_config.delete_after_upload:
try:
if local_path.is_dir():
shutil.rmtree(local_path)
else:
local_path.unlink()
result_text += "\n🗑️ 本地文件已删除"
except Exception as e:
result_text += f"\n⚠️ 删除本地文件失败: {e}"
await msg.edit_text(result_text)
logger.info(f"上传成功 GID={gid} - {_get_user_info(update)}")
else:
await msg.edit_text(f"❌ 上传失败: {task.name}")
logger.error(f"上传失败 GID={gid} - {_get_user_info(update)}")
async def handle_button_text(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""处理 Reply Keyboard 按钮点击"""
text = update.message.text
@@ -559,6 +756,11 @@ class Aria2BotAPI:
await self._handle_stats_callback(query, rpc)
elif action == "cancel":
await query.edit_message_text("❌ 操作已取消")
# 云存储相关回调
elif action == "cloud":
await self._handle_cloud_callback(query, update, context, parts)
elif action == "upload":
await self._handle_upload_callback(query, update, context, parts)
except RpcError as e:
await query.edit_message_text(f"❌ 操作失败: {e}")
@@ -740,7 +942,13 @@ class Aria2BotAPI:
if task.error_message:
text += f"\n❌ 错误: {task.error_message}"
keyboard = build_detail_keyboard(gid, task.status)
# 检查是否显示上传按钮(任务完成且云存储已配置)
show_upload = (
task.status == "complete" and
self._onedrive_config and
self._onedrive_config.enabled
)
keyboard = build_detail_keyboard_with_upload(gid, task.status, show_upload)
# 只有内容变化时才更新
if text != last_text:
@@ -774,6 +982,83 @@ class Aria2BotAPI:
keyboard = InlineKeyboardMarkup([[InlineKeyboardButton("🔙 返回列表", callback_data="list:menu")]])
await query.edit_message_text(text, parse_mode="Markdown", reply_markup=keyboard)
# === 云存储回调处理 ===
async def _handle_cloud_callback(self, query, update: Update, context: ContextTypes.DEFAULT_TYPE, parts: list) -> None:
"""处理云存储相关回调"""
if len(parts) < 2:
await query.edit_message_text("❌ 无效操作")
return
sub_action = parts[1]
if sub_action == "auth":
# 认证请求
await self.cloud_auth(update, context)
elif sub_action == "status":
# 状态查询
client = self._get_onedrive_client()
if not client:
await query.edit_message_text("❌ OneDrive 未配置")
return
is_auth = await client.is_authenticated()
auto_upload = self._onedrive_config.auto_upload if self._onedrive_config else False
delete_after = self._onedrive_config.delete_after_upload if self._onedrive_config else False
remote_path = self._onedrive_config.remote_path if self._onedrive_config else "/aria2bot"
text = (
"☁️ *OneDrive 状态*\n\n"
f"🔐 认证状态: {'✅ 已认证' if is_auth else '❌ 未认证'}\n"
f"📤 自动上传: {'✅ 开启' if auto_upload else '❌ 关闭'}\n"
f"🗑️ 上传后删除: {'✅ 开启' if delete_after else '❌ 关闭'}\n"
f"📁 远程路径: `{remote_path}`"
)
keyboard = build_cloud_menu_keyboard()
await query.edit_message_text(text, parse_mode="Markdown", reply_markup=keyboard)
elif sub_action == "settings":
# 设置页面
auto_upload = self._onedrive_config.auto_upload if self._onedrive_config else False
delete_after = self._onedrive_config.delete_after_upload if self._onedrive_config else False
keyboard = build_cloud_settings_keyboard(auto_upload, delete_after)
await query.edit_message_text("⚙️ *云存储设置*\n\n点击切换设置:", parse_mode="Markdown", reply_markup=keyboard)
elif sub_action == "logout":
# 登出
client = self._get_onedrive_client()
if client and await client.logout():
await query.edit_message_text("✅ 已登出 OneDrive")
else:
await query.edit_message_text("❌ 登出失败")
elif sub_action == "menu":
# 返回菜单
keyboard = build_cloud_menu_keyboard()
await query.edit_message_text("☁️ *云存储管理*", parse_mode="Markdown", reply_markup=keyboard)
elif sub_action == "toggle":
# 切换设置(注意:运行时修改配置,重启后会重置)
if len(parts) < 3:
return
setting = parts[2]
if self._onedrive_config:
if setting == "auto_upload":
self._onedrive_config.auto_upload = not self._onedrive_config.auto_upload
elif setting == "delete_after":
self._onedrive_config.delete_after_upload = not self._onedrive_config.delete_after_upload
auto_upload = self._onedrive_config.auto_upload if self._onedrive_config else False
delete_after = self._onedrive_config.delete_after_upload if self._onedrive_config else False
keyboard = build_cloud_settings_keyboard(auto_upload, delete_after)
await query.edit_message_text("⚙️ *云存储设置*\n\n点击切换设置:", parse_mode="Markdown", reply_markup=keyboard)
async def _handle_upload_callback(self, query, update: Update, context: ContextTypes.DEFAULT_TYPE, parts: list) -> None:
"""处理上传回调"""
if len(parts) < 3:
await query.edit_message_text("❌ 无效操作")
return
provider = parts[1] # onedrive
gid = parts[2]
if provider == "onedrive":
await query.edit_message_text("☁️ 正在准备上传...")
await self.upload_to_cloud(update, context, gid)
def build_handlers(api: Aria2BotAPI) -> list:
"""构建 Handler 列表"""
@@ -808,8 +1093,12 @@ def build_handlers(api: Aria2BotAPI) -> list:
CommandHandler("add", wrap_with_permission(api.add_download)),
CommandHandler("list", wrap_with_permission(api.list_downloads)),
CommandHandler("stats", wrap_with_permission(api.global_stats)),
# 云存储命令
CommandHandler("cloud", wrap_with_permission(api.cloud_command)),
# Reply Keyboard 按钮文本处理
MessageHandler(filters.TEXT & filters.Regex(button_pattern), wrap_with_permission(api.handle_button_text)),
# OneDrive 认证回调 URL 处理
MessageHandler(filters.TEXT & filters.Regex(r"^https://login\.microsoftonline\.com"), wrap_with_permission(api.handle_auth_callback)),
# 种子文件处理
MessageHandler(filters.Document.FileExtension("torrent"), wrap_with_permission(api.handle_torrent)),
# Callback Query 处理