mirror of
https://github.com/dnslin/aria2bot.git
synced 2026-01-11 04:02:20 +08:00
feat: 增加子进程管理aria2
This commit is contained in:
@@ -1,8 +1,14 @@
|
||||
"""Aria2 service manager - systemd service lifecycle management."""
|
||||
"""Aria2 service manager - 支持 systemd 和子进程两种模式."""
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
@@ -16,6 +22,7 @@ from src.core import (
|
||||
NotInstalledError,
|
||||
ConfigError,
|
||||
is_aria2_installed,
|
||||
detect_service_mode,
|
||||
)
|
||||
|
||||
|
||||
@@ -37,10 +44,93 @@ WantedBy=default.target
|
||||
logger = get_logger("service")
|
||||
|
||||
|
||||
class Aria2ServiceManager:
|
||||
def __init__(self) -> None:
|
||||
class ServiceManagerBase(ABC):
|
||||
"""服务管理器抽象基类"""
|
||||
|
||||
@abstractmethod
|
||||
def start(self) -> None:
|
||||
"""启动 aria2 服务"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop(self) -> None:
|
||||
"""停止 aria2 服务"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def restart(self) -> None:
|
||||
"""重启 aria2 服务"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def status(self) -> dict:
|
||||
"""获取服务状态,返回 {installed, running, pid, enabled}"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pid(self) -> int | None:
|
||||
"""获取 aria2 进程 PID"""
|
||||
pass
|
||||
|
||||
def enable(self) -> None:
|
||||
"""启用开机自启(子进程模式下静默忽略)"""
|
||||
pass
|
||||
|
||||
def disable(self) -> None:
|
||||
"""禁用开机自启(子进程模式下静默忽略)"""
|
||||
pass
|
||||
|
||||
def view_log(self, lines: int = 50) -> str:
|
||||
"""查看日志"""
|
||||
if lines <= 0 or not ARIA2_LOG.exists():
|
||||
return ""
|
||||
try:
|
||||
content = ARIA2_LOG.read_text(encoding="utf-8", errors="ignore")
|
||||
except OSError as exc:
|
||||
raise ServiceError(f"读取日志失败: {exc}") from exc
|
||||
log_lines = content.splitlines(keepends=True)
|
||||
return "".join(log_lines[-lines:])
|
||||
|
||||
def clear_log(self) -> None:
|
||||
"""清空日志"""
|
||||
try:
|
||||
ARIA2_LOG.parent.mkdir(parents=True, exist_ok=True)
|
||||
ARIA2_LOG.write_text("", encoding="utf-8")
|
||||
except OSError as exc:
|
||||
raise ServiceError(f"清空日志失败: {exc}") from exc
|
||||
|
||||
def update_rpc_secret(self, new_secret: str) -> None:
|
||||
"""更新 aria2.conf 中的 rpc-secret 配置"""
|
||||
if not ARIA2_CONF.exists():
|
||||
raise ConfigError("aria2.conf 不存在,请先安装 aria2")
|
||||
try:
|
||||
content = ARIA2_CONF.read_text(encoding="utf-8")
|
||||
lines = content.splitlines()
|
||||
new_lines = []
|
||||
found = False
|
||||
for line in lines:
|
||||
stripped = line.lstrip()
|
||||
if stripped.startswith("rpc-secret="):
|
||||
prefix = line[: len(line) - len(stripped)]
|
||||
new_lines.append(f"{prefix}rpc-secret={new_secret}")
|
||||
found = True
|
||||
else:
|
||||
new_lines.append(line)
|
||||
if not found:
|
||||
new_lines.append(f"rpc-secret={new_secret}")
|
||||
ARIA2_CONF.write_text("\n".join(new_lines) + "\n", encoding="utf-8")
|
||||
logger.info("RPC 密钥已更新")
|
||||
except OSError as exc:
|
||||
raise ConfigError(f"更新配置文件失败: {exc}") from exc
|
||||
|
||||
def remove_service(self) -> None:
|
||||
"""移除服务(子类可覆盖)"""
|
||||
self.stop()
|
||||
|
||||
|
||||
class SystemdServiceManager(ServiceManagerBase):
|
||||
"""基于 systemctl --user 的服务管理器"""
|
||||
|
||||
def _run_systemctl(self, *args: str) -> subprocess.CompletedProcess[str]:
|
||||
try:
|
||||
return subprocess.run(
|
||||
@@ -68,23 +158,23 @@ class Aria2ServiceManager:
|
||||
ARIA2_SERVICE.write_text(content, encoding="utf-8")
|
||||
self._run_systemctl("daemon-reload")
|
||||
except OSError as exc:
|
||||
raise ServiceError(f"Failed to write service file: {exc}") from exc
|
||||
raise ServiceError(f"写入服务文件失败: {exc}") from exc
|
||||
|
||||
def start(self) -> None:
|
||||
logger.info("正在启动 aria2 服务...")
|
||||
logger.info("正在启动 aria2 服务 (systemd)...")
|
||||
if not is_aria2_installed():
|
||||
raise NotInstalledError("aria2 is not installed")
|
||||
raise NotInstalledError("aria2 未安装")
|
||||
self._ensure_service_file()
|
||||
self._run_systemctl("start", "aria2")
|
||||
logger.info("aria2 服务已启动")
|
||||
|
||||
def stop(self) -> None:
|
||||
logger.info("正在停止 aria2 服务...")
|
||||
logger.info("正在停止 aria2 服务 (systemd)...")
|
||||
self._run_systemctl("stop", "aria2")
|
||||
logger.info("aria2 服务已停止")
|
||||
|
||||
def restart(self) -> None:
|
||||
logger.info("正在重启 aria2 服务...")
|
||||
logger.info("正在重启 aria2 服务 (systemd)...")
|
||||
self._run_systemctl("restart", "aria2")
|
||||
logger.info("aria2 服务已重启")
|
||||
|
||||
@@ -95,7 +185,7 @@ class Aria2ServiceManager:
|
||||
self._run_systemctl("disable", "aria2")
|
||||
|
||||
def status(self) -> dict:
|
||||
logger.info("正在获取 aria2 服务状态...")
|
||||
logger.info("正在获取 aria2 服务状态 (systemd)...")
|
||||
installed = is_aria2_installed()
|
||||
pid = self.get_pid() if installed else None
|
||||
|
||||
@@ -165,52 +255,203 @@ class Aria2ServiceManager:
|
||||
return int(line)
|
||||
return None
|
||||
|
||||
def view_log(self, lines: int = 50) -> str:
|
||||
if lines <= 0 or not ARIA2_LOG.exists():
|
||||
return ""
|
||||
try:
|
||||
content = ARIA2_LOG.read_text(encoding="utf-8", errors="ignore")
|
||||
except OSError as exc:
|
||||
raise ServiceError(f"Failed to read log: {exc}") from exc
|
||||
|
||||
log_lines = content.splitlines(keepends=True)
|
||||
return "".join(log_lines[-lines:])
|
||||
|
||||
def clear_log(self) -> None:
|
||||
try:
|
||||
ARIA2_LOG.parent.mkdir(parents=True, exist_ok=True)
|
||||
ARIA2_LOG.write_text("", encoding="utf-8")
|
||||
except OSError as exc:
|
||||
raise ServiceError(f"Failed to clear log: {exc}") from exc
|
||||
|
||||
def remove_service(self) -> None:
|
||||
self.stop()
|
||||
try:
|
||||
ARIA2_SERVICE.unlink(missing_ok=True)
|
||||
except OSError as exc:
|
||||
raise ServiceError(f"Failed to remove service file: {exc}") from exc
|
||||
raise ServiceError(f"删除服务文件失败: {exc}") from exc
|
||||
self._run_systemctl("daemon-reload")
|
||||
|
||||
def update_rpc_secret(self, new_secret: str) -> None:
|
||||
"""更新 aria2.conf 中的 rpc-secret 配置"""
|
||||
|
||||
class SubprocessServiceManager(ServiceManagerBase):
|
||||
"""基于子进程的服务管理器(用于 Docker 等无 systemd 环境)"""
|
||||
|
||||
_instance: "SubprocessServiceManager | None" = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls) -> "SubprocessServiceManager":
|
||||
"""单例模式,确保只有一个子进程管理器"""
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
instance = super().__new__(cls)
|
||||
instance._process: subprocess.Popen | None = None
|
||||
instance._registered_cleanup = False
|
||||
cls._instance = instance
|
||||
return cls._instance
|
||||
|
||||
def _register_cleanup(self) -> None:
|
||||
"""注册退出清理函数"""
|
||||
if not self._registered_cleanup:
|
||||
atexit.register(self._cleanup)
|
||||
# 保存原始信号处理器
|
||||
self._original_sigterm = signal.getsignal(signal.SIGTERM)
|
||||
self._original_sigint = signal.getsignal(signal.SIGINT)
|
||||
signal.signal(signal.SIGTERM, self._signal_handler)
|
||||
signal.signal(signal.SIGINT, self._signal_handler)
|
||||
self._registered_cleanup = True
|
||||
|
||||
def _signal_handler(self, signum: int, frame) -> None:
|
||||
"""信号处理器"""
|
||||
self._cleanup()
|
||||
# 调用原始处理器或默认退出
|
||||
if signum == signal.SIGTERM and callable(self._original_sigterm):
|
||||
self._original_sigterm(signum, frame)
|
||||
elif signum == signal.SIGINT and callable(self._original_sigint):
|
||||
self._original_sigint(signum, frame)
|
||||
else:
|
||||
sys.exit(0)
|
||||
|
||||
def _cleanup(self) -> None:
|
||||
"""清理子进程"""
|
||||
if self._process and self._process.poll() is None:
|
||||
logger.info("正在停止 aria2 子进程...")
|
||||
self._process.terminate()
|
||||
try:
|
||||
self._process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
self._process.kill()
|
||||
self._process.wait()
|
||||
logger.info("aria2 子进程已停止")
|
||||
|
||||
def start(self) -> None:
|
||||
"""启动 aria2 子进程"""
|
||||
logger.info("正在启动 aria2 子进程...")
|
||||
if not is_aria2_installed():
|
||||
raise NotInstalledError("aria2 未安装")
|
||||
|
||||
if self._process and self._process.poll() is None:
|
||||
logger.info("aria2 已在运行")
|
||||
return
|
||||
|
||||
if not ARIA2_CONF.exists():
|
||||
raise ConfigError("aria2.conf 不存在,请先安装 aria2")
|
||||
raise ConfigError("aria2.conf 不存在")
|
||||
|
||||
self._register_cleanup()
|
||||
|
||||
# 确保日志文件目录存在
|
||||
ARIA2_LOG.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 启动子进程,日志输出到文件
|
||||
log_file = open(ARIA2_LOG, "a", encoding="utf-8")
|
||||
self._process = subprocess.Popen(
|
||||
[str(ARIA2_BIN), f"--conf-path={ARIA2_CONF}"],
|
||||
stdout=log_file,
|
||||
stderr=subprocess.STDOUT,
|
||||
start_new_session=True, # 创建新会话,避免信号传播
|
||||
)
|
||||
|
||||
# 等待短暂时间检查是否启动成功
|
||||
time.sleep(0.5)
|
||||
if self._process.poll() is not None:
|
||||
raise ServiceError(f"aria2 启动失败,退出码: {self._process.returncode}")
|
||||
|
||||
logger.info(f"aria2 子进程已启动,PID={self._process.pid}")
|
||||
|
||||
def stop(self) -> None:
|
||||
"""停止 aria2 子进程"""
|
||||
logger.info("正在停止 aria2 子进程...")
|
||||
if self._process and self._process.poll() is None:
|
||||
self._process.terminate()
|
||||
try:
|
||||
self._process.wait(timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
self._process.kill()
|
||||
self._process.wait()
|
||||
self._process = None
|
||||
logger.info("aria2 子进程已停止")
|
||||
return
|
||||
|
||||
# 尝试通过 PID 查找并停止
|
||||
pid = self.get_pid()
|
||||
if pid:
|
||||
try:
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
logger.info(f"已发送 SIGTERM 到 PID={pid}")
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def restart(self) -> None:
|
||||
"""重启 aria2 子进程"""
|
||||
logger.info("正在重启 aria2 子进程...")
|
||||
self.stop()
|
||||
time.sleep(1)
|
||||
self.start()
|
||||
logger.info("aria2 子进程已重启")
|
||||
|
||||
def status(self) -> dict:
|
||||
"""获取服务状态"""
|
||||
logger.info("正在获取 aria2 子进程状态...")
|
||||
installed = is_aria2_installed()
|
||||
pid = self.get_pid() if installed else None
|
||||
running = pid is not None
|
||||
|
||||
logger.info(f"aria2 状态: 已安装={installed}, 运行中={running}, PID={pid}")
|
||||
return {
|
||||
"installed": installed,
|
||||
"running": running,
|
||||
"pid": pid,
|
||||
"enabled": False, # 子进程模式不支持开机自启
|
||||
}
|
||||
|
||||
def get_pid(self) -> int | None:
|
||||
"""获取 aria2 进程 PID"""
|
||||
# 优先检查管理的子进程
|
||||
if self._process and self._process.poll() is None:
|
||||
return self._process.pid
|
||||
|
||||
# 回退到 pgrep 查找
|
||||
try:
|
||||
content = ARIA2_CONF.read_text(encoding="utf-8")
|
||||
lines = content.splitlines()
|
||||
new_lines = []
|
||||
found = False
|
||||
for line in lines:
|
||||
stripped = line.lstrip()
|
||||
if stripped.startswith("rpc-secret="):
|
||||
prefix = line[: len(line) - len(stripped)]
|
||||
new_lines.append(f"{prefix}rpc-secret={new_secret}")
|
||||
found = True
|
||||
else:
|
||||
new_lines.append(line)
|
||||
if not found:
|
||||
new_lines.append(f"rpc-secret={new_secret}")
|
||||
ARIA2_CONF.write_text("\n".join(new_lines) + "\n", encoding="utf-8")
|
||||
logger.info("RPC 密钥已更新")
|
||||
except OSError as exc:
|
||||
raise ConfigError(f"更新配置文件失败: {exc}") from exc
|
||||
result = subprocess.run(
|
||||
["pgrep", "-u", str(os.getuid()), "-f", "aria2c"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
timeout=5,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
for line in result.stdout.splitlines():
|
||||
line = line.strip()
|
||||
if line.isdigit():
|
||||
return int(line)
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# 缓存服务管理器实例
|
||||
_service_manager: ServiceManagerBase | None = None
|
||||
_service_mode: str | None = None
|
||||
|
||||
|
||||
def Aria2ServiceManager() -> ServiceManagerBase:
|
||||
"""工厂函数:根据环境自动选择服务管理器
|
||||
|
||||
保持与现有代码的兼容性,调用方式不变:
|
||||
service = Aria2ServiceManager()
|
||||
service.start()
|
||||
"""
|
||||
global _service_manager, _service_mode
|
||||
|
||||
if _service_manager is not None:
|
||||
return _service_manager
|
||||
|
||||
mode = detect_service_mode()
|
||||
_service_mode = mode
|
||||
logger.info(f"检测到服务管理模式: {mode}")
|
||||
|
||||
if mode == "systemd":
|
||||
_service_manager = SystemdServiceManager()
|
||||
else:
|
||||
_service_manager = SubprocessServiceManager()
|
||||
|
||||
return _service_manager
|
||||
|
||||
|
||||
def get_service_mode() -> str:
|
||||
"""获取当前服务管理模式"""
|
||||
global _service_mode
|
||||
if _service_mode is None:
|
||||
_service_mode = detect_service_mode()
|
||||
return _service_mode
|
||||
|
||||
@@ -26,6 +26,7 @@ from src.core.config import Aria2Config, BotConfig
|
||||
from src.core.system import (
|
||||
detect_os,
|
||||
detect_arch,
|
||||
detect_service_mode,
|
||||
generate_rpc_secret,
|
||||
is_aria2_installed,
|
||||
get_aria2_version,
|
||||
@@ -55,6 +56,7 @@ __all__ = [
|
||||
"BotConfig",
|
||||
"detect_os",
|
||||
"detect_arch",
|
||||
"detect_service_mode",
|
||||
"generate_rpc_secret",
|
||||
"is_aria2_installed",
|
||||
"get_aria2_version",
|
||||
|
||||
@@ -57,6 +57,45 @@ def generate_rpc_secret() -> str:
|
||||
return "".join(secrets.choice(alphabet) for _ in range(20))
|
||||
|
||||
|
||||
def detect_service_mode() -> str:
|
||||
"""检测应使用的服务管理模式
|
||||
|
||||
返回: 'systemd' 或 'subprocess'
|
||||
|
||||
可通过环境变量 ARIA2_SERVICE_MODE 强制指定模式(用于测试)
|
||||
"""
|
||||
import os
|
||||
# 允许通过环境变量强制指定模式
|
||||
forced_mode = os.environ.get("ARIA2_SERVICE_MODE", "").lower()
|
||||
if forced_mode in ("systemd", "subprocess"):
|
||||
return forced_mode
|
||||
|
||||
# 1. 检查 systemctl 命令是否存在
|
||||
if shutil.which("systemctl") is None:
|
||||
return "subprocess"
|
||||
|
||||
# 2. 检查 systemd 是否真正运行(Docker 容器中可能有命令但无服务)
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["systemctl", "--user", "is-system-running"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
# running/degraded/starting/initializing 都表示 systemd 可用
|
||||
status = result.stdout.strip()
|
||||
if status in ("running", "degraded", "starting", "initializing"):
|
||||
return "systemd"
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
|
||||
pass
|
||||
|
||||
# 3. 备选检查:检查 /run/systemd/system 目录
|
||||
if Path("/run/systemd/system").exists():
|
||||
return "systemd"
|
||||
|
||||
return "subprocess"
|
||||
|
||||
|
||||
def is_aria2_installed() -> bool:
|
||||
"""检查 aria2c 是否已安装"""
|
||||
if ARIA2_BIN.exists():
|
||||
|
||||
@@ -6,7 +6,8 @@ import sys
|
||||
from telegram import Bot, BotCommand
|
||||
from telegram.ext import Application
|
||||
|
||||
from src.core import BotConfig
|
||||
from src.core import BotConfig, is_aria2_installed
|
||||
from src.aria2.service import Aria2ServiceManager, get_service_mode
|
||||
from src.telegram.handlers import Aria2BotAPI, build_handlers
|
||||
from src.utils import setup_logger
|
||||
|
||||
@@ -57,6 +58,27 @@ def create_app(config: BotConfig) -> Application:
|
||||
return app
|
||||
|
||||
|
||||
def _auto_start_aria2() -> None:
|
||||
"""子进程模式下自动启动 aria2(如果已安装)"""
|
||||
logger = setup_logger()
|
||||
mode = get_service_mode()
|
||||
|
||||
if mode != "subprocess":
|
||||
logger.info(f"服务管理模式: {mode},跳过自动启动")
|
||||
return
|
||||
|
||||
if not is_aria2_installed():
|
||||
logger.info("aria2 未安装,跳过自动启动")
|
||||
return
|
||||
|
||||
try:
|
||||
service = Aria2ServiceManager()
|
||||
service.start()
|
||||
logger.info("aria2 子进程已自动启动")
|
||||
except Exception as e:
|
||||
logger.warning(f"自动启动 aria2 失败: {e}")
|
||||
|
||||
|
||||
def run() -> None:
|
||||
"""加载配置并启动 bot"""
|
||||
import asyncio
|
||||
@@ -68,6 +90,9 @@ def run() -> None:
|
||||
logger.error("Please set TELEGRAM_BOT_TOKEN in .env or environment")
|
||||
sys.exit(1)
|
||||
|
||||
# 子进程模式下自动启动 aria2
|
||||
_auto_start_aria2()
|
||||
|
||||
app = create_app(config)
|
||||
logger.info("Bot starting...")
|
||||
|
||||
|
||||
@@ -1280,7 +1280,6 @@ class Aria2BotAPI:
|
||||
gid = parts[2]
|
||||
|
||||
if provider == "onedrive":
|
||||
await query.edit_message_text("☁️ 正在准备上传...")
|
||||
await self.upload_to_cloud(update, context, gid)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user