import sqlite3 import asyncio import os import sys import subprocess import html import requests import datetime import dateparser import logging import tempfile import pandas as pd import matplotlib import matplotlib.pyplot as plt import matplotlib.font_manager as fm import re from dotenv import load_dotenv from dateutil.relativedelta import relativedelta from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup, BotCommand, CallbackQuery from telegram.ext import ( Application, CommandHandler, MessageHandler, filters, CallbackContext, CallbackQueryHandler, ConversationHandler ) from telegram.error import TelegramError from telegram.helpers import escape_markdown # Fallback just in case import html def escape_html(text, version=None): if text is None: return '' return html.escape(str(text)) # --- 加载 .env 和设置 --- load_dotenv() # --- 日志配置 --- logging.basicConfig( format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO ) logging.getLogger("httpx").setLevel(logging.WARNING) logger = logging.getLogger(__name__) # --- 环境变量和项目配置 --- TELEGRAM_TOKEN = os.getenv('TELEGRAM_TOKEN') EXCHANGE_API_KEY = os.getenv('EXCHANGE_API_KEY') PROJECT_NAME = "SubMind" DB_FILE = 'submind.db' # 自动更新配置 UPDATE_OWNER_ID = os.getenv('UPDATE_OWNER_ID') # 仅允许此用户执行 /update AUTO_UPDATE_REMOTE = os.getenv('AUTO_UPDATE_REMOTE', 'gitllc') AUTO_UPDATE_BRANCH = os.getenv('AUTO_UPDATE_BRANCH', 'main') # --- 对话处理器状态 --- (ADD_NAME, ADD_COST, ADD_CURRENCY, ADD_CATEGORY, ADD_NEXT_DUE, ADD_FREQ_UNIT, ADD_FREQ_VALUE, ADD_RENEWAL_TYPE, ADD_NOTES) = range(9) (EDIT_SELECT_FIELD, EDIT_GET_NEW_VALUE, EDIT_FREQ_UNIT, EDIT_FREQ_VALUE) = range(4) (REMIND_SELECT_ACTION, REMIND_GET_DAYS) = range(2) (IMPORT_UPLOAD,) = range(1) # --- 辅助函数:数据库连接管理 --- def get_db_connection(): conn = sqlite3.connect(DB_FILE) conn.row_factory = sqlite3.Row return conn # --- 字体管理 --- def get_chinese_font(): font_name = 'SourceHanSansSC-Regular.otf' font_path = os.path.join('fonts', font_name) if os.path.exists(font_path): logger.debug(f"Found font at {font_path}") return fm.FontProperties(fname=font_path) logger.info(f"Font '{font_name}' not found. Attempting to download...") os.makedirs('fonts', exist_ok=True) urls = [ 'https://github.com/wweir/source-han-sans-sc/raw/refs/heads/master/SourceHanSansSC-Regular.otf', 'https://cdn.jsdelivr.net/gh/wweir/source-han-sans-sc@master/SourceHanSansSC-Regular.otf', 'https://fastly.jsdelivr.net/gh/wweir/source-han-sans-sc@master/SourceHanSansSC-Regular.otf' ] headers = { 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' } for url in urls: try: logger.info(f"Trying to download font from: {url}") response = requests.get(url, stream=True, headers=headers, timeout=15) response.raise_for_status() with open(font_path, 'wb') as f: for chunk in response.iter_content(chunk_size=8192): f.write(chunk) logger.info(f"Font '{font_name}' downloaded successfully to '{font_path}'.") fm._load_fontmanager(try_read_cache=False) return fm.FontProperties(fname=font_path) except requests.exceptions.RequestException as e: logger.warning(f"Failed to download font from {url}. Error: {e}") continue logger.error("All font download attempts failed. Falling back to default sans-serif.") return fm.FontProperties(family='sans-serif') # --- 数据库初始化与迁移 --- def init_db(): with get_db_connection() as conn: cursor = conn.cursor() cursor.execute(''' CREATE TABLE IF NOT EXISTS subscriptions ( id INTEGER PRIMARY KEY, user_id INTEGER, name TEXT, cost REAL, currency TEXT, category TEXT, next_due DATE, frequency TEXT, renewal_type TEXT DEFAULT 'auto', reminders_enabled BOOLEAN DEFAULT TRUE, reminder_days INTEGER DEFAULT 3, reminder_on_due_date BOOLEAN DEFAULT TRUE, frequency_unit TEXT, frequency_value INTEGER, notes TEXT ) ''') cursor.execute('CREATE INDEX IF NOT EXISTS idx_subscriptions_user_id ON subscriptions(user_id)') cursor.execute("PRAGMA table_info(subscriptions)") columns = [info[1] for info in cursor.fetchall()] if 'frequency_unit' not in columns: cursor.execute("ALTER TABLE subscriptions ADD COLUMN frequency_unit TEXT") if 'frequency_value' not in columns: cursor.execute("ALTER TABLE subscriptions ADD COLUMN frequency_value INTEGER") if 'reminders_enabled' not in columns: cursor.execute("ALTER TABLE subscriptions ADD COLUMN reminders_enabled BOOLEAN DEFAULT TRUE") if 'reminder_days' not in columns: cursor.execute("ALTER TABLE subscriptions ADD COLUMN reminder_days INTEGER DEFAULT 3") if 'reminder_on_due_date' not in columns: cursor.execute("ALTER TABLE subscriptions ADD COLUMN reminder_on_due_date BOOLEAN DEFAULT TRUE") if 'notes' not in columns: cursor.execute("ALTER TABLE subscriptions ADD COLUMN notes TEXT") if 'last_reminded_date' not in columns: cursor.execute("ALTER TABLE subscriptions ADD COLUMN last_reminded_date DATE") cursor.execute(''' CREATE TABLE IF NOT EXISTS categories ( id INTEGER PRIMARY KEY, user_id INTEGER, name TEXT, UNIQUE(user_id, name) ) ''') cursor.execute('CREATE INDEX IF NOT EXISTS idx_categories_user_id ON categories(user_id)') cursor.execute(''' CREATE TABLE IF NOT EXISTS users ( user_id INTEGER PRIMARY KEY, main_currency TEXT DEFAULT "USD", language TEXT DEFAULT "en" ) ''') cursor.execute(''' CREATE TABLE IF NOT EXISTS exchange_rates ( from_currency TEXT, to_currency TEXT, rate REAL, last_updated TIMESTAMP, PRIMARY KEY (from_currency, to_currency) ) ''') migrate_frequency_data(conn, cursor) conn.commit() def migrate_frequency_data(conn, cursor): cursor.execute("SELECT id, frequency FROM subscriptions WHERE frequency IS NOT NULL AND frequency_unit IS NULL") subs_to_migrate = cursor.fetchall() if not subs_to_migrate: return freq_map = { 'daily': ('day', 1), 'weekly': ('week', 1), '周付': ('week', 1), 'monthly': ('month', 1), '月付': ('month', 1), 'quarterly': ('month', 3), '季付': ('month', 3), '半年': ('month', 6), 'half-year': ('month', 6), 'biannually': ('month', 6), 'yearly': ('year', 1), '年付': ('year', 1) } for sub_id, freq_str in subs_to_migrate: unit, value = freq_map.get(str(freq_str).lower(), (None, None)) if unit and value: cursor.execute("UPDATE subscriptions SET frequency_unit = ?, frequency_value = ? WHERE id = ?", (unit, value, sub_id)) conn.commit() init_db() # --- 辅助函数 --- def get_user_main_currency(user_id): with get_db_connection() as conn: cursor = conn.cursor() cursor.execute('SELECT main_currency FROM users WHERE user_id = ?', (user_id,)) result = cursor.fetchone() return result['main_currency'] if result else 'USD' def convert_currency(amount, from_curr, to_curr): if from_curr.upper() == to_curr.upper(): return amount with get_db_connection() as conn: cursor = conn.cursor() cursor.execute(''' SELECT rate, last_updated FROM exchange_rates WHERE from_currency = ? AND to_currency = ? ''', (from_curr.upper(), to_curr.upper())) result = cursor.fetchone() now = datetime.datetime.now() cache_validity = datetime.timedelta(hours=24) if result and (now - datetime.datetime.fromisoformat(result['last_updated'])) < cache_validity: logger.debug(f"Using cached exchange rate for {from_curr} to {to_curr}: {result['rate']}") return amount * result['rate'] if not EXCHANGE_API_KEY: logger.warning("EXCHANGE_API_KEY not set, returning original amount") return amount try: url = f"https://v6.exchangerate-api.com/v6/{EXCHANGE_API_KEY}/pair/{from_curr}/{to_curr}/{amount}" response = requests.get(url, timeout=5) response.raise_for_status() data = response.json() rate = data.get('conversion_rate', 1.0) cursor.execute(''' INSERT OR REPLACE INTO exchange_rates (from_currency, to_currency, rate, last_updated) VALUES (?, ?, ?, ?) ''', (from_curr.upper(), to_curr.upper(), rate, now.isoformat())) conn.commit() logger.debug(f"Updated exchange rate cache for {from_curr} to {to_curr}: {rate}") return amount * rate except requests.exceptions.RequestException as e: logger.error(f"Currency conversion API error: {e}") if result: logger.warning(f"Falling back to cached rate: {result['rate']}") return amount * result['rate'] logger.warning("No cached rate available, returning original amount") return amount def parse_date(date_string: str) -> str: today = datetime.datetime.now() try: dt = dateparser.parse(date_string, languages=['en', 'zh'], settings={'TIMEZONE': 'Asia/Shanghai', 'RETURN_AS_TIMEZONE_AWARE': False}) if not dt: return None has_year_info = any(c in date_string for c in ['年', '/']) or (re.search(r'\d{4}', date_string) is not None) if not has_year_info and dt.date() < today.date(): dt = dt.replace(year=dt.year + 1) return dt.strftime('%Y-%m-%d') except Exception as e: logger.error(f"Date parsing failed for string '{date_string}'. Error: {e}") return None def calculate_new_due_date(base_date, unit, value): delta_map = { 'day': relativedelta(days=+value), 'week': relativedelta(weeks=+value), 'month': relativedelta(months=+value), 'year': relativedelta(years=+value) } delta = delta_map.get(str(unit).lower()) return base_date + delta if delta else None def format_frequency(unit, value) -> str: if not unit or value is None: return "未知" unit_map = {'day': '天', 'week': '周', 'month': '个月', 'year': '年'} if value == 1: single_unit_map = {'day': '每天', 'week': '每周', 'month': '每月', 'year': '每年'} return single_unit_map.get(unit, f"每 {value} {unit_map.get(unit, unit)}") return f"每 {value} {unit_map.get(unit, unit)}" CATEGORY_CB_PREFIX = "list_subs_in_category_id_" EDITABLE_SUB_FIELDS = { 'name': 'name', 'cost': 'cost', 'currency': 'currency', 'category': 'category', 'next_due': 'next_due', 'renewal_type': 'renewal_type', 'notes': 'notes' } MAX_NAME_LEN = 128 MAX_CATEGORY_LEN = 64 MAX_NOTES_LEN = 1000 VALID_FREQ_UNITS = {'day', 'week', 'month', 'year'} VALID_RENEWAL_TYPES = {'auto', 'manual'} def _build_category_callback_data(category_id: int) -> str: return f"{CATEGORY_CB_PREFIX}{category_id}" def _parse_category_id_from_callback(data: str) -> int | None: payload = data.replace(CATEGORY_CB_PREFIX, '', 1) return int(payload) if payload.isdigit() else None async def get_subs_list_keyboard(user_id: int, category_filter: str = None) -> InlineKeyboardMarkup: with get_db_connection() as conn: cursor = conn.cursor() query, params = "SELECT id, name FROM subscriptions WHERE user_id = ? ", [user_id] if category_filter: query += "AND category = ? " params.append(category_filter) query += "ORDER BY next_due ASC" cursor.execute(query, tuple(params)) subs = cursor.fetchall() if not subs: return None buttons = [InlineKeyboardButton(name, callback_data=f'view_{sub_id}') for sub_id, name in subs] keyboard = [buttons[i:i + 2] for i in range(0, len(buttons), 2)] if category_filter: keyboard.append([InlineKeyboardButton("« 返回分类列表", callback_data='list_categories')]) else: keyboard.append([InlineKeyboardButton("🗂️ 按分类浏览", callback_data='list_categories')]) return InlineKeyboardMarkup(keyboard) def _clear_action_state(context: CallbackContext, keys: list[str]): for key in keys: context.user_data.pop(key, None) def _get_new_sub_data_or_end(update: Update, context: CallbackContext): sub_data = context.user_data.get('new_sub_data') if sub_data is None: message_obj = update.message or (update.callback_query.message if update.callback_query else None) if message_obj: # 统一提示,避免 KeyError 导致会话崩溃 return None, message_obj return sub_data, None # --- 自动任务 --- def update_past_due_dates(): today = datetime.date.today() with get_db_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM subscriptions WHERE next_due < ? AND renewal_type = 'auto'", (today,)) past_due_subs = cursor.fetchall() if not past_due_subs: return for sub in past_due_subs: try: last_due_date = datetime.datetime.strptime(sub['next_due'], '%Y-%m-%d').date() new_due_date = last_due_date while new_due_date <= today: calculated_date = calculate_new_due_date(new_due_date, sub['frequency_unit'], sub['frequency_value']) if calculated_date: new_due_date = calculated_date else: break if new_due_date > last_due_date: cursor.execute("UPDATE subscriptions SET next_due = ? WHERE id = ?", (new_due_date.strftime('%Y-%m-%d'), sub['id'])) except Exception as e: logger.error(f"Failed to update subscription {sub['id']}: {e}") conn.commit() async def check_and_send_reminders(context: CallbackContext): logger.info("Running job: Checking for subscription reminders...") today = datetime.date.today() today_str = today.strftime('%Y-%m-%d') with get_db_connection() as conn: cursor = conn.cursor() # 过滤掉今天已经提醒过的订阅 cursor.execute("SELECT * FROM subscriptions WHERE reminders_enabled = TRUE AND next_due IS NOT NULL AND (last_reminded_date IS NULL OR last_reminded_date != ?)", (today_str,)) subs_to_check = cursor.fetchall() for sub in subs_to_check: try: due_date = datetime.datetime.strptime(sub['next_due'], '%Y-%m-%d').date() user_id = sub['user_id'] renewal_type = sub['renewal_type'] safe_sub_name = escape_html(sub['name']) message = None keyboard = None if renewal_type == 'manual': keyboard = InlineKeyboardMarkup([ [InlineKeyboardButton("✅ 我已续费", callback_data=f"renewfromremind_{sub['id']}")] ]) if sub['reminder_on_due_date'] and due_date == today: message = f"🔔 订阅到期提醒\n\n您的订阅 {safe_sub_name} 今天到期。" if renewal_type == 'manual': message += " 请记得手动续费。" else: message += " 将会自动续费。" keyboard = None elif renewal_type == 'manual' and sub['reminder_days'] > 0: reminder_date = due_date - datetime.timedelta(days=sub['reminder_days']) if reminder_date == today: days_left = (due_date - today).days days_text = f"{days_left}天后" if days_left > 0 else "今天" message = f"🔔 订阅即将到期提醒\n\n您的手动续费订阅 {safe_sub_name} 将在 {days_text} 到期。" if message: await context.bot.send_message( chat_id=user_id, text=message, parse_mode='HTML', reply_markup=keyboard ) # 记录今天已发送提醒 with get_db_connection() as update_conn: update_cursor = update_conn.cursor() update_cursor.execute("UPDATE subscriptions SET last_reminded_date = ? WHERE id = ?", (today_str, sub['id'])) update_conn.commit() logger.info(f"Reminder sent for sub_id {sub['id']}") except Exception as e: logger.error(f"Failed to process reminder for sub_id {sub.get('id', 'N/A')}: {e}") # --- 命令处理器 --- async def start(update: Update, context: CallbackContext): user_id = update.effective_user.id with get_db_connection() as conn: cursor = conn.cursor() cursor.execute('INSERT OR IGNORE INTO users (user_id) VALUES (?)', (user_id,)) conn.commit() await update.message.reply_text(f'欢迎使用 {escape_html(PROJECT_NAME)}!\n您的私人订阅智能管家。', parse_mode='HTML') async def help_command(update: Update, context: CallbackContext): help_text = fr""" *{escape_html(PROJECT_NAME)} 命令列表* *🌟 核心功能* /add\_sub \- 引导您添加一个新的订阅 /list\_subs \- 列出您的所有订阅 /list\_categories \- 按分类浏览您的订阅 *📊 数据管理* /stats \- 查看按类别分类的订阅统计 /import \- 通过上传 CSV 文件批量导入订阅 /export \- 将您的所有订阅导出为 CSV 文件 *⚙️ 个性化设置* /set\_currency \`\` \- 设置您的主要货币 /cancel \- 在任何流程中取消当前操作 """ await update.message.reply_text(help_text, parse_mode='HTML') def make_autopct(values, currency_code): currency_symbols = {'USD': '$', 'CNY': '¥', 'EUR': '€', 'GBP': '£', 'JPY': '¥'} symbol = currency_symbols.get(currency_code.upper(), f'{currency_code} ') def my_autopct(pct): total = sum(values) val = float(pct * total / 100.0) return f'{symbol}{val:.2f}\n({pct:.1f}%)' return my_autopct async def stats(update: Update, context: CallbackContext): user_id = update.effective_user.id await update.message.reply_text("正在为您生成更美观的统计图,请稍候...") def generate_chart(): font_prop = get_chinese_font() main_currency = get_user_main_currency(user_id) with get_db_connection() as conn: df = pd.read_sql_query("SELECT * FROM subscriptions WHERE user_id = ?", conn, params=(user_id,)) cursor = conn.cursor() cursor.execute("SELECT from_currency, to_currency, rate FROM exchange_rates WHERE to_currency = ?", (main_currency.upper(),)) rate_cache = {(row['from_currency'], row['to_currency']): row['rate'] for row in cursor.fetchall()} if df.empty: return False, "您还没有任何订阅数据。" def fast_convert(amount, from_curr, to_curr): if from_curr.upper() == to_curr.upper(): return amount cached_rate = rate_cache.get((from_curr.upper(), to_curr.upper())) if cached_rate is not None: return amount * cached_rate return convert_currency(amount, from_curr, to_curr) df['converted_cost'] = df.apply(lambda row: fast_convert(row['cost'], row['currency'], main_currency), axis=1) unit_to_days = {'day': 1, 'week': 7, 'month': 30.4375, 'year': 365.25} def normalize_to_monthly(row): if pd.isna(row['frequency_unit']) or pd.isna(row['frequency_value']) or row['frequency_value'] == 0: return 0 total_days = row['frequency_value'] * unit_to_days.get(row['frequency_unit'], 0) if total_days == 0: return 0 return (row['converted_cost'] / total_days) * 30.4375 df['monthly_cost'] = df.apply(normalize_to_monthly, axis=1) category_costs = df.groupby('category')['monthly_cost'].sum().sort_values(ascending=False) if category_costs.empty or category_costs.sum() == 0: return False, "您的订阅没有有效的费用信息。" max_categories = 8 if len(category_costs) > max_categories: top = category_costs.iloc[:max_categories] others_sum = category_costs.iloc[max_categories:].sum() if others_sum > 0: category_costs = pd.concat([top, pd.Series({'其他': others_sum})]) else: category_costs = top total_monthly = category_costs.sum() currency_symbols = {'USD': '$', 'CNY': '¥', 'EUR': '€', 'GBP': '£', 'JPY': '¥'} symbol = currency_symbols.get(main_currency.upper(), f'{main_currency.upper()} ') def autopct_if_large(pct): if pct < 4: return '' value = total_monthly * pct / 100 return f"{pct:.1f}%\n{symbol}{value:.2f}" fig = plt.figure(figsize=(15, 8.5), facecolor='#FAFAFA') gs = fig.add_gridspec(1, 2, width_ratios=[1.1, 1], wspace=0.15) ax_pie = fig.add_subplot(gs[0, 0]) ax_bar = fig.add_subplot(gs[0, 1]) image_path = None try: theme_colors = ['#3B82F6', '#10B981', '#F59E0B', '#EF4444', '#8B5CF6', '#EC4899', '#14B8A6', '#F97316', '#6366F1', '#84CC16'] if len(category_costs) > len(theme_colors): # 移除导致遮蔽的局部 import,直接使用全局的 matplotlib 和 plt extra_colors = [matplotlib.colors.to_hex(c) for c in plt.get_cmap('tab20').colors] theme_colors.extend(extra_colors) color_map = {cat: theme_colors[i] for i, cat in enumerate(category_costs.index)} pie_colors = [color_map[cat] for cat in category_costs.index] wedges, texts, autotexts = ax_pie.pie( category_costs.values, labels=category_costs.index, autopct=autopct_if_large, startangle=140, counterclock=False, pctdistance=0.75, labeldistance=1.1, colors=pie_colors, wedgeprops={'width': 0.35, 'edgecolor': '#FAFAFA', 'linewidth': 2.5} ) for t in texts: t.set_fontproperties(font_prop) t.set_fontsize(13) t.set_color('#374151') for t in autotexts: t.set_fontproperties(font_prop) t.set_fontsize(10) t.set_color('#FFFFFF') t.set_weight('bold') ax_pie.text( 0, 0, f"月总支出\n{symbol}{total_monthly:.2f}", ha='center', va='center', fontproperties=font_prop, fontsize=18, color='#1F2937', weight='bold' ) ax_pie.set_title('支出占比结构', fontproperties=font_prop, fontsize=18, pad=20, color='#111827', weight='bold') ax_pie.axis('equal') bar_series = category_costs.sort_values(ascending=True) bar_colors = [color_map[cat] for cat in bar_series.index] bars = ax_bar.barh(bar_series.index, bar_series.values, color=bar_colors, height=0.6, alpha=0.95, edgecolor='none') ax_bar.set_title('各类别月支出对比', fontproperties=font_prop, fontsize=18, pad=20, color='#111827', weight='bold') ax_bar.set_xlabel(f'金额({main_currency.upper()})', fontproperties=font_prop, fontsize=12, color='#6B7280', labelpad=10) ax_bar.spines['top'].set_visible(False) ax_bar.spines['right'].set_visible(False) ax_bar.spines['left'].set_visible(False) ax_bar.spines['bottom'].set_color('#E5E7EB') ax_bar.tick_params(axis='x', colors='#6B7280', labelsize=11) ax_bar.tick_params(axis='y', length=0, pad=10) ax_bar.grid(axis='x', color='#F3F4F6', linestyle='-', linewidth=1.5, alpha=1) ax_bar.set_axisbelow(True) for label in ax_bar.get_yticklabels(): label.set_fontproperties(font_prop) label.set_fontsize(13) label.set_color('#374151') max_val = bar_series.max() if len(bar_series) else 0 offset = max_val * 0.02 if max_val > 0 else 0.1 for rect, value in zip(bars, bar_series.values): ax_bar.text( rect.get_width() + offset, rect.get_y() + rect.get_height() / 2, f"{symbol}{value:.2f}", va='center', ha='left', fontproperties=font_prop, fontsize=11, color='#4B5563', weight='bold' ) fig.suptitle('您的订阅统计报告', fontproperties=font_prop, fontsize=24, color='#0F172A', y=1.02, weight='bold') fig.tight_layout(rect=[0, 0, 1, 0.95]) with tempfile.NamedTemporaryFile(prefix=f'stats_{user_id}_', suffix='.png', delete=False) as tmp: image_path = tmp.name plt.savefig(image_path, dpi=250, bbox_inches='tight', facecolor=fig.get_facecolor()) return True, image_path finally: plt.close(fig) success, result = await asyncio.to_thread(generate_chart) if success: try: with open(result, 'rb') as photo: await update.message.reply_photo(photo, caption="✨ 已为您生成全新的精美订阅统计图!") finally: if os.path.exists(result): os.remove(result) else: await update.message.reply_text(result) # --- Import/Export Commands --- async def export_command(update: Update, context: CallbackContext): user_id = update.effective_user.id # 将重度 I/O 和 CPU 绑定的 pandas 导出操作放入后台线程 def process_export(): with get_db_connection() as conn: df = pd.read_sql_query( "SELECT name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes FROM subscriptions WHERE user_id = ?", conn, params=(user_id,)) if df.empty: return False, None tmp = tempfile.NamedTemporaryFile(prefix=f'export_{user_id}_', suffix='.csv', delete=False) export_path = tmp.name tmp.close() df.to_csv(export_path, index=False, encoding='utf-8-sig') return True, export_path success, export_path = await asyncio.to_thread(process_export) if not success: await update.message.reply_text("您还没有任何订阅数据,无法导出。") return try: with open(export_path, 'rb') as file: await update.message.reply_document(document=file, filename='subscriptions.csv', caption="您的订阅数据已导出为 CSV 文件。") finally: if export_path and os.path.exists(export_path): os.remove(export_path) async def import_start(update: Update, context: CallbackContext): await update.message.reply_text( "请上传一个 CSV 文件以导入订阅数据。\n文件应包含以下列:name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes(notes 可为空)。") return IMPORT_UPLOAD async def import_upload_received(update: Update, context: CallbackContext): user_id = update.effective_user.id if not update.message.document or not update.message.document.file_name.endswith('.csv'): await update.message.reply_text("请上传一个有效的 CSV 文件。") return IMPORT_UPLOAD file = await update.message.document.get_file() with tempfile.NamedTemporaryFile(prefix=f'import_{user_id}_', suffix='.csv', delete=False) as tmp: file_path = tmp.name try: await file.download_to_drive(file_path) df = pd.read_csv(file_path, encoding='utf-8-sig') required_columns = ['name', 'cost', 'currency', 'category', 'next_due', 'frequency_unit', 'frequency_value', 'renewal_type'] missing_columns = [col for col in required_columns if col not in df.columns] if missing_columns: await update.message.reply_text(f"CSV 文件缺少以下必要列:{', '.join(missing_columns)}") return ConversationHandler.END valid_units = ['day', 'week', 'month', 'year'] valid_renewal_types = ['auto', 'manual'] records = [] for _, row in df.iterrows(): try: cost = float(row['cost']) if cost < 0: raise ValueError("费用不能为负数") currency = str(row['currency']).upper() if not (len(currency) == 3 and currency.isalpha()): raise ValueError(f"无效货币代码: {currency}") next_due = parse_date(str(row['next_due'])) if not next_due: raise ValueError(f"无效日期格式: {row['next_due']}") frequency_unit = str(row['frequency_unit']).lower() if frequency_unit not in valid_units: raise ValueError(f"无效周期单位: {frequency_unit}") frequency_value = int(row['frequency_value']) if frequency_value <= 0: raise ValueError(f"无效周期数量: {frequency_value}") renewal_type = str(row['renewal_type']).lower() if renewal_type not in valid_renewal_types: raise ValueError(f"无效续费类型: {renewal_type}") notes = str(row['notes']).strip() if pd.notna(row['notes']) else None if notes and len(notes) > MAX_NOTES_LEN: raise ValueError(f"备注过长(>{MAX_NOTES_LEN})") name = str(row['name']).strip() category = str(row['category']).strip() if not name: raise ValueError("名称不能为空") if not category: raise ValueError("类别不能为空") if len(name) > MAX_NAME_LEN: raise ValueError(f"名称过长(>{MAX_NAME_LEN})") if len(category) > MAX_CATEGORY_LEN: raise ValueError(f"类别过长(>{MAX_CATEGORY_LEN})") records.append(( user_id, name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes )) except Exception as e: logger.error(f"Invalid row in CSV import, error: {e}") await update.message.reply_text(f"导入失败,存在无效行:{e}") return ConversationHandler.END with get_db_connection() as conn: cursor = conn.cursor() cursor.executemany(''' INSERT INTO subscriptions (user_id, name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ''', records) for record in records: cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, record[4])) conn.commit() await update.message.reply_text(f"✅ 成功导入 {len(records)} 条订阅数据!") except Exception as e: logger.error(f"Import failed: {e}") await update.message.reply_text(f"导入失败:{e}") finally: if os.path.exists(file_path): os.remove(file_path) return ConversationHandler.END # --- Add Subscription Conversation --- async def add_sub_start(update: Update, context: CallbackContext): context.user_data['new_sub_data'] = {} await update.message.reply_text("好的,我们来添加一个新订阅。\n\n第一步:请输入订阅的 名称", parse_mode='HTML') return ADD_NAME async def add_name_received(update: Update, context: CallbackContext): sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context) if sub_data is None: if err_msg_obj: await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。") return ConversationHandler.END name = update.message.text.strip() if not name: await update.message.reply_text("订阅名称不能为空。") return ADD_NAME if len(name) > MAX_NAME_LEN: await update.message.reply_text(f"订阅名称过长,请控制在 {MAX_NAME_LEN} 个字符以内。") return ADD_NAME sub_data['name'] = name await update.message.reply_text("第二步:请输入订阅 费用", parse_mode='HTML') return ADD_COST async def add_cost_received(update: Update, context: CallbackContext): sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context) if sub_data is None: if err_msg_obj: await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。") return ConversationHandler.END try: cost = float(update.message.text) if cost < 0: raise ValueError("费用不能为负数") sub_data['cost'] = cost except (ValueError, TypeError): await update.message.reply_text("费用必须是有效的非负数字。") return ADD_COST await update.message.reply_text("第三步:请输入 货币 代码(例如 USD, CNY)", parse_mode='HTML') return ADD_CURRENCY async def add_currency_received(update: Update, context: CallbackContext): sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context) if sub_data is None: if err_msg_obj: await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。") return ConversationHandler.END currency = update.message.text.upper() if not (len(currency) == 3 and currency.isalpha()): await update.message.reply_text("请输入有效的三字母货币代码(如 USD, CNY)。") return ADD_CURRENCY sub_data['currency'] = currency await update.message.reply_text("第四步:请为订阅指定一个 类别", parse_mode='HTML') return ADD_CATEGORY async def add_category_received(update: Update, context: CallbackContext): sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context) if sub_data is None: if err_msg_obj: await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。") return ConversationHandler.END user_id, category_name = update.effective_user.id, update.message.text.strip() if not category_name: await update.message.reply_text("类别不能为空。") return ADD_CATEGORY if len(category_name) > MAX_CATEGORY_LEN: await update.message.reply_text(f"类别名称过长,请控制在 {MAX_CATEGORY_LEN} 个字符以内。") return ADD_CATEGORY sub_data['category'] = category_name with get_db_connection() as conn: cursor = conn.cursor() cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, category_name)) conn.commit() await update.message.reply_text("第五步:请输入 *下一次付款日期*(例如 2025\\-10\\-01 或 10月1日)", parse_mode='HTML') return ADD_NEXT_DUE async def add_next_due_received(update: Update, context: CallbackContext): sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context) if sub_data is None: if err_msg_obj: await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。") return ConversationHandler.END parsed_date = parse_date(update.message.text) if not parsed_date: await update.message.reply_text("无法识别的日期格式,请使用类似 '2025\\-10\\-01' 或 '10月1日' 的格式。") return ADD_NEXT_DUE sub_data['next_due'] = parsed_date keyboard = [ [InlineKeyboardButton("天", callback_data='freq_unit_day'), InlineKeyboardButton("周", callback_data='freq_unit_week')], [InlineKeyboardButton("月", callback_data='freq_unit_month'), InlineKeyboardButton("年", callback_data='freq_unit_year')] ] await update.message.reply_text("第六步:请选择付款周期的单位", reply_markup=InlineKeyboardMarkup(keyboard), parse_mode='HTML') return ADD_FREQ_UNIT async def add_freq_unit_received(update: Update, context: CallbackContext): sub_data, _ = _get_new_sub_data_or_end(update, context) query = update.callback_query await query.answer() if sub_data is None: await query.edit_message_text("会话已过期,请重新使用 /add_sub 开始。") return ConversationHandler.END unit = query.data.partition('freq_unit_')[2] if unit not in VALID_FREQ_UNITS: await query.edit_message_text("错误:无效的周期单位,请重试。") return ConversationHandler.END sub_data['unit'] = unit await query.edit_message_text("第七步:请输入周期的数量(例如:每3个月,输入 3)", parse_mode='Markdown') return ADD_FREQ_VALUE async def add_freq_value_received(update: Update, context: CallbackContext): sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context) if sub_data is None: if err_msg_obj: await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。") return ConversationHandler.END try: value = int(update.message.text) if value <= 0: raise ValueError sub_data['value'] = value except (ValueError, TypeError): await update.message.reply_text("请输入一个有效的正整数。") return ADD_FREQ_VALUE keyboard = [ [InlineKeyboardButton("自动续费", callback_data='renewal_auto'), InlineKeyboardButton("手动续费", callback_data='renewal_manual')] ] await update.message.reply_text("第八步:请选择 续费方式", reply_markup=InlineKeyboardMarkup(keyboard), parse_mode='HTML') return ADD_RENEWAL_TYPE async def add_renewal_type_received(update: Update, context: CallbackContext): sub_data, _ = _get_new_sub_data_or_end(update, context) query = update.callback_query await query.answer() if sub_data is None: await query.edit_message_text("会话已过期,请重新使用 /add_sub 开始。") return ConversationHandler.END renewal_type = query.data.partition('renewal_')[2] if renewal_type not in VALID_RENEWAL_TYPES: await query.edit_message_text("错误:无效的续费类型,请重试。") return ConversationHandler.END sub_data['renewal_type'] = renewal_type await query.edit_message_text("最后一步(可选):需要添加备注吗?\n(如:共享账号、用途等。不需要请 /skip)") return ADD_NOTES async def add_notes_received(update: Update, context: CallbackContext): sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context) if sub_data is None: _clear_action_state(context, ['new_sub_data']) if err_msg_obj: await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。") return ConversationHandler.END note = update.message.text.strip() if len(note) > MAX_NOTES_LEN: await update.message.reply_text(f"备注过长,请控制在 {MAX_NOTES_LEN} 个字符以内。") return ADD_NOTES sub_data['notes'] = note if note else None save_subscription(update.effective_user.id, sub_data) await update.message.reply_text(text=f"✅ 订阅 '{escape_html(sub_data.get('name', ''))}' 已添加!", parse_mode='HTML') _clear_action_state(context, ['new_sub_data']) return ConversationHandler.END async def skip_notes(update: Update, context: CallbackContext): sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context) if sub_data is None: _clear_action_state(context, ['new_sub_data']) if err_msg_obj: await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。") return ConversationHandler.END sub_data['notes'] = None save_subscription(update.effective_user.id, sub_data) await update.message.reply_text(text=f"✅ 订阅 '{escape_html(sub_data.get('name', ''))}' 已添加!", parse_mode='HTML') _clear_action_state(context, ['new_sub_data']) return ConversationHandler.END def save_subscription(user_id, data): with get_db_connection() as conn: cursor = conn.cursor() cursor.execute(''' INSERT INTO subscriptions (user_id, name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ''', ( user_id, data.get('name'), data.get('cost'), data.get('currency'), data.get('category'), data.get('next_due'), data.get('unit'), data.get('value'), data.get('renewal_type', 'auto'), data.get('notes') )) conn.commit() # --- List, View, Edit, Delete --- async def list_subs(update: Update, context: CallbackContext): user_id = update.effective_user.id keyboard = await get_subs_list_keyboard(user_id) if not keyboard: await update.message.reply_text("您还没有任何订阅。") return await update.message.reply_text("您的所有订阅:", reply_markup=keyboard) async def list_categories(update: Update, context: CallbackContext): user_id = update.effective_user.id with get_db_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT id, name FROM categories WHERE user_id = ? ORDER BY name", (user_id,)) categories = cursor.fetchall() if not categories: if update.callback_query: await update.callback_query.edit_message_text("您还没有任何分类。") else: await update.message.reply_text("您还没有任何分类。") return buttons = [] for cat in categories: cat_id, cat_name = cat[0], cat[1] buttons.append(InlineKeyboardButton(cat_name, callback_data=_build_category_callback_data(cat_id))) keyboard = [buttons[i:i + 2] for i in range(0, len(buttons), 2)] keyboard.append([InlineKeyboardButton("查看全部订阅", callback_data="list_all_subs")]) if update.callback_query: await update.callback_query.edit_message_text("请选择一个分类:", reply_markup=InlineKeyboardMarkup(keyboard)) else: await update.message.reply_text("请选择一个分类:", reply_markup=InlineKeyboardMarkup(keyboard)) async def show_subscription_view(update: Update, context: CallbackContext, sub_id: int): user_id = update.effective_user.id with get_db_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id)) sub = cursor.fetchone() if not sub: logger.error(f"Subscription with id {sub_id} not found for user {user_id}") if update.effective_message: await update.effective_message.reply_text("错误:找不到该订阅。") return name, cost, currency, category, next_due, renewal_type, reminders_enabled, notes = ( sub['name'], sub['cost'], sub['currency'], sub['category'], sub['next_due'], sub['renewal_type'], sub['reminders_enabled'], sub['notes']) freq_text = format_frequency(sub['frequency_unit'], sub['frequency_value']) main_currency = get_user_main_currency(user_id) converted_cost = convert_currency(cost, currency, main_currency) safe_name, safe_category, safe_freq = escape_html(name), escape_html(category), escape_html(freq_text) cost_str, converted_cost_str = escape_html(f"{cost:.2f}"), escape_html(f"{converted_cost:.2f}") renewal_text = "手动续费" if renewal_type == 'manual' else "自动续费" reminder_status = "开启" if reminders_enabled else "关闭" text = (f"*订阅详情: {safe_name}*\n\n" f"\\- *费用*: `{cost_str} {currency.upper()}` \\(\\~`{converted_cost_str} {main_currency.upper()}`\\)\n" f"\\- *类别*: `{safe_category}`\n" f"\\- *下次付款*: `{next_due}` \\(周期: {safe_freq}\\)\n" f"\\- *续费方式*: `{renewal_text}`\n" f"\\- *提醒状态*: `{reminder_status}`") if notes: text += f"\n\\- *备注*: {escape_html(notes)}" keyboard_buttons = [ [InlineKeyboardButton("✏️ 编辑", callback_data=f'edit_{sub_id}'), InlineKeyboardButton("🗑️ 删除", callback_data=f'delete_{sub_id}')], [InlineKeyboardButton("🔔 提醒设置", callback_data=f'remind_{sub_id}')] ] if renewal_type == 'manual': keyboard_buttons.insert(0, [InlineKeyboardButton("✅ 续费", callback_data=f'renewmanual_{sub_id}')]) if 'list_subs_in_category' in context.user_data: cat_filter = context.user_data['list_subs_in_category'] category_id = context.user_data.get('list_subs_in_category_id') if category_id: back_cb = _build_category_callback_data(category_id) else: back_cb = 'list_categories' keyboard_buttons.append([InlineKeyboardButton("« 返回分类订阅", callback_data=back_cb)]) else: keyboard_buttons.append([InlineKeyboardButton("« 返回全部订阅", callback_data='list_all_subs')]) logger.debug(f"Generated buttons for sub_id {sub_id}: edit_{sub_id}, remind_{sub_id}") if update.callback_query: await update.callback_query.edit_message_text(text, reply_markup=InlineKeyboardMarkup(keyboard_buttons), parse_mode='HTML') elif update.effective_message: await update.effective_message.reply_text(text, reply_markup=InlineKeyboardMarkup(keyboard_buttons), parse_mode='HTML') async def button_callback_handler(update: Update, context: CallbackContext): query = update.callback_query await query.answer() data = query.data user_id = query.from_user.id logger.debug(f"Received callback query: {data} from user {user_id}") if data.startswith(CATEGORY_CB_PREFIX): category_id = _parse_category_id_from_callback(data) if not category_id: await query.edit_message_text("错误:无效或已过期的分类,请重新选择。") return with get_db_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT name FROM categories WHERE id = ? AND user_id = ?", (category_id, user_id)) row = cursor.fetchone() if not row: await query.edit_message_text("错误:分类不存在或无权限。") return category = row['name'] context.user_data['list_subs_in_category'] = category context.user_data['list_subs_in_category_id'] = category_id keyboard = await get_subs_list_keyboard(user_id, category_filter=category) msg_text = f"分类 {escape_html(category)} 下的订阅:" if not keyboard: msg_text = f"分类 {escape_html(category)} 下没有订阅。" keyboard = InlineKeyboardMarkup([[InlineKeyboardButton("« 返回分类列表", callback_data='list_categories')]]) await query.edit_message_text(msg_text, reply_markup=keyboard, parse_mode='HTML') return if data == 'list_categories': context.user_data.pop('list_subs_in_category', None) context.user_data.pop('list_subs_in_category_id', None) await list_categories(update, context) return if data == 'list_all_subs': context.user_data.pop('list_subs_in_category', None) context.user_data.pop('list_subs_in_category_id', None) keyboard = await get_subs_list_keyboard(user_id) if not keyboard: await query.edit_message_text("您还没有任何订阅。") return await query.edit_message_text("您的所有订阅:", reply_markup=keyboard) return action, _, sub_id_str = data.partition('_') sub_id = int(sub_id_str) if sub_id_str.isdigit() else None if not sub_id: logger.error(f"Invalid sub_id in callback data: {data}") await query.edit_message_text("错误:无效的订阅ID。") return if action == 'view': await show_subscription_view(update, context, sub_id) elif action == 'renewmanual': with get_db_connection() as conn: cursor = conn.cursor() cursor.execute( "SELECT frequency_unit, frequency_value FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id) ) sub = cursor.fetchone() if sub: today = datetime.date.today() new_due_date = calculate_new_due_date(today, sub['frequency_unit'], sub['frequency_value']) if new_due_date: new_date_str = new_due_date.strftime('%Y-%m-%d') cursor.execute( "UPDATE subscriptions SET next_due = ? WHERE id = ? AND user_id = ?", (new_date_str, sub_id, user_id) ) conn.commit() await query.answer(f"✅ 续费成功!新到期日: {new_date_str}", show_alert=True) await show_subscription_view(update, context, sub_id) else: await query.answer("续费失败:无法计算新的到期日期。", show_alert=True) else: await query.answer("续费失败:订阅不存在或无权限。", show_alert=True) elif action == 'renewfromremind': with get_db_connection() as conn: cursor = conn.cursor() cursor.execute( "SELECT name, frequency_unit, frequency_value FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id) ) sub = cursor.fetchone() if sub: today = datetime.date.today() new_due_date = calculate_new_due_date(today, sub['frequency_unit'], sub['frequency_value']) if new_due_date: new_date_str = new_due_date.strftime('%Y-%m-%d') cursor.execute( "UPDATE subscriptions SET next_due = ? WHERE id = ? AND user_id = ?", (new_date_str, sub_id, user_id) ) conn.commit() safe_sub_name = escape_html(sub['name']) await query.edit_message_text( text=f"✅ 续费成功\n\n您的订阅 {safe_sub_name} 新的到期日为: {new_date_str}", parse_mode='HTML', reply_markup=None ) else: await query.answer("续费失败:无法计算新的到期日期。", show_alert=True) else: await query.answer("续费失败:此订阅可能已被删除或无权限。", show_alert=True) await query.edit_message_text(text=query.message.text + "\n\n*(错误:此订阅不存在或无权限)*", parse_mode='HTML', reply_markup=None) elif action == 'delete': with get_db_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT 1 FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id)) if not cursor.fetchone(): await query.answer("错误:找不到该订阅或无权限。", show_alert=True) return keyboard = InlineKeyboardMarkup([ [InlineKeyboardButton("✅ 是的,删除", callback_data=f'confirmdelete_{sub_id}'), InlineKeyboardButton("❌ 取消", callback_data=f'view_{sub_id}')] ]) await query.edit_message_text(text="您确定要删除这个订阅吗?", reply_markup=keyboard) elif action == 'confirmdelete': with get_db_connection() as conn: cursor = conn.cursor() cursor.execute("DELETE FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id)) deleted = cursor.rowcount conn.commit() if deleted == 0: await query.answer("错误:找不到该订阅或无权限。", show_alert=True) return await query.answer("订阅已删除") if 'list_subs_in_category' in context.user_data: category = context.user_data['list_subs_in_category'] keyboard = await get_subs_list_keyboard(user_id, category_filter=category) msg_text = f"分类 {escape_html(category)} 下的订阅:" if not keyboard: msg_text = f"分类 {escape_html(category)} 下没有订阅。" keyboard = InlineKeyboardMarkup( [[InlineKeyboardButton("« 返回分类列表", callback_data='list_categories')]]) await query.edit_message_text(msg_text, reply_markup=keyboard, parse_mode='HTML') else: keyboard = await get_subs_list_keyboard(user_id) if not keyboard: await query.edit_message_text("您还没有任何订阅。") else: await query.edit_message_text("您的所有订阅:", reply_markup=keyboard) # --- 【新增】包装函数,用于在会话中处理“返回”按钮 --- async def fallback_view_button(update: Update, context: CallbackContext): """ 在会话的 fallback 中调用,处理 view_... 按钮的点击。 它会先显示订阅详情,然后明确地结束当前会话。 """ # 先执行通用的按钮处理逻辑来显示界面 await button_callback_handler(update, context) # 然后返回 END,以确保当前会话(如编辑、提醒设置)被正确终止 return ConversationHandler.END # --- Edit Subscription Conversation --- async def edit_start(update: Update, context: CallbackContext): query = update.callback_query await query.answer() sub_id_str = query.data.split('_')[1] user_id = query.from_user.id if not sub_id_str.isdigit(): await query.edit_message_text("错误:无效的订阅ID。") return ConversationHandler.END sub_id = int(sub_id_str) with get_db_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT 1 FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id)) if not cursor.fetchone(): await query.edit_message_text("错误:找不到该订阅或无权限。") return ConversationHandler.END logger.debug(f"Starting edit for sub_id: {sub_id}") context.user_data['sub_id_for_action'] = sub_id keyboard = [ [InlineKeyboardButton("名称", callback_data="editfield_name"), InlineKeyboardButton("费用", callback_data="editfield_cost")], [InlineKeyboardButton("货币", callback_data="editfield_currency"), InlineKeyboardButton("类别", callback_data="editfield_category")], [InlineKeyboardButton("下次付款日", callback_data="editfield_next_due"), InlineKeyboardButton("周期", callback_data="editfield_frequency")], [InlineKeyboardButton("续费方式", callback_data="editfield_renewal_type"), InlineKeyboardButton("📝 备注", callback_data="editfield_notes")], [InlineKeyboardButton("« 返回详情", callback_data=f'view_{sub_id}')] ] await query.edit_message_text("请选择您想编辑的字段:", reply_markup=InlineKeyboardMarkup(keyboard)) return EDIT_SELECT_FIELD async def edit_field_selected(update: Update, context: CallbackContext): query = update.callback_query await query.answer() field_to_edit = query.data.partition('_')[2] context.user_data['field_to_edit'] = field_to_edit if field_to_edit == 'renewal_type': keyboard = [ [InlineKeyboardButton("自动续费", callback_data='editvalue_auto'), InlineKeyboardButton("手动续费", callback_data='editvalue_manual')] ] await query.edit_message_text("请选择新的续费方式:", reply_markup=InlineKeyboardMarkup(keyboard)) return EDIT_GET_NEW_VALUE if field_to_edit == 'frequency': keyboard = [ [InlineKeyboardButton("天", callback_data='freq_unit_day'), InlineKeyboardButton("周", callback_data='freq_unit_week')], [InlineKeyboardButton("月", callback_data='freq_unit_month'), InlineKeyboardButton("年", callback_data='freq_unit_year')] ] await query.edit_message_text("请选择新的周期单位", reply_markup=InlineKeyboardMarkup(keyboard), parse_mode='HTML') return EDIT_FREQ_UNIT else: field_map = {'name': '名称', 'cost': '费用', 'currency': '货币', 'category': '类别', 'next_due': '下次付款日', 'notes': '备注'} prompt = f"好的,请输入新的 {field_map.get(field_to_edit, field_to_edit)} 值:" if field_to_edit == 'notes': prompt += "\n(如需清空备注,请输入 /empty )" await query.edit_message_text(prompt, parse_mode='HTML') return EDIT_GET_NEW_VALUE async def edit_freq_unit_received(update: Update, context: CallbackContext): query = update.callback_query await query.answer() unit = query.data.partition('freq_unit_')[2] if unit not in VALID_FREQ_UNITS: await query.edit_message_text("错误:无效的周期单位,请重试。") return ConversationHandler.END context.user_data['new_freq_unit'] = unit await query.edit_message_text("好的,现在请输入新的周期数量。", parse_mode='HTML') return EDIT_FREQ_VALUE async def edit_freq_value_received(update: Update, context: CallbackContext): user_id = update.effective_user.id try: value = int(update.message.text) if value <= 0: raise ValueError except (ValueError, TypeError): await update.message.reply_text("请输入一个有效的正整数。") return EDIT_FREQ_VALUE unit = context.user_data.get('new_freq_unit') try: sub_id = int(context.user_data.get('sub_id_for_action')) except (ValueError, TypeError): await update.message.reply_text("错误:会话已过期,请重试。") return ConversationHandler.END with get_db_connection() as conn: cursor = conn.cursor() cursor.execute("UPDATE subscriptions SET frequency_unit = ?, frequency_value = ? WHERE id = ? AND user_id = ?", (unit, value, sub_id, user_id)) if cursor.rowcount == 0: await update.message.reply_text("错误:找不到该订阅或无权限。") return ConversationHandler.END conn.commit() await update.message.reply_text("✅ 周期已更新!") _clear_action_state(context, ['sub_id_for_action', 'new_freq_unit', 'field_to_edit']) await show_subscription_view(update, context, sub_id) return ConversationHandler.END async def edit_new_value_received(update: Update, context: CallbackContext): user_id = update.effective_user.id field = context.user_data.get('field_to_edit') try: sub_id = int(context.user_data.get('sub_id_for_action')) except (ValueError, TypeError): if update.effective_message: await update.effective_message.reply_text("错误:无效的订阅ID。") return ConversationHandler.END if not field: if update.effective_message: await update.effective_message.reply_text("错误:未选择要编辑的字段。") return ConversationHandler.END db_field = EDITABLE_SUB_FIELDS.get(field) if not db_field or not db_field.isidentifier(): if update.effective_message: await update.effective_message.reply_text("错误:不允许编辑该字段。") logger.warning(f"Blocked unsafe field update attempt: {field}") return ConversationHandler.END query, new_value = update.callback_query, "" message_to_reply = update.effective_message if update.message and update.message.text == '/empty' and field == 'notes': new_value = None elif query: new_value = query.data.split('_')[1] elif update.message: new_value = update.message.text else: if message_to_reply: await message_to_reply.reply_text("错误:未提供新值。") return ConversationHandler.END validation_failed = False if field == 'cost': try: new_value = float(new_value) if new_value < 0: raise ValueError("费用不能为负数") except (ValueError, TypeError): if message_to_reply: await message_to_reply.reply_text("费用必须是有效的非负数字。") validation_failed = True elif field == 'name': new_value = str(new_value).strip() if not new_value: if message_to_reply: await message_to_reply.reply_text("名称不能为空。") validation_failed = True elif len(new_value) > MAX_NAME_LEN: if message_to_reply: await message_to_reply.reply_text(f"名称过长,请控制在 {MAX_NAME_LEN} 个字符以内。") validation_failed = True elif field == 'currency': new_value = str(new_value).upper() if not (len(new_value) == 3 and new_value.isalpha()): if message_to_reply: await message_to_reply.reply_text("请输入有效的三字母货币代码(如 USD, CNY)。") validation_failed = True elif field == 'next_due': parsed = parse_date(str(new_value)) if not parsed: if message_to_reply: await message_to_reply.reply_text("无法识别的日期格式,请使用类似 '2025\\-10\\-01' 或 '10月1日' 的格式。") validation_failed = True else: new_value = parsed elif field == 'renewal_type': if str(new_value) not in VALID_RENEWAL_TYPES: if message_to_reply: await message_to_reply.reply_text("续费方式只能为 auto 或 manual。") validation_failed = True elif field == 'notes': note_val = str(new_value).strip() if note_val and len(note_val) > MAX_NOTES_LEN: if message_to_reply: await message_to_reply.reply_text(f"备注过长,请控制在 {MAX_NOTES_LEN} 个字符以内。") validation_failed = True else: new_value = note_val if note_val else None elif field == 'category': new_value = str(new_value).strip() if not new_value: if message_to_reply: await message_to_reply.reply_text("类别不能为空。") validation_failed = True elif len(new_value) > MAX_CATEGORY_LEN: if message_to_reply: await message_to_reply.reply_text(f"类别名称过长,请控制在 {MAX_CATEGORY_LEN} 个字符以内。") validation_failed = True else: with get_db_connection() as conn: cursor = conn.cursor() cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, new_value)) conn.commit() if validation_failed: return EDIT_GET_NEW_VALUE with get_db_connection() as conn: cursor = conn.cursor() cursor.execute(f"UPDATE subscriptions SET {db_field} = ? WHERE id = ? AND user_id = ?", (new_value, sub_id, user_id)) if cursor.rowcount == 0: if message_to_reply: await message_to_reply.reply_text("错误:找不到该订阅或无权限。") return ConversationHandler.END conn.commit() if query: await query.answer("✅ 字段已更新!") elif message_to_reply: await message_to_reply.reply_text("✅ 字段已更新!") _clear_action_state(context, ['sub_id_for_action', 'field_to_edit', 'new_freq_unit']) await show_subscription_view(update, context, sub_id) return ConversationHandler.END # --- Reminder Settings Conversation --- async def _display_reminder_settings(query: CallbackQuery, context: CallbackContext, sub_id: int): user_id = query.from_user.id with get_db_connection() as conn: cursor = conn.cursor() cursor.execute( "SELECT name, renewal_type, reminders_enabled, reminder_on_due_date, reminder_days " "FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id) ) sub = cursor.fetchone() if not sub: await query.edit_message_text("错误:找不到该订阅或无权限。") return enabled_text = "❌ 关闭提醒" if sub['reminders_enabled'] else "✅ 开启提醒" due_date_text = "❌ 关闭到期日提醒" if sub['reminder_on_due_date'] else "✅ 开启到期日提醒" keyboard = [ [InlineKeyboardButton(enabled_text, callback_data='remindaction_toggle_enabled')], [InlineKeyboardButton(due_date_text, callback_data='remindaction_toggle_due_date')] ] safe_name = escape_html(sub['name']) current_status = f"🔔 提醒设置: {safe_name}\n\n" if sub['renewal_type'] == 'manual': current_status += f"当前提前提醒: *{sub['reminder_days']}天*\n" keyboard.append([InlineKeyboardButton("⚙️ 更改提前天数", callback_data='remindaction_ask_days')]) keyboard.append([InlineKeyboardButton("« 返回详情", callback_data=f'view_{sub_id}')]) await query.edit_message_text(current_status, reply_markup=InlineKeyboardMarkup(keyboard), parse_mode='HTML') async def remind_settings_start(update: Update, context: CallbackContext): query = update.callback_query await query.answer() sub_id_str = query.data.partition('_')[2] user_id = query.from_user.id if not sub_id_str.isdigit(): await query.edit_message_text("错误:无效的订阅ID。") return ConversationHandler.END sub_id = int(sub_id_str) with get_db_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT 1 FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id)) if not cursor.fetchone(): await query.edit_message_text("错误:找不到该订阅或无权限。") return ConversationHandler.END logger.debug(f"Starting reminder settings for sub_id: {sub_id}") context.user_data['sub_id_for_action'] = sub_id await _display_reminder_settings(query, context, sub_id) return REMIND_SELECT_ACTION async def remind_action_handler(update: Update, context: CallbackContext): query = update.callback_query await query.answer() if not query.data: return REMIND_SELECT_ACTION action = query.data.partition('remindaction_')[2] sub_id = context.user_data.get('sub_id_for_action') if not sub_id: await query.edit_message_text("错误:会话已过期,请重试。") return ConversationHandler.END user_id = query.from_user.id if action == 'ask_days': await query.edit_message_text("请输入您想提前几天收到提醒?(输入0则不提前提醒)") return REMIND_GET_DAYS if action not in ['toggle_enabled', 'toggle_due_date']: logger.warning(f"Unexpected action '{query.data}' in remind_action_handler") return REMIND_SELECT_ACTION with get_db_connection() as conn: cursor = conn.cursor() if action == 'toggle_enabled': cursor.execute( "UPDATE subscriptions SET reminders_enabled = CASE WHEN reminders_enabled THEN 0 ELSE 1 END " "WHERE id = ? AND user_id = ?", (sub_id, user_id) ) elif action == 'toggle_due_date': cursor.execute( "UPDATE subscriptions SET reminder_on_due_date = CASE WHEN reminder_on_due_date THEN 0 ELSE 1 END " "WHERE id = ? AND user_id = ?", (sub_id, user_id) ) if cursor.rowcount == 0: await query.edit_message_text("错误:找不到该订阅或无权限。") return ConversationHandler.END conn.commit() await _display_reminder_settings(query, context, sub_id) return REMIND_SELECT_ACTION async def remind_days_received(update: Update, context: CallbackContext): sub_id = context.user_data.get('sub_id_for_action') if not sub_id: await update.message.reply_text("错误:会话已过期,请重试。") return ConversationHandler.END user_id = update.effective_user.id try: days = int(update.message.text) if days < 0: raise ValueError with get_db_connection() as conn: cursor = conn.cursor() cursor.execute("UPDATE subscriptions SET reminder_days = ? WHERE id = ? AND user_id = ?", (days, sub_id, user_id)) if cursor.rowcount == 0: await update.message.reply_text("错误:找不到该订阅或无权限。") return ConversationHandler.END conn.commit() await update.message.reply_text(f"✅ 提前提醒天数已设置为: {days}天。") _clear_action_state(context, ['sub_id_for_action']) await show_subscription_view(update, context, sub_id) except (ValueError, TypeError): await update.message.reply_text("请输入一个有效的非负整数。") return REMIND_GET_DAYS return ConversationHandler.END # --- Other Commands --- async def set_currency(update: Update, context: CallbackContext): user_id, args = update.effective_user.id, context.args if len(args) != 1: await update.message.reply_text("用法: /set_currency 代码(例如 /set_currency USD)", parse_mode='HTML') return new_currency = args[0].upper() if len(new_currency) != 3 or not new_currency.isalpha(): await update.message.reply_text("请输入有效的三字母货币代码(如 USD, CNY)。") return with get_db_connection() as conn: cursor = conn.cursor() cursor.execute(""" INSERT INTO users (user_id, main_currency) VALUES (?, ?) ON CONFLICT(user_id) DO UPDATE SET main_currency = excluded.main_currency """, (user_id, new_currency)) conn.commit() await update.message.reply_text(f"您的主货币已设为 {escape_html(new_currency)}。", parse_mode='HTML') return ConversationHandler.END async def cancel(update: Update, context: CallbackContext): _clear_action_state(context, ['new_sub_data', 'sub_id_for_action', 'field_to_edit', 'new_freq_unit']) if update.callback_query: await update.callback_query.answer() await update.callback_query.edit_message_text('操作已取消。') else: await update.message.reply_text('操作已取消。') return ConversationHandler.END def _can_run_update(user_id: int) -> bool: """仅允许指定 owner 执行自动更新。未配置 owner 时默认拒绝。""" if not UPDATE_OWNER_ID: return False try: return int(UPDATE_OWNER_ID) == int(user_id) except (ValueError, TypeError): return False async def update_bot(update: Update, context: CallbackContext): user_id = update.effective_user.id if not _can_run_update(user_id): await update.message.reply_text("无权限执行 /update。") return await update.message.reply_text("开始检查更新,请稍候…") repo_dir = os.path.dirname(os.path.abspath(__file__)) try: fetch_cmd = ["git", "fetch", AUTO_UPDATE_REMOTE, AUTO_UPDATE_BRANCH] fetch_proc = subprocess.run(fetch_cmd, cwd=repo_dir, capture_output=True, text=True) if fetch_proc.returncode != 0: err = (fetch_proc.stderr or fetch_proc.stdout or "未知错误").strip() await update.message.reply_text(f"更新失败(fetch):\n{escape_html(err)}", parse_mode='HTML') return local_rev = subprocess.run( ["git", "rev-parse", "HEAD"], cwd=repo_dir, capture_output=True, text=True ) remote_rev = subprocess.run( ["git", "rev-parse", f"{AUTO_UPDATE_REMOTE}/{AUTO_UPDATE_BRANCH}"], cwd=repo_dir, capture_output=True, text=True ) if local_rev.returncode != 0 or remote_rev.returncode != 0: await update.message.reply_text("更新失败:无法读取当前版本。") return local_hash = local_rev.stdout.strip() remote_hash = remote_rev.stdout.strip() if local_hash == remote_hash: await update.message.reply_text("当前已是最新版本,无需更新。") return reset_proc = subprocess.run( ["git", "reset", "--hard", f"{AUTO_UPDATE_REMOTE}/{AUTO_UPDATE_BRANCH}"], cwd=repo_dir, capture_output=True, text=True ) if reset_proc.returncode != 0: err = (reset_proc.stderr or reset_proc.stdout or "未知错误").strip() await update.message.reply_text(f"更新失败(reset):\n{escape_html(err)}", parse_mode='HTML') return pip_proc = subprocess.run( [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd=repo_dir, capture_output=True, text=True ) if pip_proc.returncode != 0: err = (pip_proc.stderr or pip_proc.stdout or "未知错误").strip() await update.message.reply_text(f"依赖安装失败:\n{escape_html(err[-1800:])}", parse_mode='HTML') return await update.message.reply_text("更新完成,正在重启机器人…") os.execv(sys.executable, [sys.executable] + sys.argv) except Exception as e: logger.error(f"/update failed: {e}") await update.message.reply_text(f"更新异常:{escape_html(str(e))}", parse_mode='HTML') # --- Main --- def main(): if not TELEGRAM_TOKEN: logger.critical("TELEGRAM_TOKEN 环境变量未设置!") return if not EXCHANGE_API_KEY: logger.info("未配置 EXCHANGE_API_KEY,多货币换算将降级为只使用本地缓存(若无缓存则不转换)。") application = Application.builder().token(TELEGRAM_TOKEN).build() async def post_init(app: Application): try: bot_info = await app.bot.get_me() logger.info(f"TELEGRAM_TOKEN 验证成功: {bot_info.username}") except TelegramError as e: logger.critical(f"TELEGRAM_TOKEN 无效或无法连接 Telegram API: {e}") raise SystemExit commands = [ BotCommand("start", "🚀 开始使用"), BotCommand("add_sub", "➕ 添加新订阅"), BotCommand("list_subs", "📋 列出所有订阅"), BotCommand("list_categories", "🗂️ 按分类浏览"), BotCommand("stats", "📊 查看订阅统计"), BotCommand("import", "📥 导入订阅"), BotCommand("export", "📤 导出订阅"), BotCommand("set_currency", "💲 设置主货币"), BotCommand("update", "🛠️ 拉取最新代码并重启"), BotCommand("help", "ℹ️ 获取帮助"), BotCommand("cancel", "❌ 取消当前操作") ] try: await app.bot.delete_my_commands() logger.debug("Cleared existing bot commands") await app.bot.set_my_commands(commands) logger.info("Bot commands registered successfully") except TelegramError as e: logger.error(f"Failed to register bot commands: {e}") app.job_queue.run_daily( check_and_send_reminders, time=datetime.time(hour=9, minute=0, tzinfo=datetime.timezone(datetime.timedelta(hours=8))), name='daily_reminders' ) logger.info("Daily reminder job scheduled.") application.post_init = post_init add_conv = ConversationHandler( entry_points=[CommandHandler('add_sub', add_sub_start)], states={ ADD_NAME: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_name_received)], ADD_COST: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_cost_received)], ADD_CURRENCY: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_currency_received)], ADD_CATEGORY: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_category_received)], ADD_NEXT_DUE: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_next_due_received)], ADD_FREQ_UNIT: [CallbackQueryHandler(add_freq_unit_received, pattern='^freq_unit_')], ADD_FREQ_VALUE: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_freq_value_received)], ADD_RENEWAL_TYPE: [CallbackQueryHandler(add_renewal_type_received, pattern='^renewal_')], ADD_NOTES: [ MessageHandler(filters.TEXT & ~filters.COMMAND, add_notes_received), CommandHandler('skip', skip_notes) ], }, fallbacks=[CommandHandler('cancel', cancel)] ) edit_conv = ConversationHandler( entry_points=[CallbackQueryHandler(edit_start, pattern=r'^edit_\d+$')], states={ EDIT_SELECT_FIELD: [CallbackQueryHandler(edit_field_selected, pattern='^editfield_')], EDIT_GET_NEW_VALUE: [ MessageHandler(filters.TEXT & ~filters.COMMAND, edit_new_value_received), CallbackQueryHandler(edit_new_value_received, pattern='^editvalue_'), CommandHandler('empty', edit_new_value_received) ], EDIT_FREQ_UNIT: [CallbackQueryHandler(edit_freq_unit_received, pattern='^freq_unit_')], EDIT_FREQ_VALUE: [MessageHandler(filters.TEXT & ~filters.COMMAND, edit_freq_value_received)], }, fallbacks=[ CommandHandler('cancel', cancel), # 【修改】使用新的包装函数来确保会话能正确结束 CallbackQueryHandler(fallback_view_button, pattern=r'^view_\d+$'), CallbackQueryHandler(edit_start, pattern=r'^edit_\d+$'), CallbackQueryHandler(remind_settings_start, pattern=r'^remind_\d+$') ], per_message=False ) remind_conv = ConversationHandler( entry_points=[CallbackQueryHandler(remind_settings_start, pattern=r'^remind_\d+$')], states={ REMIND_SELECT_ACTION: [CallbackQueryHandler(remind_action_handler, pattern='^remindaction_')], REMIND_GET_DAYS: [MessageHandler(filters.TEXT & ~filters.COMMAND, remind_days_received)], }, fallbacks=[ CommandHandler('cancel', cancel), # 【修改】使用新的包装函数来确保会话能正确结束 CallbackQueryHandler(fallback_view_button, pattern=r'^view_\d+$'), CallbackQueryHandler(edit_start, pattern=r'^edit_\d+$'), CallbackQueryHandler(remind_settings_start, pattern=r'^remind_\d+$') ], per_message=False ) import_conv = ConversationHandler( entry_points=[CommandHandler('import', import_start)], states={ IMPORT_UPLOAD: [MessageHandler(filters.Document.ALL, import_upload_received)], }, fallbacks=[CommandHandler('cancel', cancel)] ) button_pattern = r'^(view_\d+|renewmanual_\d+|delete_\d+|confirmdelete_\d+|renewfromremind_\d+|list_subs_in_category_id_\d+|list_categories|list_all_subs)$' application.add_handler(CommandHandler('start', start)) application.add_handler(CommandHandler('help', help_command)) application.add_handler(CommandHandler('list_subs', list_subs)) application.add_handler(CommandHandler('list_categories', list_categories)) application.add_handler(CommandHandler('set_currency', set_currency)) application.add_handler(CommandHandler('stats', stats)) application.add_handler(CommandHandler('export', export_command)) application.add_handler(CommandHandler('update', update_bot)) application.add_handler(CommandHandler('cancel', cancel)) application.add_handler(add_conv) application.add_handler(edit_conv) application.add_handler(remind_conv) application.add_handler(import_conv) application.add_handler(CallbackQueryHandler(button_callback_handler, pattern=button_pattern)) logger.info(f"{PROJECT_NAME} Bot is starting...") application.run_polling() if __name__ == '__main__': update_past_due_dates() main()