Compare commits
18 Commits
c80914f257
...
f064f751f0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f064f751f0 | ||
|
|
5eebf4bf66 | ||
|
|
210af75e2c | ||
|
|
d212d73c2a | ||
|
|
052966e07c | ||
|
|
095e88cad3 | ||
|
|
276bb5fc83 | ||
|
|
decb9c12c1 | ||
|
|
ec06c5fac3 | ||
|
|
ced65fc4da | ||
|
|
36b136289c | ||
|
|
98a863f567 | ||
|
|
15f9ceb841 | ||
|
|
8601e78e17 | ||
|
|
530d81b565 | ||
|
|
8354e38e89 | ||
|
|
97bcee7258 | ||
|
|
db8257fdde |
141
README.md
141
README.md
@@ -1 +1,140 @@
|
|||||||
power by gemini
|
# SubMind
|
||||||
|
|
||||||
|
一个基于 Telegram 的订阅管理机器人,帮助你记录订阅、设置提醒、查看支出统计并导入/导出数据。
|
||||||
|
|
||||||
|
## 功能特性
|
||||||
|
|
||||||
|
- ➕ 添加订阅(名称、费用、货币、分类、到期日、周期、续费方式、备注)
|
||||||
|
- 📋 列出订阅并查看详情
|
||||||
|
- 🗂️ 按分类浏览订阅
|
||||||
|
- ✏️ 编辑订阅信息
|
||||||
|
- 🗑️ 删除订阅
|
||||||
|
- 🔔 提醒设置(到期日提醒、提前 N 天提醒、手动续费一键确认)
|
||||||
|
- 📊 统计图(按分类汇总月均支出)
|
||||||
|
- 📥 CSV 导入
|
||||||
|
- 📤 CSV 导出
|
||||||
|
- 💱 多货币换算(支持缓存汇率)
|
||||||
|
|
||||||
|
## 技术栈
|
||||||
|
|
||||||
|
- Python 3.10+
|
||||||
|
- [python-telegram-bot](https://github.com/python-telegram-bot/python-telegram-bot)
|
||||||
|
- SQLite
|
||||||
|
- pandas + matplotlib
|
||||||
|
- dateparser
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
### 1) 克隆项目
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/zkysimon/SubMind.git
|
||||||
|
cd SubMind
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2) 安装依赖
|
||||||
|
|
||||||
|
建议使用虚拟环境:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m venv .venv
|
||||||
|
source .venv/bin/activate # Windows: .venv\Scripts\activate
|
||||||
|
pip install -U pip
|
||||||
|
pip install python-telegram-bot pandas matplotlib python-dateutil dateparser python-dotenv requests
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3) 配置环境变量
|
||||||
|
|
||||||
|
复制并填写 `.env`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp .env.example .env
|
||||||
|
```
|
||||||
|
|
||||||
|
`.env` 示例:
|
||||||
|
|
||||||
|
```env
|
||||||
|
TELEGRAM_TOKEN="<YOUR_TELEGRAM_BOT_TOKEN>"
|
||||||
|
EXCHANGE_API_KEY="<YOUR_EXCHANGE_API_KEY>"
|
||||||
|
```
|
||||||
|
|
||||||
|
说明:
|
||||||
|
- `TELEGRAM_TOKEN` 必填。
|
||||||
|
- `EXCHANGE_API_KEY` 可选(不填时不做在线汇率转换)。
|
||||||
|
|
||||||
|
### 4) 运行
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python SubMind.py
|
||||||
|
```
|
||||||
|
|
||||||
|
首次启动会自动初始化数据库(默认 `submind.db`)。
|
||||||
|
|
||||||
|
## Bot 命令
|
||||||
|
|
||||||
|
- `/start` 开始使用
|
||||||
|
- `/add_sub` 添加订阅
|
||||||
|
- `/list_subs` 查看所有订阅
|
||||||
|
- `/list_categories` 按分类查看
|
||||||
|
- `/stats` 查看统计图
|
||||||
|
- `/import` 导入 CSV
|
||||||
|
- `/export` 导出 CSV
|
||||||
|
- `/set_currency <CODE>` 设置主货币(例如 `USD`、`CNY`)
|
||||||
|
- `/help` 帮助
|
||||||
|
- `/cancel` 取消当前流程
|
||||||
|
|
||||||
|
## CSV 导入格式
|
||||||
|
|
||||||
|
导入文件需包含以下列:
|
||||||
|
|
||||||
|
- `name`
|
||||||
|
- `cost`
|
||||||
|
- `currency`
|
||||||
|
- `category`
|
||||||
|
- `next_due`
|
||||||
|
- `frequency_unit`(`day` / `week` / `month` / `year`)
|
||||||
|
- `frequency_value`(正整数)
|
||||||
|
- `renewal_type`(`auto` / `manual`)
|
||||||
|
- `notes`(可选)
|
||||||
|
|
||||||
|
示例:
|
||||||
|
|
||||||
|
```csv
|
||||||
|
name,cost,currency,category,next_due,frequency_unit,frequency_value,renewal_type,notes
|
||||||
|
Netflix,15.99,USD,影音,2026-03-01,month,1,auto,
|
||||||
|
VPS,60,CNY,开发,2026-03-12,month,1,manual,生产环境
|
||||||
|
```
|
||||||
|
|
||||||
|
## 数据与迁移
|
||||||
|
|
||||||
|
- 使用 SQLite(文件:`submind.db`)。
|
||||||
|
- 启动时自动执行兼容性检查和必要字段补齐。
|
||||||
|
- 建议升级前先备份数据库:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp submind.db submind.db.bak
|
||||||
|
```
|
||||||
|
|
||||||
|
## 常见问题
|
||||||
|
|
||||||
|
### 1) 统计图中文乱码怎么办?
|
||||||
|
程序会尝试下载中文字体到 `fonts/` 目录;若网络不可达,可手动放置字体文件。
|
||||||
|
|
||||||
|
### 2) 为什么汇率没变化?
|
||||||
|
请检查 `EXCHANGE_API_KEY` 是否配置正确。未配置时会使用原货币金额。
|
||||||
|
|
||||||
|
### 3) 旧消息按钮点不动?
|
||||||
|
更新后建议重新执行 `/list_categories` 或 `/list_subs` 刷新最新按钮。
|
||||||
|
|
||||||
|
## 安全说明
|
||||||
|
|
||||||
|
- 项目已对订阅编辑/提醒等关键路径做用户归属校验(`user_id`)。
|
||||||
|
- 数据库操作使用参数化查询,并对可编辑字段做白名单约束。
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
可按你的开源策略补充(例如 MIT)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
如果这个项目对你有帮助,欢迎 Star ⭐
|
||||||
430
SubMind.py
430
SubMind.py
@@ -4,6 +4,7 @@ import requests
|
|||||||
import datetime
|
import datetime
|
||||||
import dateparser
|
import dateparser
|
||||||
import logging
|
import logging
|
||||||
|
import tempfile
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import matplotlib
|
import matplotlib
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@@ -241,6 +242,32 @@ def format_frequency(unit, value) -> str:
|
|||||||
return f"每 {value} {unit_map.get(unit, unit)}"
|
return f"每 {value} {unit_map.get(unit, unit)}"
|
||||||
|
|
||||||
|
|
||||||
|
CATEGORY_CB_PREFIX = "list_subs_in_category_id_"
|
||||||
|
EDITABLE_SUB_FIELDS = {
|
||||||
|
'name': 'name',
|
||||||
|
'cost': 'cost',
|
||||||
|
'currency': 'currency',
|
||||||
|
'category': 'category',
|
||||||
|
'next_due': 'next_due',
|
||||||
|
'renewal_type': 'renewal_type',
|
||||||
|
'notes': 'notes'
|
||||||
|
}
|
||||||
|
MAX_NAME_LEN = 128
|
||||||
|
MAX_CATEGORY_LEN = 64
|
||||||
|
MAX_NOTES_LEN = 1000
|
||||||
|
VALID_FREQ_UNITS = {'day', 'week', 'month', 'year'}
|
||||||
|
VALID_RENEWAL_TYPES = {'auto', 'manual'}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_category_callback_data(category_id: int) -> str:
|
||||||
|
return f"{CATEGORY_CB_PREFIX}{category_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_category_id_from_callback(data: str) -> int | None:
|
||||||
|
payload = data.replace(CATEGORY_CB_PREFIX, '', 1)
|
||||||
|
return int(payload) if payload.isdigit() else None
|
||||||
|
|
||||||
|
|
||||||
async def get_subs_list_keyboard(user_id: int, category_filter: str = None) -> InlineKeyboardMarkup:
|
async def get_subs_list_keyboard(user_id: int, category_filter: str = None) -> InlineKeyboardMarkup:
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
@@ -262,6 +289,21 @@ async def get_subs_list_keyboard(user_id: int, category_filter: str = None) -> I
|
|||||||
return InlineKeyboardMarkup(keyboard)
|
return InlineKeyboardMarkup(keyboard)
|
||||||
|
|
||||||
|
|
||||||
|
def _clear_action_state(context: CallbackContext, keys: list[str]):
|
||||||
|
for key in keys:
|
||||||
|
context.user_data.pop(key, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_new_sub_data_or_end(update: Update, context: CallbackContext):
|
||||||
|
sub_data = context.user_data.get('new_sub_data')
|
||||||
|
if sub_data is None:
|
||||||
|
message_obj = update.message or (update.callback_query.message if update.callback_query else None)
|
||||||
|
if message_obj:
|
||||||
|
# 统一提示,避免 KeyError 导致会话崩溃
|
||||||
|
return None, message_obj
|
||||||
|
return sub_data, None
|
||||||
|
|
||||||
|
|
||||||
# --- 自动任务 ---
|
# --- 自动任务 ---
|
||||||
def update_past_due_dates():
|
def update_past_due_dates():
|
||||||
today = datetime.date.today()
|
today = datetime.date.today()
|
||||||
@@ -352,7 +394,7 @@ async def start(update: Update, context: CallbackContext):
|
|||||||
|
|
||||||
|
|
||||||
async def help_command(update: Update, context: CallbackContext):
|
async def help_command(update: Update, context: CallbackContext):
|
||||||
help_text = f"""
|
help_text = fr"""
|
||||||
*{escape_markdown(PROJECT_NAME, version=2)} 命令列表*
|
*{escape_markdown(PROJECT_NAME, version=2)} 命令列表*
|
||||||
*🌟 核心功能*
|
*🌟 核心功能*
|
||||||
/add\_sub \- 引导您添加一个新的订阅
|
/add\_sub \- 引导您添加一个新的订阅
|
||||||
@@ -413,7 +455,9 @@ async def stats(update: Update, context: CallbackContext):
|
|||||||
|
|
||||||
plt.style.use('seaborn-v0_8-darkgrid')
|
plt.style.use('seaborn-v0_8-darkgrid')
|
||||||
fig, ax = plt.subplots(figsize=(12, 12))
|
fig, ax = plt.subplots(figsize=(12, 12))
|
||||||
|
image_path = None
|
||||||
|
|
||||||
|
try:
|
||||||
autopct_function = make_autopct(category_costs.values, main_currency)
|
autopct_function = make_autopct(category_costs.values, main_currency)
|
||||||
|
|
||||||
wedges, texts, autotexts = ax.pie(category_costs.values,
|
wedges, texts, autotexts = ax.pie(category_costs.values,
|
||||||
@@ -435,17 +479,18 @@ async def stats(update: Update, context: CallbackContext):
|
|||||||
autotext.set_color('white')
|
autotext.set_color('white')
|
||||||
|
|
||||||
ax.axis('equal')
|
ax.axis('equal')
|
||||||
|
|
||||||
fig.tight_layout()
|
fig.tight_layout()
|
||||||
image_path = f'stats_{user_id}.png'
|
|
||||||
plt.savefig(image_path)
|
|
||||||
plt.close(fig)
|
|
||||||
|
|
||||||
try:
|
with tempfile.NamedTemporaryFile(prefix=f'stats_{user_id}_', suffix='.png', delete=False) as tmp:
|
||||||
|
image_path = tmp.name
|
||||||
|
|
||||||
|
plt.savefig(image_path)
|
||||||
|
|
||||||
with open(image_path, 'rb') as photo:
|
with open(image_path, 'rb') as photo:
|
||||||
await update.message.reply_photo(photo, caption="这是您按类别统计的每月订阅总支出。")
|
await update.message.reply_photo(photo, caption="这是您按类别统计的每月订阅总支出。")
|
||||||
finally:
|
finally:
|
||||||
if os.path.exists(image_path):
|
plt.close(fig)
|
||||||
|
if image_path and os.path.exists(image_path):
|
||||||
os.remove(image_path)
|
os.remove(image_path)
|
||||||
|
|
||||||
|
|
||||||
@@ -460,10 +505,12 @@ async def export_command(update: Update, context: CallbackContext):
|
|||||||
await update.message.reply_text("您还没有任何订阅数据,无法导出。")
|
await update.message.reply_text("您还没有任何订阅数据,无法导出。")
|
||||||
return
|
return
|
||||||
|
|
||||||
export_path = f'export_{user_id}.csv'
|
with tempfile.NamedTemporaryFile(prefix=f'export_{user_id}_', suffix='.csv', delete=False) as tmp:
|
||||||
df.to_csv(export_path, index=False, encoding='utf-8-sig')
|
export_path = tmp.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
df.to_csv(export_path, index=False, encoding='utf-8-sig')
|
||||||
|
|
||||||
with open(export_path, 'rb') as file:
|
with open(export_path, 'rb') as file:
|
||||||
await update.message.reply_document(document=file, filename='subscriptions.csv',
|
await update.message.reply_document(document=file, filename='subscriptions.csv',
|
||||||
caption="您的订阅数据已导出为 CSV 文件。")
|
caption="您的订阅数据已导出为 CSV 文件。")
|
||||||
@@ -485,7 +532,8 @@ async def import_upload_received(update: Update, context: CallbackContext):
|
|||||||
return IMPORT_UPLOAD
|
return IMPORT_UPLOAD
|
||||||
|
|
||||||
file = await update.message.document.get_file()
|
file = await update.message.document.get_file()
|
||||||
file_path = f'import_{user_id}.csv'
|
with tempfile.NamedTemporaryFile(prefix=f'import_{user_id}_', suffix='.csv', delete=False) as tmp:
|
||||||
|
file_path = tmp.name
|
||||||
try:
|
try:
|
||||||
await file.download_to_drive(file_path)
|
await file.download_to_drive(file_path)
|
||||||
df = pd.read_csv(file_path, encoding='utf-8-sig')
|
df = pd.read_csv(file_path, encoding='utf-8-sig')
|
||||||
@@ -519,14 +567,26 @@ async def import_upload_received(update: Update, context: CallbackContext):
|
|||||||
renewal_type = str(row['renewal_type']).lower()
|
renewal_type = str(row['renewal_type']).lower()
|
||||||
if renewal_type not in valid_renewal_types:
|
if renewal_type not in valid_renewal_types:
|
||||||
raise ValueError(f"无效续费类型: {renewal_type}")
|
raise ValueError(f"无效续费类型: {renewal_type}")
|
||||||
notes = str(row['notes']) if pd.notna(row['notes']) else None
|
notes = str(row['notes']).strip() if pd.notna(row['notes']) else None
|
||||||
|
if notes and len(notes) > MAX_NOTES_LEN:
|
||||||
|
raise ValueError(f"备注过长(>{MAX_NOTES_LEN})")
|
||||||
|
name = str(row['name']).strip()
|
||||||
|
category = str(row['category']).strip()
|
||||||
|
if not name:
|
||||||
|
raise ValueError("名称不能为空")
|
||||||
|
if not category:
|
||||||
|
raise ValueError("类别不能为空")
|
||||||
|
if len(name) > MAX_NAME_LEN:
|
||||||
|
raise ValueError(f"名称过长(>{MAX_NAME_LEN})")
|
||||||
|
if len(category) > MAX_CATEGORY_LEN:
|
||||||
|
raise ValueError(f"类别过长(>{MAX_CATEGORY_LEN})")
|
||||||
records.append((
|
records.append((
|
||||||
user_id, row['name'], cost, currency, row['category'],
|
user_id, name, cost, currency, category,
|
||||||
next_due, frequency_unit, frequency_value, renewal_type, notes
|
next_due, frequency_unit, frequency_value, renewal_type, notes
|
||||||
))
|
))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Invalid row in CSV: {row.to_dict()}, error: {e}")
|
logger.error(f"Invalid row in CSV import, error: {e}")
|
||||||
await update.message.reply_text(f"导入失败,行数据无效:{row.to_dict()},错误:{e}")
|
await update.message.reply_text(f"导入失败,存在无效行:{e}")
|
||||||
return ConversationHandler.END
|
return ConversationHandler.END
|
||||||
|
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
@@ -558,17 +618,36 @@ async def add_sub_start(update: Update, context: CallbackContext):
|
|||||||
|
|
||||||
|
|
||||||
async def add_name_received(update: Update, context: CallbackContext):
|
async def add_name_received(update: Update, context: CallbackContext):
|
||||||
context.user_data['new_sub_data']['name'] = update.message.text
|
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
|
||||||
|
if sub_data is None:
|
||||||
|
if err_msg_obj:
|
||||||
|
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
|
||||||
|
return ConversationHandler.END
|
||||||
|
|
||||||
|
name = update.message.text.strip()
|
||||||
|
if not name:
|
||||||
|
await update.message.reply_text("订阅名称不能为空。")
|
||||||
|
return ADD_NAME
|
||||||
|
if len(name) > MAX_NAME_LEN:
|
||||||
|
await update.message.reply_text(f"订阅名称过长,请控制在 {MAX_NAME_LEN} 个字符以内。")
|
||||||
|
return ADD_NAME
|
||||||
|
sub_data['name'] = name
|
||||||
await update.message.reply_text("第二步:请输入订阅 *费用*", parse_mode='MarkdownV2')
|
await update.message.reply_text("第二步:请输入订阅 *费用*", parse_mode='MarkdownV2')
|
||||||
return ADD_COST
|
return ADD_COST
|
||||||
|
|
||||||
|
|
||||||
async def add_cost_received(update: Update, context: CallbackContext):
|
async def add_cost_received(update: Update, context: CallbackContext):
|
||||||
|
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
|
||||||
|
if sub_data is None:
|
||||||
|
if err_msg_obj:
|
||||||
|
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
|
||||||
|
return ConversationHandler.END
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cost = float(update.message.text)
|
cost = float(update.message.text)
|
||||||
if cost < 0:
|
if cost < 0:
|
||||||
raise ValueError("费用不能为负数")
|
raise ValueError("费用不能为负数")
|
||||||
context.user_data['new_sub_data']['cost'] = cost
|
sub_data['cost'] = cost
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
await update.message.reply_text("费用必须是有效的非负数字。")
|
await update.message.reply_text("费用必须是有效的非负数字。")
|
||||||
return ADD_COST
|
return ADD_COST
|
||||||
@@ -577,21 +656,36 @@ async def add_cost_received(update: Update, context: CallbackContext):
|
|||||||
|
|
||||||
|
|
||||||
async def add_currency_received(update: Update, context: CallbackContext):
|
async def add_currency_received(update: Update, context: CallbackContext):
|
||||||
|
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
|
||||||
|
if sub_data is None:
|
||||||
|
if err_msg_obj:
|
||||||
|
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
|
||||||
|
return ConversationHandler.END
|
||||||
|
|
||||||
currency = update.message.text.upper()
|
currency = update.message.text.upper()
|
||||||
if not (len(currency) == 3 and currency.isalpha()):
|
if not (len(currency) == 3 and currency.isalpha()):
|
||||||
await update.message.reply_text("请输入有效的三字母货币代码(如 USD, CNY)。")
|
await update.message.reply_text("请输入有效的三字母货币代码(如 USD, CNY)。")
|
||||||
return ADD_CURRENCY
|
return ADD_CURRENCY
|
||||||
context.user_data['new_sub_data']['currency'] = currency
|
sub_data['currency'] = currency
|
||||||
await update.message.reply_text("第四步:请为订阅指定一个 *类别*", parse_mode='MarkdownV2')
|
await update.message.reply_text("第四步:请为订阅指定一个 *类别*", parse_mode='MarkdownV2')
|
||||||
return ADD_CATEGORY
|
return ADD_CATEGORY
|
||||||
|
|
||||||
|
|
||||||
async def add_category_received(update: Update, context: CallbackContext):
|
async def add_category_received(update: Update, context: CallbackContext):
|
||||||
|
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
|
||||||
|
if sub_data is None:
|
||||||
|
if err_msg_obj:
|
||||||
|
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
|
||||||
|
return ConversationHandler.END
|
||||||
|
|
||||||
user_id, category_name = update.effective_user.id, update.message.text.strip()
|
user_id, category_name = update.effective_user.id, update.message.text.strip()
|
||||||
if not category_name:
|
if not category_name:
|
||||||
await update.message.reply_text("类别不能为空。")
|
await update.message.reply_text("类别不能为空。")
|
||||||
return ADD_CATEGORY
|
return ADD_CATEGORY
|
||||||
context.user_data['new_sub_data']['category'] = category_name
|
if len(category_name) > MAX_CATEGORY_LEN:
|
||||||
|
await update.message.reply_text(f"类别名称过长,请控制在 {MAX_CATEGORY_LEN} 个字符以内。")
|
||||||
|
return ADD_CATEGORY
|
||||||
|
sub_data['category'] = category_name
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, category_name))
|
cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, category_name))
|
||||||
@@ -602,11 +696,17 @@ async def add_category_received(update: Update, context: CallbackContext):
|
|||||||
|
|
||||||
|
|
||||||
async def add_next_due_received(update: Update, context: CallbackContext):
|
async def add_next_due_received(update: Update, context: CallbackContext):
|
||||||
|
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
|
||||||
|
if sub_data is None:
|
||||||
|
if err_msg_obj:
|
||||||
|
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
|
||||||
|
return ConversationHandler.END
|
||||||
|
|
||||||
parsed_date = parse_date(update.message.text)
|
parsed_date = parse_date(update.message.text)
|
||||||
if not parsed_date:
|
if not parsed_date:
|
||||||
await update.message.reply_text("无法识别的日期格式,请使用类似 '2025\\-10\\-01' 或 '10月1日' 的格式。")
|
await update.message.reply_text("无法识别的日期格式,请使用类似 '2025\\-10\\-01' 或 '10月1日' 的格式。")
|
||||||
return ADD_NEXT_DUE
|
return ADD_NEXT_DUE
|
||||||
context.user_data['new_sub_data']['next_due'] = parsed_date
|
sub_data['next_due'] = parsed_date
|
||||||
keyboard = [
|
keyboard = [
|
||||||
[InlineKeyboardButton("天", callback_data='freq_unit_day'),
|
[InlineKeyboardButton("天", callback_data='freq_unit_day'),
|
||||||
InlineKeyboardButton("周", callback_data='freq_unit_week')],
|
InlineKeyboardButton("周", callback_data='freq_unit_week')],
|
||||||
@@ -619,19 +719,34 @@ async def add_next_due_received(update: Update, context: CallbackContext):
|
|||||||
|
|
||||||
|
|
||||||
async def add_freq_unit_received(update: Update, context: CallbackContext):
|
async def add_freq_unit_received(update: Update, context: CallbackContext):
|
||||||
|
sub_data, _ = _get_new_sub_data_or_end(update, context)
|
||||||
query = update.callback_query
|
query = update.callback_query
|
||||||
await query.answer()
|
await query.answer()
|
||||||
context.user_data['new_sub_data']['unit'] = query.data.split('_')[2]
|
if sub_data is None:
|
||||||
|
await query.edit_message_text("会话已过期,请重新使用 /add_sub 开始。")
|
||||||
|
return ConversationHandler.END
|
||||||
|
|
||||||
|
unit = query.data.partition('freq_unit_')[2]
|
||||||
|
if unit not in VALID_FREQ_UNITS:
|
||||||
|
await query.edit_message_text("错误:无效的周期单位,请重试。")
|
||||||
|
return ConversationHandler.END
|
||||||
|
sub_data['unit'] = unit
|
||||||
await query.edit_message_text("第七步:请输入周期的*数量*(例如:每3个月,输入 3)", parse_mode='Markdown')
|
await query.edit_message_text("第七步:请输入周期的*数量*(例如:每3个月,输入 3)", parse_mode='Markdown')
|
||||||
return ADD_FREQ_VALUE
|
return ADD_FREQ_VALUE
|
||||||
|
|
||||||
|
|
||||||
async def add_freq_value_received(update: Update, context: CallbackContext):
|
async def add_freq_value_received(update: Update, context: CallbackContext):
|
||||||
|
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
|
||||||
|
if sub_data is None:
|
||||||
|
if err_msg_obj:
|
||||||
|
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
|
||||||
|
return ConversationHandler.END
|
||||||
|
|
||||||
try:
|
try:
|
||||||
value = int(update.message.text)
|
value = int(update.message.text)
|
||||||
if value <= 0:
|
if value <= 0:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
context.user_data['new_sub_data']['value'] = value
|
sub_data['value'] = value
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
await update.message.reply_text("请输入一个有效的正整数。")
|
await update.message.reply_text("请输入一个有效的正整数。")
|
||||||
return ADD_FREQ_VALUE
|
return ADD_FREQ_VALUE
|
||||||
@@ -645,36 +760,55 @@ async def add_freq_value_received(update: Update, context: CallbackContext):
|
|||||||
|
|
||||||
|
|
||||||
async def add_renewal_type_received(update: Update, context: CallbackContext):
|
async def add_renewal_type_received(update: Update, context: CallbackContext):
|
||||||
|
sub_data, _ = _get_new_sub_data_or_end(update, context)
|
||||||
query = update.callback_query
|
query = update.callback_query
|
||||||
await query.answer()
|
await query.answer()
|
||||||
context.user_data['new_sub_data']['renewal_type'] = query.data.split('_')[1]
|
if sub_data is None:
|
||||||
|
await query.edit_message_text("会话已过期,请重新使用 /add_sub 开始。")
|
||||||
|
return ConversationHandler.END
|
||||||
|
|
||||||
|
renewal_type = query.data.partition('renewal_')[2]
|
||||||
|
if renewal_type not in VALID_RENEWAL_TYPES:
|
||||||
|
await query.edit_message_text("错误:无效的续费类型,请重试。")
|
||||||
|
return ConversationHandler.END
|
||||||
|
sub_data['renewal_type'] = renewal_type
|
||||||
await query.edit_message_text("最后一步(可选):需要添加备注吗?\n(如:共享账号、用途等。不需要请 /skip)")
|
await query.edit_message_text("最后一步(可选):需要添加备注吗?\n(如:共享账号、用途等。不需要请 /skip)")
|
||||||
return ADD_NOTES
|
return ADD_NOTES
|
||||||
|
|
||||||
|
|
||||||
async def add_notes_received(update: Update, context: CallbackContext):
|
async def add_notes_received(update: Update, context: CallbackContext):
|
||||||
sub_data = context.user_data.get('new_sub_data')
|
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
|
||||||
if not sub_data:
|
if sub_data is None:
|
||||||
await update.message.reply_text("发生错误,请重试。")
|
_clear_action_state(context, ['new_sub_data'])
|
||||||
|
if err_msg_obj:
|
||||||
|
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
|
||||||
return ConversationHandler.END
|
return ConversationHandler.END
|
||||||
sub_data['notes'] = update.message.text
|
|
||||||
|
note = update.message.text.strip()
|
||||||
|
if len(note) > MAX_NOTES_LEN:
|
||||||
|
await update.message.reply_text(f"备注过长,请控制在 {MAX_NOTES_LEN} 个字符以内。")
|
||||||
|
return ADD_NOTES
|
||||||
|
sub_data['notes'] = note if note else None
|
||||||
save_subscription(update.effective_user.id, sub_data)
|
save_subscription(update.effective_user.id, sub_data)
|
||||||
await update.message.reply_text(text=f"✅ 订阅 '{escape_markdown(sub_data.get('name', ''), version=2)}' 已添加!",
|
await update.message.reply_text(text=f"✅ 订阅 '{escape_markdown(sub_data.get('name', ''), version=2)}' 已添加!",
|
||||||
parse_mode='MarkdownV2')
|
parse_mode='MarkdownV2')
|
||||||
context.user_data.clear()
|
_clear_action_state(context, ['new_sub_data'])
|
||||||
return ConversationHandler.END
|
return ConversationHandler.END
|
||||||
|
|
||||||
|
|
||||||
async def skip_notes(update: Update, context: CallbackContext):
|
async def skip_notes(update: Update, context: CallbackContext):
|
||||||
sub_data = context.user_data.get('new_sub_data')
|
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
|
||||||
if not sub_data:
|
if sub_data is None:
|
||||||
await update.message.reply_text("发生错误,请重试。")
|
_clear_action_state(context, ['new_sub_data'])
|
||||||
|
if err_msg_obj:
|
||||||
|
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
|
||||||
return ConversationHandler.END
|
return ConversationHandler.END
|
||||||
|
|
||||||
sub_data['notes'] = None
|
sub_data['notes'] = None
|
||||||
save_subscription(update.effective_user.id, sub_data)
|
save_subscription(update.effective_user.id, sub_data)
|
||||||
await update.message.reply_text(text=f"✅ 订阅 '{escape_markdown(sub_data.get('name', ''), version=2)}' 已添加!",
|
await update.message.reply_text(text=f"✅ 订阅 '{escape_markdown(sub_data.get('name', ''), version=2)}' 已添加!",
|
||||||
parse_mode='MarkdownV2')
|
parse_mode='MarkdownV2')
|
||||||
context.user_data.clear()
|
_clear_action_state(context, ['new_sub_data'])
|
||||||
return ConversationHandler.END
|
return ConversationHandler.END
|
||||||
|
|
||||||
|
|
||||||
@@ -706,7 +840,7 @@ async def list_categories(update: Update, context: CallbackContext):
|
|||||||
user_id = update.effective_user.id
|
user_id = update.effective_user.id
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("SELECT name FROM categories WHERE user_id = ? ORDER BY name", (user_id,))
|
cursor.execute("SELECT id, name FROM categories WHERE user_id = ? ORDER BY name", (user_id,))
|
||||||
categories = cursor.fetchall()
|
categories = cursor.fetchall()
|
||||||
if not categories:
|
if not categories:
|
||||||
if update.callback_query:
|
if update.callback_query:
|
||||||
@@ -715,7 +849,11 @@ async def list_categories(update: Update, context: CallbackContext):
|
|||||||
await update.message.reply_text("您还没有任何分类。")
|
await update.message.reply_text("您还没有任何分类。")
|
||||||
return
|
return
|
||||||
|
|
||||||
buttons = [InlineKeyboardButton(cat[0], callback_data=f"list_subs_in_category_{cat[0]}") for cat in categories]
|
buttons = []
|
||||||
|
for cat in categories:
|
||||||
|
cat_id, cat_name = cat[0], cat[1]
|
||||||
|
buttons.append(InlineKeyboardButton(cat_name, callback_data=_build_category_callback_data(cat_id)))
|
||||||
|
|
||||||
keyboard = [buttons[i:i + 2] for i in range(0, len(buttons), 2)]
|
keyboard = [buttons[i:i + 2] for i in range(0, len(buttons), 2)]
|
||||||
keyboard.append([InlineKeyboardButton("查看全部订阅", callback_data="list_all_subs")])
|
keyboard.append([InlineKeyboardButton("查看全部订阅", callback_data="list_all_subs")])
|
||||||
if update.callback_query:
|
if update.callback_query:
|
||||||
@@ -765,8 +903,12 @@ async def show_subscription_view(update: Update, context: CallbackContext, sub_i
|
|||||||
keyboard_buttons.insert(0, [InlineKeyboardButton("✅ 续费", callback_data=f'renewmanual_{sub_id}')])
|
keyboard_buttons.insert(0, [InlineKeyboardButton("✅ 续费", callback_data=f'renewmanual_{sub_id}')])
|
||||||
if 'list_subs_in_category' in context.user_data:
|
if 'list_subs_in_category' in context.user_data:
|
||||||
cat_filter = context.user_data['list_subs_in_category']
|
cat_filter = context.user_data['list_subs_in_category']
|
||||||
keyboard_buttons.append(
|
category_id = context.user_data.get('list_subs_in_category_id')
|
||||||
[InlineKeyboardButton("« 返回分类订阅", callback_data=f'list_subs_in_category_{cat_filter}')])
|
if category_id:
|
||||||
|
back_cb = _build_category_callback_data(category_id)
|
||||||
|
else:
|
||||||
|
back_cb = 'list_categories'
|
||||||
|
keyboard_buttons.append([InlineKeyboardButton("« 返回分类订阅", callback_data=back_cb)])
|
||||||
else:
|
else:
|
||||||
keyboard_buttons.append([InlineKeyboardButton("« 返回全部订阅", callback_data='list_all_subs')])
|
keyboard_buttons.append([InlineKeyboardButton("« 返回全部订阅", callback_data='list_all_subs')])
|
||||||
logger.debug(f"Generated buttons for sub_id {sub_id}: edit_{sub_id}, remind_{sub_id}")
|
logger.debug(f"Generated buttons for sub_id {sub_id}: edit_{sub_id}, remind_{sub_id}")
|
||||||
@@ -785,9 +927,24 @@ async def button_callback_handler(update: Update, context: CallbackContext):
|
|||||||
user_id = query.from_user.id
|
user_id = query.from_user.id
|
||||||
logger.debug(f"Received callback query: {data} from user {user_id}")
|
logger.debug(f"Received callback query: {data} from user {user_id}")
|
||||||
|
|
||||||
if data.startswith('list_subs_in_category_'):
|
if data.startswith(CATEGORY_CB_PREFIX):
|
||||||
category = data.replace('list_subs_in_category_', '')
|
category_id = _parse_category_id_from_callback(data)
|
||||||
|
if not category_id:
|
||||||
|
await query.edit_message_text("错误:无效或已过期的分类,请重新选择。")
|
||||||
|
return
|
||||||
|
|
||||||
|
with get_db_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT name FROM categories WHERE id = ? AND user_id = ?", (category_id, user_id))
|
||||||
|
row = cursor.fetchone()
|
||||||
|
|
||||||
|
if not row:
|
||||||
|
await query.edit_message_text("错误:分类不存在或无权限。")
|
||||||
|
return
|
||||||
|
|
||||||
|
category = row['name']
|
||||||
context.user_data['list_subs_in_category'] = category
|
context.user_data['list_subs_in_category'] = category
|
||||||
|
context.user_data['list_subs_in_category_id'] = category_id
|
||||||
keyboard = await get_subs_list_keyboard(user_id, category_filter=category)
|
keyboard = await get_subs_list_keyboard(user_id, category_filter=category)
|
||||||
msg_text = f"分类“{escape_markdown(category, version=2)}”下的订阅:"
|
msg_text = f"分类“{escape_markdown(category, version=2)}”下的订阅:"
|
||||||
if not keyboard:
|
if not keyboard:
|
||||||
@@ -797,10 +954,12 @@ async def button_callback_handler(update: Update, context: CallbackContext):
|
|||||||
return
|
return
|
||||||
if data == 'list_categories':
|
if data == 'list_categories':
|
||||||
context.user_data.pop('list_subs_in_category', None)
|
context.user_data.pop('list_subs_in_category', None)
|
||||||
|
context.user_data.pop('list_subs_in_category_id', None)
|
||||||
await list_categories(update, context)
|
await list_categories(update, context)
|
||||||
return
|
return
|
||||||
if data == 'list_all_subs':
|
if data == 'list_all_subs':
|
||||||
context.user_data.pop('list_subs_in_category', None)
|
context.user_data.pop('list_subs_in_category', None)
|
||||||
|
context.user_data.pop('list_subs_in_category_id', None)
|
||||||
keyboard = await get_subs_list_keyboard(user_id)
|
keyboard = await get_subs_list_keyboard(user_id)
|
||||||
if not keyboard:
|
if not keyboard:
|
||||||
await query.edit_message_text("您还没有任何订阅。")
|
await query.edit_message_text("您还没有任何订阅。")
|
||||||
@@ -821,33 +980,45 @@ async def button_callback_handler(update: Update, context: CallbackContext):
|
|||||||
elif action == 'renewmanual':
|
elif action == 'renewmanual':
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("SELECT frequency_unit, frequency_value FROM subscriptions WHERE id = ?", (sub_id,))
|
cursor.execute(
|
||||||
|
"SELECT frequency_unit, frequency_value FROM subscriptions WHERE id = ? AND user_id = ?",
|
||||||
|
(sub_id, user_id)
|
||||||
|
)
|
||||||
sub = cursor.fetchone()
|
sub = cursor.fetchone()
|
||||||
if sub:
|
if sub:
|
||||||
today = datetime.date.today()
|
today = datetime.date.today()
|
||||||
new_due_date = calculate_new_due_date(today, sub['frequency_unit'], sub['frequency_value'])
|
new_due_date = calculate_new_due_date(today, sub['frequency_unit'], sub['frequency_value'])
|
||||||
if new_due_date:
|
if new_due_date:
|
||||||
new_date_str = new_due_date.strftime('%Y-%m-%d')
|
new_date_str = new_due_date.strftime('%Y-%m-%d')
|
||||||
cursor.execute("UPDATE subscriptions SET next_due = ? WHERE id = ?", (new_date_str, sub_id))
|
cursor.execute(
|
||||||
|
"UPDATE subscriptions SET next_due = ? WHERE id = ? AND user_id = ?",
|
||||||
|
(new_date_str, sub_id, user_id)
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
await query.answer(f"✅ 续费成功!新到期日: {new_date_str}", show_alert=True)
|
await query.answer(f"✅ 续费成功!新到期日: {new_date_str}", show_alert=True)
|
||||||
await show_subscription_view(update, context, sub_id)
|
await show_subscription_view(update, context, sub_id)
|
||||||
else:
|
else:
|
||||||
await query.answer("续费失败:无法计算新的到期日期。", show_alert=True)
|
await query.answer("续费失败:无法计算新的到期日期。", show_alert=True)
|
||||||
else:
|
else:
|
||||||
await query.answer("续费失败:订阅不存在。", show_alert=True)
|
await query.answer("续费失败:订阅不存在或无权限。", show_alert=True)
|
||||||
|
|
||||||
elif action == 'renewfromremind':
|
elif action == 'renewfromremind':
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("SELECT name, frequency_unit, frequency_value FROM subscriptions WHERE id = ?", (sub_id,))
|
cursor.execute(
|
||||||
|
"SELECT name, frequency_unit, frequency_value FROM subscriptions WHERE id = ? AND user_id = ?",
|
||||||
|
(sub_id, user_id)
|
||||||
|
)
|
||||||
sub = cursor.fetchone()
|
sub = cursor.fetchone()
|
||||||
if sub:
|
if sub:
|
||||||
today = datetime.date.today()
|
today = datetime.date.today()
|
||||||
new_due_date = calculate_new_due_date(today, sub['frequency_unit'], sub['frequency_value'])
|
new_due_date = calculate_new_due_date(today, sub['frequency_unit'], sub['frequency_value'])
|
||||||
if new_due_date:
|
if new_due_date:
|
||||||
new_date_str = new_due_date.strftime('%Y-%m-%d')
|
new_date_str = new_due_date.strftime('%Y-%m-%d')
|
||||||
cursor.execute("UPDATE subscriptions SET next_due = ? WHERE id = ?", (new_date_str, sub_id))
|
cursor.execute(
|
||||||
|
"UPDATE subscriptions SET next_due = ? WHERE id = ? AND user_id = ?",
|
||||||
|
(new_date_str, sub_id, user_id)
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
safe_sub_name = escape_markdown(sub['name'], version=2)
|
safe_sub_name = escape_markdown(sub['name'], version=2)
|
||||||
await query.edit_message_text(
|
await query.edit_message_text(
|
||||||
@@ -858,11 +1029,17 @@ async def button_callback_handler(update: Update, context: CallbackContext):
|
|||||||
else:
|
else:
|
||||||
await query.answer("续费失败:无法计算新的到期日期。", show_alert=True)
|
await query.answer("续费失败:无法计算新的到期日期。", show_alert=True)
|
||||||
else:
|
else:
|
||||||
await query.answer("续费失败:此订阅可能已被删除。", show_alert=True)
|
await query.answer("续费失败:此订阅可能已被删除或无权限。", show_alert=True)
|
||||||
await query.edit_message_text(text=query.message.text + "\n\n*(错误:此订阅已被删除)*",
|
await query.edit_message_text(text=query.message.text + "\n\n*(错误:此订阅不存在或无权限)*",
|
||||||
parse_mode='MarkdownV2', reply_markup=None)
|
parse_mode='MarkdownV2', reply_markup=None)
|
||||||
|
|
||||||
elif action == 'delete':
|
elif action == 'delete':
|
||||||
|
with get_db_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT 1 FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id))
|
||||||
|
if not cursor.fetchone():
|
||||||
|
await query.answer("错误:找不到该订阅或无权限。", show_alert=True)
|
||||||
|
return
|
||||||
keyboard = InlineKeyboardMarkup([
|
keyboard = InlineKeyboardMarkup([
|
||||||
[InlineKeyboardButton("✅ 是的,删除", callback_data=f'confirmdelete_{sub_id}'),
|
[InlineKeyboardButton("✅ 是的,删除", callback_data=f'confirmdelete_{sub_id}'),
|
||||||
InlineKeyboardButton("❌ 取消", callback_data=f'view_{sub_id}')]
|
InlineKeyboardButton("❌ 取消", callback_data=f'view_{sub_id}')]
|
||||||
@@ -872,7 +1049,11 @@ async def button_callback_handler(update: Update, context: CallbackContext):
|
|||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("DELETE FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id))
|
cursor.execute("DELETE FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id))
|
||||||
|
deleted = cursor.rowcount
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
if deleted == 0:
|
||||||
|
await query.answer("错误:找不到该订阅或无权限。", show_alert=True)
|
||||||
|
return
|
||||||
await query.answer("订阅已删除")
|
await query.answer("订阅已删除")
|
||||||
if 'list_subs_in_category' in context.user_data:
|
if 'list_subs_in_category' in context.user_data:
|
||||||
category = context.user_data['list_subs_in_category']
|
category = context.user_data['list_subs_in_category']
|
||||||
@@ -907,7 +1088,21 @@ async def fallback_view_button(update: Update, context: CallbackContext):
|
|||||||
async def edit_start(update: Update, context: CallbackContext):
|
async def edit_start(update: Update, context: CallbackContext):
|
||||||
query = update.callback_query
|
query = update.callback_query
|
||||||
await query.answer()
|
await query.answer()
|
||||||
sub_id = query.data.split('_')[1]
|
sub_id_str = query.data.split('_')[1]
|
||||||
|
user_id = query.from_user.id
|
||||||
|
|
||||||
|
if not sub_id_str.isdigit():
|
||||||
|
await query.edit_message_text("错误:无效的订阅ID。")
|
||||||
|
return ConversationHandler.END
|
||||||
|
|
||||||
|
sub_id = int(sub_id_str)
|
||||||
|
with get_db_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT 1 FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id))
|
||||||
|
if not cursor.fetchone():
|
||||||
|
await query.edit_message_text("错误:找不到该订阅或无权限。")
|
||||||
|
return ConversationHandler.END
|
||||||
|
|
||||||
logger.debug(f"Starting edit for sub_id: {sub_id}")
|
logger.debug(f"Starting edit for sub_id: {sub_id}")
|
||||||
context.user_data['sub_id_for_action'] = sub_id
|
context.user_data['sub_id_for_action'] = sub_id
|
||||||
keyboard = [
|
keyboard = [
|
||||||
@@ -960,7 +1155,11 @@ async def edit_field_selected(update: Update, context: CallbackContext):
|
|||||||
async def edit_freq_unit_received(update: Update, context: CallbackContext):
|
async def edit_freq_unit_received(update: Update, context: CallbackContext):
|
||||||
query = update.callback_query
|
query = update.callback_query
|
||||||
await query.answer()
|
await query.answer()
|
||||||
context.user_data['new_freq_unit'] = query.data.split('_')[2]
|
unit = query.data.partition('freq_unit_')[2]
|
||||||
|
if unit not in VALID_FREQ_UNITS:
|
||||||
|
await query.edit_message_text("错误:无效的周期单位,请重试。")
|
||||||
|
return ConversationHandler.END
|
||||||
|
context.user_data['new_freq_unit'] = unit
|
||||||
await query.edit_message_text("好的,现在请输入新的周期*数量*。", parse_mode='MarkdownV2')
|
await query.edit_message_text("好的,现在请输入新的周期*数量*。", parse_mode='MarkdownV2')
|
||||||
return EDIT_FREQ_VALUE
|
return EDIT_FREQ_VALUE
|
||||||
|
|
||||||
@@ -975,14 +1174,23 @@ async def edit_freq_value_received(update: Update, context: CallbackContext):
|
|||||||
await update.message.reply_text("请输入一个有效的正整数。")
|
await update.message.reply_text("请输入一个有效的正整数。")
|
||||||
return EDIT_FREQ_VALUE
|
return EDIT_FREQ_VALUE
|
||||||
unit = context.user_data.get('new_freq_unit')
|
unit = context.user_data.get('new_freq_unit')
|
||||||
|
try:
|
||||||
sub_id = int(context.user_data.get('sub_id_for_action'))
|
sub_id = int(context.user_data.get('sub_id_for_action'))
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
await update.message.reply_text("错误:会话已过期,请重试。")
|
||||||
|
return ConversationHandler.END
|
||||||
|
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("UPDATE subscriptions SET frequency_unit = ?, frequency_value = ? WHERE id = ? AND user_id = ?",
|
cursor.execute("UPDATE subscriptions SET frequency_unit = ?, frequency_value = ? WHERE id = ? AND user_id = ?",
|
||||||
(unit, value, sub_id, user_id))
|
(unit, value, sub_id, user_id))
|
||||||
|
if cursor.rowcount == 0:
|
||||||
|
await update.message.reply_text("错误:找不到该订阅或无权限。")
|
||||||
|
return ConversationHandler.END
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
await update.message.reply_text("✅ 周期已更新!")
|
await update.message.reply_text("✅ 周期已更新!")
|
||||||
context.user_data.clear()
|
_clear_action_state(context, ['sub_id_for_action', 'new_freq_unit', 'field_to_edit'])
|
||||||
await show_subscription_view(update, context, sub_id)
|
await show_subscription_view(update, context, sub_id)
|
||||||
return ConversationHandler.END
|
return ConversationHandler.END
|
||||||
|
|
||||||
@@ -1000,6 +1208,13 @@ async def edit_new_value_received(update: Update, context: CallbackContext):
|
|||||||
if update.effective_message:
|
if update.effective_message:
|
||||||
await update.effective_message.reply_text("错误:未选择要编辑的字段。")
|
await update.effective_message.reply_text("错误:未选择要编辑的字段。")
|
||||||
return ConversationHandler.END
|
return ConversationHandler.END
|
||||||
|
db_field = EDITABLE_SUB_FIELDS.get(field)
|
||||||
|
if not db_field:
|
||||||
|
if update.effective_message:
|
||||||
|
await update.effective_message.reply_text("错误:不允许编辑该字段。")
|
||||||
|
logger.warning(f"Blocked unsafe field update attempt: {field}")
|
||||||
|
return ConversationHandler.END
|
||||||
|
|
||||||
query, new_value = update.callback_query, ""
|
query, new_value = update.callback_query, ""
|
||||||
message_to_reply = update.effective_message
|
message_to_reply = update.effective_message
|
||||||
|
|
||||||
@@ -1021,25 +1236,55 @@ async def edit_new_value_received(update: Update, context: CallbackContext):
|
|||||||
if new_value < 0:
|
if new_value < 0:
|
||||||
raise ValueError("费用不能为负数")
|
raise ValueError("费用不能为负数")
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
if message_to_reply: await message_to_reply.reply_text("费用必须是有效的非负数字。")
|
if message_to_reply:
|
||||||
|
await message_to_reply.reply_text("费用必须是有效的非负数字。")
|
||||||
|
validation_failed = True
|
||||||
|
elif field == 'name':
|
||||||
|
new_value = str(new_value).strip()
|
||||||
|
if not new_value:
|
||||||
|
if message_to_reply:
|
||||||
|
await message_to_reply.reply_text("名称不能为空。")
|
||||||
|
validation_failed = True
|
||||||
|
elif len(new_value) > MAX_NAME_LEN:
|
||||||
|
if message_to_reply:
|
||||||
|
await message_to_reply.reply_text(f"名称过长,请控制在 {MAX_NAME_LEN} 个字符以内。")
|
||||||
validation_failed = True
|
validation_failed = True
|
||||||
elif field == 'currency':
|
elif field == 'currency':
|
||||||
new_value = str(new_value).upper()
|
new_value = str(new_value).upper()
|
||||||
if not (len(new_value) == 3 and new_value.isalpha()):
|
if not (len(new_value) == 3 and new_value.isalpha()):
|
||||||
if message_to_reply: await message_to_reply.reply_text("请输入有效的三字母货币代码(如 USD, CNY)。")
|
if message_to_reply:
|
||||||
|
await message_to_reply.reply_text("请输入有效的三字母货币代码(如 USD, CNY)。")
|
||||||
validation_failed = True
|
validation_failed = True
|
||||||
elif field == 'next_due':
|
elif field == 'next_due':
|
||||||
parsed = parse_date(str(new_value))
|
parsed = parse_date(str(new_value))
|
||||||
if not parsed:
|
if not parsed:
|
||||||
if message_to_reply: await message_to_reply.reply_text(
|
if message_to_reply:
|
||||||
"无法识别的日期格式,请使用类似 '2025\\-10\\-01' 或 '10月1日' 的格式。")
|
await message_to_reply.reply_text("无法识别的日期格式,请使用类似 '2025\\-10\\-01' 或 '10月1日' 的格式。")
|
||||||
validation_failed = True
|
validation_failed = True
|
||||||
else:
|
else:
|
||||||
new_value = parsed
|
new_value = parsed
|
||||||
|
elif field == 'renewal_type':
|
||||||
|
if str(new_value) not in VALID_RENEWAL_TYPES:
|
||||||
|
if message_to_reply:
|
||||||
|
await message_to_reply.reply_text("续费方式只能为 auto 或 manual。")
|
||||||
|
validation_failed = True
|
||||||
|
elif field == 'notes':
|
||||||
|
note_val = str(new_value).strip()
|
||||||
|
if note_val and len(note_val) > MAX_NOTES_LEN:
|
||||||
|
if message_to_reply:
|
||||||
|
await message_to_reply.reply_text(f"备注过长,请控制在 {MAX_NOTES_LEN} 个字符以内。")
|
||||||
|
validation_failed = True
|
||||||
|
else:
|
||||||
|
new_value = note_val if note_val else None
|
||||||
elif field == 'category':
|
elif field == 'category':
|
||||||
new_value = str(new_value).strip()
|
new_value = str(new_value).strip()
|
||||||
if not new_value:
|
if not new_value:
|
||||||
if message_to_reply: await message_to_reply.reply_text("类别不能为空。")
|
if message_to_reply:
|
||||||
|
await message_to_reply.reply_text("类别不能为空。")
|
||||||
|
validation_failed = True
|
||||||
|
elif len(new_value) > MAX_CATEGORY_LEN:
|
||||||
|
if message_to_reply:
|
||||||
|
await message_to_reply.reply_text(f"类别名称过长,请控制在 {MAX_CATEGORY_LEN} 个字符以内。")
|
||||||
validation_failed = True
|
validation_failed = True
|
||||||
else:
|
else:
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
@@ -1052,30 +1297,37 @@ async def edit_new_value_received(update: Update, context: CallbackContext):
|
|||||||
|
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute(f"UPDATE subscriptions SET {field} = ? WHERE id = ? AND user_id = ?",
|
cursor.execute(f"UPDATE subscriptions SET {db_field} = ? WHERE id = ? AND user_id = ?",
|
||||||
(new_value, sub_id, user_id))
|
(new_value, sub_id, user_id))
|
||||||
|
if cursor.rowcount == 0:
|
||||||
|
if message_to_reply:
|
||||||
|
await message_to_reply.reply_text("错误:找不到该订阅或无权限。")
|
||||||
|
return ConversationHandler.END
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
if query:
|
if query:
|
||||||
await query.answer(f"✅ 字段已更新!")
|
await query.answer("✅ 字段已更新!")
|
||||||
elif message_to_reply:
|
elif message_to_reply:
|
||||||
await message_to_reply.reply_text("✅ 字段已更新!")
|
await message_to_reply.reply_text("✅ 字段已更新!")
|
||||||
|
|
||||||
context.user_data.clear()
|
_clear_action_state(context, ['sub_id_for_action', 'field_to_edit', 'new_freq_unit'])
|
||||||
await show_subscription_view(update, context, sub_id)
|
await show_subscription_view(update, context, sub_id)
|
||||||
return ConversationHandler.END
|
return ConversationHandler.END
|
||||||
|
|
||||||
|
|
||||||
# --- Reminder Settings Conversation ---
|
# --- Reminder Settings Conversation ---
|
||||||
async def _display_reminder_settings(query: CallbackQuery, context: CallbackContext, sub_id: int):
|
async def _display_reminder_settings(query: CallbackQuery, context: CallbackContext, sub_id: int):
|
||||||
|
user_id = query.from_user.id
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT name, renewal_type, reminders_enabled, reminder_on_due_date, reminder_days FROM subscriptions WHERE id = ?",
|
"SELECT name, renewal_type, reminders_enabled, reminder_on_due_date, reminder_days "
|
||||||
(sub_id,))
|
"FROM subscriptions WHERE id = ? AND user_id = ?",
|
||||||
|
(sub_id, user_id)
|
||||||
|
)
|
||||||
sub = cursor.fetchone()
|
sub = cursor.fetchone()
|
||||||
if not sub:
|
if not sub:
|
||||||
await query.edit_message_text("错误:找不到该订阅。")
|
await query.edit_message_text("错误:找不到该订阅或无权限。")
|
||||||
return
|
return
|
||||||
enabled_text = "❌ 关闭提醒" if sub['reminders_enabled'] else "✅ 开启提醒"
|
enabled_text = "❌ 关闭提醒" if sub['reminders_enabled'] else "✅ 开启提醒"
|
||||||
due_date_text = "❌ 关闭到期日提醒" if sub['reminder_on_due_date'] else "✅ 开启到期日提醒"
|
due_date_text = "❌ 关闭到期日提醒" if sub['reminder_on_due_date'] else "✅ 开启到期日提醒"
|
||||||
@@ -1096,10 +1348,20 @@ async def remind_settings_start(update: Update, context: CallbackContext):
|
|||||||
query = update.callback_query
|
query = update.callback_query
|
||||||
await query.answer()
|
await query.answer()
|
||||||
sub_id_str = query.data.partition('_')[2]
|
sub_id_str = query.data.partition('_')[2]
|
||||||
|
user_id = query.from_user.id
|
||||||
|
|
||||||
if not sub_id_str.isdigit():
|
if not sub_id_str.isdigit():
|
||||||
await query.edit_message_text("错误:无效的订阅ID。")
|
await query.edit_message_text("错误:无效的订阅ID。")
|
||||||
return ConversationHandler.END
|
return ConversationHandler.END
|
||||||
|
|
||||||
sub_id = int(sub_id_str)
|
sub_id = int(sub_id_str)
|
||||||
|
with get_db_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT 1 FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id))
|
||||||
|
if not cursor.fetchone():
|
||||||
|
await query.edit_message_text("错误:找不到该订阅或无权限。")
|
||||||
|
return ConversationHandler.END
|
||||||
|
|
||||||
logger.debug(f"Starting reminder settings for sub_id: {sub_id}")
|
logger.debug(f"Starting reminder settings for sub_id: {sub_id}")
|
||||||
context.user_data['sub_id_for_action'] = sub_id
|
context.user_data['sub_id_for_action'] = sub_id
|
||||||
await _display_reminder_settings(query, context, sub_id)
|
await _display_reminder_settings(query, context, sub_id)
|
||||||
@@ -1119,6 +1381,8 @@ async def remind_action_handler(update: Update, context: CallbackContext):
|
|||||||
await query.edit_message_text("错误:会话已过期,请重试。")
|
await query.edit_message_text("错误:会话已过期,请重试。")
|
||||||
return ConversationHandler.END
|
return ConversationHandler.END
|
||||||
|
|
||||||
|
user_id = query.from_user.id
|
||||||
|
|
||||||
if action == 'ask_days':
|
if action == 'ask_days':
|
||||||
await query.edit_message_text("请输入您想提前几天收到提醒?(输入0则不提前提醒)")
|
await query.edit_message_text("请输入您想提前几天收到提醒?(输入0则不提前提醒)")
|
||||||
return REMIND_GET_DAYS
|
return REMIND_GET_DAYS
|
||||||
@@ -1130,10 +1394,20 @@ async def remind_action_handler(update: Update, context: CallbackContext):
|
|||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
if action == 'toggle_enabled':
|
if action == 'toggle_enabled':
|
||||||
cursor.execute("UPDATE subscriptions SET reminders_enabled = NOT reminders_enabled WHERE id = ?", (sub_id,))
|
cursor.execute(
|
||||||
|
"UPDATE subscriptions SET reminders_enabled = CASE WHEN reminders_enabled THEN 0 ELSE 1 END "
|
||||||
|
"WHERE id = ? AND user_id = ?",
|
||||||
|
(sub_id, user_id)
|
||||||
|
)
|
||||||
elif action == 'toggle_due_date':
|
elif action == 'toggle_due_date':
|
||||||
cursor.execute("UPDATE subscriptions SET reminder_on_due_date = NOT reminder_on_due_date WHERE id = ?",
|
cursor.execute(
|
||||||
(sub_id,))
|
"UPDATE subscriptions SET reminder_on_due_date = CASE WHEN reminder_on_due_date THEN 0 ELSE 1 END "
|
||||||
|
"WHERE id = ? AND user_id = ?",
|
||||||
|
(sub_id, user_id)
|
||||||
|
)
|
||||||
|
if cursor.rowcount == 0:
|
||||||
|
await query.edit_message_text("错误:找不到该订阅或无权限。")
|
||||||
|
return ConversationHandler.END
|
||||||
conn.commit()
|
conn.commit()
|
||||||
await _display_reminder_settings(query, context, sub_id)
|
await _display_reminder_settings(query, context, sub_id)
|
||||||
return REMIND_SELECT_ACTION
|
return REMIND_SELECT_ACTION
|
||||||
@@ -1144,16 +1418,20 @@ async def remind_days_received(update: Update, context: CallbackContext):
|
|||||||
if not sub_id:
|
if not sub_id:
|
||||||
await update.message.reply_text("错误:会话已过期,请重试。")
|
await update.message.reply_text("错误:会话已过期,请重试。")
|
||||||
return ConversationHandler.END
|
return ConversationHandler.END
|
||||||
|
user_id = update.effective_user.id
|
||||||
try:
|
try:
|
||||||
days = int(update.message.text)
|
days = int(update.message.text)
|
||||||
if days < 0:
|
if days < 0:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("UPDATE subscriptions SET reminder_days = ? WHERE id = ?", (days, sub_id))
|
cursor.execute("UPDATE subscriptions SET reminder_days = ? WHERE id = ? AND user_id = ?", (days, sub_id, user_id))
|
||||||
|
if cursor.rowcount == 0:
|
||||||
|
await update.message.reply_text("错误:找不到该订阅或无权限。")
|
||||||
|
return ConversationHandler.END
|
||||||
conn.commit()
|
conn.commit()
|
||||||
await update.message.reply_text(f"✅ 提前提醒天数已设置为: {days}天。")
|
await update.message.reply_text(f"✅ 提前提醒天数已设置为: {days}天。")
|
||||||
context.user_data.clear()
|
_clear_action_state(context, ['sub_id_for_action'])
|
||||||
await show_subscription_view(update, context, sub_id)
|
await show_subscription_view(update, context, sub_id)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
await update.message.reply_text("请输入一个有效的非负整数。")
|
await update.message.reply_text("请输入一个有效的非负整数。")
|
||||||
@@ -1173,14 +1451,18 @@ async def set_currency(update: Update, context: CallbackContext):
|
|||||||
return
|
return
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("INSERT OR REPLACE INTO users (user_id, main_currency) VALUES (?, ?)", (user_id, new_currency))
|
cursor.execute("""
|
||||||
|
INSERT INTO users (user_id, main_currency)
|
||||||
|
VALUES (?, ?)
|
||||||
|
ON CONFLICT(user_id) DO UPDATE SET main_currency = excluded.main_currency
|
||||||
|
""", (user_id, new_currency))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
await update.message.reply_text(f"您的主货币已设为 {escape_markdown(new_currency, version=2)}。",
|
await update.message.reply_text(f"您的主货币已设为 {escape_markdown(new_currency, version=2)}。",
|
||||||
parse_mode='MarkdownV2')
|
parse_mode='MarkdownV2')
|
||||||
|
|
||||||
|
|
||||||
async def cancel(update: Update, context: CallbackContext):
|
async def cancel(update: Update, context: CallbackContext):
|
||||||
context.user_data.clear()
|
_clear_action_state(context, ['new_sub_data', 'sub_id_for_action', 'field_to_edit', 'new_freq_unit'])
|
||||||
if update.callback_query:
|
if update.callback_query:
|
||||||
await update.callback_query.answer()
|
await update.callback_query.answer()
|
||||||
await update.callback_query.edit_message_text('操作已取消。')
|
await update.callback_query.edit_message_text('操作已取消。')
|
||||||
@@ -1254,7 +1536,7 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
edit_conv = ConversationHandler(
|
edit_conv = ConversationHandler(
|
||||||
entry_points=[CallbackQueryHandler(edit_start, pattern='^edit_')],
|
entry_points=[CallbackQueryHandler(edit_start, pattern=r'^edit_\d+$')],
|
||||||
states={
|
states={
|
||||||
EDIT_SELECT_FIELD: [CallbackQueryHandler(edit_field_selected, pattern='^editfield_')],
|
EDIT_SELECT_FIELD: [CallbackQueryHandler(edit_field_selected, pattern='^editfield_')],
|
||||||
EDIT_GET_NEW_VALUE: [
|
EDIT_GET_NEW_VALUE: [
|
||||||
@@ -1268,15 +1550,15 @@ def main():
|
|||||||
fallbacks=[
|
fallbacks=[
|
||||||
CommandHandler('cancel', cancel),
|
CommandHandler('cancel', cancel),
|
||||||
# 【修改】使用新的包装函数来确保会话能正确结束
|
# 【修改】使用新的包装函数来确保会话能正确结束
|
||||||
CallbackQueryHandler(fallback_view_button, pattern='^view_'),
|
CallbackQueryHandler(fallback_view_button, pattern=r'^view_\d+$'),
|
||||||
CallbackQueryHandler(edit_start, pattern='^edit_'),
|
CallbackQueryHandler(edit_start, pattern=r'^edit_\d+$'),
|
||||||
CallbackQueryHandler(remind_settings_start, pattern='^remind_')
|
CallbackQueryHandler(remind_settings_start, pattern=r'^remind_\d+$')
|
||||||
],
|
],
|
||||||
per_message=False
|
per_message=False
|
||||||
)
|
)
|
||||||
|
|
||||||
remind_conv = ConversationHandler(
|
remind_conv = ConversationHandler(
|
||||||
entry_points=[CallbackQueryHandler(remind_settings_start, pattern='^remind_')],
|
entry_points=[CallbackQueryHandler(remind_settings_start, pattern=r'^remind_\d+$')],
|
||||||
states={
|
states={
|
||||||
REMIND_SELECT_ACTION: [CallbackQueryHandler(remind_action_handler, pattern='^remindaction_')],
|
REMIND_SELECT_ACTION: [CallbackQueryHandler(remind_action_handler, pattern='^remindaction_')],
|
||||||
REMIND_GET_DAYS: [MessageHandler(filters.TEXT & ~filters.COMMAND, remind_days_received)],
|
REMIND_GET_DAYS: [MessageHandler(filters.TEXT & ~filters.COMMAND, remind_days_received)],
|
||||||
@@ -1284,9 +1566,9 @@ def main():
|
|||||||
fallbacks=[
|
fallbacks=[
|
||||||
CommandHandler('cancel', cancel),
|
CommandHandler('cancel', cancel),
|
||||||
# 【修改】使用新的包装函数来确保会话能正确结束
|
# 【修改】使用新的包装函数来确保会话能正确结束
|
||||||
CallbackQueryHandler(fallback_view_button, pattern='^view_'),
|
CallbackQueryHandler(fallback_view_button, pattern=r'^view_\d+$'),
|
||||||
CallbackQueryHandler(edit_start, pattern='^edit_'),
|
CallbackQueryHandler(edit_start, pattern=r'^edit_\d+$'),
|
||||||
CallbackQueryHandler(remind_settings_start, pattern='^remind_')
|
CallbackQueryHandler(remind_settings_start, pattern=r'^remind_\d+$')
|
||||||
],
|
],
|
||||||
per_message=False
|
per_message=False
|
||||||
)
|
)
|
||||||
@@ -1299,7 +1581,7 @@ def main():
|
|||||||
fallbacks=[CommandHandler('cancel', cancel)]
|
fallbacks=[CommandHandler('cancel', cancel)]
|
||||||
)
|
)
|
||||||
|
|
||||||
button_pattern = r'^(view_\d+|renewmanual_\d+|delete_\d+|confirmdelete_\d+|renewfromremind_\d+|list_subs_in_category_.+|list_categories|list_all_subs)$'
|
button_pattern = r'^(view_\d+|renewmanual_\d+|delete_\d+|confirmdelete_\d+|renewfromremind_\d+|list_subs_in_category_id_\d+|list_categories|list_all_subs)$'
|
||||||
|
|
||||||
application.add_handler(CommandHandler('start', start))
|
application.add_handler(CommandHandler('start', start))
|
||||||
application.add_handler(CommandHandler('help', help_command))
|
application.add_handler(CommandHandler('help', help_command))
|
||||||
|
|||||||
Reference in New Issue
Block a user