fix: harden subscription callbacks and settings updates
This commit is contained in:
127
SubMind.py
127
SubMind.py
@@ -241,6 +241,27 @@ def format_frequency(unit, value) -> str:
|
|||||||
return f"每 {value} {unit_map.get(unit, unit)}"
|
return f"每 {value} {unit_map.get(unit, unit)}"
|
||||||
|
|
||||||
|
|
||||||
|
CATEGORY_CB_PREFIX = "list_subs_in_category_"
|
||||||
|
EDITABLE_SUB_FIELDS = {'name', 'cost', 'currency', 'category', 'next_due', 'renewal_type', 'notes'}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_category_callback_data(category_name: str) -> str:
|
||||||
|
"""Build callback_data within Telegram's 64-byte limit by falling back to a hash token."""
|
||||||
|
candidate = f"{CATEGORY_CB_PREFIX}{category_name}"
|
||||||
|
if len(candidate.encode('utf-8')) <= 64:
|
||||||
|
return candidate
|
||||||
|
token = abs(hash(category_name)) % 100000000
|
||||||
|
return f"{CATEGORY_CB_PREFIX}h{token}"
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_category_from_callback(data: str, context: CallbackContext) -> str | None:
|
||||||
|
payload = data.replace(CATEGORY_CB_PREFIX, '', 1)
|
||||||
|
if payload.startswith('h') and payload[1:].isdigit():
|
||||||
|
mapping = context.user_data.get('category_cb_map', {})
|
||||||
|
return mapping.get(payload)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
async def get_subs_list_keyboard(user_id: int, category_filter: str = None) -> InlineKeyboardMarkup:
|
async def get_subs_list_keyboard(user_id: int, category_filter: str = None) -> InlineKeyboardMarkup:
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
@@ -352,7 +373,7 @@ async def start(update: Update, context: CallbackContext):
|
|||||||
|
|
||||||
|
|
||||||
async def help_command(update: Update, context: CallbackContext):
|
async def help_command(update: Update, context: CallbackContext):
|
||||||
help_text = f"""
|
help_text = fr"""
|
||||||
*{escape_markdown(PROJECT_NAME, version=2)} 命令列表*
|
*{escape_markdown(PROJECT_NAME, version=2)} 命令列表*
|
||||||
*🌟 核心功能*
|
*🌟 核心功能*
|
||||||
/add\_sub \- 引导您添加一个新的订阅
|
/add\_sub \- 引导您添加一个新的订阅
|
||||||
@@ -715,7 +736,16 @@ async def list_categories(update: Update, context: CallbackContext):
|
|||||||
await update.message.reply_text("您还没有任何分类。")
|
await update.message.reply_text("您还没有任何分类。")
|
||||||
return
|
return
|
||||||
|
|
||||||
buttons = [InlineKeyboardButton(cat[0], callback_data=f"list_subs_in_category_{cat[0]}") for cat in categories]
|
context.user_data['category_cb_map'] = {}
|
||||||
|
buttons = []
|
||||||
|
for cat in categories:
|
||||||
|
cat_name = cat[0]
|
||||||
|
cb_data = _build_category_callback_data(cat_name)
|
||||||
|
payload = cb_data.replace(CATEGORY_CB_PREFIX, '', 1)
|
||||||
|
if payload.startswith('h') and payload[1:].isdigit():
|
||||||
|
context.user_data['category_cb_map'][payload] = cat_name
|
||||||
|
buttons.append(InlineKeyboardButton(cat_name, callback_data=cb_data))
|
||||||
|
|
||||||
keyboard = [buttons[i:i + 2] for i in range(0, len(buttons), 2)]
|
keyboard = [buttons[i:i + 2] for i in range(0, len(buttons), 2)]
|
||||||
keyboard.append([InlineKeyboardButton("查看全部订阅", callback_data="list_all_subs")])
|
keyboard.append([InlineKeyboardButton("查看全部订阅", callback_data="list_all_subs")])
|
||||||
if update.callback_query:
|
if update.callback_query:
|
||||||
@@ -765,8 +795,12 @@ async def show_subscription_view(update: Update, context: CallbackContext, sub_i
|
|||||||
keyboard_buttons.insert(0, [InlineKeyboardButton("✅ 续费", callback_data=f'renewmanual_{sub_id}')])
|
keyboard_buttons.insert(0, [InlineKeyboardButton("✅ 续费", callback_data=f'renewmanual_{sub_id}')])
|
||||||
if 'list_subs_in_category' in context.user_data:
|
if 'list_subs_in_category' in context.user_data:
|
||||||
cat_filter = context.user_data['list_subs_in_category']
|
cat_filter = context.user_data['list_subs_in_category']
|
||||||
keyboard_buttons.append(
|
back_cb = _build_category_callback_data(cat_filter)
|
||||||
[InlineKeyboardButton("« 返回分类订阅", callback_data=f'list_subs_in_category_{cat_filter}')])
|
payload = back_cb.replace(CATEGORY_CB_PREFIX, '', 1)
|
||||||
|
if payload.startswith('h') and payload[1:].isdigit():
|
||||||
|
category_cb_map = context.user_data.setdefault('category_cb_map', {})
|
||||||
|
category_cb_map[payload] = cat_filter
|
||||||
|
keyboard_buttons.append([InlineKeyboardButton("« 返回分类订阅", callback_data=back_cb)])
|
||||||
else:
|
else:
|
||||||
keyboard_buttons.append([InlineKeyboardButton("« 返回全部订阅", callback_data='list_all_subs')])
|
keyboard_buttons.append([InlineKeyboardButton("« 返回全部订阅", callback_data='list_all_subs')])
|
||||||
logger.debug(f"Generated buttons for sub_id {sub_id}: edit_{sub_id}, remind_{sub_id}")
|
logger.debug(f"Generated buttons for sub_id {sub_id}: edit_{sub_id}, remind_{sub_id}")
|
||||||
@@ -785,8 +819,11 @@ async def button_callback_handler(update: Update, context: CallbackContext):
|
|||||||
user_id = query.from_user.id
|
user_id = query.from_user.id
|
||||||
logger.debug(f"Received callback query: {data} from user {user_id}")
|
logger.debug(f"Received callback query: {data} from user {user_id}")
|
||||||
|
|
||||||
if data.startswith('list_subs_in_category_'):
|
if data.startswith(CATEGORY_CB_PREFIX):
|
||||||
category = data.replace('list_subs_in_category_', '')
|
category = _parse_category_from_callback(data, context)
|
||||||
|
if not category:
|
||||||
|
await query.edit_message_text("错误:无效或已过期的分类,请重新选择。")
|
||||||
|
return
|
||||||
context.user_data['list_subs_in_category'] = category
|
context.user_data['list_subs_in_category'] = category
|
||||||
keyboard = await get_subs_list_keyboard(user_id, category_filter=category)
|
keyboard = await get_subs_list_keyboard(user_id, category_filter=category)
|
||||||
msg_text = f"分类“{escape_markdown(category, version=2)}”下的订阅:"
|
msg_text = f"分类“{escape_markdown(category, version=2)}”下的订阅:"
|
||||||
@@ -821,33 +858,45 @@ async def button_callback_handler(update: Update, context: CallbackContext):
|
|||||||
elif action == 'renewmanual':
|
elif action == 'renewmanual':
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("SELECT frequency_unit, frequency_value FROM subscriptions WHERE id = ?", (sub_id,))
|
cursor.execute(
|
||||||
|
"SELECT frequency_unit, frequency_value FROM subscriptions WHERE id = ? AND user_id = ?",
|
||||||
|
(sub_id, user_id)
|
||||||
|
)
|
||||||
sub = cursor.fetchone()
|
sub = cursor.fetchone()
|
||||||
if sub:
|
if sub:
|
||||||
today = datetime.date.today()
|
today = datetime.date.today()
|
||||||
new_due_date = calculate_new_due_date(today, sub['frequency_unit'], sub['frequency_value'])
|
new_due_date = calculate_new_due_date(today, sub['frequency_unit'], sub['frequency_value'])
|
||||||
if new_due_date:
|
if new_due_date:
|
||||||
new_date_str = new_due_date.strftime('%Y-%m-%d')
|
new_date_str = new_due_date.strftime('%Y-%m-%d')
|
||||||
cursor.execute("UPDATE subscriptions SET next_due = ? WHERE id = ?", (new_date_str, sub_id))
|
cursor.execute(
|
||||||
|
"UPDATE subscriptions SET next_due = ? WHERE id = ? AND user_id = ?",
|
||||||
|
(new_date_str, sub_id, user_id)
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
await query.answer(f"✅ 续费成功!新到期日: {new_date_str}", show_alert=True)
|
await query.answer(f"✅ 续费成功!新到期日: {new_date_str}", show_alert=True)
|
||||||
await show_subscription_view(update, context, sub_id)
|
await show_subscription_view(update, context, sub_id)
|
||||||
else:
|
else:
|
||||||
await query.answer("续费失败:无法计算新的到期日期。", show_alert=True)
|
await query.answer("续费失败:无法计算新的到期日期。", show_alert=True)
|
||||||
else:
|
else:
|
||||||
await query.answer("续费失败:订阅不存在。", show_alert=True)
|
await query.answer("续费失败:订阅不存在或无权限。", show_alert=True)
|
||||||
|
|
||||||
elif action == 'renewfromremind':
|
elif action == 'renewfromremind':
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("SELECT name, frequency_unit, frequency_value FROM subscriptions WHERE id = ?", (sub_id,))
|
cursor.execute(
|
||||||
|
"SELECT name, frequency_unit, frequency_value FROM subscriptions WHERE id = ? AND user_id = ?",
|
||||||
|
(sub_id, user_id)
|
||||||
|
)
|
||||||
sub = cursor.fetchone()
|
sub = cursor.fetchone()
|
||||||
if sub:
|
if sub:
|
||||||
today = datetime.date.today()
|
today = datetime.date.today()
|
||||||
new_due_date = calculate_new_due_date(today, sub['frequency_unit'], sub['frequency_value'])
|
new_due_date = calculate_new_due_date(today, sub['frequency_unit'], sub['frequency_value'])
|
||||||
if new_due_date:
|
if new_due_date:
|
||||||
new_date_str = new_due_date.strftime('%Y-%m-%d')
|
new_date_str = new_due_date.strftime('%Y-%m-%d')
|
||||||
cursor.execute("UPDATE subscriptions SET next_due = ? WHERE id = ?", (new_date_str, sub_id))
|
cursor.execute(
|
||||||
|
"UPDATE subscriptions SET next_due = ? WHERE id = ? AND user_id = ?",
|
||||||
|
(new_date_str, sub_id, user_id)
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
safe_sub_name = escape_markdown(sub['name'], version=2)
|
safe_sub_name = escape_markdown(sub['name'], version=2)
|
||||||
await query.edit_message_text(
|
await query.edit_message_text(
|
||||||
@@ -858,8 +907,8 @@ async def button_callback_handler(update: Update, context: CallbackContext):
|
|||||||
else:
|
else:
|
||||||
await query.answer("续费失败:无法计算新的到期日期。", show_alert=True)
|
await query.answer("续费失败:无法计算新的到期日期。", show_alert=True)
|
||||||
else:
|
else:
|
||||||
await query.answer("续费失败:此订阅可能已被删除。", show_alert=True)
|
await query.answer("续费失败:此订阅可能已被删除或无权限。", show_alert=True)
|
||||||
await query.edit_message_text(text=query.message.text + "\n\n*(错误:此订阅已被删除)*",
|
await query.edit_message_text(text=query.message.text + "\n\n*(错误:此订阅不存在或无权限)*",
|
||||||
parse_mode='MarkdownV2', reply_markup=None)
|
parse_mode='MarkdownV2', reply_markup=None)
|
||||||
|
|
||||||
elif action == 'delete':
|
elif action == 'delete':
|
||||||
@@ -1000,6 +1049,12 @@ async def edit_new_value_received(update: Update, context: CallbackContext):
|
|||||||
if update.effective_message:
|
if update.effective_message:
|
||||||
await update.effective_message.reply_text("错误:未选择要编辑的字段。")
|
await update.effective_message.reply_text("错误:未选择要编辑的字段。")
|
||||||
return ConversationHandler.END
|
return ConversationHandler.END
|
||||||
|
if field not in EDITABLE_SUB_FIELDS:
|
||||||
|
if update.effective_message:
|
||||||
|
await update.effective_message.reply_text("错误:不允许编辑该字段。")
|
||||||
|
logger.warning(f"Blocked unsafe field update attempt: {field}")
|
||||||
|
return ConversationHandler.END
|
||||||
|
|
||||||
query, new_value = update.callback_query, ""
|
query, new_value = update.callback_query, ""
|
||||||
message_to_reply = update.effective_message
|
message_to_reply = update.effective_message
|
||||||
|
|
||||||
@@ -1021,25 +1076,28 @@ async def edit_new_value_received(update: Update, context: CallbackContext):
|
|||||||
if new_value < 0:
|
if new_value < 0:
|
||||||
raise ValueError("费用不能为负数")
|
raise ValueError("费用不能为负数")
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
if message_to_reply: await message_to_reply.reply_text("费用必须是有效的非负数字。")
|
if message_to_reply:
|
||||||
|
await message_to_reply.reply_text("费用必须是有效的非负数字。")
|
||||||
validation_failed = True
|
validation_failed = True
|
||||||
elif field == 'currency':
|
elif field == 'currency':
|
||||||
new_value = str(new_value).upper()
|
new_value = str(new_value).upper()
|
||||||
if not (len(new_value) == 3 and new_value.isalpha()):
|
if not (len(new_value) == 3 and new_value.isalpha()):
|
||||||
if message_to_reply: await message_to_reply.reply_text("请输入有效的三字母货币代码(如 USD, CNY)。")
|
if message_to_reply:
|
||||||
|
await message_to_reply.reply_text("请输入有效的三字母货币代码(如 USD, CNY)。")
|
||||||
validation_failed = True
|
validation_failed = True
|
||||||
elif field == 'next_due':
|
elif field == 'next_due':
|
||||||
parsed = parse_date(str(new_value))
|
parsed = parse_date(str(new_value))
|
||||||
if not parsed:
|
if not parsed:
|
||||||
if message_to_reply: await message_to_reply.reply_text(
|
if message_to_reply:
|
||||||
"无法识别的日期格式,请使用类似 '2025\\-10\\-01' 或 '10月1日' 的格式。")
|
await message_to_reply.reply_text("无法识别的日期格式,请使用类似 '2025\\-10\\-01' 或 '10月1日' 的格式。")
|
||||||
validation_failed = True
|
validation_failed = True
|
||||||
else:
|
else:
|
||||||
new_value = parsed
|
new_value = parsed
|
||||||
elif field == 'category':
|
elif field == 'category':
|
||||||
new_value = str(new_value).strip()
|
new_value = str(new_value).strip()
|
||||||
if not new_value:
|
if not new_value:
|
||||||
if message_to_reply: await message_to_reply.reply_text("类别不能为空。")
|
if message_to_reply:
|
||||||
|
await message_to_reply.reply_text("类别不能为空。")
|
||||||
validation_failed = True
|
validation_failed = True
|
||||||
else:
|
else:
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
@@ -1057,7 +1115,7 @@ async def edit_new_value_received(update: Update, context: CallbackContext):
|
|||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
if query:
|
if query:
|
||||||
await query.answer(f"✅ 字段已更新!")
|
await query.answer("✅ 字段已更新!")
|
||||||
elif message_to_reply:
|
elif message_to_reply:
|
||||||
await message_to_reply.reply_text("✅ 字段已更新!")
|
await message_to_reply.reply_text("✅ 字段已更新!")
|
||||||
|
|
||||||
@@ -1068,14 +1126,17 @@ async def edit_new_value_received(update: Update, context: CallbackContext):
|
|||||||
|
|
||||||
# --- Reminder Settings Conversation ---
|
# --- Reminder Settings Conversation ---
|
||||||
async def _display_reminder_settings(query: CallbackQuery, context: CallbackContext, sub_id: int):
|
async def _display_reminder_settings(query: CallbackQuery, context: CallbackContext, sub_id: int):
|
||||||
|
user_id = query.from_user.id
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT name, renewal_type, reminders_enabled, reminder_on_due_date, reminder_days FROM subscriptions WHERE id = ?",
|
"SELECT name, renewal_type, reminders_enabled, reminder_on_due_date, reminder_days "
|
||||||
(sub_id,))
|
"FROM subscriptions WHERE id = ? AND user_id = ?",
|
||||||
|
(sub_id, user_id)
|
||||||
|
)
|
||||||
sub = cursor.fetchone()
|
sub = cursor.fetchone()
|
||||||
if not sub:
|
if not sub:
|
||||||
await query.edit_message_text("错误:找不到该订阅。")
|
await query.edit_message_text("错误:找不到该订阅或无权限。")
|
||||||
return
|
return
|
||||||
enabled_text = "❌ 关闭提醒" if sub['reminders_enabled'] else "✅ 开启提醒"
|
enabled_text = "❌ 关闭提醒" if sub['reminders_enabled'] else "✅ 开启提醒"
|
||||||
due_date_text = "❌ 关闭到期日提醒" if sub['reminder_on_due_date'] else "✅ 开启到期日提醒"
|
due_date_text = "❌ 关闭到期日提醒" if sub['reminder_on_due_date'] else "✅ 开启到期日提醒"
|
||||||
@@ -1119,6 +1180,8 @@ async def remind_action_handler(update: Update, context: CallbackContext):
|
|||||||
await query.edit_message_text("错误:会话已过期,请重试。")
|
await query.edit_message_text("错误:会话已过期,请重试。")
|
||||||
return ConversationHandler.END
|
return ConversationHandler.END
|
||||||
|
|
||||||
|
user_id = query.from_user.id
|
||||||
|
|
||||||
if action == 'ask_days':
|
if action == 'ask_days':
|
||||||
await query.edit_message_text("请输入您想提前几天收到提醒?(输入0则不提前提醒)")
|
await query.edit_message_text("请输入您想提前几天收到提醒?(输入0则不提前提醒)")
|
||||||
return REMIND_GET_DAYS
|
return REMIND_GET_DAYS
|
||||||
@@ -1130,10 +1193,15 @@ async def remind_action_handler(update: Update, context: CallbackContext):
|
|||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
if action == 'toggle_enabled':
|
if action == 'toggle_enabled':
|
||||||
cursor.execute("UPDATE subscriptions SET reminders_enabled = NOT reminders_enabled WHERE id = ?", (sub_id,))
|
cursor.execute(
|
||||||
|
"UPDATE subscriptions SET reminders_enabled = NOT reminders_enabled WHERE id = ? AND user_id = ?",
|
||||||
|
(sub_id, user_id)
|
||||||
|
)
|
||||||
elif action == 'toggle_due_date':
|
elif action == 'toggle_due_date':
|
||||||
cursor.execute("UPDATE subscriptions SET reminder_on_due_date = NOT reminder_on_due_date WHERE id = ?",
|
cursor.execute(
|
||||||
(sub_id,))
|
"UPDATE subscriptions SET reminder_on_due_date = NOT reminder_on_due_date WHERE id = ? AND user_id = ?",
|
||||||
|
(sub_id, user_id)
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
await _display_reminder_settings(query, context, sub_id)
|
await _display_reminder_settings(query, context, sub_id)
|
||||||
return REMIND_SELECT_ACTION
|
return REMIND_SELECT_ACTION
|
||||||
@@ -1144,13 +1212,14 @@ async def remind_days_received(update: Update, context: CallbackContext):
|
|||||||
if not sub_id:
|
if not sub_id:
|
||||||
await update.message.reply_text("错误:会话已过期,请重试。")
|
await update.message.reply_text("错误:会话已过期,请重试。")
|
||||||
return ConversationHandler.END
|
return ConversationHandler.END
|
||||||
|
user_id = update.effective_user.id
|
||||||
try:
|
try:
|
||||||
days = int(update.message.text)
|
days = int(update.message.text)
|
||||||
if days < 0:
|
if days < 0:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("UPDATE subscriptions SET reminder_days = ? WHERE id = ?", (days, sub_id))
|
cursor.execute("UPDATE subscriptions SET reminder_days = ? WHERE id = ? AND user_id = ?", (days, sub_id, user_id))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
await update.message.reply_text(f"✅ 提前提醒天数已设置为: {days}天。")
|
await update.message.reply_text(f"✅ 提前提醒天数已设置为: {days}天。")
|
||||||
context.user_data.clear()
|
context.user_data.clear()
|
||||||
@@ -1173,7 +1242,11 @@ async def set_currency(update: Update, context: CallbackContext):
|
|||||||
return
|
return
|
||||||
with get_db_connection() as conn:
|
with get_db_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("INSERT OR REPLACE INTO users (user_id, main_currency) VALUES (?, ?)", (user_id, new_currency))
|
cursor.execute("""
|
||||||
|
INSERT INTO users (user_id, main_currency)
|
||||||
|
VALUES (?, ?)
|
||||||
|
ON CONFLICT(user_id) DO UPDATE SET main_currency = excluded.main_currency
|
||||||
|
""", (user_id, new_currency))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
await update.message.reply_text(f"您的主货币已设为 {escape_markdown(new_currency, version=2)}。",
|
await update.message.reply_text(f"您的主货币已设为 {escape_markdown(new_currency, version=2)}。",
|
||||||
parse_mode='MarkdownV2')
|
parse_mode='MarkdownV2')
|
||||||
|
|||||||
Reference in New Issue
Block a user