From fb8a5521a99297485129b21c356a069877efaf28 Mon Sep 17 00:00:00 2001 From: Xiaolan Bot Date: Mon, 23 Feb 2026 00:05:50 +0800 Subject: [PATCH] perf: offload heavy blocking I/O (matplotlib, pandas) to asyncio threads --- SubMind.py | 331 ++++++++++++++++++++++++++++------------------------- 1 file changed, 177 insertions(+), 154 deletions(-) diff --git a/SubMind.py b/SubMind.py index 68d36b6..719eedf 100644 --- a/SubMind.py +++ b/SubMind.py @@ -1,4 +1,5 @@ import sqlite3 +import asyncio import os import html import requests @@ -447,186 +448,208 @@ async def stats(update: Update, context: CallbackContext): user_id = update.effective_user.id await update.message.reply_text("正在为您生成更美观的统计图,请稍候...") - font_prop = get_chinese_font() - main_currency = get_user_main_currency(user_id) - with get_db_connection() as conn: - df = pd.read_sql_query("SELECT * FROM subscriptions WHERE user_id = ?", conn, params=(user_id,)) - if df.empty: - await update.message.reply_text("您还没有任何订阅数据。") - return - - df['converted_cost'] = df.apply(lambda row: convert_currency(row['cost'], row['currency'], main_currency), axis=1) - unit_to_days = {'day': 1, 'week': 7, 'month': 30.4375, 'year': 365.25} - - def normalize_to_monthly(row): - if pd.isna(row['frequency_unit']) or pd.isna(row['frequency_value']) or row['frequency_value'] == 0: - return 0 - total_days = row['frequency_value'] * unit_to_days.get(row['frequency_unit'], 0) - if total_days == 0: - return 0 - return (row['converted_cost'] / total_days) * 30.4375 - - df['monthly_cost'] = df.apply(normalize_to_monthly, axis=1) - category_costs = df.groupby('category')['monthly_cost'].sum().sort_values(ascending=False) - - if category_costs.empty or category_costs.sum() == 0: - await update.message.reply_text("您的订阅没有有效的费用信息。") - return - - max_categories = 8 - if len(category_costs) > max_categories: - top = category_costs.iloc[:max_categories] - others_sum = category_costs.iloc[max_categories:].sum() - if others_sum > 0: - category_costs = pd.concat([top, pd.Series({'其他': others_sum})]) - else: - category_costs = top - - total_monthly = category_costs.sum() - currency_symbols = {'USD': '$', 'CNY': '¥', 'EUR': '€', 'GBP': '£', 'JPY': '¥'} - symbol = currency_symbols.get(main_currency.upper(), f'{main_currency.upper()} ') - - def autopct_if_large(pct): - if pct < 4: - return '' - value = total_monthly * pct / 100 - return f"{pct:.1f}%\\n{symbol}{value:.2f}" - - # Setup figure with a clean, modern background - fig = plt.figure(figsize=(15, 8.5), facecolor='#FAFAFA') - gs = fig.add_gridspec(1, 2, width_ratios=[1.1, 1], wspace=0.15) - ax_pie = fig.add_subplot(gs[0, 0]) - ax_bar = fig.add_subplot(gs[0, 1]) - image_path = None - - try: - # Modern color palette (Tailwind-inspired) - theme_colors = ['#3B82F6', '#10B981', '#F59E0B', '#EF4444', '#8B5CF6', '#EC4899', '#14B8A6', '#F97316', '#6366F1', '#84CC16'] - if len(category_costs) > len(theme_colors): - import matplotlib.pyplot as plt - extra_colors = [matplotlib.colors.to_hex(c) for c in plt.get_cmap('tab20').colors] - theme_colors.extend(extra_colors) + def generate_chart(): + font_prop = get_chinese_font() + main_currency = get_user_main_currency(user_id) + with get_db_connection() as conn: + df = pd.read_sql_query("SELECT * FROM subscriptions WHERE user_id = ?", conn, params=(user_id,)) + cursor = conn.cursor() + cursor.execute("SELECT from_currency, to_currency, rate FROM exchange_rates WHERE to_currency = ?", (main_currency.upper(),)) + rate_cache = {(row['from_currency'], row['to_currency']): row['rate'] for row in cursor.fetchall()} - color_map = {cat: theme_colors[i] for i, cat in enumerate(category_costs.index)} - pie_colors = [color_map[cat] for cat in category_costs.index] + if df.empty: + return False, "您还没有任何订阅数据。" - # Enhanced Donut Chart - wedges, texts, autotexts = ax_pie.pie( - category_costs.values, - labels=category_costs.index, - autopct=autopct_if_large, - startangle=140, - counterclock=False, - pctdistance=0.75, - labeldistance=1.1, - colors=pie_colors, - wedgeprops={'width': 0.35, 'edgecolor': '#FAFAFA', 'linewidth': 2.5} - ) + def fast_convert(amount, from_curr, to_curr): + if from_curr.upper() == to_curr.upper(): + return amount + cached_rate = rate_cache.get((from_curr.upper(), to_curr.upper())) + if cached_rate is not None: + return amount * cached_rate + return convert_currency(amount, from_curr, to_curr) - for t in texts: - t.set_fontproperties(font_prop) - t.set_fontsize(13) - t.set_color('#374151') - for t in autotexts: - t.set_fontproperties(font_prop) - t.set_fontsize(10) - t.set_color('#FFFFFF') - t.set_weight('bold') + df['converted_cost'] = df.apply(lambda row: fast_convert(row['cost'], row['currency'], main_currency), axis=1) + unit_to_days = {'day': 1, 'week': 7, 'month': 30.4375, 'year': 365.25} - # Center text for donut - ax_pie.text( - 0, 0, - f"月总支出\n{symbol}{total_monthly:.2f}", - ha='center', va='center', - fontproperties=font_prop, - fontsize=18, - color='#1F2937', - weight='bold' - ) - ax_pie.set_title('支出占比结构', fontproperties=font_prop, fontsize=18, pad=20, color='#111827', weight='bold') - ax_pie.axis('equal') + def normalize_to_monthly(row): + if pd.isna(row['frequency_unit']) or pd.isna(row['frequency_value']) or row['frequency_value'] == 0: + return 0 + total_days = row['frequency_value'] * unit_to_days.get(row['frequency_unit'], 0) + if total_days == 0: + return 0 + return (row['converted_cost'] / total_days) * 30.4375 - # Enhanced Bar Chart - bar_series = category_costs.sort_values(ascending=True) - bar_colors = [color_map[cat] for cat in bar_series.index] - - bars = ax_bar.barh(bar_series.index, bar_series.values, color=bar_colors, height=0.6, alpha=0.95, edgecolor='none') - - ax_bar.set_title('各类别月支出对比', fontproperties=font_prop, fontsize=18, pad=20, color='#111827', weight='bold') - ax_bar.set_xlabel(f'金额({main_currency.upper()})', fontproperties=font_prop, fontsize=12, color='#6B7280', labelpad=10) - - # Clean up axes - ax_bar.spines['top'].set_visible(False) - ax_bar.spines['right'].set_visible(False) - ax_bar.spines['left'].set_visible(False) - ax_bar.spines['bottom'].set_color('#E5E7EB') - - ax_bar.tick_params(axis='x', colors='#6B7280', labelsize=11) - ax_bar.tick_params(axis='y', length=0, pad=10) # Hide y ticks but keep labels - - ax_bar.grid(axis='x', color='#F3F4F6', linestyle='-', linewidth=1.5, alpha=1) - ax_bar.set_axisbelow(True) + df['monthly_cost'] = df.apply(normalize_to_monthly, axis=1) + category_costs = df.groupby('category')['monthly_cost'].sum().sort_values(ascending=False) - for label in ax_bar.get_yticklabels(): - label.set_fontproperties(font_prop) - label.set_fontsize(13) - label.set_color('#374151') + if category_costs.empty or category_costs.sum() == 0: + return False, "您的订阅没有有效的费用信息。" - # Bar value labels - max_val = bar_series.max() if len(bar_series) else 0 - offset = max_val * 0.02 if max_val > 0 else 0.1 - for rect, value in zip(bars, bar_series.values): - ax_bar.text( - rect.get_width() + offset, - rect.get_y() + rect.get_height() / 2, - f"{symbol}{value:.2f}", - va='center', - ha='left', - fontproperties=font_prop, - fontsize=11, - color='#4B5563', - weight='bold' + max_categories = 8 + if len(category_costs) > max_categories: + top = category_costs.iloc[:max_categories] + others_sum = category_costs.iloc[max_categories:].sum() + if others_sum > 0: + category_costs = pd.concat([top, pd.Series({'其他': others_sum})]) + else: + category_costs = top + + total_monthly = category_costs.sum() + currency_symbols = {'USD': '$', 'CNY': '¥', 'EUR': '€', 'GBP': '£', 'JPY': '¥'} + symbol = currency_symbols.get(main_currency.upper(), f'{main_currency.upper()} ') + + def autopct_if_large(pct): + if pct < 4: + return '' + value = total_monthly * pct / 100 + return f"{pct:.1f}%\n{symbol}{value:.2f}" + + fig = plt.figure(figsize=(15, 8.5), facecolor='#FAFAFA') + gs = fig.add_gridspec(1, 2, width_ratios=[1.1, 1], wspace=0.15) + ax_pie = fig.add_subplot(gs[0, 0]) + ax_bar = fig.add_subplot(gs[0, 1]) + image_path = None + + try: + theme_colors = ['#3B82F6', '#10B981', '#F59E0B', '#EF4444', '#8B5CF6', '#EC4899', '#14B8A6', '#F97316', '#6366F1', '#84CC16'] + if len(category_costs) > len(theme_colors): + import matplotlib.pyplot as plt + extra_colors = [matplotlib.colors.to_hex(c) for c in plt.get_cmap('tab20').colors] + theme_colors.extend(extra_colors) + + color_map = {cat: theme_colors[i] for i, cat in enumerate(category_costs.index)} + pie_colors = [color_map[cat] for cat in category_costs.index] + + wedges, texts, autotexts = ax_pie.pie( + category_costs.values, + labels=category_costs.index, + autopct=autopct_if_large, + startangle=140, + counterclock=False, + pctdistance=0.75, + labeldistance=1.1, + colors=pie_colors, + wedgeprops={'width': 0.35, 'edgecolor': '#FAFAFA', 'linewidth': 2.5} ) - fig.suptitle('📊 您的订阅支出洞察', fontproperties=font_prop, fontsize=24, color='#0F172A', y=1.02, weight='bold') - fig.tight_layout(rect=[0, 0, 1, 0.95]) + for t in texts: + t.set_fontproperties(font_prop) + t.set_fontsize(13) + t.set_color('#374151') + for t in autotexts: + t.set_fontproperties(font_prop) + t.set_fontsize(10) + t.set_color('#FFFFFF') + t.set_weight('bold') - with tempfile.NamedTemporaryFile(prefix=f'stats_{user_id}_', suffix='.png', delete=False) as tmp: - image_path = tmp.name + ax_pie.text( + 0, 0, + f"月总支出\n{symbol}{total_monthly:.2f}", + ha='center', va='center', + fontproperties=font_prop, + fontsize=18, + color='#1F2937', + weight='bold' + ) + ax_pie.set_title('支出占比结构', fontproperties=font_prop, fontsize=18, pad=20, color='#111827', weight='bold') + ax_pie.axis('equal') - plt.savefig(image_path, dpi=250, bbox_inches='tight', facecolor=fig.get_facecolor()) + bar_series = category_costs.sort_values(ascending=True) + bar_colors = [color_map[cat] for cat in bar_series.index] + + bars = ax_bar.barh(bar_series.index, bar_series.values, color=bar_colors, height=0.6, alpha=0.95, edgecolor='none') + + ax_bar.set_title('各类别月支出对比', fontproperties=font_prop, fontsize=18, pad=20, color='#111827', weight='bold') + ax_bar.set_xlabel(f'金额({main_currency.upper()})', fontproperties=font_prop, fontsize=12, color='#6B7280', labelpad=10) + + ax_bar.spines['top'].set_visible(False) + ax_bar.spines['right'].set_visible(False) + ax_bar.spines['left'].set_visible(False) + ax_bar.spines['bottom'].set_color('#E5E7EB') + + ax_bar.tick_params(axis='x', colors='#6B7280', labelsize=11) + ax_bar.tick_params(axis='y', length=0, pad=10) + + ax_bar.grid(axis='x', color='#F3F4F6', linestyle='-', linewidth=1.5, alpha=1) + ax_bar.set_axisbelow(True) - with open(image_path, 'rb') as photo: - await update.message.reply_photo(photo, caption="✨ 已为您生成全新的精美订阅统计图!") - finally: - plt.close(fig) - if image_path and os.path.exists(image_path): - os.remove(image_path) + for label in ax_bar.get_yticklabels(): + label.set_fontproperties(font_prop) + label.set_fontsize(13) + label.set_color('#374151') + + max_val = bar_series.max() if len(bar_series) else 0 + offset = max_val * 0.02 if max_val > 0 else 0.1 + for rect, value in zip(bars, bar_series.values): + ax_bar.text( + rect.get_width() + offset, + rect.get_y() + rect.get_height() / 2, + f"{symbol}{value:.2f}", + va='center', + ha='left', + fontproperties=font_prop, + fontsize=11, + color='#4B5563', + weight='bold' + ) + + fig.suptitle('📊 您的订阅支出洞察', fontproperties=font_prop, fontsize=24, color='#0F172A', y=1.02, weight='bold') + fig.tight_layout(rect=[0, 0, 1, 0.95]) + + with tempfile.NamedTemporaryFile(prefix=f'stats_{user_id}_', suffix='.png', delete=False) as tmp: + image_path = tmp.name + + plt.savefig(image_path, dpi=250, bbox_inches='tight', facecolor=fig.get_facecolor()) + return True, image_path + finally: + plt.close(fig) + + success, result = await asyncio.to_thread(generate_chart) + + if success: + try: + with open(result, 'rb') as photo: + await update.message.reply_photo(photo, caption="✨ 已为您生成全新的精美订阅统计图!") + finally: + if os.path.exists(result): + os.remove(result) + else: + await update.message.reply_text(result) # --- Import/Export Commands --- async def export_command(update: Update, context: CallbackContext): user_id = update.effective_user.id - with get_db_connection() as conn: - df = pd.read_sql_query( - "SELECT name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes FROM subscriptions WHERE user_id = ?", - conn, params=(user_id,)) - if df.empty: + + # 将重度 I/O 和 CPU 绑定的 pandas 导出操作放入后台线程 + def process_export(): + with get_db_connection() as conn: + df = pd.read_sql_query( + "SELECT name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes FROM subscriptions WHERE user_id = ?", + conn, params=(user_id,)) + if df.empty: + return False, None + + tmp = tempfile.NamedTemporaryFile(prefix=f'export_{user_id}_', suffix='.csv', delete=False) + export_path = tmp.name + tmp.close() + + df.to_csv(export_path, index=False, encoding='utf-8-sig') + return True, export_path + + success, export_path = await asyncio.to_thread(process_export) + + if not success: await update.message.reply_text("您还没有任何订阅数据,无法导出。") return - with tempfile.NamedTemporaryFile(prefix=f'export_{user_id}_', suffix='.csv', delete=False) as tmp: - export_path = tmp.name - try: - df.to_csv(export_path, index=False, encoding='utf-8-sig') - with open(export_path, 'rb') as file: await update.message.reply_document(document=file, filename='subscriptions.csv', caption="您的订阅数据已导出为 CSV 文件。") finally: - if os.path.exists(export_path): + if export_path and os.path.exists(export_path): os.remove(export_path)