diff --git a/SubMind.py b/SubMind.py index de2ccaa..22cd575 100644 --- a/SubMind.py +++ b/SubMind.py @@ -263,25 +263,9 @@ def _parse_category_id_from_callback(data: str) -> int | None: return int(payload) if payload.isdigit() else None -async def get_subs_list_keyboard(user_id: int, category_filter: str = None) -> InlineKeyboardMarkup: - with get_db_connection() as conn: - cursor = conn.cursor() - query, params = "SELECT id, name FROM subscriptions WHERE user_id = ? ", [user_id] - if category_filter: - query += "AND category = ? " - params.append(category_filter) - query += "ORDER BY next_due ASC" - cursor.execute(query, tuple(params)) - subs = cursor.fetchall() - if not subs: - return None - buttons = [InlineKeyboardButton(name, callback_data=f'view_{sub_id}') for sub_id, name in subs] - keyboard = [buttons[i:i + 2] for i in range(0, len(buttons), 2)] - if category_filter: - keyboard.append([InlineKeyboardButton("« 返回分类列表", callback_data='list_categories')]) - else: - keyboard.append([InlineKeyboardButton("🗂️ 按分类浏览", callback_data='list_categories')]) - return InlineKeyboardMarkup(keyboard) +def _clear_action_state(context: CallbackContext, keys: list[str]): + for key in keys: + context.user_data.pop(key, None) # --- 自动任务 --- @@ -435,40 +419,42 @@ async def stats(update: Update, context: CallbackContext): plt.style.use('seaborn-v0_8-darkgrid') fig, ax = plt.subplots(figsize=(12, 12)) - - autopct_function = make_autopct(category_costs.values, main_currency) - - wedges, texts, autotexts = ax.pie(category_costs.values, - labels=category_costs.index, - autopct=autopct_function, - startangle=140, - pctdistance=0.7, - labeldistance=1.05) - - ax.set_title('每月订阅支出分类统计', fontproperties=font_prop, fontsize=32, pad=20) - - for text in texts: - text.set_fontproperties(font_prop) - text.set_fontsize(22) - - for autotext in autotexts: - autotext.set_fontproperties(font_prop) - autotext.set_fontsize(20) - autotext.set_color('white') - - ax.axis('equal') - - fig.tight_layout() - with tempfile.NamedTemporaryFile(prefix=f'stats_{user_id}_', suffix='.png', delete=False) as tmp: - image_path = tmp.name - plt.savefig(image_path) - plt.close(fig) + image_path = None try: + autopct_function = make_autopct(category_costs.values, main_currency) + + wedges, texts, autotexts = ax.pie(category_costs.values, + labels=category_costs.index, + autopct=autopct_function, + startangle=140, + pctdistance=0.7, + labeldistance=1.05) + + ax.set_title('每月订阅支出分类统计', fontproperties=font_prop, fontsize=32, pad=20) + + for text in texts: + text.set_fontproperties(font_prop) + text.set_fontsize(22) + + for autotext in autotexts: + autotext.set_fontproperties(font_prop) + autotext.set_fontsize(20) + autotext.set_color('white') + + ax.axis('equal') + fig.tight_layout() + + with tempfile.NamedTemporaryFile(prefix=f'stats_{user_id}_', suffix='.png', delete=False) as tmp: + image_path = tmp.name + + plt.savefig(image_path) + with open(image_path, 'rb') as photo: await update.message.reply_photo(photo, caption="这是您按类别统计的每月订阅总支出。") finally: - if os.path.exists(image_path): + plt.close(fig) + if image_path and os.path.exists(image_path): os.remove(image_path) @@ -486,9 +472,9 @@ async def export_command(update: Update, context: CallbackContext): with tempfile.NamedTemporaryFile(prefix=f'export_{user_id}_', suffix='.csv', delete=False) as tmp: export_path = tmp.name - df.to_csv(export_path, index=False, encoding='utf-8-sig') - try: + df.to_csv(export_path, index=False, encoding='utf-8-sig') + with open(export_path, 'rb') as file: await update.message.reply_document(document=file, filename='subscriptions.csv', caption="您的订阅数据已导出为 CSV 文件。") @@ -545,14 +531,20 @@ async def import_upload_received(update: Update, context: CallbackContext): renewal_type = str(row['renewal_type']).lower() if renewal_type not in valid_renewal_types: raise ValueError(f"无效续费类型: {renewal_type}") - notes = str(row['notes']) if pd.notna(row['notes']) else None + notes = str(row['notes']).strip() if pd.notna(row['notes']) else None + name = str(row['name']).strip() + category = str(row['category']).strip() + if not name: + raise ValueError("名称不能为空") + if not category: + raise ValueError("类别不能为空") records.append(( - user_id, row['name'], cost, currency, row['category'], + user_id, name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes )) except Exception as e: - logger.error(f"Invalid row in CSV: {row.to_dict()}, error: {e}") - await update.message.reply_text(f"导入失败,行数据无效:{row.to_dict()},错误:{e}") + logger.error(f"Invalid row in CSV import, error: {e}") + await update.message.reply_text(f"导入失败,存在无效行:{e}") return ConversationHandler.END with get_db_connection() as conn: @@ -926,6 +918,12 @@ async def button_callback_handler(update: Update, context: CallbackContext): parse_mode='MarkdownV2', reply_markup=None) elif action == 'delete': + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT 1 FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id)) + if not cursor.fetchone(): + await query.answer("错误:找不到该订阅或无权限。", show_alert=True) + return keyboard = InlineKeyboardMarkup([ [InlineKeyboardButton("✅ 是的,删除", callback_data=f'confirmdelete_{sub_id}'), InlineKeyboardButton("❌ 取消", callback_data=f'view_{sub_id}')] @@ -935,7 +933,11 @@ async def button_callback_handler(update: Update, context: CallbackContext): with get_db_connection() as conn: cursor = conn.cursor() cursor.execute("DELETE FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id)) + deleted = cursor.rowcount conn.commit() + if deleted == 0: + await query.answer("错误:找不到该订阅或无权限。", show_alert=True) + return await query.answer("订阅已删除") if 'list_subs_in_category' in context.user_data: category = context.user_data['list_subs_in_category'] @@ -1068,7 +1070,7 @@ async def edit_freq_value_received(update: Update, context: CallbackContext): conn.commit() await update.message.reply_text("✅ 周期已更新!") - context.user_data.clear() + _clear_action_state(context, ['sub_id_for_action', 'new_freq_unit', 'field_to_edit']) await show_subscription_view(update, context, sub_id) return ConversationHandler.END @@ -1131,6 +1133,11 @@ async def edit_new_value_received(update: Update, context: CallbackContext): validation_failed = True else: new_value = parsed + elif field == 'renewal_type': + if str(new_value) not in ('auto', 'manual'): + if message_to_reply: + await message_to_reply.reply_text("续费方式只能为 auto 或 manual。") + validation_failed = True elif field == 'category': new_value = str(new_value).strip() if not new_value: @@ -1161,7 +1168,7 @@ async def edit_new_value_received(update: Update, context: CallbackContext): elif message_to_reply: await message_to_reply.reply_text("✅ 字段已更新!") - context.user_data.clear() + _clear_action_state(context, ['sub_id_for_action', 'field_to_edit', 'new_freq_unit']) await show_subscription_view(update, context, sub_id) return ConversationHandler.END @@ -1282,7 +1289,7 @@ async def remind_days_received(update: Update, context: CallbackContext): return ConversationHandler.END conn.commit() await update.message.reply_text(f"✅ 提前提醒天数已设置为: {days}天。") - context.user_data.clear() + _clear_action_state(context, ['sub_id_for_action']) await show_subscription_view(update, context, sub_id) except (ValueError, TypeError): await update.message.reply_text("请输入一个有效的非负整数。")