mirror of
https://github.com/dnslin/aria2bot.git
synced 2026-01-11 20:12:20 +08:00
feat: init project
This commit is contained in:
1
src/__init__.py
Normal file
1
src/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Aria2bot - Telegram bot for managing aria2 downloads."""
|
||||
5
src/aria2/__init__.py
Normal file
5
src/aria2/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Aria2 operations module - installer and service management."""
|
||||
from src.aria2.installer import Aria2Installer
|
||||
from src.aria2.service import Aria2ServiceManager
|
||||
|
||||
__all__ = ["Aria2Installer", "Aria2ServiceManager"]
|
||||
246
src/aria2/installer.py
Normal file
246
src/aria2/installer.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""Aria2 installer - download, install, and configure aria2."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import json
|
||||
import shutil
|
||||
import tarfile
|
||||
import tempfile
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from urllib import error, request
|
||||
|
||||
from src.core import (
|
||||
ARIA2_BIN,
|
||||
ARIA2_CONFIG_DIR,
|
||||
ARIA2_CONF,
|
||||
ARIA2_LOG,
|
||||
ARIA2_SESSION,
|
||||
Aria2Config,
|
||||
Aria2Error,
|
||||
ConfigError,
|
||||
DownloadError,
|
||||
detect_arch,
|
||||
detect_os,
|
||||
generate_rpc_secret,
|
||||
is_aria2_installed,
|
||||
)
|
||||
|
||||
|
||||
class Aria2Installer:
|
||||
GITHUB_API = "https://api.github.com/repos/P3TERX/Aria2-Pro-Core/releases/latest"
|
||||
GITHUB_MIRROR = "https://gh-api.p3terx.com/repos/P3TERX/Aria2-Pro-Core/releases/latest"
|
||||
CONFIG_URLS = [
|
||||
"https://p3terx.github.io/aria2.conf",
|
||||
"https://cdn.jsdelivr.net/gh/P3TERX/aria2.conf@master",
|
||||
]
|
||||
CONFIG_FILES = ["aria2.conf", "script.conf", "dht.dat", "dht6.dat"]
|
||||
|
||||
def __init__(self, config: Aria2Config | None = None):
|
||||
self.config = config or Aria2Config()
|
||||
self.os_type = detect_os()
|
||||
self.arch = detect_arch()
|
||||
self._executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
async def get_latest_version(self) -> str:
|
||||
"""从 GitHub API 获取最新版本号"""
|
||||
loop = asyncio.get_running_loop()
|
||||
last_error: Exception | None = None
|
||||
|
||||
for url in (self.GITHUB_API, self.GITHUB_MIRROR):
|
||||
try:
|
||||
data = await loop.run_in_executor(
|
||||
self._executor, functools.partial(self._fetch_url, url)
|
||||
)
|
||||
payload = json.loads(data.decode("utf-8"))
|
||||
tag_name = payload.get("tag_name")
|
||||
if not tag_name:
|
||||
raise DownloadError("tag_name missing in GitHub API response")
|
||||
return tag_name
|
||||
except Exception as exc: # noqa: PERF203
|
||||
last_error = exc
|
||||
continue
|
||||
|
||||
raise DownloadError(f"Failed to fetch latest version: {last_error}") from last_error
|
||||
|
||||
async def download_binary(self, version: str | None = None) -> Path:
|
||||
"""下载并解压 aria2 静态二进制到 ~/.local/bin/"""
|
||||
resolved_version = version or await self.get_latest_version()
|
||||
version_name = resolved_version.lstrip("v")
|
||||
archive_name = f"aria2-{version_name}-static-linux-{self.arch}.tar.gz"
|
||||
download_url = (
|
||||
f"https://github.com/P3TERX/Aria2-Pro-Core/releases/download/"
|
||||
f"{resolved_version}/{archive_name}"
|
||||
)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmp_dir_path = Path(tmpdir)
|
||||
archive_path = tmp_dir_path / archive_name
|
||||
extract_dir = tmp_dir_path / "extract"
|
||||
extract_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
data = await loop.run_in_executor(
|
||||
self._executor, functools.partial(self._fetch_url, download_url)
|
||||
)
|
||||
await loop.run_in_executor(
|
||||
self._executor, functools.partial(self._write_file, archive_path, data)
|
||||
)
|
||||
except Exception as exc: # noqa: PERF203
|
||||
raise DownloadError(f"Failed to download aria2 binary: {exc}") from exc
|
||||
|
||||
try:
|
||||
binary_path = await loop.run_in_executor(
|
||||
self._executor, functools.partial(self._extract_binary, archive_path, extract_dir)
|
||||
)
|
||||
except Exception as exc: # noqa: PERF203
|
||||
raise DownloadError(f"Failed to extract aria2 binary: {exc}") from exc
|
||||
|
||||
try:
|
||||
ARIA2_BIN.parent.mkdir(parents=True, exist_ok=True)
|
||||
if ARIA2_BIN.exists():
|
||||
ARIA2_BIN.unlink()
|
||||
shutil.move(str(binary_path), ARIA2_BIN)
|
||||
ARIA2_BIN.chmod(0o755)
|
||||
except Exception as exc: # noqa: PERF203
|
||||
raise DownloadError(f"Failed to install aria2 binary: {exc}") from exc
|
||||
|
||||
return ARIA2_BIN
|
||||
|
||||
async def download_config(self) -> None:
|
||||
"""下载配置模板文件"""
|
||||
ARIA2_CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
for filename in self.CONFIG_FILES:
|
||||
last_error: Exception | None = None
|
||||
for base in self.CONFIG_URLS:
|
||||
url = f"{base.rstrip('/')}/{filename}"
|
||||
try:
|
||||
data = await loop.run_in_executor(
|
||||
self._executor, functools.partial(self._fetch_url, url)
|
||||
)
|
||||
target = ARIA2_CONFIG_DIR / filename
|
||||
await loop.run_in_executor(
|
||||
self._executor, functools.partial(self._write_file, target, data)
|
||||
)
|
||||
last_error = None
|
||||
break
|
||||
except Exception as exc: # noqa: PERF203
|
||||
last_error = exc
|
||||
continue
|
||||
if last_error is not None:
|
||||
raise DownloadError(f"Failed to download {filename}: {last_error}") from last_error
|
||||
|
||||
def render_config(self) -> None:
|
||||
"""渲染配置文件,注入用户参数"""
|
||||
if not ARIA2_CONF.exists():
|
||||
raise ConfigError("Config template not found. Run download_config first.")
|
||||
|
||||
try:
|
||||
content = ARIA2_CONF.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
raise ConfigError(f"Failed to read config: {exc}") from exc
|
||||
|
||||
rpc_secret = self.config.rpc_secret or generate_rpc_secret()
|
||||
self.config.rpc_secret = rpc_secret
|
||||
|
||||
replacements = {
|
||||
"dir=": str(self.config.download_dir),
|
||||
"rpc-listen-port=": str(self.config.rpc_port),
|
||||
"rpc-secret=": rpc_secret,
|
||||
"max-concurrent-downloads=": str(self.config.max_concurrent_downloads),
|
||||
"max-connection-per-server=": str(self.config.max_connection_per_server),
|
||||
}
|
||||
|
||||
new_lines: list[str] = []
|
||||
for line in content.splitlines():
|
||||
stripped = line.lstrip()
|
||||
replaced = False
|
||||
for key, value in replacements.items():
|
||||
if stripped.startswith(key):
|
||||
prefix = line[: len(line) - len(stripped)]
|
||||
new_lines.append(f"{prefix}{key}{value}")
|
||||
replaced = True
|
||||
break
|
||||
if not replaced:
|
||||
new_lines.append(line)
|
||||
|
||||
try:
|
||||
ARIA2_CONF.write_text("\n".join(new_lines) + "\n", encoding="utf-8")
|
||||
ARIA2_SESSION.touch(exist_ok=True)
|
||||
self.config.download_dir.mkdir(parents=True, exist_ok=True)
|
||||
ARIA2_LOG.touch(exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigError(f"Failed to render config: {exc}") from exc
|
||||
|
||||
async def install(self, version: str | None = None) -> dict:
|
||||
"""完整安装流程"""
|
||||
resolved_version = version or await self.get_latest_version()
|
||||
await self.download_binary(resolved_version)
|
||||
await self.download_config()
|
||||
self.render_config()
|
||||
|
||||
return {
|
||||
"version": resolved_version,
|
||||
"binary": str(ARIA2_BIN),
|
||||
"config_dir": str(ARIA2_CONFIG_DIR),
|
||||
"config": str(ARIA2_CONF),
|
||||
"session": str(ARIA2_SESSION),
|
||||
"installed": is_aria2_installed(),
|
||||
}
|
||||
|
||||
def uninstall(self) -> None:
|
||||
"""卸载 aria2"""
|
||||
errors: list[Exception] = []
|
||||
|
||||
try:
|
||||
if ARIA2_BIN.exists():
|
||||
ARIA2_BIN.unlink()
|
||||
except Exception as exc: # noqa: PERF203
|
||||
errors.append(exc)
|
||||
|
||||
try:
|
||||
if ARIA2_CONFIG_DIR.exists():
|
||||
shutil.rmtree(ARIA2_CONFIG_DIR)
|
||||
except Exception as exc: # noqa: PERF203
|
||||
errors.append(exc)
|
||||
|
||||
try:
|
||||
service_path = Path.home() / ".config" / "systemd" / "user" / "aria2.service"
|
||||
if service_path.exists():
|
||||
service_path.unlink()
|
||||
except Exception as exc: # noqa: PERF203
|
||||
errors.append(exc)
|
||||
|
||||
if errors:
|
||||
messages = "; ".join(str(err) for err in errors)
|
||||
raise Aria2Error(f"Failed to uninstall aria2: {messages}")
|
||||
|
||||
def _fetch_url(self, url: str) -> bytes:
|
||||
"""阻塞式 URL 获取,放在线程池中运行"""
|
||||
req = request.Request(url, headers={"User-Agent": "aria2-installer"})
|
||||
try:
|
||||
with request.urlopen(req, timeout=30) as resp:
|
||||
if getattr(resp, "status", 200) >= 400:
|
||||
raise DownloadError(f"HTTP {resp.status} for {url}")
|
||||
return resp.read()
|
||||
except (error.HTTPError, error.URLError) as exc:
|
||||
raise DownloadError(f"Network error for {url}: {exc}") from exc
|
||||
|
||||
@staticmethod
|
||||
def _write_file(path: Path, data: bytes) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_bytes(data)
|
||||
|
||||
@staticmethod
|
||||
def _extract_binary(archive_path: Path, extract_dir: Path) -> Path:
|
||||
with tarfile.open(archive_path, "r:gz") as tar:
|
||||
tar.extractall(extract_dir)
|
||||
for candidate in extract_dir.rglob("aria2c"):
|
||||
if candidate.is_file():
|
||||
return candidate
|
||||
raise DownloadError("aria2c binary not found in archive")
|
||||
170
src/aria2/service.py
Normal file
170
src/aria2/service.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""Aria2 service manager - systemd service lifecycle management."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
from src.core import (
|
||||
ARIA2_BIN,
|
||||
ARIA2_CONF,
|
||||
ARIA2_LOG,
|
||||
ARIA2_SERVICE,
|
||||
SYSTEMD_USER_DIR,
|
||||
ServiceError,
|
||||
NotInstalledError,
|
||||
is_aria2_installed,
|
||||
)
|
||||
|
||||
|
||||
SYSTEMD_SERVICE_TEMPLATE = """[Unit]
|
||||
Description=Aria2 Download Manager
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
ExecStart={aria2_bin} --conf-path={aria2_conf}
|
||||
ExecReload=/bin/kill -HUP $MAINPID
|
||||
Restart=on-failure
|
||||
RestartSec=5
|
||||
|
||||
[Install]
|
||||
WantedBy=default.target
|
||||
"""
|
||||
|
||||
|
||||
class Aria2ServiceManager:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def _run_systemctl(self, *args: str) -> subprocess.CompletedProcess[str]:
|
||||
try:
|
||||
return subprocess.run(
|
||||
["systemctl", "--user", *args],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
except FileNotFoundError as exc:
|
||||
raise ServiceError("systemctl command not found") from exc
|
||||
except subprocess.CalledProcessError as exc:
|
||||
output = exc.stderr.strip() or exc.stdout.strip() or str(exc)
|
||||
raise ServiceError(output) from exc
|
||||
|
||||
def _ensure_service_file(self) -> None:
|
||||
try:
|
||||
SYSTEMD_USER_DIR.mkdir(parents=True, exist_ok=True)
|
||||
content = SYSTEMD_SERVICE_TEMPLATE.format(
|
||||
aria2_bin=str(ARIA2_BIN),
|
||||
aria2_conf=str(ARIA2_CONF),
|
||||
)
|
||||
ARIA2_SERVICE.write_text(content, encoding="utf-8")
|
||||
self._run_systemctl("daemon-reload")
|
||||
except OSError as exc:
|
||||
raise ServiceError(f"Failed to write service file: {exc}") from exc
|
||||
|
||||
def start(self) -> None:
|
||||
if not is_aria2_installed():
|
||||
raise NotInstalledError("aria2 is not installed")
|
||||
self._ensure_service_file()
|
||||
self._run_systemctl("start", "aria2")
|
||||
|
||||
def stop(self) -> None:
|
||||
self._run_systemctl("stop", "aria2")
|
||||
|
||||
def restart(self) -> None:
|
||||
self._run_systemctl("restart", "aria2")
|
||||
|
||||
def enable(self) -> None:
|
||||
self._run_systemctl("enable", "aria2")
|
||||
|
||||
def disable(self) -> None:
|
||||
self._run_systemctl("disable", "aria2")
|
||||
|
||||
def status(self) -> dict:
|
||||
installed = is_aria2_installed()
|
||||
pid = self.get_pid() if installed else None
|
||||
|
||||
try:
|
||||
active_proc = subprocess.run(
|
||||
["systemctl", "--user", "is-active", "aria2"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
enabled_proc = subprocess.run(
|
||||
["systemctl", "--user", "is-enabled", "aria2"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
except FileNotFoundError as exc:
|
||||
raise ServiceError("systemctl command not found") from exc
|
||||
|
||||
running = active_proc.returncode == 0
|
||||
enabled = enabled_proc.returncode == 0
|
||||
|
||||
return {
|
||||
"installed": installed,
|
||||
"running": running,
|
||||
"pid": pid,
|
||||
"enabled": enabled,
|
||||
}
|
||||
|
||||
def get_pid(self) -> int | None:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["pgrep", "-u", str(os.getuid()), "-f", "aria2c"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
result = None
|
||||
|
||||
if result and result.returncode == 0:
|
||||
for line in result.stdout.splitlines():
|
||||
line = line.strip()
|
||||
if line.isdigit():
|
||||
return int(line)
|
||||
|
||||
try:
|
||||
ps_result = subprocess.run(
|
||||
["ps", "-C", "aria2c", "-o", "pid="],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
for line in ps_result.stdout.splitlines():
|
||||
line = line.strip()
|
||||
if line.isdigit():
|
||||
return int(line)
|
||||
return None
|
||||
|
||||
def view_log(self, lines: int = 50) -> str:
|
||||
if lines <= 0 or not ARIA2_LOG.exists():
|
||||
return ""
|
||||
try:
|
||||
content = ARIA2_LOG.read_text(encoding="utf-8", errors="ignore")
|
||||
except OSError as exc:
|
||||
raise ServiceError(f"Failed to read log: {exc}") from exc
|
||||
|
||||
log_lines = content.splitlines(keepends=True)
|
||||
return "".join(log_lines[-lines:])
|
||||
|
||||
def clear_log(self) -> None:
|
||||
try:
|
||||
ARIA2_LOG.parent.mkdir(parents=True, exist_ok=True)
|
||||
ARIA2_LOG.write_text("", encoding="utf-8")
|
||||
except OSError as exc:
|
||||
raise ServiceError(f"Failed to clear log: {exc}") from exc
|
||||
|
||||
def remove_service(self) -> None:
|
||||
self.stop()
|
||||
try:
|
||||
ARIA2_SERVICE.unlink(missing_ok=True)
|
||||
except OSError as exc:
|
||||
raise ServiceError(f"Failed to remove service file: {exc}") from exc
|
||||
self._run_systemctl("daemon-reload")
|
||||
55
src/core/__init__.py
Normal file
55
src/core/__init__.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Core module for aria2bot - constants, config, exceptions, and system utilities."""
|
||||
from src.core.constants import (
|
||||
HOME,
|
||||
ARIA2_BIN,
|
||||
ARIA2_CONFIG_DIR,
|
||||
ARIA2_CONF,
|
||||
ARIA2_SESSION,
|
||||
ARIA2_LOG,
|
||||
DOWNLOAD_DIR,
|
||||
SYSTEMD_USER_DIR,
|
||||
ARIA2_SERVICE,
|
||||
)
|
||||
from src.core.exceptions import (
|
||||
Aria2Error,
|
||||
UnsupportedOSError,
|
||||
UnsupportedArchError,
|
||||
DownloadError,
|
||||
ConfigError,
|
||||
ServiceError,
|
||||
NotInstalledError,
|
||||
)
|
||||
from src.core.config import Aria2Config, BotConfig
|
||||
from src.core.system import (
|
||||
detect_os,
|
||||
detect_arch,
|
||||
generate_rpc_secret,
|
||||
is_aria2_installed,
|
||||
get_aria2_version,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"HOME",
|
||||
"ARIA2_BIN",
|
||||
"ARIA2_CONFIG_DIR",
|
||||
"ARIA2_CONF",
|
||||
"ARIA2_SESSION",
|
||||
"ARIA2_LOG",
|
||||
"DOWNLOAD_DIR",
|
||||
"SYSTEMD_USER_DIR",
|
||||
"ARIA2_SERVICE",
|
||||
"Aria2Error",
|
||||
"UnsupportedOSError",
|
||||
"UnsupportedArchError",
|
||||
"DownloadError",
|
||||
"ConfigError",
|
||||
"ServiceError",
|
||||
"NotInstalledError",
|
||||
"Aria2Config",
|
||||
"BotConfig",
|
||||
"detect_os",
|
||||
"detect_arch",
|
||||
"generate_rpc_secret",
|
||||
"is_aria2_installed",
|
||||
"get_aria2_version",
|
||||
]
|
||||
42
src/core/config.py
Normal file
42
src/core/config.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Configuration dataclass for aria2bot."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from src.core.constants import DOWNLOAD_DIR
|
||||
|
||||
|
||||
@dataclass
|
||||
class Aria2Config:
|
||||
rpc_port: int = 6800
|
||||
rpc_secret: str = ""
|
||||
download_dir: Path = DOWNLOAD_DIR
|
||||
max_concurrent_downloads: int = 5
|
||||
max_connection_per_server: int = 16
|
||||
bt_tracker_update: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotConfig:
|
||||
token: str = ""
|
||||
api_base_url: str = ""
|
||||
aria2: Aria2Config = field(default_factory=Aria2Config)
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "BotConfig":
|
||||
"""从环境变量加载配置"""
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
token = os.environ.get("TELEGRAM_BOT_TOKEN", "")
|
||||
aria2 = Aria2Config(
|
||||
rpc_port=int(os.environ.get("ARIA2_RPC_PORT", "6800")),
|
||||
rpc_secret=os.environ.get("ARIA2_RPC_SECRET", ""),
|
||||
)
|
||||
return cls(
|
||||
token=token,
|
||||
api_base_url=os.environ.get("TELEGRAM_API_BASE_URL", ""),
|
||||
aria2=aria2,
|
||||
)
|
||||
12
src/core/constants.py
Normal file
12
src/core/constants.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Path constants for aria2bot."""
|
||||
from pathlib import Path
|
||||
|
||||
HOME = Path.home()
|
||||
ARIA2_BIN = HOME / ".local" / "bin" / "aria2c"
|
||||
ARIA2_CONFIG_DIR = HOME / ".config" / "aria2"
|
||||
ARIA2_CONF = ARIA2_CONFIG_DIR / "aria2.conf"
|
||||
ARIA2_SESSION = ARIA2_CONFIG_DIR / "aria2.session"
|
||||
ARIA2_LOG = ARIA2_CONFIG_DIR / "aria2.log"
|
||||
DOWNLOAD_DIR = HOME / "downloads"
|
||||
SYSTEMD_USER_DIR = HOME / ".config" / "systemd" / "user"
|
||||
ARIA2_SERVICE = SYSTEMD_USER_DIR / "aria2.service"
|
||||
29
src/core/exceptions.py
Normal file
29
src/core/exceptions.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Exception classes for aria2bot."""
|
||||
|
||||
|
||||
class Aria2Error(Exception):
|
||||
"""Base exception"""
|
||||
|
||||
|
||||
class UnsupportedOSError(Aria2Error):
|
||||
"""不支持的操作系统"""
|
||||
|
||||
|
||||
class UnsupportedArchError(Aria2Error):
|
||||
"""不支持的 CPU 架构"""
|
||||
|
||||
|
||||
class DownloadError(Aria2Error):
|
||||
"""下载失败"""
|
||||
|
||||
|
||||
class ConfigError(Aria2Error):
|
||||
"""配置错误"""
|
||||
|
||||
|
||||
class ServiceError(Aria2Error):
|
||||
"""服务操作失败"""
|
||||
|
||||
|
||||
class NotInstalledError(Aria2Error):
|
||||
"""aria2 未安装"""
|
||||
91
src/core/system.py
Normal file
91
src/core/system.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""System detection utilities for aria2bot."""
|
||||
from __future__ import annotations
|
||||
|
||||
import platform
|
||||
import secrets
|
||||
import shutil
|
||||
import string
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from src.core.constants import ARIA2_BIN
|
||||
from src.core.exceptions import UnsupportedOSError, UnsupportedArchError
|
||||
|
||||
|
||||
def detect_os() -> str:
|
||||
"""检测操作系统,返回 'centos', 'debian', 'ubuntu' 或抛出 UnsupportedOSError"""
|
||||
os_release_path = Path("/etc/os-release")
|
||||
if os_release_path.exists():
|
||||
info: dict[str, str] = {}
|
||||
for line in os_release_path.read_text(encoding="utf-8", errors="ignore").splitlines():
|
||||
if "=" not in line:
|
||||
continue
|
||||
key, value = line.split("=", 1)
|
||||
info[key.strip()] = value.strip().strip('"').lower()
|
||||
os_id = info.get("ID")
|
||||
if os_id in {"ubuntu", "debian"}:
|
||||
return os_id
|
||||
if os_id in {"centos", "rhel", "rocky", "almalinux"}:
|
||||
return "centos"
|
||||
|
||||
redhat_release = Path("/etc/redhat-release")
|
||||
if redhat_release.exists():
|
||||
content = redhat_release.read_text(encoding="utf-8", errors="ignore").lower()
|
||||
if any(name in content for name in ("centos", "red hat", "rocky", "alma")):
|
||||
return "centos"
|
||||
|
||||
raise UnsupportedOSError("Unsupported operating system")
|
||||
|
||||
|
||||
def detect_arch() -> str:
|
||||
"""检测 CPU 架构,返回 'amd64', 'arm64', 'armhf', 'i386' 或抛出 UnsupportedArchError"""
|
||||
machine = platform.machine().lower()
|
||||
if machine in {"x86_64", "amd64"}:
|
||||
return "amd64"
|
||||
if machine in {"aarch64", "arm64", "armv8"}:
|
||||
return "arm64"
|
||||
if machine.startswith("armv7") or machine.startswith("armv6"):
|
||||
return "armhf"
|
||||
if machine in {"i386", "i686", "x86"}:
|
||||
return "i386"
|
||||
raise UnsupportedArchError(f"Unsupported CPU architecture: {machine}")
|
||||
|
||||
|
||||
def generate_rpc_secret() -> str:
|
||||
"""生成 20 位随机 RPC 密钥"""
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
return "".join(secrets.choice(alphabet) for _ in range(20))
|
||||
|
||||
|
||||
def is_aria2_installed() -> bool:
|
||||
"""检查 aria2c 是否已安装"""
|
||||
if ARIA2_BIN.exists():
|
||||
return True
|
||||
return shutil.which("aria2c") is not None
|
||||
|
||||
|
||||
def get_aria2_version() -> str | None:
|
||||
"""获取已安装的 aria2 版本"""
|
||||
candidates = [ARIA2_BIN] if ARIA2_BIN.exists() else []
|
||||
path_cmd = shutil.which("aria2c")
|
||||
if path_cmd:
|
||||
candidates.append(Path(path_cmd))
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
for cmd in candidates:
|
||||
result = subprocess.run(
|
||||
[str(cmd), "-v"], capture_output=True, text=True, check=False
|
||||
)
|
||||
if result.returncode != 0:
|
||||
continue
|
||||
for line in result.stdout.splitlines():
|
||||
lowered = line.lower()
|
||||
if "aria2 version" in lowered:
|
||||
parts = line.split()
|
||||
return parts[-1] if parts else line.strip()
|
||||
if result.stdout.strip():
|
||||
return result.stdout.splitlines()[0].strip()
|
||||
|
||||
return None
|
||||
5
src/telegram/__init__.py
Normal file
5
src/telegram/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Telegram bot module - command handlers and application."""
|
||||
from src.telegram.handlers import Aria2BotAPI, build_handlers
|
||||
from src.telegram.app import create_app, run
|
||||
|
||||
__all__ = ["Aria2BotAPI", "build_handlers", "create_app", "run"]
|
||||
38
src/telegram/app.py
Normal file
38
src/telegram/app.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Telegram application builder and runner."""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
from telegram.ext import Application
|
||||
|
||||
from src.core import BotConfig
|
||||
from src.telegram.handlers import Aria2BotAPI, build_handlers
|
||||
from src.utils import setup_logger
|
||||
|
||||
|
||||
def create_app(config: BotConfig) -> Application:
|
||||
"""创建 Telegram Application"""
|
||||
builder = Application.builder().token(config.token)
|
||||
if config.api_base_url:
|
||||
builder = builder.base_url(config.api_base_url).base_file_url(config.api_base_url + "/file")
|
||||
app = builder.build()
|
||||
|
||||
api = Aria2BotAPI(config.aria2)
|
||||
for handler in build_handlers(api):
|
||||
app.add_handler(handler)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def run() -> None:
|
||||
"""加载配置并启动 bot"""
|
||||
logger = setup_logger()
|
||||
config = BotConfig.from_env()
|
||||
|
||||
if not config.token:
|
||||
logger.error("Please set TELEGRAM_BOT_TOKEN in .env or environment")
|
||||
sys.exit(1)
|
||||
|
||||
app = create_app(config)
|
||||
logger.info("Bot starting...")
|
||||
app.run_polling()
|
||||
212
src/telegram/handlers.py
Normal file
212
src/telegram/handlers.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""Telegram bot command handlers."""
|
||||
from __future__ import annotations
|
||||
|
||||
from telegram import Update
|
||||
from telegram.ext import ContextTypes, CommandHandler
|
||||
|
||||
from src.core import (
|
||||
Aria2Config,
|
||||
Aria2Error,
|
||||
NotInstalledError,
|
||||
ServiceError,
|
||||
DownloadError,
|
||||
ConfigError,
|
||||
is_aria2_installed,
|
||||
get_aria2_version,
|
||||
ARIA2_CONF,
|
||||
)
|
||||
from src.aria2 import Aria2Installer, Aria2ServiceManager
|
||||
|
||||
|
||||
class Aria2BotAPI:
|
||||
def __init__(self, config: Aria2Config | None = None):
|
||||
self.config = config or Aria2Config()
|
||||
self.installer = Aria2Installer(self.config)
|
||||
self.service = Aria2ServiceManager()
|
||||
|
||||
async def _reply(self, update: Update, context: ContextTypes.DEFAULT_TYPE, text: str, **kwargs):
|
||||
if update.effective_message:
|
||||
return await update.effective_message.reply_text(text, **kwargs)
|
||||
if update.effective_chat:
|
||||
return await context.bot.send_message(chat_id=update.effective_chat.id, text=text, **kwargs)
|
||||
return None
|
||||
|
||||
def _get_rpc_secret(self) -> str:
|
||||
if self.config.rpc_secret:
|
||||
return self.config.rpc_secret
|
||||
if ARIA2_CONF.exists():
|
||||
try:
|
||||
for line in ARIA2_CONF.read_text(encoding="utf-8", errors="ignore").splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("rpc-secret="):
|
||||
secret = stripped.split("=", 1)[1].strip()
|
||||
if secret:
|
||||
self.config.rpc_secret = secret
|
||||
return secret
|
||||
except OSError:
|
||||
return ""
|
||||
return ""
|
||||
|
||||
def _get_rpc_port(self) -> int | None:
|
||||
if ARIA2_CONF.exists():
|
||||
try:
|
||||
for line in ARIA2_CONF.read_text(encoding="utf-8", errors="ignore").splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("rpc-listen-port="):
|
||||
port_str = stripped.split("=", 1)[1].strip()
|
||||
if port_str.isdigit():
|
||||
return int(port_str)
|
||||
except OSError:
|
||||
return None
|
||||
return self.config.rpc_port
|
||||
|
||||
async def install(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
await self._reply(update, context, "正在安装 aria2,处理中,请稍候...")
|
||||
try:
|
||||
result = await self.installer.install()
|
||||
version = get_aria2_version() or result.get("version") or "未知"
|
||||
rpc_secret = self._get_rpc_secret() or "未设置"
|
||||
rpc_port = self._get_rpc_port() or self.config.rpc_port
|
||||
await self._reply(
|
||||
update,
|
||||
context,
|
||||
"\n".join(
|
||||
[
|
||||
"安装完成 ✅",
|
||||
f"版本:{version}",
|
||||
f"二进制:{result.get('binary')}",
|
||||
f"配置目录:{result.get('config_dir')}",
|
||||
f"配置文件:{result.get('config')}",
|
||||
f"RPC 端口:{rpc_port}",
|
||||
f"RPC 密钥:{rpc_secret}",
|
||||
]
|
||||
),
|
||||
)
|
||||
except (DownloadError, ConfigError, Aria2Error) as exc:
|
||||
await self._reply(update, context, f"安装失败:{exc}")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
await self._reply(update, context, f"安装失败,发生未知错误:{exc}")
|
||||
|
||||
async def uninstall(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
await self._reply(update, context, "正在卸载 aria2,处理中,请稍候...")
|
||||
try:
|
||||
try:
|
||||
self.service.stop()
|
||||
except ServiceError:
|
||||
pass
|
||||
self.installer.uninstall()
|
||||
await self._reply(update, context, "卸载完成 ✅")
|
||||
except Aria2Error as exc:
|
||||
await self._reply(update, context, f"卸载失败:{exc}")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
await self._reply(update, context, f"卸载失败,发生未知错误:{exc}")
|
||||
|
||||
async def start_service(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
try:
|
||||
if not is_aria2_installed():
|
||||
await self._reply(update, context, "aria2 未安装,请先运行 /install")
|
||||
return
|
||||
self.service.start()
|
||||
await self._reply(update, context, "aria2 服务已启动 ✅")
|
||||
except NotInstalledError:
|
||||
await self._reply(update, context, "aria2 未安装,请先运行 /install")
|
||||
except ServiceError as exc:
|
||||
await self._reply(update, context, f"启动失败:{exc}")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
await self._reply(update, context, f"启动失败,发生未知错误:{exc}")
|
||||
|
||||
async def stop_service(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
try:
|
||||
self.service.stop()
|
||||
await self._reply(update, context, "aria2 服务已停止 ✅")
|
||||
except ServiceError as exc:
|
||||
await self._reply(update, context, f"停止失败:{exc}")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
await self._reply(update, context, f"停止失败,发生未知错误:{exc}")
|
||||
|
||||
async def restart_service(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
try:
|
||||
self.service.restart()
|
||||
await self._reply(update, context, "aria2 服务已重启 ✅")
|
||||
except ServiceError as exc:
|
||||
await self._reply(update, context, f"重启失败:{exc}")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
await self._reply(update, context, f"重启失败,发生未知错误:{exc}")
|
||||
|
||||
async def status(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
try:
|
||||
info = self.service.status()
|
||||
version = get_aria2_version() or "未知"
|
||||
rpc_secret = self._get_rpc_secret() or "未设置"
|
||||
rpc_port = self._get_rpc_port() or self.config.rpc_port or "未知"
|
||||
except ServiceError as exc:
|
||||
await self._reply(update, context, f"获取状态失败:{exc}")
|
||||
return
|
||||
except Exception as exc: # noqa: BLE001
|
||||
await self._reply(update, context, f"获取状态失败,发生未知错误:{exc}")
|
||||
return
|
||||
|
||||
text = (
|
||||
"*Aria2 状态*\n"
|
||||
f"- 安装状态:{'已安装 ✅' if info.get('installed') or is_aria2_installed() else '未安装 ❌'}\n"
|
||||
f"- 运行状态:{'运行中 ✅' if info.get('running') else '未运行 ❌'}\n"
|
||||
f"- PID:`{info.get('pid') or 'N/A'}`\n"
|
||||
f"- 版本:`{version}`\n"
|
||||
f"- RPC 端口:`{rpc_port}`\n"
|
||||
f"- RPC 密钥:`{rpc_secret}`"
|
||||
)
|
||||
await self._reply(update, context, text, parse_mode="Markdown")
|
||||
|
||||
async def view_logs(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
try:
|
||||
logs = self.service.view_log(lines=30)
|
||||
except ServiceError as exc:
|
||||
await self._reply(update, context, f"读取日志失败:{exc}")
|
||||
return
|
||||
except Exception as exc: # noqa: BLE001
|
||||
await self._reply(update, context, f"读取日志失败,发生未知错误:{exc}")
|
||||
return
|
||||
|
||||
if not logs.strip():
|
||||
await self._reply(update, context, "暂无日志内容。")
|
||||
return
|
||||
|
||||
await self._reply(update, context, f"最近 30 行日志:\n{logs}")
|
||||
|
||||
async def clear_logs(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
try:
|
||||
self.service.clear_log()
|
||||
await self._reply(update, context, "日志已清空 ✅")
|
||||
except ServiceError as exc:
|
||||
await self._reply(update, context, f"清空日志失败:{exc}")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
await self._reply(update, context, f"清空日志失败,发生未知错误:{exc}")
|
||||
|
||||
async def help_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
commands = [
|
||||
"/install - 安装 aria2",
|
||||
"/uninstall - 卸载 aria2",
|
||||
"/start - 启动 aria2 服务",
|
||||
"/stop - 停止 aria2 服务",
|
||||
"/restart - 重启 aria2 服务",
|
||||
"/status - 查看 aria2 状态",
|
||||
"/logs - 查看最近日志",
|
||||
"/clear_logs - 清空日志",
|
||||
"/help - 显示此帮助",
|
||||
]
|
||||
await self._reply(update, context, "可用命令:\n" + "\n".join(commands))
|
||||
|
||||
|
||||
def build_handlers(api: Aria2BotAPI) -> list[CommandHandler]:
|
||||
"""构建 CommandHandler 列表"""
|
||||
return [
|
||||
CommandHandler("install", api.install),
|
||||
CommandHandler("uninstall", api.uninstall),
|
||||
CommandHandler("start", api.start_service),
|
||||
CommandHandler("stop", api.stop_service),
|
||||
CommandHandler("restart", api.restart_service),
|
||||
CommandHandler("status", api.status),
|
||||
CommandHandler("logs", api.view_logs),
|
||||
CommandHandler("clear_logs", api.clear_logs),
|
||||
CommandHandler("help", api.help_command),
|
||||
]
|
||||
4
src/utils/__init__.py
Normal file
4
src/utils/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""Utility module - logging and other helpers."""
|
||||
from src.utils.logger import setup_logger, get_logger
|
||||
|
||||
__all__ = ["setup_logger", "get_logger"]
|
||||
28
src/utils/logger.py
Normal file
28
src/utils/logger.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Logging module for aria2bot"""
|
||||
import logging
|
||||
import sys
|
||||
|
||||
_initialized = False
|
||||
|
||||
|
||||
def setup_logger(name: str = "aria2bot", level: int = logging.INFO) -> logging.Logger:
|
||||
"""Initialize and configure the root logger."""
|
||||
global _initialized
|
||||
logger = logging.getLogger(name)
|
||||
if not _initialized:
|
||||
logger.setLevel(level)
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setFormatter(
|
||||
logging.Formatter(
|
||||
"%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
)
|
||||
logger.addHandler(handler)
|
||||
_initialized = True
|
||||
return logger
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""Get a child logger for a specific module."""
|
||||
return logging.getLogger(f"aria2bot.{name}")
|
||||
Reference in New Issue
Block a user