mirror of
https://github.com/dnslin/aria2bot.git
synced 2026-01-12 04:22:21 +08:00
feat: 增加tg文件上传
This commit is contained in:
@@ -22,7 +22,7 @@ from src.core import (
|
||||
ARIA2_CONF,
|
||||
DOWNLOAD_DIR,
|
||||
)
|
||||
from src.core.config import OneDriveConfig
|
||||
from src.core.config import OneDriveConfig, TelegramChannelConfig
|
||||
from src.cloud.base import UploadProgress, UploadStatus
|
||||
from src.aria2 import Aria2Installer, Aria2ServiceManager
|
||||
from src.aria2.rpc import Aria2RpcClient, DownloadTask, _format_size
|
||||
@@ -88,7 +88,9 @@ from functools import wraps
|
||||
|
||||
class Aria2BotAPI:
|
||||
def __init__(self, config: Aria2Config | None = None, allowed_users: set[int] | None = None,
|
||||
onedrive_config: OneDriveConfig | None = None):
|
||||
onedrive_config: OneDriveConfig | None = None,
|
||||
telegram_channel_config: TelegramChannelConfig | None = None,
|
||||
api_base_url: str = ""):
|
||||
self.config = config or Aria2Config()
|
||||
self.allowed_users = allowed_users or set()
|
||||
self.installer = Aria2Installer(self.config)
|
||||
@@ -102,6 +104,11 @@ class Aria2BotAPI:
|
||||
self._onedrive_config = onedrive_config
|
||||
self._onedrive = None
|
||||
self._pending_auth: dict[int, dict] = {} # user_id -> flow
|
||||
# Telegram 频道存储
|
||||
self._telegram_channel_config = telegram_channel_config
|
||||
self._telegram_channel = None
|
||||
self._api_base_url = api_base_url
|
||||
self._channel_uploaded_gids: set[str] = set() # 已上传到频道的 GID
|
||||
|
||||
async def _check_permission(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> bool:
|
||||
"""检查用户权限,返回 True 表示有权限"""
|
||||
@@ -132,6 +139,14 @@ class Aria2BotAPI:
|
||||
self._onedrive = OneDriveClient(self._onedrive_config)
|
||||
return self._onedrive
|
||||
|
||||
def _get_telegram_channel_client(self, bot):
|
||||
"""获取或创建 Telegram 频道客户端"""
|
||||
if self._telegram_channel is None and self._telegram_channel_config and self._telegram_channel_config.enabled:
|
||||
from src.cloud.telegram_channel import TelegramChannelClient
|
||||
is_local_api = bool(self._api_base_url)
|
||||
self._telegram_channel = TelegramChannelClient(self._telegram_channel_config, bot, is_local_api)
|
||||
return self._telegram_channel
|
||||
|
||||
async def _reply(self, update: Update, context: ContextTypes.DEFAULT_TYPE, text: str, **kwargs):
|
||||
if update.effective_message:
|
||||
return await update.effective_message.reply_text(text, **kwargs)
|
||||
@@ -740,6 +755,79 @@ class Aria2BotAPI:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _trigger_channel_auto_upload(self, chat_id: int, gid: str, bot) -> None:
|
||||
"""触发频道自动上传"""
|
||||
from pathlib import Path
|
||||
|
||||
logger.info(f"触发频道自动上传 GID={gid}")
|
||||
|
||||
client = self._get_telegram_channel_client(bot)
|
||||
if not client:
|
||||
logger.warning(f"频道上传跳过:频道未配置 GID={gid}")
|
||||
return
|
||||
|
||||
rpc = self._get_rpc_client()
|
||||
try:
|
||||
task = await rpc.get_status(gid)
|
||||
except RpcError as e:
|
||||
logger.error(f"频道上传失败:获取任务信息失败 GID={gid}: {e}")
|
||||
return
|
||||
|
||||
if task.status != "complete":
|
||||
return
|
||||
|
||||
local_path = Path(task.dir) / task.name
|
||||
if not local_path.exists():
|
||||
logger.error(f"频道上传失败:本地文件不存在 GID={gid}")
|
||||
return
|
||||
|
||||
# 检查文件大小
|
||||
file_size = local_path.stat().st_size
|
||||
if file_size > client.get_max_size():
|
||||
limit_mb = client.get_max_size_mb()
|
||||
await bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=f"⚠️ 文件 {task.name} 超过 {limit_mb}MB 限制,跳过频道上传"
|
||||
)
|
||||
return
|
||||
|
||||
asyncio.create_task(self._do_channel_upload(client, local_path, task.name, chat_id, gid, bot))
|
||||
|
||||
async def _do_channel_upload(self, client, local_path, task_name: str, chat_id: int, gid: str, bot) -> None:
|
||||
"""执行频道上传"""
|
||||
import shutil
|
||||
|
||||
try:
|
||||
msg = await bot.send_message(chat_id=chat_id, text=f"📢 正在发送到频道: {task_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"频道上传失败:发送消息失败 GID={gid}: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
success, result = await client.upload_file(local_path)
|
||||
if success:
|
||||
result_text = f"✅ 已发送到频道: {task_name}"
|
||||
if self._telegram_channel_config and self._telegram_channel_config.delete_after_upload:
|
||||
try:
|
||||
if local_path.is_dir():
|
||||
shutil.rmtree(local_path)
|
||||
else:
|
||||
local_path.unlink()
|
||||
result_text += "\n🗑️ 本地文件已删除"
|
||||
except Exception as e:
|
||||
result_text += f"\n⚠️ 删除本地文件失败: {e}"
|
||||
await msg.edit_text(result_text)
|
||||
logger.info(f"频道上传成功 GID={gid}")
|
||||
else:
|
||||
await msg.edit_text(f"❌ 发送到频道失败: {task_name}\n原因: {result}")
|
||||
logger.error(f"频道上传失败 GID={gid}: {result}")
|
||||
except Exception as e:
|
||||
logger.error(f"频道上传异常 GID={gid}: {e}")
|
||||
try:
|
||||
await msg.edit_text(f"❌ 发送到频道失败: {task_name}\n错误: {e}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def handle_button_text(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""处理 Reply Keyboard 按钮点击"""
|
||||
text = update.message.text
|
||||
@@ -1097,13 +1185,18 @@ class Aria2BotAPI:
|
||||
if task.error_message:
|
||||
text += f"\n❌ 错误: {task.error_message}"
|
||||
|
||||
# 检查是否显示上传按钮(任务完成且云存储已配置)
|
||||
show_upload = (
|
||||
# 检查是否显示上传按钮
|
||||
show_onedrive = (
|
||||
task.status == "complete" and
|
||||
self._onedrive_config and
|
||||
self._onedrive_config.enabled
|
||||
)
|
||||
keyboard = build_detail_keyboard_with_upload(gid, task.status, show_upload)
|
||||
show_channel = (
|
||||
task.status == "complete" and
|
||||
self._telegram_channel_config and
|
||||
self._telegram_channel_config.enabled
|
||||
)
|
||||
keyboard = build_detail_keyboard_with_upload(gid, task.status, show_onedrive, show_channel)
|
||||
|
||||
# 只有内容变化时才更新
|
||||
if text != last_text:
|
||||
@@ -1176,6 +1269,13 @@ class Aria2BotAPI:
|
||||
text = f"✅ *下载完成*\n📄 {safe_name}\n📦 大小: {task.size_str}\n🆔 GID: `{task.gid}`"
|
||||
try:
|
||||
await _bot_instance.send_message(chat_id=chat_id, text=text, parse_mode="Markdown")
|
||||
# 触发频道自动上传
|
||||
if (self._telegram_channel_config and
|
||||
self._telegram_channel_config.enabled and
|
||||
self._telegram_channel_config.auto_upload and
|
||||
task.gid not in self._channel_uploaded_gids):
|
||||
self._channel_uploaded_gids.add(task.gid)
|
||||
asyncio.create_task(self._trigger_channel_auto_upload(chat_id, task.gid, _bot_instance))
|
||||
except Exception as e:
|
||||
logger.warning(f"发送完成通知失败 (GID={task.gid}): {e}")
|
||||
|
||||
@@ -1276,11 +1376,52 @@ class Aria2BotAPI:
|
||||
await query.edit_message_text("❌ 无效操作")
|
||||
return
|
||||
|
||||
provider = parts[1] # onedrive
|
||||
provider = parts[1] # onedrive / telegram
|
||||
gid = parts[2]
|
||||
|
||||
if provider == "onedrive":
|
||||
await self.upload_to_cloud(update, context, gid)
|
||||
elif provider == "telegram":
|
||||
await self._upload_to_channel_manual(query, update, context, gid)
|
||||
|
||||
async def _upload_to_channel_manual(self, query, update: Update, context: ContextTypes.DEFAULT_TYPE, gid: str) -> None:
|
||||
"""手动上传到频道"""
|
||||
from pathlib import Path
|
||||
|
||||
client = self._get_telegram_channel_client(context.bot)
|
||||
if not client:
|
||||
await query.edit_message_text("❌ 频道存储未配置")
|
||||
return
|
||||
|
||||
rpc = self._get_rpc_client()
|
||||
try:
|
||||
task = await rpc.get_status(gid)
|
||||
except RpcError as e:
|
||||
await query.edit_message_text(f"❌ 获取任务信息失败: {e}")
|
||||
return
|
||||
|
||||
if task.status != "complete":
|
||||
await query.edit_message_text("❌ 任务未完成,无法上传")
|
||||
return
|
||||
|
||||
local_path = Path(task.dir) / task.name
|
||||
if not local_path.exists():
|
||||
await query.edit_message_text("❌ 本地文件不存在")
|
||||
return
|
||||
|
||||
# 检查文件大小
|
||||
file_size = local_path.stat().st_size
|
||||
if file_size > client.get_max_size():
|
||||
limit_mb = client.get_max_size_mb()
|
||||
await query.edit_message_text(f"❌ 文件超过 {limit_mb}MB 限制")
|
||||
return
|
||||
|
||||
await query.edit_message_text(f"📢 正在发送到频道: {task.name}")
|
||||
success, result = await client.upload_file(local_path)
|
||||
if success:
|
||||
await query.edit_message_text(f"✅ 已发送到频道: {task.name}")
|
||||
else:
|
||||
await query.edit_message_text(f"❌ 发送失败: {result}")
|
||||
|
||||
|
||||
def build_handlers(api: Aria2BotAPI) -> list:
|
||||
|
||||
Reference in New Issue
Block a user