mirror of
https://github.com/dnslin/aria2bot.git
synced 2026-01-11 20:12:20 +08:00
feat: 增加上传回调
This commit is contained in:
@@ -3,9 +3,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
import httpx
|
||||||
from O365 import Account
|
from O365 import Account
|
||||||
from O365.utils import BaseTokenBackend
|
from O365.utils import BaseTokenBackend
|
||||||
|
|
||||||
@@ -16,6 +19,11 @@ from src.utils.logger import get_logger
|
|||||||
|
|
||||||
logger = get_logger("onedrive")
|
logger = get_logger("onedrive")
|
||||||
|
|
||||||
|
# 上传相关常量
|
||||||
|
SIMPLE_UPLOAD_LIMIT = 4 * 1024 * 1024 # 4MB,超过此大小使用分块上传
|
||||||
|
DEFAULT_CHUNK_SIZE = 5 * 1024 * 1024 # 5MB,必须是 320KB 的倍数
|
||||||
|
PROGRESS_UPDATE_INTERVAL = 2.0 # 进度更新间隔(秒)
|
||||||
|
|
||||||
|
|
||||||
class FileTokenBackend(BaseTokenBackend):
|
class FileTokenBackend(BaseTokenBackend):
|
||||||
"""文件系统 Token 存储后端"""
|
"""文件系统 Token 存储后端"""
|
||||||
@@ -128,16 +136,19 @@ class OneDriveClient(CloudStorageBase):
|
|||||||
remote_path: str,
|
remote_path: str,
|
||||||
progress_callback: Callable[[UploadProgress], None] | None = None
|
progress_callback: Callable[[UploadProgress], None] | None = None
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""上传文件到 OneDrive"""
|
"""上传文件到 OneDrive
|
||||||
|
|
||||||
|
对于大文件(>4MB)或需要进度回调时,使用分块上传以支持实时进度显示。
|
||||||
|
"""
|
||||||
account = self._get_account()
|
account = self._get_account()
|
||||||
if not account.is_authenticated:
|
if not account.is_authenticated:
|
||||||
raise RuntimeError("OneDrive 未认证")
|
raise RuntimeError("OneDrive 未认证")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 获取存储和驱动器
|
# 获取存储和驱动器(在线程池中执行,避免阻塞事件循环)
|
||||||
storage = account.storage()
|
storage = await asyncio.to_thread(account.storage)
|
||||||
drive = storage.get_default_drive()
|
drive = await asyncio.to_thread(storage.get_default_drive)
|
||||||
root = drive.get_root_folder()
|
root = await asyncio.to_thread(drive.get_root_folder)
|
||||||
|
|
||||||
# 确保远程目录存在
|
# 确保远程目录存在
|
||||||
target_folder = await self._ensure_folder_path(root, remote_path)
|
target_folder = await self._ensure_folder_path(root, remote_path)
|
||||||
@@ -154,13 +165,24 @@ class OneDriveClient(CloudStorageBase):
|
|||||||
status=UploadStatus.UPLOADING
|
status=UploadStatus.UPLOADING
|
||||||
))
|
))
|
||||||
|
|
||||||
# 执行上传(python-o365 会自动处理大文件分块)
|
# 大文件或需要进度回调时使用分块上传
|
||||||
uploaded = await asyncio.to_thread(
|
if file_size > SIMPLE_UPLOAD_LIMIT or progress_callback:
|
||||||
target_folder.upload_file,
|
success = await self._chunked_upload_with_progress(
|
||||||
item=str(local_path)
|
target_folder=target_folder,
|
||||||
)
|
local_path=local_path,
|
||||||
|
file_name=file_name,
|
||||||
|
file_size=file_size,
|
||||||
|
progress_callback=progress_callback
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 小文件使用简单上传
|
||||||
|
uploaded = await asyncio.to_thread(
|
||||||
|
target_folder.upload_file,
|
||||||
|
item=str(local_path)
|
||||||
|
)
|
||||||
|
success = uploaded is not None
|
||||||
|
|
||||||
if uploaded:
|
if success:
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(UploadProgress(
|
progress_callback(UploadProgress(
|
||||||
file_name=file_name,
|
file_name=file_name,
|
||||||
@@ -195,7 +217,7 @@ class OneDriveClient(CloudStorageBase):
|
|||||||
Returns:
|
Returns:
|
||||||
目标文件夹对象
|
目标文件夹对象
|
||||||
"""
|
"""
|
||||||
parts = [p for p in path.strip("/").split("/") if p]
|
parts = [p for p in path.strip("/").split("/") if p and p not in (".", "..")]
|
||||||
current = root_folder
|
current = root_folder
|
||||||
|
|
||||||
for part in parts:
|
for part in parts:
|
||||||
@@ -218,6 +240,128 @@ class OneDriveClient(CloudStorageBase):
|
|||||||
|
|
||||||
return current
|
return current
|
||||||
|
|
||||||
|
def _sync_chunked_upload(
|
||||||
|
self,
|
||||||
|
folder_id: str,
|
||||||
|
access_token: str,
|
||||||
|
local_path: Path,
|
||||||
|
file_name: str,
|
||||||
|
file_size: int,
|
||||||
|
progress_callback: Callable[[UploadProgress], None] | None,
|
||||||
|
chunk_size: int = DEFAULT_CHUNK_SIZE
|
||||||
|
) -> bool:
|
||||||
|
"""同步分块上传(在线程池中执行,避免阻塞事件循环)"""
|
||||||
|
create_session_url = (
|
||||||
|
f"https://graph.microsoft.com/v1.0/me/drive/items/{folder_id}:/{quote(file_name)}:/createUploadSession"
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
with httpx.Client(timeout=60.0) as client:
|
||||||
|
# 创建上传会话
|
||||||
|
response = client.post(
|
||||||
|
create_session_url,
|
||||||
|
headers=headers,
|
||||||
|
json={"item": {"@microsoft.graph.conflictBehavior": "replace"}}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"创建上传会话失败: {response.status_code} - {response.text}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
session_data = response.json()
|
||||||
|
upload_url = session_data.get("uploadUrl")
|
||||||
|
if not upload_url:
|
||||||
|
logger.error("上传会话响应中没有 uploadUrl")
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.info(f"创建上传会话成功,开始分块上传: {file_name}")
|
||||||
|
|
||||||
|
# 分块上传
|
||||||
|
uploaded_size = 0
|
||||||
|
last_progress_time = time.time()
|
||||||
|
|
||||||
|
with open(local_path, "rb") as f:
|
||||||
|
while uploaded_size < file_size:
|
||||||
|
chunk_data = f.read(chunk_size)
|
||||||
|
if not chunk_data:
|
||||||
|
break
|
||||||
|
|
||||||
|
chunk_len = len(chunk_data)
|
||||||
|
range_start = uploaded_size
|
||||||
|
range_end = uploaded_size + chunk_len - 1
|
||||||
|
|
||||||
|
chunk_headers = {
|
||||||
|
"Content-Length": str(chunk_len),
|
||||||
|
"Content-Range": f"bytes {range_start}-{range_end}/{file_size}"
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk_response = client.put(
|
||||||
|
upload_url,
|
||||||
|
headers=chunk_headers,
|
||||||
|
content=chunk_data
|
||||||
|
)
|
||||||
|
|
||||||
|
if chunk_response.status_code not in (200, 201, 202):
|
||||||
|
logger.error(f"上传 chunk 失败: {chunk_response.status_code} - {chunk_response.text}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
uploaded_size += chunk_len
|
||||||
|
|
||||||
|
# 按时间间隔更新进度
|
||||||
|
current_time = time.time()
|
||||||
|
if progress_callback and (current_time - last_progress_time >= PROGRESS_UPDATE_INTERVAL):
|
||||||
|
progress_callback(UploadProgress(
|
||||||
|
file_name=file_name,
|
||||||
|
total_size=file_size,
|
||||||
|
uploaded_size=uploaded_size,
|
||||||
|
status=UploadStatus.UPLOADING
|
||||||
|
))
|
||||||
|
last_progress_time = current_time
|
||||||
|
|
||||||
|
if chunk_response.status_code in (200, 201):
|
||||||
|
logger.info(f"文件上传完成: {file_name}")
|
||||||
|
break
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _chunked_upload_with_progress(
|
||||||
|
self,
|
||||||
|
target_folder,
|
||||||
|
local_path: Path,
|
||||||
|
file_name: str,
|
||||||
|
file_size: int,
|
||||||
|
progress_callback: Callable[[UploadProgress], None] | None,
|
||||||
|
chunk_size: int = DEFAULT_CHUNK_SIZE
|
||||||
|
) -> bool:
|
||||||
|
"""分块上传文件,支持进度回调(在线程池中执行,不阻塞事件循环)"""
|
||||||
|
account = self._get_account()
|
||||||
|
|
||||||
|
# 从 token backend 获取 access token
|
||||||
|
token_info = account.con.token_backend.get_access_token()
|
||||||
|
if not token_info:
|
||||||
|
raise RuntimeError("无法获取 access token")
|
||||||
|
access_token = token_info.get("secret")
|
||||||
|
if not access_token:
|
||||||
|
raise RuntimeError("access token 无效")
|
||||||
|
|
||||||
|
folder_id = target_folder.object_id
|
||||||
|
|
||||||
|
# 在线程池中执行同步上传,避免阻塞事件循环
|
||||||
|
return await asyncio.to_thread(
|
||||||
|
self._sync_chunked_upload,
|
||||||
|
folder_id,
|
||||||
|
access_token,
|
||||||
|
local_path,
|
||||||
|
file_name,
|
||||||
|
file_size,
|
||||||
|
progress_callback,
|
||||||
|
chunk_size
|
||||||
|
)
|
||||||
|
|
||||||
async def logout(self) -> bool:
|
async def logout(self) -> bool:
|
||||||
"""清除认证"""
|
"""清除认证"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from src.core import (
|
|||||||
DOWNLOAD_DIR,
|
DOWNLOAD_DIR,
|
||||||
)
|
)
|
||||||
from src.core.config import OneDriveConfig
|
from src.core.config import OneDriveConfig
|
||||||
|
from src.cloud.base import UploadProgress, UploadStatus
|
||||||
from src.aria2 import Aria2Installer, Aria2ServiceManager
|
from src.aria2 import Aria2Installer, Aria2ServiceManager
|
||||||
from src.aria2.rpc import Aria2RpcClient, DownloadTask, _format_size
|
from src.aria2.rpc import Aria2RpcClient, DownloadTask, _format_size
|
||||||
from src.telegram.keyboards import (
|
from src.telegram.keyboards import (
|
||||||
@@ -536,9 +537,8 @@ class Aria2BotAPI:
|
|||||||
await self._reply(update, context, text, parse_mode="Markdown")
|
await self._reply(update, context, text, parse_mode="Markdown")
|
||||||
|
|
||||||
async def upload_to_cloud(self, update: Update, context: ContextTypes.DEFAULT_TYPE, gid: str) -> None:
|
async def upload_to_cloud(self, update: Update, context: ContextTypes.DEFAULT_TYPE, gid: str) -> None:
|
||||||
"""上传文件到云存储"""
|
"""上传文件到云存储(启动后台任务,不阻塞其他命令)"""
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import shutil
|
|
||||||
|
|
||||||
logger.info(f"收到上传请求 GID={gid} - {_get_user_info(update)}")
|
logger.info(f"收到上传请求 GID={gid} - {_get_user_info(update)}")
|
||||||
client = self._get_onedrive_client()
|
client = self._get_onedrive_client()
|
||||||
@@ -572,24 +572,65 @@ class Aria2BotAPI:
|
|||||||
|
|
||||||
msg = await self._reply(update, context, f"☁️ 正在上传: {task.name}\n⏳ 请稍候...")
|
msg = await self._reply(update, context, f"☁️ 正在上传: {task.name}\n⏳ 请稍候...")
|
||||||
|
|
||||||
success = await client.upload_file(local_path, remote_path)
|
# 启动后台上传任务,不阻塞其他命令
|
||||||
|
asyncio.create_task(self._do_upload_to_cloud(
|
||||||
|
client, local_path, remote_path, task.name, msg, gid, _get_user_info(update)
|
||||||
|
))
|
||||||
|
|
||||||
if success:
|
async def _do_upload_to_cloud(
|
||||||
result_text = f"✅ 上传成功: {task.name}"
|
self, client, local_path, remote_path: str, task_name: str, msg, gid: str, user_info: str
|
||||||
if self._onedrive_config and self._onedrive_config.delete_after_upload:
|
) -> None:
|
||||||
|
"""后台执行上传任务"""
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
# 进度回调函数
|
||||||
|
async def update_progress(progress: UploadProgress):
|
||||||
|
"""更新上传进度消息"""
|
||||||
|
if progress.status == UploadStatus.UPLOADING and progress.total_size > 0:
|
||||||
|
percent = progress.progress
|
||||||
|
uploaded_mb = progress.uploaded_size / (1024 * 1024)
|
||||||
|
total_mb = progress.total_size / (1024 * 1024)
|
||||||
|
progress_text = (
|
||||||
|
f"☁️ 正在上传: {task_name}\n"
|
||||||
|
f"📤 {percent:.1f}% ({uploaded_mb:.1f}MB / {total_mb:.1f}MB)"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
if local_path.is_dir():
|
await msg.edit_text(progress_text)
|
||||||
shutil.rmtree(local_path)
|
except Exception:
|
||||||
else:
|
pass # 忽略消息更新失败(如内容未变化)
|
||||||
local_path.unlink()
|
|
||||||
result_text += "\n🗑️ 本地文件已删除"
|
def sync_progress_callback(progress: UploadProgress):
|
||||||
except Exception as e:
|
"""同步回调,将异步更新调度到事件循环"""
|
||||||
result_text += f"\n⚠️ 删除本地文件失败: {e}"
|
if progress.status == UploadStatus.UPLOADING:
|
||||||
await msg.edit_text(result_text)
|
asyncio.run_coroutine_threadsafe(update_progress(progress), loop)
|
||||||
logger.info(f"上传成功 GID={gid} - {_get_user_info(update)}")
|
|
||||||
else:
|
try:
|
||||||
await msg.edit_text(f"❌ 上传失败: {task.name}")
|
success = await client.upload_file(local_path, remote_path, progress_callback=sync_progress_callback)
|
||||||
logger.error(f"上传失败 GID={gid} - {_get_user_info(update)}")
|
|
||||||
|
if success:
|
||||||
|
result_text = f"✅ 上传成功: {task_name}"
|
||||||
|
if self._onedrive_config and self._onedrive_config.delete_after_upload:
|
||||||
|
try:
|
||||||
|
if local_path.is_dir():
|
||||||
|
shutil.rmtree(local_path)
|
||||||
|
else:
|
||||||
|
local_path.unlink()
|
||||||
|
result_text += "\n🗑️ 本地文件已删除"
|
||||||
|
except Exception as e:
|
||||||
|
result_text += f"\n⚠️ 删除本地文件失败: {e}"
|
||||||
|
await msg.edit_text(result_text)
|
||||||
|
logger.info(f"上传成功 GID={gid} - {user_info}")
|
||||||
|
else:
|
||||||
|
await msg.edit_text(f"❌ 上传失败: {task_name}")
|
||||||
|
logger.error(f"上传失败 GID={gid} - {user_info}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"上传异常 GID={gid}: {e} - {user_info}")
|
||||||
|
try:
|
||||||
|
await msg.edit_text(f"❌ 上传失败: {task_name}\n错误: {e}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
async def handle_button_text(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
async def handle_button_text(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
"""处理 Reply Keyboard 按钮点击"""
|
"""处理 Reply Keyboard 按钮点击"""
|
||||||
|
|||||||
Reference in New Issue
Block a user