diff --git a/src/cloud/onedrive.py b/src/cloud/onedrive.py index 98b39d9..3fffed8 100644 --- a/src/cloud/onedrive.py +++ b/src/cloud/onedrive.py @@ -3,9 +3,12 @@ from __future__ import annotations import asyncio import json +import time from pathlib import Path from typing import Callable +from urllib.parse import quote +import httpx from O365 import Account from O365.utils import BaseTokenBackend @@ -16,6 +19,11 @@ from src.utils.logger import get_logger logger = get_logger("onedrive") +# 上传相关常量 +SIMPLE_UPLOAD_LIMIT = 4 * 1024 * 1024 # 4MB,超过此大小使用分块上传 +DEFAULT_CHUNK_SIZE = 5 * 1024 * 1024 # 5MB,必须是 320KB 的倍数 +PROGRESS_UPDATE_INTERVAL = 2.0 # 进度更新间隔(秒) + class FileTokenBackend(BaseTokenBackend): """文件系统 Token 存储后端""" @@ -128,16 +136,19 @@ class OneDriveClient(CloudStorageBase): remote_path: str, progress_callback: Callable[[UploadProgress], None] | None = None ) -> bool: - """上传文件到 OneDrive""" + """上传文件到 OneDrive + + 对于大文件(>4MB)或需要进度回调时,使用分块上传以支持实时进度显示。 + """ account = self._get_account() if not account.is_authenticated: raise RuntimeError("OneDrive 未认证") try: - # 获取存储和驱动器 - storage = account.storage() - drive = storage.get_default_drive() - root = drive.get_root_folder() + # 获取存储和驱动器(在线程池中执行,避免阻塞事件循环) + storage = await asyncio.to_thread(account.storage) + drive = await asyncio.to_thread(storage.get_default_drive) + root = await asyncio.to_thread(drive.get_root_folder) # 确保远程目录存在 target_folder = await self._ensure_folder_path(root, remote_path) @@ -154,13 +165,24 @@ class OneDriveClient(CloudStorageBase): status=UploadStatus.UPLOADING )) - # 执行上传(python-o365 会自动处理大文件分块) - uploaded = await asyncio.to_thread( - target_folder.upload_file, - item=str(local_path) - ) + # 大文件或需要进度回调时使用分块上传 + if file_size > SIMPLE_UPLOAD_LIMIT or progress_callback: + success = await self._chunked_upload_with_progress( + target_folder=target_folder, + local_path=local_path, + file_name=file_name, + file_size=file_size, + progress_callback=progress_callback + ) + else: + # 小文件使用简单上传 + uploaded = await asyncio.to_thread( + target_folder.upload_file, + item=str(local_path) + ) + success = uploaded is not None - if uploaded: + if success: if progress_callback: progress_callback(UploadProgress( file_name=file_name, @@ -195,7 +217,7 @@ class OneDriveClient(CloudStorageBase): Returns: 目标文件夹对象 """ - parts = [p for p in path.strip("/").split("/") if p] + parts = [p for p in path.strip("/").split("/") if p and p not in (".", "..")] current = root_folder for part in parts: @@ -218,6 +240,128 @@ class OneDriveClient(CloudStorageBase): return current + def _sync_chunked_upload( + self, + folder_id: str, + access_token: str, + local_path: Path, + file_name: str, + file_size: int, + progress_callback: Callable[[UploadProgress], None] | None, + chunk_size: int = DEFAULT_CHUNK_SIZE + ) -> bool: + """同步分块上传(在线程池中执行,避免阻塞事件循环)""" + create_session_url = ( + f"https://graph.microsoft.com/v1.0/me/drive/items/{folder_id}:/{quote(file_name)}:/createUploadSession" + ) + + headers = { + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json" + } + + with httpx.Client(timeout=60.0) as client: + # 创建上传会话 + response = client.post( + create_session_url, + headers=headers, + json={"item": {"@microsoft.graph.conflictBehavior": "replace"}} + ) + + if response.status_code != 200: + logger.error(f"创建上传会话失败: {response.status_code} - {response.text}") + return False + + session_data = response.json() + upload_url = session_data.get("uploadUrl") + if not upload_url: + logger.error("上传会话响应中没有 uploadUrl") + return False + + logger.info(f"创建上传会话成功,开始分块上传: {file_name}") + + # 分块上传 + uploaded_size = 0 + last_progress_time = time.time() + + with open(local_path, "rb") as f: + while uploaded_size < file_size: + chunk_data = f.read(chunk_size) + if not chunk_data: + break + + chunk_len = len(chunk_data) + range_start = uploaded_size + range_end = uploaded_size + chunk_len - 1 + + chunk_headers = { + "Content-Length": str(chunk_len), + "Content-Range": f"bytes {range_start}-{range_end}/{file_size}" + } + + chunk_response = client.put( + upload_url, + headers=chunk_headers, + content=chunk_data + ) + + if chunk_response.status_code not in (200, 201, 202): + logger.error(f"上传 chunk 失败: {chunk_response.status_code} - {chunk_response.text}") + return False + + uploaded_size += chunk_len + + # 按时间间隔更新进度 + current_time = time.time() + if progress_callback and (current_time - last_progress_time >= PROGRESS_UPDATE_INTERVAL): + progress_callback(UploadProgress( + file_name=file_name, + total_size=file_size, + uploaded_size=uploaded_size, + status=UploadStatus.UPLOADING + )) + last_progress_time = current_time + + if chunk_response.status_code in (200, 201): + logger.info(f"文件上传完成: {file_name}") + break + + return True + + async def _chunked_upload_with_progress( + self, + target_folder, + local_path: Path, + file_name: str, + file_size: int, + progress_callback: Callable[[UploadProgress], None] | None, + chunk_size: int = DEFAULT_CHUNK_SIZE + ) -> bool: + """分块上传文件,支持进度回调(在线程池中执行,不阻塞事件循环)""" + account = self._get_account() + + # 从 token backend 获取 access token + token_info = account.con.token_backend.get_access_token() + if not token_info: + raise RuntimeError("无法获取 access token") + access_token = token_info.get("secret") + if not access_token: + raise RuntimeError("access token 无效") + + folder_id = target_folder.object_id + + # 在线程池中执行同步上传,避免阻塞事件循环 + return await asyncio.to_thread( + self._sync_chunked_upload, + folder_id, + access_token, + local_path, + file_name, + file_size, + progress_callback, + chunk_size + ) + async def logout(self) -> bool: """清除认证""" try: diff --git a/src/telegram/handlers.py b/src/telegram/handlers.py index 4dad7cd..e8b7753 100644 --- a/src/telegram/handlers.py +++ b/src/telegram/handlers.py @@ -23,6 +23,7 @@ from src.core import ( DOWNLOAD_DIR, ) from src.core.config import OneDriveConfig +from src.cloud.base import UploadProgress, UploadStatus from src.aria2 import Aria2Installer, Aria2ServiceManager from src.aria2.rpc import Aria2RpcClient, DownloadTask, _format_size from src.telegram.keyboards import ( @@ -536,9 +537,8 @@ class Aria2BotAPI: await self._reply(update, context, text, parse_mode="Markdown") async def upload_to_cloud(self, update: Update, context: ContextTypes.DEFAULT_TYPE, gid: str) -> None: - """上传文件到云存储""" + """上传文件到云存储(启动后台任务,不阻塞其他命令)""" from pathlib import Path - import shutil logger.info(f"收到上传请求 GID={gid} - {_get_user_info(update)}") client = self._get_onedrive_client() @@ -572,24 +572,65 @@ class Aria2BotAPI: msg = await self._reply(update, context, f"☁️ 正在上传: {task.name}\n⏳ 请稍候...") - success = await client.upload_file(local_path, remote_path) + # 启动后台上传任务,不阻塞其他命令 + asyncio.create_task(self._do_upload_to_cloud( + client, local_path, remote_path, task.name, msg, gid, _get_user_info(update) + )) - if success: - result_text = f"✅ 上传成功: {task.name}" - if self._onedrive_config and self._onedrive_config.delete_after_upload: + async def _do_upload_to_cloud( + self, client, local_path, remote_path: str, task_name: str, msg, gid: str, user_info: str + ) -> None: + """后台执行上传任务""" + import shutil + + loop = asyncio.get_event_loop() + + # 进度回调函数 + async def update_progress(progress: UploadProgress): + """更新上传进度消息""" + if progress.status == UploadStatus.UPLOADING and progress.total_size > 0: + percent = progress.progress + uploaded_mb = progress.uploaded_size / (1024 * 1024) + total_mb = progress.total_size / (1024 * 1024) + progress_text = ( + f"☁️ 正在上传: {task_name}\n" + f"📤 {percent:.1f}% ({uploaded_mb:.1f}MB / {total_mb:.1f}MB)" + ) try: - if local_path.is_dir(): - shutil.rmtree(local_path) - else: - local_path.unlink() - result_text += "\n🗑️ 本地文件已删除" - except Exception as e: - result_text += f"\n⚠️ 删除本地文件失败: {e}" - await msg.edit_text(result_text) - logger.info(f"上传成功 GID={gid} - {_get_user_info(update)}") - else: - await msg.edit_text(f"❌ 上传失败: {task.name}") - logger.error(f"上传失败 GID={gid} - {_get_user_info(update)}") + await msg.edit_text(progress_text) + except Exception: + pass # 忽略消息更新失败(如内容未变化) + + def sync_progress_callback(progress: UploadProgress): + """同步回调,将异步更新调度到事件循环""" + if progress.status == UploadStatus.UPLOADING: + asyncio.run_coroutine_threadsafe(update_progress(progress), loop) + + try: + success = await client.upload_file(local_path, remote_path, progress_callback=sync_progress_callback) + + if success: + result_text = f"✅ 上传成功: {task_name}" + if self._onedrive_config and self._onedrive_config.delete_after_upload: + try: + if local_path.is_dir(): + shutil.rmtree(local_path) + else: + local_path.unlink() + result_text += "\n🗑️ 本地文件已删除" + except Exception as e: + result_text += f"\n⚠️ 删除本地文件失败: {e}" + await msg.edit_text(result_text) + logger.info(f"上传成功 GID={gid} - {user_info}") + else: + await msg.edit_text(f"❌ 上传失败: {task_name}") + logger.error(f"上传失败 GID={gid} - {user_info}") + except Exception as e: + logger.error(f"上传异常 GID={gid}: {e} - {user_info}") + try: + await msg.edit_text(f"❌ 上传失败: {task_name}\n错误: {e}") + except Exception: + pass async def handle_button_text(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """处理 Reply Keyboard 按钮点击"""