feat: 增加子进程管理aria2

This commit is contained in:
dnslin
2025-12-12 17:12:05 +08:00
parent bf4fdf1377
commit debce7c476
5 changed files with 357 additions and 51 deletions

View File

@@ -1,8 +1,14 @@
"""Aria2 service manager - systemd service lifecycle management."""
"""Aria2 service manager - 支持 systemd 和子进程两种模式."""
from __future__ import annotations
import atexit
import os
import signal
import subprocess
import sys
import threading
import time
from abc import ABC, abstractmethod
from src.utils.logger import get_logger
@@ -16,6 +22,7 @@ from src.core import (
NotInstalledError,
ConfigError,
is_aria2_installed,
detect_service_mode,
)
@@ -37,10 +44,93 @@ WantedBy=default.target
logger = get_logger("service")
class Aria2ServiceManager:
def __init__(self) -> None:
class ServiceManagerBase(ABC):
"""服务管理器抽象基类"""
@abstractmethod
def start(self) -> None:
"""启动 aria2 服务"""
pass
@abstractmethod
def stop(self) -> None:
"""停止 aria2 服务"""
pass
@abstractmethod
def restart(self) -> None:
"""重启 aria2 服务"""
pass
@abstractmethod
def status(self) -> dict:
"""获取服务状态,返回 {installed, running, pid, enabled}"""
pass
@abstractmethod
def get_pid(self) -> int | None:
"""获取 aria2 进程 PID"""
pass
def enable(self) -> None:
"""启用开机自启(子进程模式下静默忽略)"""
pass
def disable(self) -> None:
"""禁用开机自启(子进程模式下静默忽略)"""
pass
def view_log(self, lines: int = 50) -> str:
"""查看日志"""
if lines <= 0 or not ARIA2_LOG.exists():
return ""
try:
content = ARIA2_LOG.read_text(encoding="utf-8", errors="ignore")
except OSError as exc:
raise ServiceError(f"读取日志失败: {exc}") from exc
log_lines = content.splitlines(keepends=True)
return "".join(log_lines[-lines:])
def clear_log(self) -> None:
"""清空日志"""
try:
ARIA2_LOG.parent.mkdir(parents=True, exist_ok=True)
ARIA2_LOG.write_text("", encoding="utf-8")
except OSError as exc:
raise ServiceError(f"清空日志失败: {exc}") from exc
def update_rpc_secret(self, new_secret: str) -> None:
"""更新 aria2.conf 中的 rpc-secret 配置"""
if not ARIA2_CONF.exists():
raise ConfigError("aria2.conf 不存在,请先安装 aria2")
try:
content = ARIA2_CONF.read_text(encoding="utf-8")
lines = content.splitlines()
new_lines = []
found = False
for line in lines:
stripped = line.lstrip()
if stripped.startswith("rpc-secret="):
prefix = line[: len(line) - len(stripped)]
new_lines.append(f"{prefix}rpc-secret={new_secret}")
found = True
else:
new_lines.append(line)
if not found:
new_lines.append(f"rpc-secret={new_secret}")
ARIA2_CONF.write_text("\n".join(new_lines) + "\n", encoding="utf-8")
logger.info("RPC 密钥已更新")
except OSError as exc:
raise ConfigError(f"更新配置文件失败: {exc}") from exc
def remove_service(self) -> None:
"""移除服务(子类可覆盖)"""
self.stop()
class SystemdServiceManager(ServiceManagerBase):
"""基于 systemctl --user 的服务管理器"""
def _run_systemctl(self, *args: str) -> subprocess.CompletedProcess[str]:
try:
return subprocess.run(
@@ -68,23 +158,23 @@ class Aria2ServiceManager:
ARIA2_SERVICE.write_text(content, encoding="utf-8")
self._run_systemctl("daemon-reload")
except OSError as exc:
raise ServiceError(f"Failed to write service file: {exc}") from exc
raise ServiceError(f"写入服务文件失败: {exc}") from exc
def start(self) -> None:
logger.info("正在启动 aria2 服务...")
logger.info("正在启动 aria2 服务 (systemd)...")
if not is_aria2_installed():
raise NotInstalledError("aria2 is not installed")
raise NotInstalledError("aria2 未安装")
self._ensure_service_file()
self._run_systemctl("start", "aria2")
logger.info("aria2 服务已启动")
def stop(self) -> None:
logger.info("正在停止 aria2 服务...")
logger.info("正在停止 aria2 服务 (systemd)...")
self._run_systemctl("stop", "aria2")
logger.info("aria2 服务已停止")
def restart(self) -> None:
logger.info("正在重启 aria2 服务...")
logger.info("正在重启 aria2 服务 (systemd)...")
self._run_systemctl("restart", "aria2")
logger.info("aria2 服务已重启")
@@ -95,7 +185,7 @@ class Aria2ServiceManager:
self._run_systemctl("disable", "aria2")
def status(self) -> dict:
logger.info("正在获取 aria2 服务状态...")
logger.info("正在获取 aria2 服务状态 (systemd)...")
installed = is_aria2_installed()
pid = self.get_pid() if installed else None
@@ -165,52 +255,203 @@ class Aria2ServiceManager:
return int(line)
return None
def view_log(self, lines: int = 50) -> str:
if lines <= 0 or not ARIA2_LOG.exists():
return ""
try:
content = ARIA2_LOG.read_text(encoding="utf-8", errors="ignore")
except OSError as exc:
raise ServiceError(f"Failed to read log: {exc}") from exc
log_lines = content.splitlines(keepends=True)
return "".join(log_lines[-lines:])
def clear_log(self) -> None:
try:
ARIA2_LOG.parent.mkdir(parents=True, exist_ok=True)
ARIA2_LOG.write_text("", encoding="utf-8")
except OSError as exc:
raise ServiceError(f"Failed to clear log: {exc}") from exc
def remove_service(self) -> None:
self.stop()
try:
ARIA2_SERVICE.unlink(missing_ok=True)
except OSError as exc:
raise ServiceError(f"Failed to remove service file: {exc}") from exc
raise ServiceError(f"删除服务文件失败: {exc}") from exc
self._run_systemctl("daemon-reload")
def update_rpc_secret(self, new_secret: str) -> None:
"""更新 aria2.conf 中的 rpc-secret 配置"""
class SubprocessServiceManager(ServiceManagerBase):
"""基于子进程的服务管理器(用于 Docker 等无 systemd 环境)"""
_instance: "SubprocessServiceManager | None" = None
_lock = threading.Lock()
def __new__(cls) -> "SubprocessServiceManager":
"""单例模式,确保只有一个子进程管理器"""
with cls._lock:
if cls._instance is None:
instance = super().__new__(cls)
instance._process: subprocess.Popen | None = None
instance._registered_cleanup = False
cls._instance = instance
return cls._instance
def _register_cleanup(self) -> None:
"""注册退出清理函数"""
if not self._registered_cleanup:
atexit.register(self._cleanup)
# 保存原始信号处理器
self._original_sigterm = signal.getsignal(signal.SIGTERM)
self._original_sigint = signal.getsignal(signal.SIGINT)
signal.signal(signal.SIGTERM, self._signal_handler)
signal.signal(signal.SIGINT, self._signal_handler)
self._registered_cleanup = True
def _signal_handler(self, signum: int, frame) -> None:
"""信号处理器"""
self._cleanup()
# 调用原始处理器或默认退出
if signum == signal.SIGTERM and callable(self._original_sigterm):
self._original_sigterm(signum, frame)
elif signum == signal.SIGINT and callable(self._original_sigint):
self._original_sigint(signum, frame)
else:
sys.exit(0)
def _cleanup(self) -> None:
"""清理子进程"""
if self._process and self._process.poll() is None:
logger.info("正在停止 aria2 子进程...")
self._process.terminate()
try:
self._process.wait(timeout=5)
except subprocess.TimeoutExpired:
self._process.kill()
self._process.wait()
logger.info("aria2 子进程已停止")
def start(self) -> None:
"""启动 aria2 子进程"""
logger.info("正在启动 aria2 子进程...")
if not is_aria2_installed():
raise NotInstalledError("aria2 未安装")
if self._process and self._process.poll() is None:
logger.info("aria2 已在运行")
return
if not ARIA2_CONF.exists():
raise ConfigError("aria2.conf 不存在,请先安装 aria2")
raise ConfigError("aria2.conf 不存在")
self._register_cleanup()
# 确保日志文件目录存在
ARIA2_LOG.parent.mkdir(parents=True, exist_ok=True)
# 启动子进程,日志输出到文件
log_file = open(ARIA2_LOG, "a", encoding="utf-8")
self._process = subprocess.Popen(
[str(ARIA2_BIN), f"--conf-path={ARIA2_CONF}"],
stdout=log_file,
stderr=subprocess.STDOUT,
start_new_session=True, # 创建新会话,避免信号传播
)
# 等待短暂时间检查是否启动成功
time.sleep(0.5)
if self._process.poll() is not None:
raise ServiceError(f"aria2 启动失败,退出码: {self._process.returncode}")
logger.info(f"aria2 子进程已启动PID={self._process.pid}")
def stop(self) -> None:
"""停止 aria2 子进程"""
logger.info("正在停止 aria2 子进程...")
if self._process and self._process.poll() is None:
self._process.terminate()
try:
self._process.wait(timeout=10)
except subprocess.TimeoutExpired:
self._process.kill()
self._process.wait()
self._process = None
logger.info("aria2 子进程已停止")
return
# 尝试通过 PID 查找并停止
pid = self.get_pid()
if pid:
try:
os.kill(pid, signal.SIGTERM)
logger.info(f"已发送 SIGTERM 到 PID={pid}")
except OSError:
pass
def restart(self) -> None:
"""重启 aria2 子进程"""
logger.info("正在重启 aria2 子进程...")
self.stop()
time.sleep(1)
self.start()
logger.info("aria2 子进程已重启")
def status(self) -> dict:
"""获取服务状态"""
logger.info("正在获取 aria2 子进程状态...")
installed = is_aria2_installed()
pid = self.get_pid() if installed else None
running = pid is not None
logger.info(f"aria2 状态: 已安装={installed}, 运行中={running}, PID={pid}")
return {
"installed": installed,
"running": running,
"pid": pid,
"enabled": False, # 子进程模式不支持开机自启
}
def get_pid(self) -> int | None:
"""获取 aria2 进程 PID"""
# 优先检查管理的子进程
if self._process and self._process.poll() is None:
return self._process.pid
# 回退到 pgrep 查找
try:
content = ARIA2_CONF.read_text(encoding="utf-8")
lines = content.splitlines()
new_lines = []
found = False
for line in lines:
stripped = line.lstrip()
if stripped.startswith("rpc-secret="):
prefix = line[: len(line) - len(stripped)]
new_lines.append(f"{prefix}rpc-secret={new_secret}")
found = True
else:
new_lines.append(line)
if not found:
new_lines.append(f"rpc-secret={new_secret}")
ARIA2_CONF.write_text("\n".join(new_lines) + "\n", encoding="utf-8")
logger.info("RPC 密钥已更新")
except OSError as exc:
raise ConfigError(f"更新配置文件失败: {exc}") from exc
result = subprocess.run(
["pgrep", "-u", str(os.getuid()), "-f", "aria2c"],
capture_output=True,
text=True,
check=False,
timeout=5,
)
if result.returncode == 0:
for line in result.stdout.splitlines():
line = line.strip()
if line.isdigit():
return int(line)
except (FileNotFoundError, subprocess.TimeoutExpired):
pass
return None
# 缓存服务管理器实例
_service_manager: ServiceManagerBase | None = None
_service_mode: str | None = None
def Aria2ServiceManager() -> ServiceManagerBase:
"""工厂函数:根据环境自动选择服务管理器
保持与现有代码的兼容性,调用方式不变:
service = Aria2ServiceManager()
service.start()
"""
global _service_manager, _service_mode
if _service_manager is not None:
return _service_manager
mode = detect_service_mode()
_service_mode = mode
logger.info(f"检测到服务管理模式: {mode}")
if mode == "systemd":
_service_manager = SystemdServiceManager()
else:
_service_manager = SubprocessServiceManager()
return _service_manager
def get_service_mode() -> str:
"""获取当前服务管理模式"""
global _service_mode
if _service_mode is None:
_service_mode = detect_service_mode()
return _service_mode

View File

@@ -26,6 +26,7 @@ from src.core.config import Aria2Config, BotConfig
from src.core.system import (
detect_os,
detect_arch,
detect_service_mode,
generate_rpc_secret,
is_aria2_installed,
get_aria2_version,
@@ -55,6 +56,7 @@ __all__ = [
"BotConfig",
"detect_os",
"detect_arch",
"detect_service_mode",
"generate_rpc_secret",
"is_aria2_installed",
"get_aria2_version",

View File

@@ -57,6 +57,45 @@ def generate_rpc_secret() -> str:
return "".join(secrets.choice(alphabet) for _ in range(20))
def detect_service_mode() -> str:
"""检测应使用的服务管理模式
返回: 'systemd''subprocess'
可通过环境变量 ARIA2_SERVICE_MODE 强制指定模式(用于测试)
"""
import os
# 允许通过环境变量强制指定模式
forced_mode = os.environ.get("ARIA2_SERVICE_MODE", "").lower()
if forced_mode in ("systemd", "subprocess"):
return forced_mode
# 1. 检查 systemctl 命令是否存在
if shutil.which("systemctl") is None:
return "subprocess"
# 2. 检查 systemd 是否真正运行Docker 容器中可能有命令但无服务)
try:
result = subprocess.run(
["systemctl", "--user", "is-system-running"],
capture_output=True,
text=True,
timeout=5,
)
# running/degraded/starting/initializing 都表示 systemd 可用
status = result.stdout.strip()
if status in ("running", "degraded", "starting", "initializing"):
return "systemd"
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
pass
# 3. 备选检查:检查 /run/systemd/system 目录
if Path("/run/systemd/system").exists():
return "systemd"
return "subprocess"
def is_aria2_installed() -> bool:
"""检查 aria2c 是否已安装"""
if ARIA2_BIN.exists():

View File

@@ -6,7 +6,8 @@ import sys
from telegram import Bot, BotCommand
from telegram.ext import Application
from src.core import BotConfig
from src.core import BotConfig, is_aria2_installed
from src.aria2.service import Aria2ServiceManager, get_service_mode
from src.telegram.handlers import Aria2BotAPI, build_handlers
from src.utils import setup_logger
@@ -57,6 +58,27 @@ def create_app(config: BotConfig) -> Application:
return app
def _auto_start_aria2() -> None:
"""子进程模式下自动启动 aria2如果已安装"""
logger = setup_logger()
mode = get_service_mode()
if mode != "subprocess":
logger.info(f"服务管理模式: {mode},跳过自动启动")
return
if not is_aria2_installed():
logger.info("aria2 未安装,跳过自动启动")
return
try:
service = Aria2ServiceManager()
service.start()
logger.info("aria2 子进程已自动启动")
except Exception as e:
logger.warning(f"自动启动 aria2 失败: {e}")
def run() -> None:
"""加载配置并启动 bot"""
import asyncio
@@ -68,6 +90,9 @@ def run() -> None:
logger.error("Please set TELEGRAM_BOT_TOKEN in .env or environment")
sys.exit(1)
# 子进程模式下自动启动 aria2
_auto_start_aria2()
app = create_app(config)
logger.info("Bot starting...")

View File

@@ -1280,7 +1280,6 @@ class Aria2BotAPI:
gid = parts[2]
if provider == "onedrive":
await query.edit_message_text("☁️ 正在准备上传...")
await self.upload_to_cloud(update, context, gid)