Compare commits

...

10 Commits

3 changed files with 180 additions and 33 deletions

View File

@@ -1,2 +1,5 @@
TELEGRAM_TOKEN="" TELEGRAM_TOKEN=""
EXCHANGE_API_KEY="" EXCHANGE_API_KEY=""
UPDATE_OWNER_ID=""
AUTO_UPDATE_REMOTE="https://git.llc/zimk/SubMind.git"
AUTO_UPDATE_BRANCH="main"

View File

@@ -40,7 +40,7 @@ cd SubMind
python -m venv .venv python -m venv .venv
source .venv/bin/activate # Windows: .venv\Scripts\activate source .venv/bin/activate # Windows: .venv\Scripts\activate
pip install -U pip pip install -U pip
pip install python-telegram-bot pandas matplotlib python-dateutil dateparser python-dotenv requests pip install -r requirements.txt
``` ```
### 3) 配置环境变量 ### 3) 配置环境变量
@@ -56,11 +56,17 @@ cp .env.example .env
```env ```env
TELEGRAM_TOKEN="<YOUR_TELEGRAM_BOT_TOKEN>" TELEGRAM_TOKEN="<YOUR_TELEGRAM_BOT_TOKEN>"
EXCHANGE_API_KEY="<YOUR_EXCHANGE_API_KEY>" EXCHANGE_API_KEY="<YOUR_EXCHANGE_API_KEY>"
UPDATE_OWNER_ID="<YOUR_TELEGRAM_USER_ID>"
AUTO_UPDATE_REMOTE="https://git.llc/zimk/SubMind.git"
AUTO_UPDATE_BRANCH="main"
``` ```
说明: 说明:
- `TELEGRAM_TOKEN` 必填。 - `TELEGRAM_TOKEN` 必填。
- `EXCHANGE_API_KEY` 可选(不填时不做在线汇率转换)。 - `EXCHANGE_API_KEY` 可选(不填时不做在线汇率转换)。
- `UPDATE_OWNER_ID` 可选(建议配置为你的 Telegram 用户 ID仅该用户可执行 `/update`)。
- `AUTO_UPDATE_REMOTE` 可选(默认 `https://git.llc/zimk/SubMind.git`)。
- `AUTO_UPDATE_BRANCH` 可选(默认 `main`)。
### 4) 运行 ### 4) 运行
@@ -80,6 +86,7 @@ python SubMind.py
- `/import` 导入 CSV - `/import` 导入 CSV
- `/export` 导出 CSV - `/export` 导出 CSV
- `/set_currency <CODE>` 设置主货币(例如 `USD``CNY` - `/set_currency <CODE>` 设置主货币(例如 `USD``CNY`
- `/update` 拉取最新代码、安装依赖并自动重启(仅 `UPDATE_OWNER_ID` 指定用户可用)
- `/help` 帮助 - `/help` 帮助
- `/cancel` 取消当前流程 - `/cancel` 取消当前流程

View File

@@ -1,6 +1,8 @@
import sqlite3 import sqlite3
import asyncio import asyncio
import os import os
import sys
import subprocess
import html import html
import requests import requests
import datetime import datetime
@@ -20,7 +22,10 @@ from telegram.ext import (
CallbackContext, CallbackQueryHandler, ConversationHandler CallbackContext, CallbackQueryHandler, ConversationHandler
) )
from telegram.error import TelegramError from telegram.error import TelegramError
from telegram.helpers import escape_html def escape_html(text, version=None):
if text is None:
return ''
return html.escape(str(text))
# --- 加载 .env 和设置 --- # --- 加载 .env 和设置 ---
load_dotenv() load_dotenv()
@@ -39,6 +44,11 @@ EXCHANGE_API_KEY = os.getenv('EXCHANGE_API_KEY')
PROJECT_NAME = "SubMind" PROJECT_NAME = "SubMind"
DB_FILE = 'submind.db' DB_FILE = 'submind.db'
# 自动更新配置
UPDATE_OWNER_ID = os.getenv('UPDATE_OWNER_ID') # 仅允许此用户执行 /update
AUTO_UPDATE_REMOTE = os.getenv('AUTO_UPDATE_REMOTE', 'https://git.llc/zimk/SubMind.git').strip()
AUTO_UPDATE_BRANCH = os.getenv('AUTO_UPDATE_BRANCH', 'main').strip() or 'main'
# --- 对话处理器状态 --- # --- 对话处理器状态 ---
(ADD_NAME, ADD_COST, ADD_CURRENCY, ADD_CATEGORY, ADD_NEXT_DUE, (ADD_NAME, ADD_COST, ADD_CURRENCY, ADD_CATEGORY, ADD_NEXT_DUE,
ADD_FREQ_UNIT, ADD_FREQ_VALUE, ADD_RENEWAL_TYPE, ADD_NOTES) = range(9) ADD_FREQ_UNIT, ADD_FREQ_VALUE, ADD_RENEWAL_TYPE, ADD_NOTES) = range(9)
@@ -373,7 +383,7 @@ async def check_and_send_reminders(context: CallbackContext):
]) ])
if sub['reminder_on_due_date'] and due_date == today: if sub['reminder_on_due_date'] and due_date == today:
message = f"🔔 *订阅到期提醒*\n\n您的订阅 `{safe_sub_name}` 今天到期。" message = f"🔔 <b>订阅到期提醒</b>\n\n您的订阅 <code>{safe_sub_name}</code> 今天到期。"
if renewal_type == 'manual': if renewal_type == 'manual':
message += " 请记得手动续费。" message += " 请记得手动续费。"
else: else:
@@ -385,7 +395,7 @@ async def check_and_send_reminders(context: CallbackContext):
if reminder_date == today: if reminder_date == today:
days_left = (due_date - today).days days_left = (due_date - today).days
days_text = f"<b>{days_left}天后</b>" if days_left > 0 else "<b>今天</b>" days_text = f"<b>{days_left}天后</b>" if days_left > 0 else "<b>今天</b>"
message = f"🔔 *订阅即将到期提醒*\n\n您的手动续费订阅 `{safe_sub_name}` 将在 {days_text} 到期。" message = f"🔔 <b>订阅即将到期提醒</b>\n\n您的手动续费订阅 <code>{safe_sub_name}</code> 将在 {days_text} 到期。"
if message: if message:
await context.bot.send_message( await context.bot.send_message(
@@ -415,19 +425,19 @@ async def start(update: Update, context: CallbackContext):
async def help_command(update: Update, context: CallbackContext): async def help_command(update: Update, context: CallbackContext):
help_text = fr""" help_text = f"""
*{escape_html(PROJECT_NAME)} 命令列表* <b>{escape_html(PROJECT_NAME)} 命令列表</b>
*🌟 核心功能* <b>🌟 核心功能</b>
/add\_sub \- 引导您添加一个新的订阅 /add_sub - 引导您添加一个新的订阅
/list\_subs \- 列出您的所有订阅 /list_subs - 列出您的所有订阅
/list\_categories \- 按分类浏览您的订阅 /list_categories - 按分类浏览您的订阅
*📊 数据管理* <b>📊 数据管理</b>
/stats \- 查看按类别分类的订阅统计 /stats - 查看按类别分类的订阅统计
/import \- 通过上传 CSV 文件批量导入订阅 /import - 通过上传 CSV 文件批量导入订阅
/export \- 将您的所有订阅导出为 CSV 文件 /export - 将您的所有订阅导出为 CSV 文件
*⚙️ 个性化设置* <b>⚙️ 个性化设置</b>
/set\_currency \`<code>\` \- 设置您的主要货币 /set_currency &lt;代码&gt; - 设置您的主要货币
/cancel \- 在任何流程中取消当前操作 /cancel - 在任何流程中取消当前操作
""" """
await update.message.reply_text(help_text, parse_mode='HTML') await update.message.reply_text(help_text, parse_mode='HTML')
@@ -513,7 +523,7 @@ async def stats(update: Update, context: CallbackContext):
try: try:
theme_colors = ['#3B82F6', '#10B981', '#F59E0B', '#EF4444', '#8B5CF6', '#EC4899', '#14B8A6', '#F97316', '#6366F1', '#84CC16'] theme_colors = ['#3B82F6', '#10B981', '#F59E0B', '#EF4444', '#8B5CF6', '#EC4899', '#14B8A6', '#F97316', '#6366F1', '#84CC16']
if len(category_costs) > len(theme_colors): if len(category_costs) > len(theme_colors):
import matplotlib.pyplot as plt # 移除导致遮蔽的局部 import直接使用全局的 matplotlib 和 plt
extra_colors = [matplotlib.colors.to_hex(c) for c in plt.get_cmap('tab20').colors] extra_colors = [matplotlib.colors.to_hex(c) for c in plt.get_cmap('tab20').colors]
theme_colors.extend(extra_colors) theme_colors.extend(extra_colors)
@@ -593,7 +603,7 @@ async def stats(update: Update, context: CallbackContext):
weight='bold' weight='bold'
) )
fig.suptitle('📊 您的订阅支出洞察', fontproperties=font_prop, fontsize=24, color='#0F172A', y=1.02, weight='bold') fig.suptitle('您的订阅统计报告', fontproperties=font_prop, fontsize=24, color='#0F172A', y=1.02, weight='bold')
fig.tight_layout(rect=[0, 0, 1, 0.95]) fig.tight_layout(rect=[0, 0, 1, 0.95])
with tempfile.NamedTemporaryFile(prefix=f'stats_{user_id}_', suffix='.png', delete=False) as tmp: with tempfile.NamedTemporaryFile(prefix=f'stats_{user_id}_', suffix='.png', delete=False) as tmp:
@@ -824,7 +834,7 @@ async def add_category_received(update: Update, context: CallbackContext):
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, category_name)) cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, category_name))
conn.commit() conn.commit()
await update.message.reply_text("第五步:请输入 *下一次付款日期*(例如 2025\\-10\\-01 或 10月1日", await update.message.reply_text("第五步:请输入 <b>下一次付款日期</b>(例如 2025-10-01 或 10月1日",
parse_mode='HTML') parse_mode='HTML')
return ADD_NEXT_DUE return ADD_NEXT_DUE
@@ -838,7 +848,7 @@ async def add_next_due_received(update: Update, context: CallbackContext):
parsed_date = parse_date(update.message.text) parsed_date = parse_date(update.message.text)
if not parsed_date: if not parsed_date:
await update.message.reply_text("无法识别的日期格式,请使用类似 '2025\\-10\\-01''10月1日' 的格式。") await update.message.reply_text("无法识别的日期格式,请使用类似 '2025-10-01''10月1日' 的格式。")
return ADD_NEXT_DUE return ADD_NEXT_DUE
sub_data['next_due'] = parsed_date sub_data['next_due'] = parsed_date
keyboard = [ keyboard = [
@@ -865,7 +875,7 @@ async def add_freq_unit_received(update: Update, context: CallbackContext):
await query.edit_message_text("错误:无效的周期单位,请重试。") await query.edit_message_text("错误:无效的周期单位,请重试。")
return ConversationHandler.END return ConversationHandler.END
sub_data['unit'] = unit sub_data['unit'] = unit
await query.edit_message_text("第七步:请输入周期的<b>数量</b>例如每3个月输入 3", parse_mode='Markdown') await query.edit_message_text("第七步:请输入周期的<b>数量</b>例如每3个月输入 3", parse_mode='HTML')
return ADD_FREQ_VALUE return ADD_FREQ_VALUE
@@ -1012,19 +1022,19 @@ async def show_subscription_view(update: Update, context: CallbackContext, sub_i
sub['reminders_enabled'], sub['notes']) sub['reminders_enabled'], sub['notes'])
freq_text = format_frequency(sub['frequency_unit'], sub['frequency_value']) freq_text = format_frequency(sub['frequency_unit'], sub['frequency_value'])
main_currency = get_user_main_currency(user_id) main_currency = get_user_main_currency(user_id)
converted_cost = convert_currency(cost, currency, main_currency) converted_cost = await asyncio.to_thread(convert_currency, cost, currency, main_currency)
safe_name, safe_category, safe_freq = escape_html(name), escape_html(category), escape_html(freq_text) safe_name, safe_category, safe_freq = escape_html(name), escape_html(category), escape_html(freq_text)
cost_str, converted_cost_str = escape_html(f"{cost:.2f}"), escape_html(f"{converted_cost:.2f}") cost_str, converted_cost_str = escape_html(f"{cost:.2f}"), escape_html(f"{converted_cost:.2f}")
renewal_text = "手动续费" if renewal_type == 'manual' else "自动续费" renewal_text = "手动续费" if renewal_type == 'manual' else "自动续费"
reminder_status = "开启" if reminders_enabled else "关闭" reminder_status = "开启" if reminders_enabled else "关闭"
text = (f"*订阅详情: {safe_name}*\n\n" text = (f"<b>订阅详情: {safe_name}</b>\n\n"
f"\\- *费用*: `{cost_str} {currency.upper()}` \\(\\~`{converted_cost_str} {main_currency.upper()}`\\)\n" f"- <b>费用</b>: <code>{cost_str} {currency.upper()}</code> (~<code>{converted_cost_str} {main_currency.upper()}</code>)\n"
f"\\- *类别*: `{safe_category}`\n" f"- <b>类别</b>: <code>{safe_category}</code>\n"
f"\\- *下次付款*: `{next_due}` \\(周期: {safe_freq}\\)\n" f"- <b>下次付款</b>: <code>{next_due}</code> (周期: {safe_freq})\n"
f"\\- *续费方式*: `{renewal_text}`\n" f"- <b>续费方式</b>: <code>{renewal_text}</code>\n"
f"\\- *提醒状态*: `{reminder_status}`") f"- <b>提醒状态</b>: <code>{reminder_status}</code>")
if notes: if notes:
text += f"\n\\- *备注*: {escape_html(notes)}" text += f"\n- <b>备注</b>: {escape_html(notes)}"
keyboard_buttons = [ keyboard_buttons = [
[InlineKeyboardButton("✏️ 编辑", callback_data=f'edit_{sub_id}'), [InlineKeyboardButton("✏️ 编辑", callback_data=f'edit_{sub_id}'),
InlineKeyboardButton("🗑️ 删除", callback_data=f'delete_{sub_id}')], InlineKeyboardButton("🗑️ 删除", callback_data=f'delete_{sub_id}')],
@@ -1161,7 +1171,7 @@ async def button_callback_handler(update: Update, context: CallbackContext):
await query.answer("续费失败:无法计算新的到期日期。", show_alert=True) await query.answer("续费失败:无法计算新的到期日期。", show_alert=True)
else: else:
await query.answer("续费失败:此订阅可能已被删除或无权限。", show_alert=True) await query.answer("续费失败:此订阅可能已被删除或无权限。", show_alert=True)
await query.edit_message_text(text=query.message.text + "\n\n*(错误:此订阅不存在或无权限)*", await query.edit_message_text(text=query.message.text + "\n\n<b>(错误:此订阅不存在或无权限)</b>",
parse_mode='HTML', reply_markup=None) parse_mode='HTML', reply_markup=None)
elif action == 'delete': elif action == 'delete':
@@ -1390,7 +1400,7 @@ async def edit_new_value_received(update: Update, context: CallbackContext):
parsed = parse_date(str(new_value)) parsed = parse_date(str(new_value))
if not parsed: if not parsed:
if message_to_reply: if message_to_reply:
await message_to_reply.reply_text("无法识别的日期格式,请使用类似 '2025\\-10\\-01''10月1日' 的格式。") await message_to_reply.reply_text("无法识别的日期格式,请使用类似 '2025-10-01''10月1日' 的格式。")
validation_failed = True validation_failed = True
else: else:
new_value = parsed new_value = parsed
@@ -1469,7 +1479,7 @@ async def _display_reminder_settings(query: CallbackQuery, context: CallbackCont
safe_name = escape_html(sub['name']) safe_name = escape_html(sub['name'])
current_status = f"<b>🔔 提醒设置: {safe_name}</b>\n\n" current_status = f"<b>🔔 提醒设置: {safe_name}</b>\n\n"
if sub['renewal_type'] == 'manual': if sub['renewal_type'] == 'manual':
current_status += f"当前提前提醒: *{sub['reminder_days']}*\n" current_status += f"当前提前提醒: <b>{sub['reminder_days']}</b>\n"
keyboard.append([InlineKeyboardButton("⚙️ 更改提前天数", callback_data='remindaction_ask_days')]) keyboard.append([InlineKeyboardButton("⚙️ 更改提前天数", callback_data='remindaction_ask_days')])
keyboard.append([InlineKeyboardButton("« 返回详情", callback_data=f'view_{sub_id}')]) keyboard.append([InlineKeyboardButton("« 返回详情", callback_data=f'view_{sub_id}')])
await query.edit_message_text(current_status, reply_markup=InlineKeyboardMarkup(keyboard), parse_mode='HTML') await query.edit_message_text(current_status, reply_markup=InlineKeyboardMarkup(keyboard), parse_mode='HTML')
@@ -1603,6 +1613,131 @@ async def cancel(update: Update, context: CallbackContext):
return ConversationHandler.END return ConversationHandler.END
def _can_run_update(user_id: int) -> bool:
"""仅允许指定 owner 执行自动更新。未配置 owner 时默认拒绝。"""
if not UPDATE_OWNER_ID:
return False
try:
return int(UPDATE_OWNER_ID) == int(user_id)
except (ValueError, TypeError):
return False
def _resolve_update_target(repo_dir: str):
"""
解析更新目标 remote/branch。
优先级:
1) 环境变量 AUTO_UPDATE_REMOTE + AUTO_UPDATE_BRANCH
2) 当前分支上游 @{u}
3) 远程优先 gitllc其次 origin分支用 AUTO_UPDATE_BRANCH 或 main
"""
branch = (AUTO_UPDATE_BRANCH or 'main').strip() or 'main'
# 1) 明确指定 remote
if AUTO_UPDATE_REMOTE:
return AUTO_UPDATE_REMOTE, branch
# 2) 尝试读取上游分支(如 gitllc/main
upstream_proc = subprocess.run(
["git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"],
cwd=repo_dir, capture_output=True, text=True
)
if upstream_proc.returncode == 0:
upstream = upstream_proc.stdout.strip()
if '/' in upstream:
remote, up_branch = upstream.split('/', 1)
if remote and up_branch:
return remote, up_branch
# 3) 回退:从远程列表推断
remotes_proc = subprocess.run(["git", "remote"], cwd=repo_dir, capture_output=True, text=True)
if remotes_proc.returncode != 0:
return None, None
remotes = [r.strip() for r in remotes_proc.stdout.splitlines() if r.strip()]
if not remotes:
return None, None
if 'gitllc' in remotes:
return 'gitllc', branch
if 'origin' in remotes:
return 'origin', branch
return remotes[0], branch
def _run_cmd(cmd, cwd):
return subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
async def update_bot(update: Update, context: CallbackContext):
user_id = update.effective_user.id
if not _can_run_update(user_id):
await update.message.reply_text("无权限执行 /update。")
return
await update.message.reply_text("开始检查更新,请稍候…")
repo_dir = os.path.dirname(os.path.abspath(__file__))
try:
remote_name, branch_name = _resolve_update_target(repo_dir)
if not remote_name or not branch_name:
await update.message.reply_text("更新失败:无法解析 git 远程仓库,请检查仓库 remote 配置。")
return
fetch_cmd = ["git", "fetch", remote_name, branch_name]
fetch_proc = await asyncio.to_thread(_run_cmd, fetch_cmd, repo_dir)
if fetch_proc.returncode != 0:
err = (fetch_proc.stderr or fetch_proc.stdout or "未知错误").strip()
await update.message.reply_text(f"更新失败fetch\n<code>{escape_html(err)}</code>", parse_mode='HTML')
return
local_rev = await asyncio.to_thread(_run_cmd, ["git", "rev-parse", "HEAD"], repo_dir)
fetched_rev = await asyncio.to_thread(_run_cmd, ["git", "rev-parse", "FETCH_HEAD"], repo_dir)
if local_rev.returncode != 0 or fetched_rev.returncode != 0:
await update.message.reply_text("更新失败:无法读取当前版本。")
return
local_hash = local_rev.stdout.strip()
fetched_hash = fetched_rev.stdout.strip()
if local_hash == fetched_hash:
await update.message.reply_text("当前已是最新版本,无需更新。")
return
reset_proc = await asyncio.to_thread(
_run_cmd,
["git", "reset", "--hard", "FETCH_HEAD"],
repo_dir
)
if reset_proc.returncode != 0:
err = (reset_proc.stderr or reset_proc.stdout or "未知错误").strip()
await update.message.reply_text(f"更新失败reset\n<code>{escape_html(err)}</code>", parse_mode='HTML')
return
pip_proc = await asyncio.to_thread(
_run_cmd,
[sys.executable, "-m", "pip", "install", "-r", "requirements.txt"],
repo_dir
)
if pip_proc.returncode != 0:
err = (pip_proc.stderr or pip_proc.stdout or "未知错误").strip()
await update.message.reply_text(f"依赖安装失败:\n<code>{escape_html(err[-1800:])}</code>", parse_mode='HTML')
return
await update.message.reply_text(
f"更新完成({escape_html(remote_name)} {escape_html(branch_name)}),正在重启机器人…",
parse_mode='HTML'
)
os.execv(sys.executable, [sys.executable] + sys.argv)
except Exception as e:
logger.error(f"/update failed: {e}")
await update.message.reply_text(f"更新异常:<code>{escape_html(str(e))}</code>", parse_mode='HTML')
# --- Main --- # --- Main ---
def main(): def main():
if not TELEGRAM_TOKEN: if not TELEGRAM_TOKEN:
@@ -1631,6 +1766,7 @@ def main():
BotCommand("import", "📥 导入订阅"), BotCommand("import", "📥 导入订阅"),
BotCommand("export", "📤 导出订阅"), BotCommand("export", "📤 导出订阅"),
BotCommand("set_currency", "💲 设置主货币"), BotCommand("set_currency", "💲 设置主货币"),
BotCommand("update", "🛠️ 拉取最新代码并重启"),
BotCommand("help", " 获取帮助"), BotCommand("help", " 获取帮助"),
BotCommand("cancel", "❌ 取消当前操作") BotCommand("cancel", "❌ 取消当前操作")
] ]
@@ -1725,6 +1861,7 @@ def main():
application.add_handler(CommandHandler('set_currency', set_currency)) application.add_handler(CommandHandler('set_currency', set_currency))
application.add_handler(CommandHandler('stats', stats)) application.add_handler(CommandHandler('stats', stats))
application.add_handler(CommandHandler('export', export_command)) application.add_handler(CommandHandler('export', export_command))
application.add_handler(CommandHandler('update', update_bot))
application.add_handler(CommandHandler('cancel', cancel)) application.add_handler(CommandHandler('cancel', cancel))
application.add_handler(add_conv) application.add_handler(add_conv)