fix: guard add-sub conversation against expired state

This commit is contained in:
Xiaolan Bot
2026-02-22 11:39:25 +08:00
parent d212d73c2a
commit 210af75e2c

View File

@@ -289,9 +289,14 @@ async def get_subs_list_keyboard(user_id: int, category_filter: str = None) -> I
return InlineKeyboardMarkup(keyboard) return InlineKeyboardMarkup(keyboard)
def _clear_action_state(context: CallbackContext, keys: list[str]): def _get_new_sub_data_or_end(update: Update, context: CallbackContext):
for key in keys: sub_data = context.user_data.get('new_sub_data')
context.user_data.pop(key, None) if sub_data is None:
message_obj = update.message or (update.callback_query.message if update.callback_query else None)
if message_obj:
# 统一提示,避免 KeyError 导致会话崩溃
return None, message_obj
return sub_data, None
# --- 自动任务 --- # --- 自动任务 ---
@@ -608,6 +613,12 @@ async def add_sub_start(update: Update, context: CallbackContext):
async def add_name_received(update: Update, context: CallbackContext): async def add_name_received(update: Update, context: CallbackContext):
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
if sub_data is None:
if err_msg_obj:
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
return ConversationHandler.END
name = update.message.text.strip() name = update.message.text.strip()
if not name: if not name:
await update.message.reply_text("订阅名称不能为空。") await update.message.reply_text("订阅名称不能为空。")
@@ -615,17 +626,23 @@ async def add_name_received(update: Update, context: CallbackContext):
if len(name) > MAX_NAME_LEN: if len(name) > MAX_NAME_LEN:
await update.message.reply_text(f"订阅名称过长,请控制在 {MAX_NAME_LEN} 个字符以内。") await update.message.reply_text(f"订阅名称过长,请控制在 {MAX_NAME_LEN} 个字符以内。")
return ADD_NAME return ADD_NAME
context.user_data['new_sub_data']['name'] = name sub_data['name'] = name
await update.message.reply_text("第二步:请输入订阅 *费用*", parse_mode='MarkdownV2') await update.message.reply_text("第二步:请输入订阅 *费用*", parse_mode='MarkdownV2')
return ADD_COST return ADD_COST
async def add_cost_received(update: Update, context: CallbackContext): async def add_cost_received(update: Update, context: CallbackContext):
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
if sub_data is None:
if err_msg_obj:
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
return ConversationHandler.END
try: try:
cost = float(update.message.text) cost = float(update.message.text)
if cost < 0: if cost < 0:
raise ValueError("费用不能为负数") raise ValueError("费用不能为负数")
context.user_data['new_sub_data']['cost'] = cost sub_data['cost'] = cost
except (ValueError, TypeError): except (ValueError, TypeError):
await update.message.reply_text("费用必须是有效的非负数字。") await update.message.reply_text("费用必须是有效的非负数字。")
return ADD_COST return ADD_COST
@@ -634,16 +651,28 @@ async def add_cost_received(update: Update, context: CallbackContext):
async def add_currency_received(update: Update, context: CallbackContext): async def add_currency_received(update: Update, context: CallbackContext):
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
if sub_data is None:
if err_msg_obj:
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
return ConversationHandler.END
currency = update.message.text.upper() currency = update.message.text.upper()
if not (len(currency) == 3 and currency.isalpha()): if not (len(currency) == 3 and currency.isalpha()):
await update.message.reply_text("请输入有效的三字母货币代码(如 USD, CNY") await update.message.reply_text("请输入有效的三字母货币代码(如 USD, CNY")
return ADD_CURRENCY return ADD_CURRENCY
context.user_data['new_sub_data']['currency'] = currency sub_data['currency'] = currency
await update.message.reply_text("第四步:请为订阅指定一个 *类别*", parse_mode='MarkdownV2') await update.message.reply_text("第四步:请为订阅指定一个 *类别*", parse_mode='MarkdownV2')
return ADD_CATEGORY return ADD_CATEGORY
async def add_category_received(update: Update, context: CallbackContext): async def add_category_received(update: Update, context: CallbackContext):
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
if sub_data is None:
if err_msg_obj:
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
return ConversationHandler.END
user_id, category_name = update.effective_user.id, update.message.text.strip() user_id, category_name = update.effective_user.id, update.message.text.strip()
if not category_name: if not category_name:
await update.message.reply_text("类别不能为空。") await update.message.reply_text("类别不能为空。")
@@ -651,7 +680,7 @@ async def add_category_received(update: Update, context: CallbackContext):
if len(category_name) > MAX_CATEGORY_LEN: if len(category_name) > MAX_CATEGORY_LEN:
await update.message.reply_text(f"类别名称过长,请控制在 {MAX_CATEGORY_LEN} 个字符以内。") await update.message.reply_text(f"类别名称过长,请控制在 {MAX_CATEGORY_LEN} 个字符以内。")
return ADD_CATEGORY return ADD_CATEGORY
context.user_data['new_sub_data']['category'] = category_name sub_data['category'] = category_name
with get_db_connection() as conn: with get_db_connection() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, category_name)) cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, category_name))
@@ -662,11 +691,17 @@ async def add_category_received(update: Update, context: CallbackContext):
async def add_next_due_received(update: Update, context: CallbackContext): async def add_next_due_received(update: Update, context: CallbackContext):
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
if sub_data is None:
if err_msg_obj:
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
return ConversationHandler.END
parsed_date = parse_date(update.message.text) parsed_date = parse_date(update.message.text)
if not parsed_date: if not parsed_date:
await update.message.reply_text("无法识别的日期格式,请使用类似 '2025\\-10\\-01''10月1日' 的格式。") await update.message.reply_text("无法识别的日期格式,请使用类似 '2025\\-10\\-01''10月1日' 的格式。")
return ADD_NEXT_DUE return ADD_NEXT_DUE
context.user_data['new_sub_data']['next_due'] = parsed_date sub_data['next_due'] = parsed_date
keyboard = [ keyboard = [
[InlineKeyboardButton("", callback_data='freq_unit_day'), [InlineKeyboardButton("", callback_data='freq_unit_day'),
InlineKeyboardButton("", callback_data='freq_unit_week')], InlineKeyboardButton("", callback_data='freq_unit_week')],
@@ -679,23 +714,34 @@ async def add_next_due_received(update: Update, context: CallbackContext):
async def add_freq_unit_received(update: Update, context: CallbackContext): async def add_freq_unit_received(update: Update, context: CallbackContext):
sub_data, _ = _get_new_sub_data_or_end(update, context)
query = update.callback_query query = update.callback_query
await query.answer() await query.answer()
if sub_data is None:
await query.edit_message_text("会话已过期,请重新使用 /add_sub 开始。")
return ConversationHandler.END
unit = query.data.split('_')[2] unit = query.data.split('_')[2]
if unit not in VALID_FREQ_UNITS: if unit not in VALID_FREQ_UNITS:
await query.edit_message_text("错误:无效的周期单位,请重试。") await query.edit_message_text("错误:无效的周期单位,请重试。")
return ConversationHandler.END return ConversationHandler.END
context.user_data['new_sub_data']['unit'] = unit sub_data['unit'] = unit
await query.edit_message_text("第七步:请输入周期的*数量*例如每3个月输入 3", parse_mode='Markdown') await query.edit_message_text("第七步:请输入周期的*数量*例如每3个月输入 3", parse_mode='Markdown')
return ADD_FREQ_VALUE return ADD_FREQ_VALUE
async def add_freq_value_received(update: Update, context: CallbackContext): async def add_freq_value_received(update: Update, context: CallbackContext):
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
if sub_data is None:
if err_msg_obj:
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
return ConversationHandler.END
try: try:
value = int(update.message.text) value = int(update.message.text)
if value <= 0: if value <= 0:
raise ValueError raise ValueError
context.user_data['new_sub_data']['value'] = value sub_data['value'] = value
except (ValueError, TypeError): except (ValueError, TypeError):
await update.message.reply_text("请输入一个有效的正整数。") await update.message.reply_text("请输入一个有效的正整数。")
return ADD_FREQ_VALUE return ADD_FREQ_VALUE
@@ -709,13 +755,18 @@ async def add_freq_value_received(update: Update, context: CallbackContext):
async def add_renewal_type_received(update: Update, context: CallbackContext): async def add_renewal_type_received(update: Update, context: CallbackContext):
sub_data, _ = _get_new_sub_data_or_end(update, context)
query = update.callback_query query = update.callback_query
await query.answer() await query.answer()
if sub_data is None:
await query.edit_message_text("会话已过期,请重新使用 /add_sub 开始。")
return ConversationHandler.END
renewal_type = query.data.split('_')[1] renewal_type = query.data.split('_')[1]
if renewal_type not in VALID_RENEWAL_TYPES: if renewal_type not in VALID_RENEWAL_TYPES:
await query.edit_message_text("错误:无效的续费类型,请重试。") await query.edit_message_text("错误:无效的续费类型,请重试。")
return ConversationHandler.END return ConversationHandler.END
context.user_data['new_sub_data']['renewal_type'] = renewal_type sub_data['renewal_type'] = renewal_type
await query.edit_message_text("最后一步(可选):需要添加备注吗?\n(如:共享账号、用途等。不需要请 /skip") await query.edit_message_text("最后一步(可选):需要添加备注吗?\n(如:共享账号、用途等。不需要请 /skip")
return ADD_NOTES return ADD_NOTES