feat: 增加上传回调

This commit is contained in:
dnslin
2025-12-12 16:04:14 +08:00
parent be9ce53561
commit f71ce7e0ca
2 changed files with 215 additions and 30 deletions

View File

@@ -3,9 +3,12 @@ from __future__ import annotations
import asyncio
import json
import time
from pathlib import Path
from typing import Callable
from urllib.parse import quote
import httpx
from O365 import Account
from O365.utils import BaseTokenBackend
@@ -16,6 +19,11 @@ from src.utils.logger import get_logger
logger = get_logger("onedrive")
# 上传相关常量
SIMPLE_UPLOAD_LIMIT = 4 * 1024 * 1024 # 4MB超过此大小使用分块上传
DEFAULT_CHUNK_SIZE = 5 * 1024 * 1024 # 5MB必须是 320KB 的倍数
PROGRESS_UPDATE_INTERVAL = 2.0 # 进度更新间隔(秒)
class FileTokenBackend(BaseTokenBackend):
"""文件系统 Token 存储后端"""
@@ -128,16 +136,19 @@ class OneDriveClient(CloudStorageBase):
remote_path: str,
progress_callback: Callable[[UploadProgress], None] | None = None
) -> bool:
"""上传文件到 OneDrive"""
"""上传文件到 OneDrive
对于大文件(>4MB或需要进度回调时使用分块上传以支持实时进度显示。
"""
account = self._get_account()
if not account.is_authenticated:
raise RuntimeError("OneDrive 未认证")
try:
# 获取存储和驱动器
storage = account.storage()
drive = storage.get_default_drive()
root = drive.get_root_folder()
# 获取存储和驱动器(在线程池中执行,避免阻塞事件循环)
storage = await asyncio.to_thread(account.storage)
drive = await asyncio.to_thread(storage.get_default_drive)
root = await asyncio.to_thread(drive.get_root_folder)
# 确保远程目录存在
target_folder = await self._ensure_folder_path(root, remote_path)
@@ -154,13 +165,24 @@ class OneDriveClient(CloudStorageBase):
status=UploadStatus.UPLOADING
))
# 执行上传python-o365 会自动处理大文件分块)
uploaded = await asyncio.to_thread(
target_folder.upload_file,
item=str(local_path)
)
# 大文件或需要进度回调时使用分块上传
if file_size > SIMPLE_UPLOAD_LIMIT or progress_callback:
success = await self._chunked_upload_with_progress(
target_folder=target_folder,
local_path=local_path,
file_name=file_name,
file_size=file_size,
progress_callback=progress_callback
)
else:
# 小文件使用简单上传
uploaded = await asyncio.to_thread(
target_folder.upload_file,
item=str(local_path)
)
success = uploaded is not None
if uploaded:
if success:
if progress_callback:
progress_callback(UploadProgress(
file_name=file_name,
@@ -195,7 +217,7 @@ class OneDriveClient(CloudStorageBase):
Returns:
目标文件夹对象
"""
parts = [p for p in path.strip("/").split("/") if p]
parts = [p for p in path.strip("/").split("/") if p and p not in (".", "..")]
current = root_folder
for part in parts:
@@ -218,6 +240,128 @@ class OneDriveClient(CloudStorageBase):
return current
def _sync_chunked_upload(
self,
folder_id: str,
access_token: str,
local_path: Path,
file_name: str,
file_size: int,
progress_callback: Callable[[UploadProgress], None] | None,
chunk_size: int = DEFAULT_CHUNK_SIZE
) -> bool:
"""同步分块上传(在线程池中执行,避免阻塞事件循环)"""
create_session_url = (
f"https://graph.microsoft.com/v1.0/me/drive/items/{folder_id}:/{quote(file_name)}:/createUploadSession"
)
headers = {
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json"
}
with httpx.Client(timeout=60.0) as client:
# 创建上传会话
response = client.post(
create_session_url,
headers=headers,
json={"item": {"@microsoft.graph.conflictBehavior": "replace"}}
)
if response.status_code != 200:
logger.error(f"创建上传会话失败: {response.status_code} - {response.text}")
return False
session_data = response.json()
upload_url = session_data.get("uploadUrl")
if not upload_url:
logger.error("上传会话响应中没有 uploadUrl")
return False
logger.info(f"创建上传会话成功,开始分块上传: {file_name}")
# 分块上传
uploaded_size = 0
last_progress_time = time.time()
with open(local_path, "rb") as f:
while uploaded_size < file_size:
chunk_data = f.read(chunk_size)
if not chunk_data:
break
chunk_len = len(chunk_data)
range_start = uploaded_size
range_end = uploaded_size + chunk_len - 1
chunk_headers = {
"Content-Length": str(chunk_len),
"Content-Range": f"bytes {range_start}-{range_end}/{file_size}"
}
chunk_response = client.put(
upload_url,
headers=chunk_headers,
content=chunk_data
)
if chunk_response.status_code not in (200, 201, 202):
logger.error(f"上传 chunk 失败: {chunk_response.status_code} - {chunk_response.text}")
return False
uploaded_size += chunk_len
# 按时间间隔更新进度
current_time = time.time()
if progress_callback and (current_time - last_progress_time >= PROGRESS_UPDATE_INTERVAL):
progress_callback(UploadProgress(
file_name=file_name,
total_size=file_size,
uploaded_size=uploaded_size,
status=UploadStatus.UPLOADING
))
last_progress_time = current_time
if chunk_response.status_code in (200, 201):
logger.info(f"文件上传完成: {file_name}")
break
return True
async def _chunked_upload_with_progress(
self,
target_folder,
local_path: Path,
file_name: str,
file_size: int,
progress_callback: Callable[[UploadProgress], None] | None,
chunk_size: int = DEFAULT_CHUNK_SIZE
) -> bool:
"""分块上传文件,支持进度回调(在线程池中执行,不阻塞事件循环)"""
account = self._get_account()
# 从 token backend 获取 access token
token_info = account.con.token_backend.get_access_token()
if not token_info:
raise RuntimeError("无法获取 access token")
access_token = token_info.get("secret")
if not access_token:
raise RuntimeError("access token 无效")
folder_id = target_folder.object_id
# 在线程池中执行同步上传,避免阻塞事件循环
return await asyncio.to_thread(
self._sync_chunked_upload,
folder_id,
access_token,
local_path,
file_name,
file_size,
progress_callback,
chunk_size
)
async def logout(self) -> bool:
"""清除认证"""
try:

View File

@@ -23,6 +23,7 @@ from src.core import (
DOWNLOAD_DIR,
)
from src.core.config import OneDriveConfig
from src.cloud.base import UploadProgress, UploadStatus
from src.aria2 import Aria2Installer, Aria2ServiceManager
from src.aria2.rpc import Aria2RpcClient, DownloadTask, _format_size
from src.telegram.keyboards import (
@@ -536,9 +537,8 @@ class Aria2BotAPI:
await self._reply(update, context, text, parse_mode="Markdown")
async def upload_to_cloud(self, update: Update, context: ContextTypes.DEFAULT_TYPE, gid: str) -> None:
"""上传文件到云存储"""
"""上传文件到云存储(启动后台任务,不阻塞其他命令)"""
from pathlib import Path
import shutil
logger.info(f"收到上传请求 GID={gid} - {_get_user_info(update)}")
client = self._get_onedrive_client()
@@ -572,24 +572,65 @@ class Aria2BotAPI:
msg = await self._reply(update, context, f"☁️ 正在上传: {task.name}\n⏳ 请稍候...")
success = await client.upload_file(local_path, remote_path)
# 启动后台上传任务,不阻塞其他命令
asyncio.create_task(self._do_upload_to_cloud(
client, local_path, remote_path, task.name, msg, gid, _get_user_info(update)
))
if success:
result_text = f"✅ 上传成功: {task.name}"
if self._onedrive_config and self._onedrive_config.delete_after_upload:
async def _do_upload_to_cloud(
self, client, local_path, remote_path: str, task_name: str, msg, gid: str, user_info: str
) -> None:
"""后台执行上传任务"""
import shutil
loop = asyncio.get_event_loop()
# 进度回调函数
async def update_progress(progress: UploadProgress):
"""更新上传进度消息"""
if progress.status == UploadStatus.UPLOADING and progress.total_size > 0:
percent = progress.progress
uploaded_mb = progress.uploaded_size / (1024 * 1024)
total_mb = progress.total_size / (1024 * 1024)
progress_text = (
f"☁️ 正在上传: {task_name}\n"
f"📤 {percent:.1f}% ({uploaded_mb:.1f}MB / {total_mb:.1f}MB)"
)
try:
if local_path.is_dir():
shutil.rmtree(local_path)
else:
local_path.unlink()
result_text += "\n🗑️ 本地文件已删除"
except Exception as e:
result_text += f"\n⚠️ 删除本地文件失败: {e}"
await msg.edit_text(result_text)
logger.info(f"上传成功 GID={gid} - {_get_user_info(update)}")
else:
await msg.edit_text(f"❌ 上传失败: {task.name}")
logger.error(f"上传失败 GID={gid} - {_get_user_info(update)}")
await msg.edit_text(progress_text)
except Exception:
pass # 忽略消息更新失败(如内容未变化)
def sync_progress_callback(progress: UploadProgress):
"""同步回调,将异步更新调度到事件循环"""
if progress.status == UploadStatus.UPLOADING:
asyncio.run_coroutine_threadsafe(update_progress(progress), loop)
try:
success = await client.upload_file(local_path, remote_path, progress_callback=sync_progress_callback)
if success:
result_text = f"✅ 上传成功: {task_name}"
if self._onedrive_config and self._onedrive_config.delete_after_upload:
try:
if local_path.is_dir():
shutil.rmtree(local_path)
else:
local_path.unlink()
result_text += "\n🗑️ 本地文件已删除"
except Exception as e:
result_text += f"\n⚠️ 删除本地文件失败: {e}"
await msg.edit_text(result_text)
logger.info(f"上传成功 GID={gid} - {user_info}")
else:
await msg.edit_text(f"❌ 上传失败: {task_name}")
logger.error(f"上传失败 GID={gid} - {user_info}")
except Exception as e:
logger.error(f"上传异常 GID={gid}: {e} - {user_info}")
try:
await msg.edit_text(f"❌ 上传失败: {task_name}\n错误: {e}")
except Exception:
pass
async def handle_button_text(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""处理 Reply Keyboard 按钮点击"""