fix: guard add-sub conversation against expired state
This commit is contained in:
73
SubMind.py
73
SubMind.py
@@ -289,9 +289,14 @@ async def get_subs_list_keyboard(user_id: int, category_filter: str = None) -> I
|
||||
return InlineKeyboardMarkup(keyboard)
|
||||
|
||||
|
||||
def _clear_action_state(context: CallbackContext, keys: list[str]):
|
||||
for key in keys:
|
||||
context.user_data.pop(key, None)
|
||||
def _get_new_sub_data_or_end(update: Update, context: CallbackContext):
|
||||
sub_data = context.user_data.get('new_sub_data')
|
||||
if sub_data is None:
|
||||
message_obj = update.message or (update.callback_query.message if update.callback_query else None)
|
||||
if message_obj:
|
||||
# 统一提示,避免 KeyError 导致会话崩溃
|
||||
return None, message_obj
|
||||
return sub_data, None
|
||||
|
||||
|
||||
# --- 自动任务 ---
|
||||
@@ -608,6 +613,12 @@ async def add_sub_start(update: Update, context: CallbackContext):
|
||||
|
||||
|
||||
async def add_name_received(update: Update, context: CallbackContext):
|
||||
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
|
||||
if sub_data is None:
|
||||
if err_msg_obj:
|
||||
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
|
||||
return ConversationHandler.END
|
||||
|
||||
name = update.message.text.strip()
|
||||
if not name:
|
||||
await update.message.reply_text("订阅名称不能为空。")
|
||||
@@ -615,17 +626,23 @@ async def add_name_received(update: Update, context: CallbackContext):
|
||||
if len(name) > MAX_NAME_LEN:
|
||||
await update.message.reply_text(f"订阅名称过长,请控制在 {MAX_NAME_LEN} 个字符以内。")
|
||||
return ADD_NAME
|
||||
context.user_data['new_sub_data']['name'] = name
|
||||
sub_data['name'] = name
|
||||
await update.message.reply_text("第二步:请输入订阅 *费用*", parse_mode='MarkdownV2')
|
||||
return ADD_COST
|
||||
|
||||
|
||||
async def add_cost_received(update: Update, context: CallbackContext):
|
||||
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
|
||||
if sub_data is None:
|
||||
if err_msg_obj:
|
||||
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
|
||||
return ConversationHandler.END
|
||||
|
||||
try:
|
||||
cost = float(update.message.text)
|
||||
if cost < 0:
|
||||
raise ValueError("费用不能为负数")
|
||||
context.user_data['new_sub_data']['cost'] = cost
|
||||
sub_data['cost'] = cost
|
||||
except (ValueError, TypeError):
|
||||
await update.message.reply_text("费用必须是有效的非负数字。")
|
||||
return ADD_COST
|
||||
@@ -634,16 +651,28 @@ async def add_cost_received(update: Update, context: CallbackContext):
|
||||
|
||||
|
||||
async def add_currency_received(update: Update, context: CallbackContext):
|
||||
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
|
||||
if sub_data is None:
|
||||
if err_msg_obj:
|
||||
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
|
||||
return ConversationHandler.END
|
||||
|
||||
currency = update.message.text.upper()
|
||||
if not (len(currency) == 3 and currency.isalpha()):
|
||||
await update.message.reply_text("请输入有效的三字母货币代码(如 USD, CNY)。")
|
||||
return ADD_CURRENCY
|
||||
context.user_data['new_sub_data']['currency'] = currency
|
||||
sub_data['currency'] = currency
|
||||
await update.message.reply_text("第四步:请为订阅指定一个 *类别*", parse_mode='MarkdownV2')
|
||||
return ADD_CATEGORY
|
||||
|
||||
|
||||
async def add_category_received(update: Update, context: CallbackContext):
|
||||
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
|
||||
if sub_data is None:
|
||||
if err_msg_obj:
|
||||
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
|
||||
return ConversationHandler.END
|
||||
|
||||
user_id, category_name = update.effective_user.id, update.message.text.strip()
|
||||
if not category_name:
|
||||
await update.message.reply_text("类别不能为空。")
|
||||
@@ -651,7 +680,7 @@ async def add_category_received(update: Update, context: CallbackContext):
|
||||
if len(category_name) > MAX_CATEGORY_LEN:
|
||||
await update.message.reply_text(f"类别名称过长,请控制在 {MAX_CATEGORY_LEN} 个字符以内。")
|
||||
return ADD_CATEGORY
|
||||
context.user_data['new_sub_data']['category'] = category_name
|
||||
sub_data['category'] = category_name
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, category_name))
|
||||
@@ -662,11 +691,17 @@ async def add_category_received(update: Update, context: CallbackContext):
|
||||
|
||||
|
||||
async def add_next_due_received(update: Update, context: CallbackContext):
|
||||
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
|
||||
if sub_data is None:
|
||||
if err_msg_obj:
|
||||
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
|
||||
return ConversationHandler.END
|
||||
|
||||
parsed_date = parse_date(update.message.text)
|
||||
if not parsed_date:
|
||||
await update.message.reply_text("无法识别的日期格式,请使用类似 '2025\\-10\\-01' 或 '10月1日' 的格式。")
|
||||
return ADD_NEXT_DUE
|
||||
context.user_data['new_sub_data']['next_due'] = parsed_date
|
||||
sub_data['next_due'] = parsed_date
|
||||
keyboard = [
|
||||
[InlineKeyboardButton("天", callback_data='freq_unit_day'),
|
||||
InlineKeyboardButton("周", callback_data='freq_unit_week')],
|
||||
@@ -679,23 +714,34 @@ async def add_next_due_received(update: Update, context: CallbackContext):
|
||||
|
||||
|
||||
async def add_freq_unit_received(update: Update, context: CallbackContext):
|
||||
sub_data, _ = _get_new_sub_data_or_end(update, context)
|
||||
query = update.callback_query
|
||||
await query.answer()
|
||||
if sub_data is None:
|
||||
await query.edit_message_text("会话已过期,请重新使用 /add_sub 开始。")
|
||||
return ConversationHandler.END
|
||||
|
||||
unit = query.data.split('_')[2]
|
||||
if unit not in VALID_FREQ_UNITS:
|
||||
await query.edit_message_text("错误:无效的周期单位,请重试。")
|
||||
return ConversationHandler.END
|
||||
context.user_data['new_sub_data']['unit'] = unit
|
||||
sub_data['unit'] = unit
|
||||
await query.edit_message_text("第七步:请输入周期的*数量*(例如:每3个月,输入 3)", parse_mode='Markdown')
|
||||
return ADD_FREQ_VALUE
|
||||
|
||||
|
||||
async def add_freq_value_received(update: Update, context: CallbackContext):
|
||||
sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
|
||||
if sub_data is None:
|
||||
if err_msg_obj:
|
||||
await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
|
||||
return ConversationHandler.END
|
||||
|
||||
try:
|
||||
value = int(update.message.text)
|
||||
if value <= 0:
|
||||
raise ValueError
|
||||
context.user_data['new_sub_data']['value'] = value
|
||||
sub_data['value'] = value
|
||||
except (ValueError, TypeError):
|
||||
await update.message.reply_text("请输入一个有效的正整数。")
|
||||
return ADD_FREQ_VALUE
|
||||
@@ -709,13 +755,18 @@ async def add_freq_value_received(update: Update, context: CallbackContext):
|
||||
|
||||
|
||||
async def add_renewal_type_received(update: Update, context: CallbackContext):
|
||||
sub_data, _ = _get_new_sub_data_or_end(update, context)
|
||||
query = update.callback_query
|
||||
await query.answer()
|
||||
if sub_data is None:
|
||||
await query.edit_message_text("会话已过期,请重新使用 /add_sub 开始。")
|
||||
return ConversationHandler.END
|
||||
|
||||
renewal_type = query.data.split('_')[1]
|
||||
if renewal_type not in VALID_RENEWAL_TYPES:
|
||||
await query.edit_message_text("错误:无效的续费类型,请重试。")
|
||||
return ConversationHandler.END
|
||||
context.user_data['new_sub_data']['renewal_type'] = renewal_type
|
||||
sub_data['renewal_type'] = renewal_type
|
||||
await query.edit_message_text("最后一步(可选):需要添加备注吗?\n(如:共享账号、用途等。不需要请 /skip)")
|
||||
return ADD_NOTES
|
||||
|
||||
|
||||
Reference in New Issue
Block a user