diff --git a/SubMind.py b/SubMind.py index ab02922..6f56eec 100644 --- a/SubMind.py +++ b/SubMind.py @@ -289,9 +289,14 @@ async def get_subs_list_keyboard(user_id: int, category_filter: str = None) -> I return InlineKeyboardMarkup(keyboard) -def _clear_action_state(context: CallbackContext, keys: list[str]): - for key in keys: - context.user_data.pop(key, None) +def _get_new_sub_data_or_end(update: Update, context: CallbackContext): + sub_data = context.user_data.get('new_sub_data') + if sub_data is None: + message_obj = update.message or (update.callback_query.message if update.callback_query else None) + if message_obj: + # 统一提示,避免 KeyError 导致会话崩溃 + return None, message_obj + return sub_data, None # --- 自动任务 --- @@ -608,6 +613,12 @@ async def add_sub_start(update: Update, context: CallbackContext): async def add_name_received(update: Update, context: CallbackContext): + sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context) + if sub_data is None: + if err_msg_obj: + await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。") + return ConversationHandler.END + name = update.message.text.strip() if not name: await update.message.reply_text("订阅名称不能为空。") @@ -615,17 +626,23 @@ async def add_name_received(update: Update, context: CallbackContext): if len(name) > MAX_NAME_LEN: await update.message.reply_text(f"订阅名称过长,请控制在 {MAX_NAME_LEN} 个字符以内。") return ADD_NAME - context.user_data['new_sub_data']['name'] = name + sub_data['name'] = name await update.message.reply_text("第二步:请输入订阅 *费用*", parse_mode='MarkdownV2') return ADD_COST async def add_cost_received(update: Update, context: CallbackContext): + sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context) + if sub_data is None: + if err_msg_obj: + await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。") + return ConversationHandler.END + try: cost = float(update.message.text) if cost < 0: raise ValueError("费用不能为负数") - context.user_data['new_sub_data']['cost'] = cost + sub_data['cost'] = cost except (ValueError, TypeError): await update.message.reply_text("费用必须是有效的非负数字。") return ADD_COST @@ -634,16 +651,28 @@ async def add_cost_received(update: Update, context: CallbackContext): async def add_currency_received(update: Update, context: CallbackContext): + sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context) + if sub_data is None: + if err_msg_obj: + await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。") + return ConversationHandler.END + currency = update.message.text.upper() if not (len(currency) == 3 and currency.isalpha()): await update.message.reply_text("请输入有效的三字母货币代码(如 USD, CNY)。") return ADD_CURRENCY - context.user_data['new_sub_data']['currency'] = currency + sub_data['currency'] = currency await update.message.reply_text("第四步:请为订阅指定一个 *类别*", parse_mode='MarkdownV2') return ADD_CATEGORY async def add_category_received(update: Update, context: CallbackContext): + sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context) + if sub_data is None: + if err_msg_obj: + await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。") + return ConversationHandler.END + user_id, category_name = update.effective_user.id, update.message.text.strip() if not category_name: await update.message.reply_text("类别不能为空。") @@ -651,7 +680,7 @@ async def add_category_received(update: Update, context: CallbackContext): if len(category_name) > MAX_CATEGORY_LEN: await update.message.reply_text(f"类别名称过长,请控制在 {MAX_CATEGORY_LEN} 个字符以内。") return ADD_CATEGORY - context.user_data['new_sub_data']['category'] = category_name + sub_data['category'] = category_name with get_db_connection() as conn: cursor = conn.cursor() cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, category_name)) @@ -662,11 +691,17 @@ async def add_category_received(update: Update, context: CallbackContext): async def add_next_due_received(update: Update, context: CallbackContext): + sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context) + if sub_data is None: + if err_msg_obj: + await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。") + return ConversationHandler.END + parsed_date = parse_date(update.message.text) if not parsed_date: await update.message.reply_text("无法识别的日期格式,请使用类似 '2025\\-10\\-01' 或 '10月1日' 的格式。") return ADD_NEXT_DUE - context.user_data['new_sub_data']['next_due'] = parsed_date + sub_data['next_due'] = parsed_date keyboard = [ [InlineKeyboardButton("天", callback_data='freq_unit_day'), InlineKeyboardButton("周", callback_data='freq_unit_week')], @@ -679,23 +714,34 @@ async def add_next_due_received(update: Update, context: CallbackContext): async def add_freq_unit_received(update: Update, context: CallbackContext): + sub_data, _ = _get_new_sub_data_or_end(update, context) query = update.callback_query await query.answer() + if sub_data is None: + await query.edit_message_text("会话已过期,请重新使用 /add_sub 开始。") + return ConversationHandler.END + unit = query.data.split('_')[2] if unit not in VALID_FREQ_UNITS: await query.edit_message_text("错误:无效的周期单位,请重试。") return ConversationHandler.END - context.user_data['new_sub_data']['unit'] = unit + sub_data['unit'] = unit await query.edit_message_text("第七步:请输入周期的*数量*(例如:每3个月,输入 3)", parse_mode='Markdown') return ADD_FREQ_VALUE async def add_freq_value_received(update: Update, context: CallbackContext): + sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context) + if sub_data is None: + if err_msg_obj: + await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。") + return ConversationHandler.END + try: value = int(update.message.text) if value <= 0: raise ValueError - context.user_data['new_sub_data']['value'] = value + sub_data['value'] = value except (ValueError, TypeError): await update.message.reply_text("请输入一个有效的正整数。") return ADD_FREQ_VALUE @@ -709,13 +755,18 @@ async def add_freq_value_received(update: Update, context: CallbackContext): async def add_renewal_type_received(update: Update, context: CallbackContext): + sub_data, _ = _get_new_sub_data_or_end(update, context) query = update.callback_query await query.answer() + if sub_data is None: + await query.edit_message_text("会话已过期,请重新使用 /add_sub 开始。") + return ConversationHandler.END + renewal_type = query.data.split('_')[1] if renewal_type not in VALID_RENEWAL_TYPES: await query.edit_message_text("错误:无效的续费类型,请重试。") return ConversationHandler.END - context.user_data['new_sub_data']['renewal_type'] = renewal_type + sub_data['renewal_type'] = renewal_type await query.edit_message_text("最后一步(可选):需要添加备注吗?\n(如:共享账号、用途等。不需要请 /skip)") return ADD_NOTES