diff --git a/SubMind.py b/SubMind.py index dd02249..2cf1fb6 100644 --- a/SubMind.py +++ b/SubMind.py @@ -425,7 +425,7 @@ def make_autopct(values, currency_code): async def stats(update: Update, context: CallbackContext): user_id = update.effective_user.id - await update.message.reply_text("正在为您生成订阅统计图表,请稍候...") + await update.message.reply_text("正在为您生成更美观的统计图,请稍候...") font_prop = get_chinese_font() main_currency = get_user_main_currency(user_id) @@ -453,41 +453,102 @@ async def stats(update: Update, context: CallbackContext): await update.message.reply_text("您的订阅没有有效的费用信息。") return - plt.style.use('seaborn-v0_8-darkgrid') - fig, ax = plt.subplots(figsize=(12, 12)) + max_categories = 8 + if len(category_costs) > max_categories: + top = category_costs.iloc[:max_categories] + others_sum = category_costs.iloc[max_categories:].sum() + if others_sum > 0: + category_costs = pd.concat([top, pd.Series({'其他': others_sum})]) + else: + category_costs = top + + total_monthly = category_costs.sum() + currency_symbols = {'USD': '$', 'CNY': '¥', 'EUR': '€', 'GBP': '£', 'JPY': '¥'} + symbol = currency_symbols.get(main_currency.upper(), f'{main_currency.upper()} ') + + def autopct_if_large(pct): + if pct < 4: + return '' + value = total_monthly * pct / 100 + return f'{pct:.1f}%\n{symbol}{value:.2f}' + + fig = plt.figure(figsize=(14, 8), facecolor='#f8fafc') + gs = fig.add_gridspec(1, 2, width_ratios=[1.2, 1], wspace=0.12) + ax_pie = fig.add_subplot(gs[0, 0]) + ax_bar = fig.add_subplot(gs[0, 1]) image_path = None try: - autopct_function = make_autopct(category_costs.values, main_currency) + colors = list(plt.get_cmap('tab20').colors)[:len(category_costs)] - wedges, texts, autotexts = ax.pie(category_costs.values, - labels=category_costs.index, - autopct=autopct_function, - startangle=140, - pctdistance=0.7, - labeldistance=1.05) + _, texts, autotexts = ax_pie.pie( + category_costs.values, + labels=category_costs.index, + autopct=autopct_if_large, + startangle=120, + counterclock=False, + pctdistance=0.78, + labeldistance=1.08, + colors=colors, + wedgeprops={'width': 0.42, 'edgecolor': 'white', 'linewidth': 1.2} + ) - ax.set_title('每月订阅支出分类统计', fontproperties=font_prop, fontsize=32, pad=20) + for t in texts: + t.set_fontproperties(font_prop) + t.set_fontsize(12) + t.set_color('#1f2937') + for t in autotexts: + t.set_fontproperties(font_prop) + t.set_fontsize(10) + t.set_color('#111827') - for text in texts: - text.set_fontproperties(font_prop) - text.set_fontsize(22) + ax_pie.text( + 0, 0, + f"月总支出\n{symbol}{total_monthly:.2f}", + ha='center', va='center', + fontproperties=font_prop, + fontsize=16, + color='#111827', + weight='bold' + ) + ax_pie.set_title('订阅支出结构(按类别)', fontproperties=font_prop, fontsize=16, pad=14, color='#111827') + ax_pie.axis('equal') - for autotext in autotexts: - autotext.set_fontproperties(font_prop) - autotext.set_fontsize(20) - autotext.set_color('white') + bar_series = category_costs.sort_values(ascending=True) + bars = ax_bar.barh(bar_series.index, bar_series.values, color=colors[:len(bar_series)], alpha=0.9) + ax_bar.set_title('类别月支出对比', fontproperties=font_prop, fontsize=16, pad=14, color='#111827') + ax_bar.set_xlabel(f'金额({main_currency.upper()})', fontproperties=font_prop, fontsize=11, color='#374151') + ax_bar.tick_params(axis='x', colors='#6b7280') + ax_bar.tick_params(axis='y', colors='#111827') + ax_bar.grid(axis='x', linestyle='--', alpha=0.25) + for label in ax_bar.get_yticklabels(): + label.set_fontproperties(font_prop) + label.set_fontsize(11) - ax.axis('equal') - fig.tight_layout() + max_val = bar_series.max() if len(bar_series) else 0 + offset = max_val * 0.015 if max_val > 0 else 0.1 + for rect, value in zip(bars, bar_series.values): + ax_bar.text( + rect.get_width() + offset, + rect.get_y() + rect.get_height() / 2, + f"{symbol}{value:.2f}", + va='center', + ha='left', + fontproperties=font_prop, + fontsize=10, + color='#1f2937' + ) + + fig.suptitle('每月订阅支出统计', fontproperties=font_prop, fontsize=20, color='#0f172a', y=0.98) + fig.tight_layout(rect=[0, 0, 1, 0.96]) with tempfile.NamedTemporaryFile(prefix=f'stats_{user_id}_', suffix='.png', delete=False) as tmp: image_path = tmp.name - plt.savefig(image_path) + plt.savefig(image_path, dpi=220, bbox_inches='tight', facecolor=fig.get_facecolor()) with open(image_path, 'rb') as photo: - await update.message.reply_photo(photo, caption="这是您按类别统计的每月订阅总支出。") + await update.message.reply_photo(photo, caption="这是优化后的订阅月支出统计图(结构 + 对比)。") finally: plt.close(fig) if image_path and os.path.exists(image_path):