Compare commits

..

18 Commits

Author SHA1 Message Date
Xiaolan Bot
f064f751f0 fix: harden callback payload parsing in add/edit flows 2026-02-22 11:53:22 +08:00
Xiaolan Bot
5eebf4bf66 fix: restore clear helper and unify notes/skip expiry handling 2026-02-22 11:44:13 +08:00
Xiaolan Bot
210af75e2c fix: guard add-sub conversation against expired state 2026-02-22 11:39:25 +08:00
Xiaolan Bot
d212d73c2a refactor: centralize enum validation for unit and renewal type 2026-02-22 11:33:41 +08:00
Xiaolan Bot
052966e07c fix: validate name/notes constraints in edit flow 2026-02-22 11:30:22 +08:00
Xiaolan Bot
095e88cad3 refactor: add input length guards for add/edit/import flows 2026-02-22 11:07:42 +08:00
Xiaolan Bot
276bb5fc83 fix: restore get_subs_list_keyboard helper 2026-02-22 02:54:56 +08:00
Xiaolan Bot
decb9c12c1 hardening: remove broad session clears and validate add flow inputs 2026-02-22 02:43:26 +08:00
Xiaolan Bot
ec06c5fac3 chore: tighten conversation entry callback patterns 2026-02-22 02:22:26 +08:00
Xiaolan Bot
ced65fc4da chore: tighten conversation fallback callback patterns 2026-02-22 02:17:00 +08:00
Xiaolan Bot
36b136289c hardening: close low-risk gaps and improve import validation 2026-02-22 02:11:26 +08:00
Xiaolan Bot
98a863f567 docs: add comprehensive README for open-source usage 2026-02-22 02:07:50 +08:00
Xiaolan Bot
15f9ceb841 refactor: use tempfile for import/export/stats artifacts 2026-02-22 01:48:44 +08:00
Xiaolan Bot
8601e78e17 hardening: validate ownership on entry points and failed updates 2026-02-22 01:41:46 +08:00
Xiaolan Bot
530d81b565 refactor: harden field mapping and sqlite boolean toggles 2026-02-22 01:33:02 +08:00
Xiaolan Bot
8354e38e89 fix: tighten callback pattern for id-based category routing 2026-02-22 01:27:11 +08:00
Xiaolan Bot
97bcee7258 fix: make category callbacks id-based and tighten ownership checks 2026-02-22 01:26:24 +08:00
Xiaolan Bot
db8257fdde fix: harden subscription callbacks and settings updates 2026-02-22 01:17:20 +08:00
2 changed files with 519 additions and 98 deletions

141
README.md
View File

@@ -1 +1,140 @@
power by gemini # SubMind
一个基于 Telegram 的订阅管理机器人,帮助你记录订阅、设置提醒、查看支出统计并导入/导出数据。
## 功能特性
- 添加订阅(名称、费用、货币、分类、到期日、周期、续费方式、备注)
- 📋 列出订阅并查看详情
- 🗂️ 按分类浏览订阅
- ✏️ 编辑订阅信息
- 🗑️ 删除订阅
- 🔔 提醒设置(到期日提醒、提前 N 天提醒、手动续费一键确认)
- 📊 统计图(按分类汇总月均支出)
- 📥 CSV 导入
- 📤 CSV 导出
- 💱 多货币换算(支持缓存汇率)
## 技术栈
- Python 3.10+
- [python-telegram-bot](https://github.com/python-telegram-bot/python-telegram-bot)
- SQLite
- pandas + matplotlib
- dateparser
## 快速开始
### 1) 克隆项目
```bash
git clone https://github.com/zkysimon/SubMind.git
cd SubMind
```
### 2) 安装依赖
建议使用虚拟环境:
```bash
python -m venv .venv
source .venv/bin/activate # Windows: .venv\Scripts\activate
pip install -U pip
pip install python-telegram-bot pandas matplotlib python-dateutil dateparser python-dotenv requests
```
### 3) 配置环境变量
复制并填写 `.env`
```bash
cp .env.example .env
```
`.env` 示例:
```env
TELEGRAM_TOKEN="<YOUR_TELEGRAM_BOT_TOKEN>"
EXCHANGE_API_KEY="<YOUR_EXCHANGE_API_KEY>"
```
说明:
- `TELEGRAM_TOKEN` 必填。
- `EXCHANGE_API_KEY` 可选(不填时不做在线汇率转换)。
### 4) 运行
```bash
python SubMind.py
```
首次启动会自动初始化数据库(默认 `submind.db`)。
## Bot 命令
- `/start` 开始使用
- `/add_sub` 添加订阅
- `/list_subs` 查看所有订阅
- `/list_categories` 按分类查看
- `/stats` 查看统计图
- `/import` 导入 CSV
- `/export` 导出 CSV
- `/set_currency <CODE>` 设置主货币(例如 `USD``CNY`
- `/help` 帮助
- `/cancel` 取消当前流程
## CSV 导入格式
导入文件需包含以下列:
- `name`
- `cost`
- `currency`
- `category`
- `next_due`
- `frequency_unit``day` / `week` / `month` / `year`
- `frequency_value`(正整数)
- `renewal_type``auto` / `manual`
- `notes`(可选)
示例:
```csv
name,cost,currency,category,next_due,frequency_unit,frequency_value,renewal_type,notes
Netflix,15.99,USD,影音,2026-03-01,month,1,auto,
VPS,60,CNY,开发,2026-03-12,month,1,manual,生产环境
```
## 数据与迁移
- 使用 SQLite文件`submind.db`)。
- 启动时自动执行兼容性检查和必要字段补齐。
- 建议升级前先备份数据库:
```bash
cp submind.db submind.db.bak
```
## 常见问题
### 1) 统计图中文乱码怎么办?
程序会尝试下载中文字体到 `fonts/` 目录;若网络不可达,可手动放置字体文件。
### 2) 为什么汇率没变化?
请检查 `EXCHANGE_API_KEY` 是否配置正确。未配置时会使用原货币金额。
### 3) 旧消息按钮点不动?
更新后建议重新执行 `/list_categories``/list_subs` 刷新最新按钮。
## 安全说明
- 项目已对订阅编辑/提醒等关键路径做用户归属校验(`user_id`)。
- 数据库操作使用参数化查询,并对可编辑字段做白名单约束。
## License
可按你的开源策略补充(例如 MIT
---
如果这个项目对你有帮助,欢迎 Star ⭐

View File

@@ -4,6 +4,7 @@ import requests
import datetime import datetime
import dateparser import dateparser
import logging import logging
import tempfile
import pandas as pd import pandas as pd
import matplotlib import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@@ -241,6 +242,32 @@ def format_frequency(unit, value) -> str:
return f"{value} {unit_map.get(unit, unit)}" return f"{value} {unit_map.get(unit, unit)}"
CATEGORY_CB_PREFIX = "list_subs_in_category_id_"
EDITABLE_SUB_FIELDS = {
'name': 'name',
'cost': 'cost',
'currency': 'currency',
'category': 'category',
'next_due': 'next_due',
'renewal_type': 'renewal_type',
'notes': 'notes'
}
MAX_NAME_LEN = 128
MAX_CATEGORY_LEN = 64
MAX_NOTES_LEN = 1000
VALID_FREQ_UNITS = {'day', 'week', 'month', 'year'}
VALID_RENEWAL_TYPES = {'auto', 'manual'}
def _build_category_callback_data(category_id: int) -> str:
return f"{CATEGORY_CB_PREFIX}{category_id}"
def _parse_category_id_from_callback(data: str) -> int | None:
payload = data.replace(CATEGORY_CB_PREFIX, '', 1)
return int(payload) if payload.isdigit() else None
async def get_subs_list_keyboard(user_id: int, category_filter: str = None) -> InlineKeyboardMarkup: async def get_subs_list_keyboard(user_id: int, category_filter: str = None) -> InlineKeyboardMarkup:
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor() cursor = conn.cursor()
@@ -262,6 +289,21 @@ async def get_subs_list_keyboard(user_id: int, category_filter: str = None) -> I
return InlineKeyboardMarkup(keyboard) return InlineKeyboardMarkup(keyboard)
def _clear_action_state(context: CallbackContext, keys: list[str]):
for key in keys:
context.user_data.pop(key, None)
def _get_new_sub_data_or_end(update: Update, context: CallbackContext):
sub_data = context.user_data.get('new_sub_data')
if sub_data is None:
message_obj = update.message or (update.callback_query.message if update.callback_query else None)
if message_obj:
# 统一提示,避免 KeyError 导致会话崩溃
return None, message_obj
return sub_data, None
# --- 自动任务 --- # --- 自动任务 ---
def update_past_due_dates(): def update_past_due_dates():
today = datetime.date.today() today = datetime.date.today()
@@ -352,7 +394,7 @@ async def start(update: Update, context: CallbackContext):
async def help_command(update: Update, context: CallbackContext): async def help_command(update: Update, context: CallbackContext):
help_text = f""" help_text = fr"""
*{escape_markdown(PROJECT_NAME, version=2)} 命令列表* *{escape_markdown(PROJECT_NAME, version=2)} 命令列表*
*🌟 核心功能* *🌟 核心功能*
/add\_sub \- 引导您添加一个新的订阅 /add\_sub \- 引导您添加一个新的订阅
@@ -413,7 +455,9 @@ async def stats(update: Update, context: CallbackContext):
plt.style.use('seaborn-v0_8-darkgrid') plt.style.use('seaborn-v0_8-darkgrid')
fig, ax = plt.subplots(figsize=(12, 12)) fig, ax = plt.subplots(figsize=(12, 12))
image_path = None
try:
autopct_function = make_autopct(category_costs.values, main_currency) autopct_function = make_autopct(category_costs.values, main_currency)
wedges, texts, autotexts = ax.pie(category_costs.values, wedges, texts, autotexts = ax.pie(category_costs.values,
@@ -435,17 +479,18 @@ async def stats(update: Update, context: CallbackContext):
autotext.set_color('white') autotext.set_color('white')
ax.axis('equal') ax.axis('equal')
fig.tight_layout() fig.tight_layout()
image_path = f'stats_{user_id}.png'
plt.savefig(image_path)
plt.close(fig)
try: with tempfile.NamedTemporaryFile(prefix=f'stats_{user_id}_', suffix='.png', delete=False) as tmp:
image_path = tmp.name
plt.savefig(image_path)
with open(image_path, 'rb') as photo: with open(image_path, 'rb') as photo:
await update.message.reply_photo(photo, caption="这是您按类别统计的每月订阅总支出。") await update.message.reply_photo(photo, caption="这是您按类别统计的每月订阅总支出。")
finally: finally:
if os.path.exists(image_path): plt.close(fig)
if image_path and os.path.exists(image_path):
os.remove(image_path) os.remove(image_path)
@@ -460,10 +505,12 @@ async def export_command(update: Update, context: CallbackContext):
await update.message.reply_text("您还没有任何订阅数据,无法导出。") await update.message.reply_text("您还没有任何订阅数据,无法导出。")
return return
export_path = f'export_{user_id}.csv' with tempfile.NamedTemporaryFile(prefix=f'export_{user_id}_', suffix='.csv', delete=False) as tmp:
df.to_csv(export_path, index=False, encoding='utf-8-sig') export_path = tmp.name
try: try:
df.to_csv(export_path, index=False, encoding='utf-8-sig')
with open(export_path, 'rb') as file: with open(export_path, 'rb') as file:
await update.message.reply_document(document=file, filename='subscriptions.csv', await update.message.reply_document(document=file, filename='subscriptions.csv',
caption="您的订阅数据已导出为 CSV 文件。") caption="您的订阅数据已导出为 CSV 文件。")
@@ -485,7 +532,8 @@ async def import_upload_received(update: Update, context: CallbackContext):
return IMPORT_UPLOAD return IMPORT_UPLOAD
file = await update.message.document.get_file() file = await update.message.document.get_file()
file_path = f'import_{user_id}.csv' with tempfile.NamedTemporaryFile(prefix=f'import_{user_id}_', suffix='.csv', delete=False) as tmp:
file_path = tmp.name
try: try:
await file.download_to_drive(file_path) await file.download_to_drive(file_path)
df = pd.read_csv(file_path, encoding='utf-8-sig') df = pd.read_csv(file_path, encoding='utf-8-sig')
@@ -519,14 +567,26 @@ async def import_upload_received(update: Update, context: CallbackContext):
renewal_type = str(row['renewal_type']).lower() renewal_type = str(row['renewal_type']).lower()
if renewal_type not in valid_renewal_types: if renewal_type not in valid_renewal_types:
raise ValueError(f"无效续费类型: {renewal_type}") raise ValueError(f"无效续费类型: {renewal_type}")
notes = str(row['notes']) if pd.notna(row['notes']) else None notes = str(row['notes']).strip() if pd.notna(row['notes']) else None
if notes and len(notes) > MAX_NOTES_LEN:
raise ValueError(f"备注过长(>{MAX_NOTES_LEN}")
name = str(row['name']).strip()
category = str(row['category']).strip()
if not name:
raise ValueError("名称不能为空")
if not category:
raise ValueError("类别不能为空")
if len(name) > MAX_NAME_LEN:
raise ValueError(f"名称过长(>{MAX_NAME_LEN}")
if len(category) > MAX_CATEGORY_LEN:
raise ValueError(f"类别过长(>{MAX_CATEGORY_LEN}")
records.append(( records.append((
user_id, row['name'], cost, currency, row['category'], user_id, name, cost, currency, category,
next_due, frequency_unit, frequency_value, renewal_type, notes next_due, frequency_unit, frequency_value, renewal_type, notes
)) ))
except Exception as e: except Exception as e:
logger.error(f"Invalid row in CSV: {row.to_dict()}, error: {e}") logger.error(f"Invalid row in CSV import, error: {e}")
await update.message.reply_text(f"导入失败,行数据无效:{row.to_dict()},错误{e}") await update.message.reply_text(f"导入失败,存在无效行:{e}")
return ConversationHandler.END return ConversationHandler.END
with get_db_connection() as conn: with get_db_connection() as conn:
@@ -558,17 +618,36 @@ async def add_sub_start(update: Update, context: CallbackContext):
async def add_name_received(update: Update, context: CallbackContext): async def add_name_received(update: Update, context: CallbackContext):
context.user_data['new_sub_data']['name'] = update.message.text sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
if sub_data is None:
if err_msg_obj:
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
return ConversationHandler.END
name = update.message.text.strip()
if not name:
await update.message.reply_text("订阅名称不能为空。")
return ADD_NAME
if len(name) > MAX_NAME_LEN:
await update.message.reply_text(f"订阅名称过长,请控制在 {MAX_NAME_LEN} 个字符以内。")
return ADD_NAME
sub_data['name'] = name
await update.message.reply_text("第二步:请输入订阅 *费用*", parse_mode='MarkdownV2') await update.message.reply_text("第二步:请输入订阅 *费用*", parse_mode='MarkdownV2')
return ADD_COST return ADD_COST
async def add_cost_received(update: Update, context: CallbackContext): async def add_cost_received(update: Update, context: CallbackContext):
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
if sub_data is None:
if err_msg_obj:
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
return ConversationHandler.END
try: try:
cost = float(update.message.text) cost = float(update.message.text)
if cost < 0: if cost < 0:
raise ValueError("费用不能为负数") raise ValueError("费用不能为负数")
context.user_data['new_sub_data']['cost'] = cost sub_data['cost'] = cost
except (ValueError, TypeError): except (ValueError, TypeError):
await update.message.reply_text("费用必须是有效的非负数字。") await update.message.reply_text("费用必须是有效的非负数字。")
return ADD_COST return ADD_COST
@@ -577,21 +656,36 @@ async def add_cost_received(update: Update, context: CallbackContext):
async def add_currency_received(update: Update, context: CallbackContext): async def add_currency_received(update: Update, context: CallbackContext):
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
if sub_data is None:
if err_msg_obj:
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
return ConversationHandler.END
currency = update.message.text.upper() currency = update.message.text.upper()
if not (len(currency) == 3 and currency.isalpha()): if not (len(currency) == 3 and currency.isalpha()):
await update.message.reply_text("请输入有效的三字母货币代码(如 USD, CNY") await update.message.reply_text("请输入有效的三字母货币代码(如 USD, CNY")
return ADD_CURRENCY return ADD_CURRENCY
context.user_data['new_sub_data']['currency'] = currency sub_data['currency'] = currency
await update.message.reply_text("第四步:请为订阅指定一个 *类别*", parse_mode='MarkdownV2') await update.message.reply_text("第四步:请为订阅指定一个 *类别*", parse_mode='MarkdownV2')
return ADD_CATEGORY return ADD_CATEGORY
async def add_category_received(update: Update, context: CallbackContext): async def add_category_received(update: Update, context: CallbackContext):
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
if sub_data is None:
if err_msg_obj:
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
return ConversationHandler.END
user_id, category_name = update.effective_user.id, update.message.text.strip() user_id, category_name = update.effective_user.id, update.message.text.strip()
if not category_name: if not category_name:
await update.message.reply_text("类别不能为空。") await update.message.reply_text("类别不能为空。")
return ADD_CATEGORY return ADD_CATEGORY
context.user_data['new_sub_data']['category'] = category_name if len(category_name) > MAX_CATEGORY_LEN:
await update.message.reply_text(f"类别名称过长,请控制在 {MAX_CATEGORY_LEN} 个字符以内。")
return ADD_CATEGORY
sub_data['category'] = category_name
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, category_name)) cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, category_name))
@@ -602,11 +696,17 @@ async def add_category_received(update: Update, context: CallbackContext):
async def add_next_due_received(update: Update, context: CallbackContext): async def add_next_due_received(update: Update, context: CallbackContext):
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
if sub_data is None:
if err_msg_obj:
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
return ConversationHandler.END
parsed_date = parse_date(update.message.text) parsed_date = parse_date(update.message.text)
if not parsed_date: if not parsed_date:
await update.message.reply_text("无法识别的日期格式,请使用类似 '2025\\-10\\-01''10月1日' 的格式。") await update.message.reply_text("无法识别的日期格式,请使用类似 '2025\\-10\\-01''10月1日' 的格式。")
return ADD_NEXT_DUE return ADD_NEXT_DUE
context.user_data['new_sub_data']['next_due'] = parsed_date sub_data['next_due'] = parsed_date
keyboard = [ keyboard = [
[InlineKeyboardButton("", callback_data='freq_unit_day'), [InlineKeyboardButton("", callback_data='freq_unit_day'),
InlineKeyboardButton("", callback_data='freq_unit_week')], InlineKeyboardButton("", callback_data='freq_unit_week')],
@@ -619,19 +719,34 @@ async def add_next_due_received(update: Update, context: CallbackContext):
async def add_freq_unit_received(update: Update, context: CallbackContext): async def add_freq_unit_received(update: Update, context: CallbackContext):
sub_data, _ = _get_new_sub_data_or_end(update, context)
query = update.callback_query query = update.callback_query
await query.answer() await query.answer()
context.user_data['new_sub_data']['unit'] = query.data.split('_')[2] if sub_data is None:
await query.edit_message_text("会话已过期,请重新使用 /add_sub 开始。")
return ConversationHandler.END
unit = query.data.partition('freq_unit_')[2]
if unit not in VALID_FREQ_UNITS:
await query.edit_message_text("错误:无效的周期单位,请重试。")
return ConversationHandler.END
sub_data['unit'] = unit
await query.edit_message_text("第七步:请输入周期的*数量*例如每3个月输入 3", parse_mode='Markdown') await query.edit_message_text("第七步:请输入周期的*数量*例如每3个月输入 3", parse_mode='Markdown')
return ADD_FREQ_VALUE return ADD_FREQ_VALUE
async def add_freq_value_received(update: Update, context: CallbackContext): async def add_freq_value_received(update: Update, context: CallbackContext):
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
if sub_data is None:
if err_msg_obj:
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
return ConversationHandler.END
try: try:
value = int(update.message.text) value = int(update.message.text)
if value <= 0: if value <= 0:
raise ValueError raise ValueError
context.user_data['new_sub_data']['value'] = value sub_data['value'] = value
except (ValueError, TypeError): except (ValueError, TypeError):
await update.message.reply_text("请输入一个有效的正整数。") await update.message.reply_text("请输入一个有效的正整数。")
return ADD_FREQ_VALUE return ADD_FREQ_VALUE
@@ -645,36 +760,55 @@ async def add_freq_value_received(update: Update, context: CallbackContext):
async def add_renewal_type_received(update: Update, context: CallbackContext): async def add_renewal_type_received(update: Update, context: CallbackContext):
sub_data, _ = _get_new_sub_data_or_end(update, context)
query = update.callback_query query = update.callback_query
await query.answer() await query.answer()
context.user_data['new_sub_data']['renewal_type'] = query.data.split('_')[1] if sub_data is None:
await query.edit_message_text("会话已过期,请重新使用 /add_sub 开始。")
return ConversationHandler.END
renewal_type = query.data.partition('renewal_')[2]
if renewal_type not in VALID_RENEWAL_TYPES:
await query.edit_message_text("错误:无效的续费类型,请重试。")
return ConversationHandler.END
sub_data['renewal_type'] = renewal_type
await query.edit_message_text("最后一步(可选):需要添加备注吗?\n(如:共享账号、用途等。不需要请 /skip") await query.edit_message_text("最后一步(可选):需要添加备注吗?\n(如:共享账号、用途等。不需要请 /skip")
return ADD_NOTES return ADD_NOTES
async def add_notes_received(update: Update, context: CallbackContext): async def add_notes_received(update: Update, context: CallbackContext):
sub_data = context.user_data.get('new_sub_data') sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
if not sub_data: if sub_data is None:
await update.message.reply_text("发生错误,请重试。") _clear_action_state(context, ['new_sub_data'])
if err_msg_obj:
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
return ConversationHandler.END return ConversationHandler.END
sub_data['notes'] = update.message.text
note = update.message.text.strip()
if len(note) > MAX_NOTES_LEN:
await update.message.reply_text(f"备注过长,请控制在 {MAX_NOTES_LEN} 个字符以内。")
return ADD_NOTES
sub_data['notes'] = note if note else None
save_subscription(update.effective_user.id, sub_data) save_subscription(update.effective_user.id, sub_data)
await update.message.reply_text(text=f"✅ 订阅 '{escape_markdown(sub_data.get('name', ''), version=2)}' 已添加!", await update.message.reply_text(text=f"✅ 订阅 '{escape_markdown(sub_data.get('name', ''), version=2)}' 已添加!",
parse_mode='MarkdownV2') parse_mode='MarkdownV2')
context.user_data.clear() _clear_action_state(context, ['new_sub_data'])
return ConversationHandler.END return ConversationHandler.END
async def skip_notes(update: Update, context: CallbackContext): async def skip_notes(update: Update, context: CallbackContext):
sub_data = context.user_data.get('new_sub_data') sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
if not sub_data: if sub_data is None:
await update.message.reply_text("发生错误,请重试。") _clear_action_state(context, ['new_sub_data'])
if err_msg_obj:
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
return ConversationHandler.END return ConversationHandler.END
sub_data['notes'] = None sub_data['notes'] = None
save_subscription(update.effective_user.id, sub_data) save_subscription(update.effective_user.id, sub_data)
await update.message.reply_text(text=f"✅ 订阅 '{escape_markdown(sub_data.get('name', ''), version=2)}' 已添加!", await update.message.reply_text(text=f"✅ 订阅 '{escape_markdown(sub_data.get('name', ''), version=2)}' 已添加!",
parse_mode='MarkdownV2') parse_mode='MarkdownV2')
context.user_data.clear() _clear_action_state(context, ['new_sub_data'])
return ConversationHandler.END return ConversationHandler.END
@@ -706,7 +840,7 @@ async def list_categories(update: Update, context: CallbackContext):
user_id = update.effective_user.id user_id = update.effective_user.id
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT name FROM categories WHERE user_id = ? ORDER BY name", (user_id,)) cursor.execute("SELECT id, name FROM categories WHERE user_id = ? ORDER BY name", (user_id,))
categories = cursor.fetchall() categories = cursor.fetchall()
if not categories: if not categories:
if update.callback_query: if update.callback_query:
@@ -715,7 +849,11 @@ async def list_categories(update: Update, context: CallbackContext):
await update.message.reply_text("您还没有任何分类。") await update.message.reply_text("您还没有任何分类。")
return return
buttons = [InlineKeyboardButton(cat[0], callback_data=f"list_subs_in_category_{cat[0]}") for cat in categories] buttons = []
for cat in categories:
cat_id, cat_name = cat[0], cat[1]
buttons.append(InlineKeyboardButton(cat_name, callback_data=_build_category_callback_data(cat_id)))
keyboard = [buttons[i:i + 2] for i in range(0, len(buttons), 2)] keyboard = [buttons[i:i + 2] for i in range(0, len(buttons), 2)]
keyboard.append([InlineKeyboardButton("查看全部订阅", callback_data="list_all_subs")]) keyboard.append([InlineKeyboardButton("查看全部订阅", callback_data="list_all_subs")])
if update.callback_query: if update.callback_query:
@@ -765,8 +903,12 @@ async def show_subscription_view(update: Update, context: CallbackContext, sub_i
keyboard_buttons.insert(0, [InlineKeyboardButton("✅ 续费", callback_data=f'renewmanual_{sub_id}')]) keyboard_buttons.insert(0, [InlineKeyboardButton("✅ 续费", callback_data=f'renewmanual_{sub_id}')])
if 'list_subs_in_category' in context.user_data: if 'list_subs_in_category' in context.user_data:
cat_filter = context.user_data['list_subs_in_category'] cat_filter = context.user_data['list_subs_in_category']
keyboard_buttons.append( category_id = context.user_data.get('list_subs_in_category_id')
[InlineKeyboardButton("« 返回分类订阅", callback_data=f'list_subs_in_category_{cat_filter}')]) if category_id:
back_cb = _build_category_callback_data(category_id)
else:
back_cb = 'list_categories'
keyboard_buttons.append([InlineKeyboardButton("« 返回分类订阅", callback_data=back_cb)])
else: else:
keyboard_buttons.append([InlineKeyboardButton("« 返回全部订阅", callback_data='list_all_subs')]) keyboard_buttons.append([InlineKeyboardButton("« 返回全部订阅", callback_data='list_all_subs')])
logger.debug(f"Generated buttons for sub_id {sub_id}: edit_{sub_id}, remind_{sub_id}") logger.debug(f"Generated buttons for sub_id {sub_id}: edit_{sub_id}, remind_{sub_id}")
@@ -785,9 +927,24 @@ async def button_callback_handler(update: Update, context: CallbackContext):
user_id = query.from_user.id user_id = query.from_user.id
logger.debug(f"Received callback query: {data} from user {user_id}") logger.debug(f"Received callback query: {data} from user {user_id}")
if data.startswith('list_subs_in_category_'): if data.startswith(CATEGORY_CB_PREFIX):
category = data.replace('list_subs_in_category_', '') category_id = _parse_category_id_from_callback(data)
if not category_id:
await query.edit_message_text("错误:无效或已过期的分类,请重新选择。")
return
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT name FROM categories WHERE id = ? AND user_id = ?", (category_id, user_id))
row = cursor.fetchone()
if not row:
await query.edit_message_text("错误:分类不存在或无权限。")
return
category = row['name']
context.user_data['list_subs_in_category'] = category context.user_data['list_subs_in_category'] = category
context.user_data['list_subs_in_category_id'] = category_id
keyboard = await get_subs_list_keyboard(user_id, category_filter=category) keyboard = await get_subs_list_keyboard(user_id, category_filter=category)
msg_text = f"分类“{escape_markdown(category, version=2)}”下的订阅:" msg_text = f"分类“{escape_markdown(category, version=2)}”下的订阅:"
if not keyboard: if not keyboard:
@@ -797,10 +954,12 @@ async def button_callback_handler(update: Update, context: CallbackContext):
return return
if data == 'list_categories': if data == 'list_categories':
context.user_data.pop('list_subs_in_category', None) context.user_data.pop('list_subs_in_category', None)
context.user_data.pop('list_subs_in_category_id', None)
await list_categories(update, context) await list_categories(update, context)
return return
if data == 'list_all_subs': if data == 'list_all_subs':
context.user_data.pop('list_subs_in_category', None) context.user_data.pop('list_subs_in_category', None)
context.user_data.pop('list_subs_in_category_id', None)
keyboard = await get_subs_list_keyboard(user_id) keyboard = await get_subs_list_keyboard(user_id)
if not keyboard: if not keyboard:
await query.edit_message_text("您还没有任何订阅。") await query.edit_message_text("您还没有任何订阅。")
@@ -821,33 +980,45 @@ async def button_callback_handler(update: Update, context: CallbackContext):
elif action == 'renewmanual': elif action == 'renewmanual':
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT frequency_unit, frequency_value FROM subscriptions WHERE id = ?", (sub_id,)) cursor.execute(
"SELECT frequency_unit, frequency_value FROM subscriptions WHERE id = ? AND user_id = ?",
(sub_id, user_id)
)
sub = cursor.fetchone() sub = cursor.fetchone()
if sub: if sub:
today = datetime.date.today() today = datetime.date.today()
new_due_date = calculate_new_due_date(today, sub['frequency_unit'], sub['frequency_value']) new_due_date = calculate_new_due_date(today, sub['frequency_unit'], sub['frequency_value'])
if new_due_date: if new_due_date:
new_date_str = new_due_date.strftime('%Y-%m-%d') new_date_str = new_due_date.strftime('%Y-%m-%d')
cursor.execute("UPDATE subscriptions SET next_due = ? WHERE id = ?", (new_date_str, sub_id)) cursor.execute(
"UPDATE subscriptions SET next_due = ? WHERE id = ? AND user_id = ?",
(new_date_str, sub_id, user_id)
)
conn.commit() conn.commit()
await query.answer(f"✅ 续费成功!新到期日: {new_date_str}", show_alert=True) await query.answer(f"✅ 续费成功!新到期日: {new_date_str}", show_alert=True)
await show_subscription_view(update, context, sub_id) await show_subscription_view(update, context, sub_id)
else: else:
await query.answer("续费失败:无法计算新的到期日期。", show_alert=True) await query.answer("续费失败:无法计算新的到期日期。", show_alert=True)
else: else:
await query.answer("续费失败:订阅不存在。", show_alert=True) await query.answer("续费失败:订阅不存在或无权限", show_alert=True)
elif action == 'renewfromremind': elif action == 'renewfromremind':
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT name, frequency_unit, frequency_value FROM subscriptions WHERE id = ?", (sub_id,)) cursor.execute(
"SELECT name, frequency_unit, frequency_value FROM subscriptions WHERE id = ? AND user_id = ?",
(sub_id, user_id)
)
sub = cursor.fetchone() sub = cursor.fetchone()
if sub: if sub:
today = datetime.date.today() today = datetime.date.today()
new_due_date = calculate_new_due_date(today, sub['frequency_unit'], sub['frequency_value']) new_due_date = calculate_new_due_date(today, sub['frequency_unit'], sub['frequency_value'])
if new_due_date: if new_due_date:
new_date_str = new_due_date.strftime('%Y-%m-%d') new_date_str = new_due_date.strftime('%Y-%m-%d')
cursor.execute("UPDATE subscriptions SET next_due = ? WHERE id = ?", (new_date_str, sub_id)) cursor.execute(
"UPDATE subscriptions SET next_due = ? WHERE id = ? AND user_id = ?",
(new_date_str, sub_id, user_id)
)
conn.commit() conn.commit()
safe_sub_name = escape_markdown(sub['name'], version=2) safe_sub_name = escape_markdown(sub['name'], version=2)
await query.edit_message_text( await query.edit_message_text(
@@ -858,11 +1029,17 @@ async def button_callback_handler(update: Update, context: CallbackContext):
else: else:
await query.answer("续费失败:无法计算新的到期日期。", show_alert=True) await query.answer("续费失败:无法计算新的到期日期。", show_alert=True)
else: else:
await query.answer("续费失败:此订阅可能已被删除。", show_alert=True) await query.answer("续费失败:此订阅可能已被删除或无权限", show_alert=True)
await query.edit_message_text(text=query.message.text + "\n\n*(错误:此订阅已被删除*", await query.edit_message_text(text=query.message.text + "\n\n*(错误:此订阅不存在或无权限*",
parse_mode='MarkdownV2', reply_markup=None) parse_mode='MarkdownV2', reply_markup=None)
elif action == 'delete': elif action == 'delete':
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT 1 FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id))
if not cursor.fetchone():
await query.answer("错误:找不到该订阅或无权限。", show_alert=True)
return
keyboard = InlineKeyboardMarkup([ keyboard = InlineKeyboardMarkup([
[InlineKeyboardButton("✅ 是的,删除", callback_data=f'confirmdelete_{sub_id}'), [InlineKeyboardButton("✅ 是的,删除", callback_data=f'confirmdelete_{sub_id}'),
InlineKeyboardButton("❌ 取消", callback_data=f'view_{sub_id}')] InlineKeyboardButton("❌ 取消", callback_data=f'view_{sub_id}')]
@@ -872,7 +1049,11 @@ async def button_callback_handler(update: Update, context: CallbackContext):
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("DELETE FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id)) cursor.execute("DELETE FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id))
deleted = cursor.rowcount
conn.commit() conn.commit()
if deleted == 0:
await query.answer("错误:找不到该订阅或无权限。", show_alert=True)
return
await query.answer("订阅已删除") await query.answer("订阅已删除")
if 'list_subs_in_category' in context.user_data: if 'list_subs_in_category' in context.user_data:
category = context.user_data['list_subs_in_category'] category = context.user_data['list_subs_in_category']
@@ -907,7 +1088,21 @@ async def fallback_view_button(update: Update, context: CallbackContext):
async def edit_start(update: Update, context: CallbackContext): async def edit_start(update: Update, context: CallbackContext):
query = update.callback_query query = update.callback_query
await query.answer() await query.answer()
sub_id = query.data.split('_')[1] sub_id_str = query.data.split('_')[1]
user_id = query.from_user.id
if not sub_id_str.isdigit():
await query.edit_message_text("错误无效的订阅ID。")
return ConversationHandler.END
sub_id = int(sub_id_str)
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT 1 FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id))
if not cursor.fetchone():
await query.edit_message_text("错误:找不到该订阅或无权限。")
return ConversationHandler.END
logger.debug(f"Starting edit for sub_id: {sub_id}") logger.debug(f"Starting edit for sub_id: {sub_id}")
context.user_data['sub_id_for_action'] = sub_id context.user_data['sub_id_for_action'] = sub_id
keyboard = [ keyboard = [
@@ -960,7 +1155,11 @@ async def edit_field_selected(update: Update, context: CallbackContext):
async def edit_freq_unit_received(update: Update, context: CallbackContext): async def edit_freq_unit_received(update: Update, context: CallbackContext):
query = update.callback_query query = update.callback_query
await query.answer() await query.answer()
context.user_data['new_freq_unit'] = query.data.split('_')[2] unit = query.data.partition('freq_unit_')[2]
if unit not in VALID_FREQ_UNITS:
await query.edit_message_text("错误:无效的周期单位,请重试。")
return ConversationHandler.END
context.user_data['new_freq_unit'] = unit
await query.edit_message_text("好的,现在请输入新的周期*数量*。", parse_mode='MarkdownV2') await query.edit_message_text("好的,现在请输入新的周期*数量*。", parse_mode='MarkdownV2')
return EDIT_FREQ_VALUE return EDIT_FREQ_VALUE
@@ -975,14 +1174,23 @@ async def edit_freq_value_received(update: Update, context: CallbackContext):
await update.message.reply_text("请输入一个有效的正整数。") await update.message.reply_text("请输入一个有效的正整数。")
return EDIT_FREQ_VALUE return EDIT_FREQ_VALUE
unit = context.user_data.get('new_freq_unit') unit = context.user_data.get('new_freq_unit')
try:
sub_id = int(context.user_data.get('sub_id_for_action')) sub_id = int(context.user_data.get('sub_id_for_action'))
except (ValueError, TypeError):
await update.message.reply_text("错误:会话已过期,请重试。")
return ConversationHandler.END
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("UPDATE subscriptions SET frequency_unit = ?, frequency_value = ? WHERE id = ? AND user_id = ?", cursor.execute("UPDATE subscriptions SET frequency_unit = ?, frequency_value = ? WHERE id = ? AND user_id = ?",
(unit, value, sub_id, user_id)) (unit, value, sub_id, user_id))
if cursor.rowcount == 0:
await update.message.reply_text("错误:找不到该订阅或无权限。")
return ConversationHandler.END
conn.commit() conn.commit()
await update.message.reply_text("✅ 周期已更新!") await update.message.reply_text("✅ 周期已更新!")
context.user_data.clear() _clear_action_state(context, ['sub_id_for_action', 'new_freq_unit', 'field_to_edit'])
await show_subscription_view(update, context, sub_id) await show_subscription_view(update, context, sub_id)
return ConversationHandler.END return ConversationHandler.END
@@ -1000,6 +1208,13 @@ async def edit_new_value_received(update: Update, context: CallbackContext):
if update.effective_message: if update.effective_message:
await update.effective_message.reply_text("错误:未选择要编辑的字段。") await update.effective_message.reply_text("错误:未选择要编辑的字段。")
return ConversationHandler.END return ConversationHandler.END
db_field = EDITABLE_SUB_FIELDS.get(field)
if not db_field:
if update.effective_message:
await update.effective_message.reply_text("错误:不允许编辑该字段。")
logger.warning(f"Blocked unsafe field update attempt: {field}")
return ConversationHandler.END
query, new_value = update.callback_query, "" query, new_value = update.callback_query, ""
message_to_reply = update.effective_message message_to_reply = update.effective_message
@@ -1021,25 +1236,55 @@ async def edit_new_value_received(update: Update, context: CallbackContext):
if new_value < 0: if new_value < 0:
raise ValueError("费用不能为负数") raise ValueError("费用不能为负数")
except (ValueError, TypeError): except (ValueError, TypeError):
if message_to_reply: await message_to_reply.reply_text("费用必须是有效的非负数字。") if message_to_reply:
await message_to_reply.reply_text("费用必须是有效的非负数字。")
validation_failed = True
elif field == 'name':
new_value = str(new_value).strip()
if not new_value:
if message_to_reply:
await message_to_reply.reply_text("名称不能为空。")
validation_failed = True
elif len(new_value) > MAX_NAME_LEN:
if message_to_reply:
await message_to_reply.reply_text(f"名称过长,请控制在 {MAX_NAME_LEN} 个字符以内。")
validation_failed = True validation_failed = True
elif field == 'currency': elif field == 'currency':
new_value = str(new_value).upper() new_value = str(new_value).upper()
if not (len(new_value) == 3 and new_value.isalpha()): if not (len(new_value) == 3 and new_value.isalpha()):
if message_to_reply: await message_to_reply.reply_text("请输入有效的三字母货币代码(如 USD, CNY") if message_to_reply:
await message_to_reply.reply_text("请输入有效的三字母货币代码(如 USD, CNY")
validation_failed = True validation_failed = True
elif field == 'next_due': elif field == 'next_due':
parsed = parse_date(str(new_value)) parsed = parse_date(str(new_value))
if not parsed: if not parsed:
if message_to_reply: await message_to_reply.reply_text( if message_to_reply:
"无法识别的日期格式,请使用类似 '2025\\-10\\-01''10月1日' 的格式。") await message_to_reply.reply_text("无法识别的日期格式,请使用类似 '2025\\-10\\-01''10月1日' 的格式。")
validation_failed = True validation_failed = True
else: else:
new_value = parsed new_value = parsed
elif field == 'renewal_type':
if str(new_value) not in VALID_RENEWAL_TYPES:
if message_to_reply:
await message_to_reply.reply_text("续费方式只能为 auto 或 manual。")
validation_failed = True
elif field == 'notes':
note_val = str(new_value).strip()
if note_val and len(note_val) > MAX_NOTES_LEN:
if message_to_reply:
await message_to_reply.reply_text(f"备注过长,请控制在 {MAX_NOTES_LEN} 个字符以内。")
validation_failed = True
else:
new_value = note_val if note_val else None
elif field == 'category': elif field == 'category':
new_value = str(new_value).strip() new_value = str(new_value).strip()
if not new_value: if not new_value:
if message_to_reply: await message_to_reply.reply_text("类别不能为空。") if message_to_reply:
await message_to_reply.reply_text("类别不能为空。")
validation_failed = True
elif len(new_value) > MAX_CATEGORY_LEN:
if message_to_reply:
await message_to_reply.reply_text(f"类别名称过长,请控制在 {MAX_CATEGORY_LEN} 个字符以内。")
validation_failed = True validation_failed = True
else: else:
with get_db_connection() as conn: with get_db_connection() as conn:
@@ -1052,30 +1297,37 @@ async def edit_new_value_received(update: Update, context: CallbackContext):
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(f"UPDATE subscriptions SET {field} = ? WHERE id = ? AND user_id = ?", cursor.execute(f"UPDATE subscriptions SET {db_field} = ? WHERE id = ? AND user_id = ?",
(new_value, sub_id, user_id)) (new_value, sub_id, user_id))
if cursor.rowcount == 0:
if message_to_reply:
await message_to_reply.reply_text("错误:找不到该订阅或无权限。")
return ConversationHandler.END
conn.commit() conn.commit()
if query: if query:
await query.answer(f"✅ 字段已更新!") await query.answer("✅ 字段已更新!")
elif message_to_reply: elif message_to_reply:
await message_to_reply.reply_text("✅ 字段已更新!") await message_to_reply.reply_text("✅ 字段已更新!")
context.user_data.clear() _clear_action_state(context, ['sub_id_for_action', 'field_to_edit', 'new_freq_unit'])
await show_subscription_view(update, context, sub_id) await show_subscription_view(update, context, sub_id)
return ConversationHandler.END return ConversationHandler.END
# --- Reminder Settings Conversation --- # --- Reminder Settings Conversation ---
async def _display_reminder_settings(query: CallbackQuery, context: CallbackContext, sub_id: int): async def _display_reminder_settings(query: CallbackQuery, context: CallbackContext, sub_id: int):
user_id = query.from_user.id
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
"SELECT name, renewal_type, reminders_enabled, reminder_on_due_date, reminder_days FROM subscriptions WHERE id = ?", "SELECT name, renewal_type, reminders_enabled, reminder_on_due_date, reminder_days "
(sub_id,)) "FROM subscriptions WHERE id = ? AND user_id = ?",
(sub_id, user_id)
)
sub = cursor.fetchone() sub = cursor.fetchone()
if not sub: if not sub:
await query.edit_message_text("错误:找不到该订阅。") await query.edit_message_text("错误:找不到该订阅或无权限")
return return
enabled_text = "❌ 关闭提醒" if sub['reminders_enabled'] else "✅ 开启提醒" enabled_text = "❌ 关闭提醒" if sub['reminders_enabled'] else "✅ 开启提醒"
due_date_text = "❌ 关闭到期日提醒" if sub['reminder_on_due_date'] else "✅ 开启到期日提醒" due_date_text = "❌ 关闭到期日提醒" if sub['reminder_on_due_date'] else "✅ 开启到期日提醒"
@@ -1096,10 +1348,20 @@ async def remind_settings_start(update: Update, context: CallbackContext):
query = update.callback_query query = update.callback_query
await query.answer() await query.answer()
sub_id_str = query.data.partition('_')[2] sub_id_str = query.data.partition('_')[2]
user_id = query.from_user.id
if not sub_id_str.isdigit(): if not sub_id_str.isdigit():
await query.edit_message_text("错误无效的订阅ID。") await query.edit_message_text("错误无效的订阅ID。")
return ConversationHandler.END return ConversationHandler.END
sub_id = int(sub_id_str) sub_id = int(sub_id_str)
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT 1 FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id))
if not cursor.fetchone():
await query.edit_message_text("错误:找不到该订阅或无权限。")
return ConversationHandler.END
logger.debug(f"Starting reminder settings for sub_id: {sub_id}") logger.debug(f"Starting reminder settings for sub_id: {sub_id}")
context.user_data['sub_id_for_action'] = sub_id context.user_data['sub_id_for_action'] = sub_id
await _display_reminder_settings(query, context, sub_id) await _display_reminder_settings(query, context, sub_id)
@@ -1119,6 +1381,8 @@ async def remind_action_handler(update: Update, context: CallbackContext):
await query.edit_message_text("错误:会话已过期,请重试。") await query.edit_message_text("错误:会话已过期,请重试。")
return ConversationHandler.END return ConversationHandler.END
user_id = query.from_user.id
if action == 'ask_days': if action == 'ask_days':
await query.edit_message_text("请输入您想提前几天收到提醒输入0则不提前提醒") await query.edit_message_text("请输入您想提前几天收到提醒输入0则不提前提醒")
return REMIND_GET_DAYS return REMIND_GET_DAYS
@@ -1130,10 +1394,20 @@ async def remind_action_handler(update: Update, context: CallbackContext):
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor() cursor = conn.cursor()
if action == 'toggle_enabled': if action == 'toggle_enabled':
cursor.execute("UPDATE subscriptions SET reminders_enabled = NOT reminders_enabled WHERE id = ?", (sub_id,)) cursor.execute(
"UPDATE subscriptions SET reminders_enabled = CASE WHEN reminders_enabled THEN 0 ELSE 1 END "
"WHERE id = ? AND user_id = ?",
(sub_id, user_id)
)
elif action == 'toggle_due_date': elif action == 'toggle_due_date':
cursor.execute("UPDATE subscriptions SET reminder_on_due_date = NOT reminder_on_due_date WHERE id = ?", cursor.execute(
(sub_id,)) "UPDATE subscriptions SET reminder_on_due_date = CASE WHEN reminder_on_due_date THEN 0 ELSE 1 END "
"WHERE id = ? AND user_id = ?",
(sub_id, user_id)
)
if cursor.rowcount == 0:
await query.edit_message_text("错误:找不到该订阅或无权限。")
return ConversationHandler.END
conn.commit() conn.commit()
await _display_reminder_settings(query, context, sub_id) await _display_reminder_settings(query, context, sub_id)
return REMIND_SELECT_ACTION return REMIND_SELECT_ACTION
@@ -1144,16 +1418,20 @@ async def remind_days_received(update: Update, context: CallbackContext):
if not sub_id: if not sub_id:
await update.message.reply_text("错误:会话已过期,请重试。") await update.message.reply_text("错误:会话已过期,请重试。")
return ConversationHandler.END return ConversationHandler.END
user_id = update.effective_user.id
try: try:
days = int(update.message.text) days = int(update.message.text)
if days < 0: if days < 0:
raise ValueError raise ValueError
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("UPDATE subscriptions SET reminder_days = ? WHERE id = ?", (days, sub_id)) cursor.execute("UPDATE subscriptions SET reminder_days = ? WHERE id = ? AND user_id = ?", (days, sub_id, user_id))
if cursor.rowcount == 0:
await update.message.reply_text("错误:找不到该订阅或无权限。")
return ConversationHandler.END
conn.commit() conn.commit()
await update.message.reply_text(f"✅ 提前提醒天数已设置为: {days}天。") await update.message.reply_text(f"✅ 提前提醒天数已设置为: {days}天。")
context.user_data.clear() _clear_action_state(context, ['sub_id_for_action'])
await show_subscription_view(update, context, sub_id) await show_subscription_view(update, context, sub_id)
except (ValueError, TypeError): except (ValueError, TypeError):
await update.message.reply_text("请输入一个有效的非负整数。") await update.message.reply_text("请输入一个有效的非负整数。")
@@ -1173,14 +1451,18 @@ async def set_currency(update: Update, context: CallbackContext):
return return
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("INSERT OR REPLACE INTO users (user_id, main_currency) VALUES (?, ?)", (user_id, new_currency)) cursor.execute("""
INSERT INTO users (user_id, main_currency)
VALUES (?, ?)
ON CONFLICT(user_id) DO UPDATE SET main_currency = excluded.main_currency
""", (user_id, new_currency))
conn.commit() conn.commit()
await update.message.reply_text(f"您的主货币已设为 {escape_markdown(new_currency, version=2)}", await update.message.reply_text(f"您的主货币已设为 {escape_markdown(new_currency, version=2)}",
parse_mode='MarkdownV2') parse_mode='MarkdownV2')
async def cancel(update: Update, context: CallbackContext): async def cancel(update: Update, context: CallbackContext):
context.user_data.clear() _clear_action_state(context, ['new_sub_data', 'sub_id_for_action', 'field_to_edit', 'new_freq_unit'])
if update.callback_query: if update.callback_query:
await update.callback_query.answer() await update.callback_query.answer()
await update.callback_query.edit_message_text('操作已取消。') await update.callback_query.edit_message_text('操作已取消。')
@@ -1254,7 +1536,7 @@ def main():
) )
edit_conv = ConversationHandler( edit_conv = ConversationHandler(
entry_points=[CallbackQueryHandler(edit_start, pattern='^edit_')], entry_points=[CallbackQueryHandler(edit_start, pattern=r'^edit_\d+$')],
states={ states={
EDIT_SELECT_FIELD: [CallbackQueryHandler(edit_field_selected, pattern='^editfield_')], EDIT_SELECT_FIELD: [CallbackQueryHandler(edit_field_selected, pattern='^editfield_')],
EDIT_GET_NEW_VALUE: [ EDIT_GET_NEW_VALUE: [
@@ -1268,15 +1550,15 @@ def main():
fallbacks=[ fallbacks=[
CommandHandler('cancel', cancel), CommandHandler('cancel', cancel),
# 【修改】使用新的包装函数来确保会话能正确结束 # 【修改】使用新的包装函数来确保会话能正确结束
CallbackQueryHandler(fallback_view_button, pattern='^view_'), CallbackQueryHandler(fallback_view_button, pattern=r'^view_\d+$'),
CallbackQueryHandler(edit_start, pattern='^edit_'), CallbackQueryHandler(edit_start, pattern=r'^edit_\d+$'),
CallbackQueryHandler(remind_settings_start, pattern='^remind_') CallbackQueryHandler(remind_settings_start, pattern=r'^remind_\d+$')
], ],
per_message=False per_message=False
) )
remind_conv = ConversationHandler( remind_conv = ConversationHandler(
entry_points=[CallbackQueryHandler(remind_settings_start, pattern='^remind_')], entry_points=[CallbackQueryHandler(remind_settings_start, pattern=r'^remind_\d+$')],
states={ states={
REMIND_SELECT_ACTION: [CallbackQueryHandler(remind_action_handler, pattern='^remindaction_')], REMIND_SELECT_ACTION: [CallbackQueryHandler(remind_action_handler, pattern='^remindaction_')],
REMIND_GET_DAYS: [MessageHandler(filters.TEXT & ~filters.COMMAND, remind_days_received)], REMIND_GET_DAYS: [MessageHandler(filters.TEXT & ~filters.COMMAND, remind_days_received)],
@@ -1284,9 +1566,9 @@ def main():
fallbacks=[ fallbacks=[
CommandHandler('cancel', cancel), CommandHandler('cancel', cancel),
# 【修改】使用新的包装函数来确保会话能正确结束 # 【修改】使用新的包装函数来确保会话能正确结束
CallbackQueryHandler(fallback_view_button, pattern='^view_'), CallbackQueryHandler(fallback_view_button, pattern=r'^view_\d+$'),
CallbackQueryHandler(edit_start, pattern='^edit_'), CallbackQueryHandler(edit_start, pattern=r'^edit_\d+$'),
CallbackQueryHandler(remind_settings_start, pattern='^remind_') CallbackQueryHandler(remind_settings_start, pattern=r'^remind_\d+$')
], ],
per_message=False per_message=False
) )
@@ -1299,7 +1581,7 @@ def main():
fallbacks=[CommandHandler('cancel', cancel)] fallbacks=[CommandHandler('cancel', cancel)]
) )
button_pattern = r'^(view_\d+|renewmanual_\d+|delete_\d+|confirmdelete_\d+|renewfromremind_\d+|list_subs_in_category_.+|list_categories|list_all_subs)$' button_pattern = r'^(view_\d+|renewmanual_\d+|delete_\d+|confirmdelete_\d+|renewfromremind_\d+|list_subs_in_category_id_\d+|list_categories|list_all_subs)$'
application.add_handler(CommandHandler('start', start)) application.add_handler(CommandHandler('start', start))
application.add_handler(CommandHandler('help', help_command)) application.add_handler(CommandHandler('help', help_command))