From 885faaa5246074ebd52247280b669ee3ef7a8463 Mon Sep 17 00:00:00 2001 From: zkysimon Date: Mon, 8 Dec 2025 09:30:12 +0800 Subject: [PATCH] Add files via upload --- SubMind.py | 1325 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1325 insertions(+) create mode 100644 SubMind.py diff --git a/SubMind.py b/SubMind.py new file mode 100644 index 0000000..cb12870 --- /dev/null +++ b/SubMind.py @@ -0,0 +1,1325 @@ +import sqlite3 +import os +import requests +import datetime +import dateparser +import logging +import pandas as pd +import matplotlib +import matplotlib.pyplot as plt +import matplotlib.font_manager as fm +import re +from dotenv import load_dotenv +from dateutil.relativedelta import relativedelta +from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup, BotCommand, CallbackQuery +from telegram.ext import ( + Application, CommandHandler, MessageHandler, filters, + CallbackContext, CallbackQueryHandler, ConversationHandler +) +from telegram.error import TelegramError +from telegram.helpers import escape_markdown + +# --- 加载 .env 和设置 --- +load_dotenv() + +# --- 日志配置 --- +logging.basicConfig( + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + level=logging.INFO +) +logging.getLogger("httpx").setLevel(logging.WARNING) +logger = logging.getLogger(__name__) + +# --- 环境变量和项目配置 --- +TELEGRAM_TOKEN = os.getenv('TELEGRAM_TOKEN') +EXCHANGE_API_KEY = os.getenv('EXCHANGE_API_KEY') +PROJECT_NAME = "SubMind" +DB_FILE = 'submind.db' + +# --- 对话处理器状态 --- +(ADD_NAME, ADD_COST, ADD_CURRENCY, ADD_CATEGORY, ADD_NEXT_DUE, + ADD_FREQ_UNIT, ADD_FREQ_VALUE, ADD_RENEWAL_TYPE, ADD_NOTES) = range(9) +(EDIT_SELECT_FIELD, EDIT_GET_NEW_VALUE, EDIT_FREQ_UNIT, EDIT_FREQ_VALUE) = range(4) +(REMIND_SELECT_ACTION, REMIND_GET_DAYS) = range(2) +(IMPORT_UPLOAD,) = range(1) + + +# --- 辅助函数:数据库连接管理 --- +def get_db_connection(): + conn = sqlite3.connect(DB_FILE) + conn.row_factory = sqlite3.Row + return conn + + +# --- 字体管理 --- +def get_chinese_font(): + font_name = 'SourceHanSansSC-Regular.otf' + font_path = os.path.join('fonts', font_name) + + if os.path.exists(font_path): + logger.debug(f"Found font at {font_path}") + return fm.FontProperties(fname=font_path) + + logger.info(f"Font '{font_name}' not found. Attempting to download...") + os.makedirs('fonts', exist_ok=True) + + url = 'https://github.com/wweir/source-han-sans-sc/raw/refs/heads/master/SourceHanSansSC-Regular.otf' + headers = { + 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' + } + + try: + response = requests.get(url, stream=True, headers=headers, timeout=10) + response.raise_for_status() + with open(font_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + logger.info(f"Font '{font_name}' downloaded successfully to '{font_path}'.") + fm._load_fontmanager(try_read_cache=False) + return fm.FontProperties(fname=font_path) + except requests.exceptions.RequestException as e: + logger.error(f"Failed to download font. Error: {e}") + return fm.FontProperties(family='sans-serif') + + +# --- 数据库初始化与迁移 --- +def init_db(): + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute(''' + CREATE TABLE IF NOT EXISTS subscriptions ( + id INTEGER PRIMARY KEY, user_id INTEGER, name TEXT, cost REAL, currency TEXT, + category TEXT, next_due DATE, frequency TEXT, + renewal_type TEXT DEFAULT 'auto', + reminders_enabled BOOLEAN DEFAULT TRUE, + reminder_days INTEGER DEFAULT 3, + reminder_on_due_date BOOLEAN DEFAULT TRUE, + frequency_unit TEXT, + frequency_value INTEGER, + notes TEXT + ) + ''') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_subscriptions_user_id ON subscriptions(user_id)') + cursor.execute("PRAGMA table_info(subscriptions)") + columns = [info[1] for info in cursor.fetchall()] + if 'frequency_unit' not in columns: + cursor.execute("ALTER TABLE subscriptions ADD COLUMN frequency_unit TEXT") + if 'frequency_value' not in columns: + cursor.execute("ALTER TABLE subscriptions ADD COLUMN frequency_value INTEGER") + if 'reminders_enabled' not in columns: + cursor.execute("ALTER TABLE subscriptions ADD COLUMN reminders_enabled BOOLEAN DEFAULT TRUE") + if 'reminder_days' not in columns: + cursor.execute("ALTER TABLE subscriptions ADD COLUMN reminder_days INTEGER DEFAULT 3") + if 'reminder_on_due_date' not in columns: + cursor.execute("ALTER TABLE subscriptions ADD COLUMN reminder_on_due_date BOOLEAN DEFAULT TRUE") + if 'notes' not in columns: + cursor.execute("ALTER TABLE subscriptions ADD COLUMN notes TEXT") + + cursor.execute(''' + CREATE TABLE IF NOT EXISTS categories ( + id INTEGER PRIMARY KEY, user_id INTEGER, name TEXT, UNIQUE(user_id, name) + ) + ''') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_categories_user_id ON categories(user_id)') + cursor.execute(''' + CREATE TABLE IF NOT EXISTS users ( + user_id INTEGER PRIMARY KEY, main_currency TEXT DEFAULT "USD", language TEXT DEFAULT "en" + ) + ''') + cursor.execute(''' + CREATE TABLE IF NOT EXISTS exchange_rates ( + from_currency TEXT, to_currency TEXT, rate REAL, last_updated TIMESTAMP, + PRIMARY KEY (from_currency, to_currency) + ) + ''') + migrate_frequency_data(conn, cursor) + conn.commit() + + +def migrate_frequency_data(conn, cursor): + cursor.execute("SELECT id, frequency FROM subscriptions WHERE frequency IS NOT NULL AND frequency_unit IS NULL") + subs_to_migrate = cursor.fetchall() + if not subs_to_migrate: + return + freq_map = { + 'daily': ('day', 1), 'weekly': ('week', 1), '周付': ('week', 1), 'monthly': ('month', 1), + '月付': ('month', 1), 'quarterly': ('month', 3), '季付': ('month', 3), '半年': ('month', 6), + 'half-year': ('month', 6), 'biannually': ('month', 6), 'yearly': ('year', 1), '年付': ('year', 1) + } + for sub_id, freq_str in subs_to_migrate: + unit, value = freq_map.get(str(freq_str).lower(), (None, None)) + if unit and value: + cursor.execute("UPDATE subscriptions SET frequency_unit = ?, frequency_value = ? WHERE id = ?", + (unit, value, sub_id)) + conn.commit() + + +init_db() + + +# --- 辅助函数 --- +def get_user_main_currency(user_id): + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute('SELECT main_currency FROM users WHERE user_id = ?', (user_id,)) + result = cursor.fetchone() + return result['main_currency'] if result else 'USD' + + +def convert_currency(amount, from_curr, to_curr): + if from_curr.upper() == to_curr.upper(): + return amount + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute(''' + SELECT rate, last_updated FROM exchange_rates + WHERE from_currency = ? AND to_currency = ? + ''', (from_curr.upper(), to_curr.upper())) + result = cursor.fetchone() + now = datetime.datetime.now() + cache_validity = datetime.timedelta(hours=24) + if result and (now - datetime.datetime.fromisoformat(result['last_updated'])) < cache_validity: + logger.debug(f"Using cached exchange rate for {from_curr} to {to_curr}: {result['rate']}") + return amount * result['rate'] + + if not EXCHANGE_API_KEY: + logger.warning("EXCHANGE_API_KEY not set, returning original amount") + return amount + try: + url = f"https://v6.exchangerate-api.com/v6/{EXCHANGE_API_KEY}/pair/{from_curr}/{to_curr}/{amount}" + response = requests.get(url, timeout=5) + response.raise_for_status() + data = response.json() + rate = data.get('conversion_rate', 1.0) + cursor.execute(''' + INSERT OR REPLACE INTO exchange_rates (from_currency, to_currency, rate, last_updated) + VALUES (?, ?, ?, ?) + ''', (from_curr.upper(), to_curr.upper(), rate, now.isoformat())) + conn.commit() + logger.debug(f"Updated exchange rate cache for {from_curr} to {to_curr}: {rate}") + return amount * rate + except requests.exceptions.RequestException as e: + logger.error(f"Currency conversion API error: {e}") + if result: + logger.warning(f"Falling back to cached rate: {result['rate']}") + return amount * result['rate'] + logger.warning("No cached rate available, returning original amount") + return amount + + +def parse_date(date_string: str) -> str: + today = datetime.datetime.now() + try: + dt = dateparser.parse(date_string, languages=['en', 'zh']) + if not dt: + return None + has_year_info = any(c in date_string for c in ['年', '/']) or (re.search(r'\d{4}', date_string) is not None) + if not has_year_info and dt.date() < today.date(): + dt = dt.replace(year=dt.year + 1) + return dt.strftime('%Y-%m-%d') + except Exception as e: + logger.error(f"Date parsing failed for string '{date_string}'. Error: {e}") + return None + + +def calculate_new_due_date(base_date, unit, value): + delta_map = { + 'day': relativedelta(days=+value), 'week': relativedelta(weeks=+value), + 'month': relativedelta(months=+value), 'year': relativedelta(years=+value) + } + delta = delta_map.get(str(unit).lower()) + return base_date + delta if delta else None + + +def format_frequency(unit, value) -> str: + if not unit or value is None: + return "未知" + unit_map = {'day': '天', 'week': '周', 'month': '个月', 'year': '年'} + if value == 1: + single_unit_map = {'day': '每天', 'week': '每周', 'month': '每月', 'year': '每年'} + return single_unit_map.get(unit, f"每 {value} {unit_map.get(unit, unit)}") + return f"每 {value} {unit_map.get(unit, unit)}" + + +async def get_subs_list_keyboard(user_id: int, category_filter: str = None) -> InlineKeyboardMarkup: + with get_db_connection() as conn: + cursor = conn.cursor() + query, params = "SELECT id, name FROM subscriptions WHERE user_id = ? ", [user_id] + if category_filter: + query += "AND category = ? " + params.append(category_filter) + query += "ORDER BY next_due ASC" + cursor.execute(query, tuple(params)) + subs = cursor.fetchall() + if not subs: + return None + buttons = [InlineKeyboardButton(name, callback_data=f'view_{sub_id}') for sub_id, name in subs] + keyboard = [buttons[i:i + 2] for i in range(0, len(buttons), 2)] + if category_filter: + keyboard.append([InlineKeyboardButton("« 返回分类列表", callback_data='list_categories')]) + else: + keyboard.append([InlineKeyboardButton("🗂️ 按分类浏览", callback_data='list_categories')]) + return InlineKeyboardMarkup(keyboard) + + +# --- 自动任务 --- +def update_past_due_dates(): + today = datetime.date.today() + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM subscriptions WHERE next_due < ? AND renewal_type = 'auto'", (today,)) + past_due_subs = cursor.fetchall() + if not past_due_subs: + return + for sub in past_due_subs: + try: + last_due_date = datetime.datetime.strptime(sub['next_due'], '%Y-%m-%d').date() + new_due_date = last_due_date + while new_due_date <= today: + calculated_date = calculate_new_due_date(new_due_date, sub['frequency_unit'], + sub['frequency_value']) + if calculated_date: + new_due_date = calculated_date + else: + break + if new_due_date > last_due_date: + cursor.execute("UPDATE subscriptions SET next_due = ? WHERE id = ?", + (new_due_date.strftime('%Y-%m-%d'), sub['id'])) + except Exception as e: + logger.error(f"Failed to update subscription {sub['id']}: {e}") + conn.commit() + + +async def check_and_send_reminders(context: CallbackContext): + logger.info("Running job: Checking for subscription reminders...") + today = datetime.date.today() + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM subscriptions WHERE reminders_enabled = TRUE AND next_due IS NOT NULL") + subs_to_check = cursor.fetchall() + + for sub in subs_to_check: + try: + due_date = datetime.datetime.strptime(sub['next_due'], '%Y-%m-%d').date() + user_id = sub['user_id'] + renewal_type = sub['renewal_type'] + safe_sub_name = escape_markdown(sub['name'], version=2) + + message = None + keyboard = None + + if renewal_type == 'manual': + keyboard = InlineKeyboardMarkup([ + [InlineKeyboardButton("✅ 我已续费", callback_data=f"renewfromremind_{sub['id']}")] + ]) + + if sub['reminder_on_due_date'] and due_date == today: + message = f"🔔 *订阅到期提醒*\n\n您的订阅 `{safe_sub_name}` 今天到期。" + if renewal_type == 'manual': + message += " 请记得手动续费。" + else: + message += " 将会自动续费。" + keyboard = None + + elif renewal_type == 'manual' and sub['reminder_days'] > 0: + reminder_date = due_date - datetime.timedelta(days=sub['reminder_days']) + if reminder_date == today: + days_left = (due_date - today).days + days_text = f"*{days_left}天后*" if days_left > 0 else "*今天*" + message = f"🔔 *订阅即将到期提醒*\n\n您的手动续费订阅 `{safe_sub_name}` 将在 {days_text} 到期。" + + if message: + await context.bot.send_message( + chat_id=user_id, + text=message, + parse_mode='MarkdownV2', + reply_markup=keyboard + ) + + except Exception as e: + logger.error(f"Failed to process reminder for sub_id {sub.get('id', 'N/A')}: {e}") + + +# --- 命令处理器 --- +async def start(update: Update, context: CallbackContext): + user_id = update.effective_user.id + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute('INSERT OR IGNORE INTO users (user_id) VALUES (?)', (user_id,)) + conn.commit() + await update.message.reply_text(f'欢迎使用 {escape_markdown(PROJECT_NAME, version=2)}!\n您的私人订阅智能管家。', + parse_mode='MarkdownV2') + + +async def help_command(update: Update, context: CallbackContext): + help_text = f""" +*{escape_markdown(PROJECT_NAME, version=2)} 命令列表* +*🌟 核心功能* +/add\_sub \- 引导您添加一个新的订阅 +/list\_subs \- 列出您的所有订阅 +/list\_categories \- 按分类浏览您的订阅 +*📊 数据管理* +/stats \- 查看按类别分类的订阅统计 +/import \- 通过上传 CSV 文件批量导入订阅 +/export \- 将您的所有订阅导出为 CSV 文件 +*⚙️ 个性化设置* +/set\_currency \`\` \- 设置您的主要货币 +/cancel \- 在任何流程中取消当前操作 +""" + await update.message.reply_text(help_text, parse_mode='MarkdownV2') + + +def make_autopct(values, currency_code): + currency_symbols = {'USD': '$', 'CNY': '¥', 'EUR': '€', 'GBP': '£', 'JPY': '¥'} + symbol = currency_symbols.get(currency_code.upper(), f'{currency_code} ') + + def my_autopct(pct): + total = sum(values) + val = float(pct * total / 100.0) + return f'{symbol}{val:.2f}\n({pct:.1f}%)' + + return my_autopct + + +async def stats(update: Update, context: CallbackContext): + user_id = update.effective_user.id + await update.message.reply_text("正在为您生成订阅统计图表,请稍候...") + + font_prop = get_chinese_font() + main_currency = get_user_main_currency(user_id) + with get_db_connection() as conn: + df = pd.read_sql_query("SELECT * FROM subscriptions WHERE user_id = ?", conn, params=(user_id,)) + if df.empty: + await update.message.reply_text("您还没有任何订阅数据。") + return + + df['converted_cost'] = df.apply(lambda row: convert_currency(row['cost'], row['currency'], main_currency), axis=1) + unit_to_days = {'day': 1, 'week': 7, 'month': 30.4375, 'year': 365.25} + + def normalize_to_monthly(row): + if pd.isna(row['frequency_unit']) or pd.isna(row['frequency_value']) or row['frequency_value'] == 0: + return 0 + total_days = row['frequency_value'] * unit_to_days.get(row['frequency_unit'], 0) + if total_days == 0: + return 0 + return (row['converted_cost'] / total_days) * 30.4375 + + df['monthly_cost'] = df.apply(normalize_to_monthly, axis=1) + category_costs = df.groupby('category')['monthly_cost'].sum().sort_values(ascending=False) + + if category_costs.empty or category_costs.sum() == 0: + await update.message.reply_text("您的订阅没有有效的费用信息。") + return + + plt.style.use('seaborn-v0_8-darkgrid') + fig, ax = plt.subplots(figsize=(12, 12)) + + autopct_function = make_autopct(category_costs.values, main_currency) + + wedges, texts, autotexts = ax.pie(category_costs.values, + labels=category_costs.index, + autopct=autopct_function, + startangle=140, + pctdistance=0.7, + labeldistance=1.05) + + ax.set_title('每月订阅支出分类统计', fontproperties=font_prop, fontsize=32, pad=20) + + for text in texts: + text.set_fontproperties(font_prop) + text.set_fontsize(22) + + for autotext in autotexts: + autotext.set_fontproperties(font_prop) + autotext.set_fontsize(20) + autotext.set_color('white') + + ax.axis('equal') + + fig.tight_layout() + image_path = f'stats_{user_id}.png' + plt.savefig(image_path) + plt.close(fig) + + try: + with open(image_path, 'rb') as photo: + await update.message.reply_photo(photo, caption="这是您按类别统计的每月订阅总支出。") + finally: + if os.path.exists(image_path): + os.remove(image_path) + + +# --- Import/Export Commands --- +async def export_command(update: Update, context: CallbackContext): + user_id = update.effective_user.id + with get_db_connection() as conn: + df = pd.read_sql_query( + "SELECT name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes FROM subscriptions WHERE user_id = ?", + conn, params=(user_id,)) + if df.empty: + await update.message.reply_text("您还没有任何订阅数据,无法导出。") + return + + export_path = f'export_{user_id}.csv' + df.to_csv(export_path, index=False, encoding='utf-8-sig') + + try: + with open(export_path, 'rb') as file: + await update.message.reply_document(document=file, filename='subscriptions.csv', + caption="您的订阅数据已导出为 CSV 文件。") + finally: + if os.path.exists(export_path): + os.remove(export_path) + + +async def import_start(update: Update, context: CallbackContext): + await update.message.reply_text( + "请上传一个 CSV 文件以导入订阅数据。\n文件应包含以下列:name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes(notes 可为空)。") + return IMPORT_UPLOAD + + +async def import_upload_received(update: Update, context: CallbackContext): + user_id = update.effective_user.id + if not update.message.document or not update.message.document.file_name.endswith('.csv'): + await update.message.reply_text("请上传一个有效的 CSV 文件。") + return IMPORT_UPLOAD + + file = await update.message.document.get_file() + file_path = f'import_{user_id}.csv' + try: + await file.download_to_drive(file_path) + df = pd.read_csv(file_path, encoding='utf-8-sig') + required_columns = ['name', 'cost', 'currency', 'category', 'next_due', 'frequency_unit', 'frequency_value', + 'renewal_type'] + missing_columns = [col for col in required_columns if col not in df.columns] + if missing_columns: + await update.message.reply_text(f"CSV 文件缺少以下必要列:{', '.join(missing_columns)}") + return ConversationHandler.END + + valid_units = ['day', 'week', 'month', 'year'] + valid_renewal_types = ['auto', 'manual'] + records = [] + for _, row in df.iterrows(): + try: + cost = float(row['cost']) + if cost < 0: + raise ValueError("费用不能为负数") + currency = str(row['currency']).upper() + if not (len(currency) == 3 and currency.isalpha()): + raise ValueError(f"无效货币代码: {currency}") + next_due = parse_date(str(row['next_due'])) + if not next_due: + raise ValueError(f"无效日期格式: {row['next_due']}") + frequency_unit = str(row['frequency_unit']).lower() + if frequency_unit not in valid_units: + raise ValueError(f"无效周期单位: {frequency_unit}") + frequency_value = int(row['frequency_value']) + if frequency_value <= 0: + raise ValueError(f"无效周期数量: {frequency_value}") + renewal_type = str(row['renewal_type']).lower() + if renewal_type not in valid_renewal_types: + raise ValueError(f"无效续费类型: {renewal_type}") + notes = str(row['notes']) if pd.notna(row['notes']) else None + records.append(( + user_id, row['name'], cost, currency, row['category'], + next_due, frequency_unit, frequency_value, renewal_type, notes + )) + except Exception as e: + logger.error(f"Invalid row in CSV: {row.to_dict()}, error: {e}") + await update.message.reply_text(f"导入失败,行数据无效:{row.to_dict()},错误:{e}") + return ConversationHandler.END + + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.executemany(''' + INSERT INTO subscriptions (user_id, name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ''', records) + for record in records: + cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, record[4])) + conn.commit() + + await update.message.reply_text(f"✅ 成功导入 {len(records)} 条订阅数据!") + except Exception as e: + logger.error(f"Import failed: {e}") + await update.message.reply_text(f"导入失败:{e}") + finally: + if os.path.exists(file_path): + os.remove(file_path) + return ConversationHandler.END + + +# --- Add Subscription Conversation --- +async def add_sub_start(update: Update, context: CallbackContext): + context.user_data['new_sub_data'] = {} + await update.message.reply_text("好的,我们来添加一个新订阅。\n\n第一步:请输入订阅的 *名称*", + parse_mode='MarkdownV2') + return ADD_NAME + + +async def add_name_received(update: Update, context: CallbackContext): + context.user_data['new_sub_data']['name'] = update.message.text + await update.message.reply_text("第二步:请输入订阅 *费用*", parse_mode='MarkdownV2') + return ADD_COST + + +async def add_cost_received(update: Update, context: CallbackContext): + try: + cost = float(update.message.text) + if cost < 0: + raise ValueError("费用不能为负数") + context.user_data['new_sub_data']['cost'] = cost + except (ValueError, TypeError): + await update.message.reply_text("费用必须是有效的非负数字。") + return ADD_COST + await update.message.reply_text("第三步:请输入 *货币* 代码(例如 USD, CNY)", parse_mode='MarkdownV2') + return ADD_CURRENCY + + +async def add_currency_received(update: Update, context: CallbackContext): + currency = update.message.text.upper() + if not (len(currency) == 3 and currency.isalpha()): + await update.message.reply_text("请输入有效的三字母货币代码(如 USD, CNY)。") + return ADD_CURRENCY + context.user_data['new_sub_data']['currency'] = currency + await update.message.reply_text("第四步:请为订阅指定一个 *类别*", parse_mode='MarkdownV2') + return ADD_CATEGORY + + +async def add_category_received(update: Update, context: CallbackContext): + user_id, category_name = update.effective_user.id, update.message.text.strip() + if not category_name: + await update.message.reply_text("类别不能为空。") + return ADD_CATEGORY + context.user_data['new_sub_data']['category'] = category_name + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, category_name)) + conn.commit() + await update.message.reply_text("第五步:请输入 *下一次付款日期*(例如 2025\\-10\\-01 或 10月1日)", + parse_mode='MarkdownV2') + return ADD_NEXT_DUE + + +async def add_next_due_received(update: Update, context: CallbackContext): + parsed_date = parse_date(update.message.text) + if not parsed_date: + await update.message.reply_text("无法识别的日期格式,请使用类似 '2025\\-10\\-01' 或 '10月1日' 的格式。") + return ADD_NEXT_DUE + context.user_data['new_sub_data']['next_due'] = parsed_date + keyboard = [ + [InlineKeyboardButton("天", callback_data='freq_unit_day'), + InlineKeyboardButton("周", callback_data='freq_unit_week')], + [InlineKeyboardButton("月", callback_data='freq_unit_month'), + InlineKeyboardButton("年", callback_data='freq_unit_year')] + ] + await update.message.reply_text("第六步:请选择付款周期的*单位*", reply_markup=InlineKeyboardMarkup(keyboard), + parse_mode='MarkdownV2') + return ADD_FREQ_UNIT + + +async def add_freq_unit_received(update: Update, context: CallbackContext): + query = update.callback_query + await query.answer() + context.user_data['new_sub_data']['unit'] = query.data.split('_')[2] + await query.edit_message_text("第七步:请输入周期的*数量*(例如:每3个月,输入 3)", parse_mode='Markdown') + return ADD_FREQ_VALUE + + +async def add_freq_value_received(update: Update, context: CallbackContext): + try: + value = int(update.message.text) + if value <= 0: + raise ValueError + context.user_data['new_sub_data']['value'] = value + except (ValueError, TypeError): + await update.message.reply_text("请输入一个有效的正整数。") + return ADD_FREQ_VALUE + keyboard = [ + [InlineKeyboardButton("自动续费", callback_data='renewal_auto'), + InlineKeyboardButton("手动续费", callback_data='renewal_manual')] + ] + await update.message.reply_text("第八步:请选择 *续费方式*", reply_markup=InlineKeyboardMarkup(keyboard), + parse_mode='MarkdownV2') + return ADD_RENEWAL_TYPE + + +async def add_renewal_type_received(update: Update, context: CallbackContext): + query = update.callback_query + await query.answer() + context.user_data['new_sub_data']['renewal_type'] = query.data.split('_')[1] + await query.edit_message_text("最后一步(可选):需要添加备注吗?\n(如:共享账号、用途等。不需要请 /skip)") + return ADD_NOTES + + +async def add_notes_received(update: Update, context: CallbackContext): + sub_data = context.user_data.get('new_sub_data') + if not sub_data: + await update.message.reply_text("发生错误,请重试。") + return ConversationHandler.END + sub_data['notes'] = update.message.text + save_subscription(update.effective_user.id, sub_data) + await update.message.reply_text(text=f"✅ 订阅 '{escape_markdown(sub_data.get('name', ''), version=2)}' 已添加!", + parse_mode='MarkdownV2') + context.user_data.clear() + return ConversationHandler.END + + +async def skip_notes(update: Update, context: CallbackContext): + sub_data = context.user_data.get('new_sub_data') + if not sub_data: + await update.message.reply_text("发生错误,请重试。") + return ConversationHandler.END + sub_data['notes'] = None + save_subscription(update.effective_user.id, sub_data) + await update.message.reply_text(text=f"✅ 订阅 '{escape_markdown(sub_data.get('name', ''), version=2)}' 已添加!", + parse_mode='MarkdownV2') + context.user_data.clear() + return ConversationHandler.END + + +def save_subscription(user_id, data): + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute(''' + INSERT INTO subscriptions (user_id, name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ''', ( + user_id, data.get('name'), data.get('cost'), data.get('currency'), data.get('category'), + data.get('next_due'), + data.get('unit'), data.get('value'), data.get('renewal_type', 'auto'), data.get('notes') + )) + conn.commit() + + +# --- List, View, Edit, Delete --- +async def list_subs(update: Update, context: CallbackContext): + user_id = update.effective_user.id + keyboard = await get_subs_list_keyboard(user_id) + if not keyboard: + await update.message.reply_text("您还没有任何订阅。") + return + await update.message.reply_text("您的所有订阅:", reply_markup=keyboard) + + +async def list_categories(update: Update, context: CallbackContext): + user_id = update.effective_user.id + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT name FROM categories WHERE user_id = ? ORDER BY name", (user_id,)) + categories = cursor.fetchall() + if not categories: + if update.callback_query: + await update.callback_query.edit_message_text("您还没有任何分类。") + else: + await update.message.reply_text("您还没有任何分类。") + return + + buttons = [InlineKeyboardButton(cat[0], callback_data=f"list_subs_in_category_{cat[0]}") for cat in categories] + keyboard = [buttons[i:i + 2] for i in range(0, len(buttons), 2)] + keyboard.append([InlineKeyboardButton("查看全部订阅", callback_data="list_all_subs")]) + if update.callback_query: + await update.callback_query.edit_message_text("请选择一个分类:", reply_markup=InlineKeyboardMarkup(keyboard)) + else: + await update.message.reply_text("请选择一个分类:", reply_markup=InlineKeyboardMarkup(keyboard)) + + +async def show_subscription_view(update: Update, context: CallbackContext, sub_id: int): + user_id = update.effective_user.id + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id)) + sub = cursor.fetchone() + if not sub: + logger.error(f"Subscription with id {sub_id} not found for user {user_id}") + if update.effective_message: + await update.effective_message.reply_text("错误:找不到该订阅。") + return + name, cost, currency, category, next_due, renewal_type, reminders_enabled, notes = ( + sub['name'], sub['cost'], sub['currency'], sub['category'], sub['next_due'], sub['renewal_type'], + sub['reminders_enabled'], sub['notes']) + freq_text = format_frequency(sub['frequency_unit'], sub['frequency_value']) + main_currency = get_user_main_currency(user_id) + converted_cost = convert_currency(cost, currency, main_currency) + safe_name, safe_category, safe_freq = escape_markdown(name, version=2), escape_markdown(category, + version=2), escape_markdown( + freq_text, version=2) + cost_str, converted_cost_str = escape_markdown(f"{cost:.2f}", version=2), escape_markdown(f"{converted_cost:.2f}", + version=2) + renewal_text = "手动续费" if renewal_type == 'manual' else "自动续费" + reminder_status = "开启" if reminders_enabled else "关闭" + text = (f"*订阅详情: {safe_name}*\n\n" + f"\\- *费用*: `{cost_str} {currency.upper()}` \\(\\~`{converted_cost_str} {main_currency.upper()}`\\)\n" + f"\\- *类别*: `{safe_category}`\n" + f"\\- *下次付款*: `{next_due}` \\(周期: {safe_freq}\\)\n" + f"\\- *续费方式*: `{renewal_text}`\n" + f"\\- *提醒状态*: `{reminder_status}`") + if notes: + text += f"\n\\- *备注*: {escape_markdown(notes, version=2)}" + keyboard_buttons = [ + [InlineKeyboardButton("✏️ 编辑", callback_data=f'edit_{sub_id}'), + InlineKeyboardButton("🗑️ 删除", callback_data=f'delete_{sub_id}')], + [InlineKeyboardButton("🔔 提醒设置", callback_data=f'remind_{sub_id}')] + ] + if renewal_type == 'manual': + keyboard_buttons.insert(0, [InlineKeyboardButton("✅ 续费", callback_data=f'renewmanual_{sub_id}')]) + if 'list_subs_in_category' in context.user_data: + cat_filter = context.user_data['list_subs_in_category'] + keyboard_buttons.append( + [InlineKeyboardButton("« 返回分类订阅", callback_data=f'list_subs_in_category_{cat_filter}')]) + else: + keyboard_buttons.append([InlineKeyboardButton("« 返回全部订阅", callback_data='list_all_subs')]) + logger.debug(f"Generated buttons for sub_id {sub_id}: edit_{sub_id}, remind_{sub_id}") + if update.callback_query: + await update.callback_query.edit_message_text(text, reply_markup=InlineKeyboardMarkup(keyboard_buttons), + parse_mode='MarkdownV2') + elif update.effective_message: + await update.effective_message.reply_text(text, reply_markup=InlineKeyboardMarkup(keyboard_buttons), + parse_mode='MarkdownV2') + + +async def button_callback_handler(update: Update, context: CallbackContext): + query = update.callback_query + await query.answer() + data = query.data + user_id = query.from_user.id + logger.debug(f"Received callback query: {data} from user {user_id}") + + if data.startswith('list_subs_in_category_'): + category = data.replace('list_subs_in_category_', '') + context.user_data['list_subs_in_category'] = category + keyboard = await get_subs_list_keyboard(user_id, category_filter=category) + msg_text = f"分类“{escape_markdown(category, version=2)}”下的订阅:" + if not keyboard: + msg_text = f"分类“{escape_markdown(category, version=2)}”下没有订阅。" + keyboard = InlineKeyboardMarkup([[InlineKeyboardButton("« 返回分类列表", callback_data='list_categories')]]) + await query.edit_message_text(msg_text, reply_markup=keyboard, parse_mode='MarkdownV2') + return + if data == 'list_categories': + context.user_data.pop('list_subs_in_category', None) + await list_categories(update, context) + return + if data == 'list_all_subs': + context.user_data.pop('list_subs_in_category', None) + keyboard = await get_subs_list_keyboard(user_id) + if not keyboard: + await query.edit_message_text("您还没有任何订阅。") + return + await query.edit_message_text("您的所有订阅:", reply_markup=keyboard) + return + + action, _, sub_id_str = data.partition('_') + sub_id = int(sub_id_str) if sub_id_str.isdigit() else None + if not sub_id: + logger.error(f"Invalid sub_id in callback data: {data}") + await query.edit_message_text("错误:无效的订阅ID。") + return + + if action == 'view': + await show_subscription_view(update, context, sub_id) + + elif action == 'renewmanual': + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT frequency_unit, frequency_value FROM subscriptions WHERE id = ?", (sub_id,)) + sub = cursor.fetchone() + if sub: + today = datetime.date.today() + new_due_date = calculate_new_due_date(today, sub['frequency_unit'], sub['frequency_value']) + if new_due_date: + new_date_str = new_due_date.strftime('%Y-%m-%d') + cursor.execute("UPDATE subscriptions SET next_due = ? WHERE id = ?", (new_date_str, sub_id)) + conn.commit() + await query.answer(f"✅ 续费成功!新到期日: {new_date_str}", show_alert=True) + await show_subscription_view(update, context, sub_id) + else: + await query.answer("续费失败:无法计算新的到期日期。", show_alert=True) + else: + await query.answer("续费失败:订阅不存在。", show_alert=True) + + elif action == 'renewfromremind': + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT name, frequency_unit, frequency_value FROM subscriptions WHERE id = ?", (sub_id,)) + sub = cursor.fetchone() + if sub: + today = datetime.date.today() + new_due_date = calculate_new_due_date(today, sub['frequency_unit'], sub['frequency_value']) + if new_due_date: + new_date_str = new_due_date.strftime('%Y-%m-%d') + cursor.execute("UPDATE subscriptions SET next_due = ? WHERE id = ?", (new_date_str, sub_id)) + conn.commit() + safe_sub_name = escape_markdown(sub['name'], version=2) + await query.edit_message_text( + text=f"✅ *续费成功*\n\n您的订阅 `{safe_sub_name}` 新的到期日为: `{new_date_str}`", + parse_mode='MarkdownV2', + reply_markup=None + ) + else: + await query.answer("续费失败:无法计算新的到期日期。", show_alert=True) + else: + await query.answer("续费失败:此订阅可能已被删除。", show_alert=True) + await query.edit_message_text(text=query.message.text + "\n\n*(错误:此订阅已被删除)*", + parse_mode='MarkdownV2', reply_markup=None) + + elif action == 'delete': + keyboard = InlineKeyboardMarkup([ + [InlineKeyboardButton("✅ 是的,删除", callback_data=f'confirmdelete_{sub_id}'), + InlineKeyboardButton("❌ 取消", callback_data=f'view_{sub_id}')] + ]) + await query.edit_message_text(text="您确定要删除这个订阅吗?", reply_markup=keyboard) + elif action == 'confirmdelete': + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id)) + conn.commit() + await query.answer("订阅已删除") + if 'list_subs_in_category' in context.user_data: + category = context.user_data['list_subs_in_category'] + keyboard = await get_subs_list_keyboard(user_id, category_filter=category) + msg_text = f"分类“{escape_markdown(category, version=2)}”下的订阅:" + if not keyboard: + msg_text = f"分类“{escape_markdown(category, version=2)}”下没有订阅。" + keyboard = InlineKeyboardMarkup( + [[InlineKeyboardButton("« 返回分类列表", callback_data='list_categories')]]) + await query.edit_message_text(msg_text, reply_markup=keyboard, parse_mode='MarkdownV2') + else: + keyboard = await get_subs_list_keyboard(user_id) + if not keyboard: + await query.edit_message_text("您还没有任何订阅。") + else: + await query.edit_message_text("您的所有订阅:", reply_markup=keyboard) + + +# --- 【新增】包装函数,用于在会话中处理“返回”按钮 --- +async def fallback_view_button(update: Update, context: CallbackContext): + """ + 在会话的 fallback 中调用,处理 view_... 按钮的点击。 + 它会先显示订阅详情,然后明确地结束当前会话。 + """ + # 先执行通用的按钮处理逻辑来显示界面 + await button_callback_handler(update, context) + # 然后返回 END,以确保当前会话(如编辑、提醒设置)被正确终止 + return ConversationHandler.END + + +# --- Edit Subscription Conversation --- +async def edit_start(update: Update, context: CallbackContext): + query = update.callback_query + await query.answer() + sub_id = query.data.split('_')[1] + logger.debug(f"Starting edit for sub_id: {sub_id}") + context.user_data['sub_id_for_action'] = sub_id + keyboard = [ + [InlineKeyboardButton("名称", callback_data="editfield_name"), + InlineKeyboardButton("费用", callback_data="editfield_cost")], + [InlineKeyboardButton("货币", callback_data="editfield_currency"), + InlineKeyboardButton("类别", callback_data="editfield_category")], + [InlineKeyboardButton("下次付款日", callback_data="editfield_next_due"), + InlineKeyboardButton("周期", callback_data="editfield_frequency")], + [InlineKeyboardButton("续费方式", callback_data="editfield_renewal_type"), + InlineKeyboardButton("📝 备注", callback_data="editfield_notes")], + [InlineKeyboardButton("« 返回详情", callback_data=f'view_{sub_id}')] + ] + await query.edit_message_text("请选择您想编辑的字段:", reply_markup=InlineKeyboardMarkup(keyboard)) + return EDIT_SELECT_FIELD + + +async def edit_field_selected(update: Update, context: CallbackContext): + query = update.callback_query + await query.answer() + field_to_edit = query.data.partition('_')[2] + context.user_data['field_to_edit'] = field_to_edit + if field_to_edit == 'renewal_type': + keyboard = [ + [InlineKeyboardButton("自动续费", callback_data='editvalue_auto'), + InlineKeyboardButton("手动续费", callback_data='editvalue_manual')] + ] + await query.edit_message_text("请选择新的续费方式:", reply_markup=InlineKeyboardMarkup(keyboard)) + return EDIT_GET_NEW_VALUE + if field_to_edit == 'frequency': + keyboard = [ + [InlineKeyboardButton("天", callback_data='freq_unit_day'), + InlineKeyboardButton("周", callback_data='freq_unit_week')], + [InlineKeyboardButton("月", callback_data='freq_unit_month'), + InlineKeyboardButton("年", callback_data='freq_unit_year')] + ] + await query.edit_message_text("请选择新的周期*单位*", reply_markup=InlineKeyboardMarkup(keyboard), + parse_mode='MarkdownV2') + return EDIT_FREQ_UNIT + else: + field_map = {'name': '名称', 'cost': '费用', 'currency': '货币', 'category': '类别', 'next_due': '下次付款日', + 'notes': '备注'} + prompt = f"好的,请输入新的 *{field_map.get(field_to_edit, field_to_edit)}* 值:" + if field_to_edit == 'notes': + prompt += "\n(如需清空备注,请输入 /empty )" + await query.edit_message_text(prompt, parse_mode='MarkdownV2') + return EDIT_GET_NEW_VALUE + + +async def edit_freq_unit_received(update: Update, context: CallbackContext): + query = update.callback_query + await query.answer() + context.user_data['new_freq_unit'] = query.data.split('_')[2] + await query.edit_message_text("好的,现在请输入新的周期*数量*。", parse_mode='MarkdownV2') + return EDIT_FREQ_VALUE + + +async def edit_freq_value_received(update: Update, context: CallbackContext): + user_id = update.effective_user.id + try: + value = int(update.message.text) + if value <= 0: + raise ValueError + except (ValueError, TypeError): + await update.message.reply_text("请输入一个有效的正整数。") + return EDIT_FREQ_VALUE + unit = context.user_data.get('new_freq_unit') + sub_id = int(context.user_data.get('sub_id_for_action')) + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute("UPDATE subscriptions SET frequency_unit = ?, frequency_value = ? WHERE id = ? AND user_id = ?", + (unit, value, sub_id, user_id)) + conn.commit() + await update.message.reply_text("✅ 周期已更新!") + context.user_data.clear() + await show_subscription_view(update, context, sub_id) + return ConversationHandler.END + + +async def edit_new_value_received(update: Update, context: CallbackContext): + user_id = update.effective_user.id + field = context.user_data.get('field_to_edit') + try: + sub_id = int(context.user_data.get('sub_id_for_action')) + except (ValueError, TypeError): + if update.effective_message: + await update.effective_message.reply_text("错误:无效的订阅ID。") + return ConversationHandler.END + if not field: + if update.effective_message: + await update.effective_message.reply_text("错误:未选择要编辑的字段。") + return ConversationHandler.END + query, new_value = update.callback_query, "" + message_to_reply = update.effective_message + + if update.message and update.message.text == '/empty' and field == 'notes': + new_value = None + elif query: + new_value = query.data.split('_')[1] + elif update.message: + new_value = update.message.text + else: + if message_to_reply: + await message_to_reply.reply_text("错误:未提供新值。") + return ConversationHandler.END + + validation_failed = False + if field == 'cost': + try: + new_value = float(new_value) + if new_value < 0: + raise ValueError("费用不能为负数") + except (ValueError, TypeError): + if message_to_reply: await message_to_reply.reply_text("费用必须是有效的非负数字。") + validation_failed = True + elif field == 'currency': + new_value = str(new_value).upper() + if not (len(new_value) == 3 and new_value.isalpha()): + if message_to_reply: await message_to_reply.reply_text("请输入有效的三字母货币代码(如 USD, CNY)。") + validation_failed = True + elif field == 'next_due': + parsed = parse_date(str(new_value)) + if not parsed: + if message_to_reply: await message_to_reply.reply_text( + "无法识别的日期格式,请使用类似 '2025\\-10\\-01' 或 '10月1日' 的格式。") + validation_failed = True + else: + new_value = parsed + elif field == 'category': + new_value = str(new_value).strip() + if not new_value: + if message_to_reply: await message_to_reply.reply_text("类别不能为空。") + validation_failed = True + else: + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, new_value)) + conn.commit() + + if validation_failed: + return EDIT_GET_NEW_VALUE + + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute(f"UPDATE subscriptions SET {field} = ? WHERE id = ? AND user_id = ?", + (new_value, sub_id, user_id)) + conn.commit() + + if query: + await query.answer(f"✅ 字段已更新!") + elif message_to_reply: + await message_to_reply.reply_text("✅ 字段已更新!") + + context.user_data.clear() + await show_subscription_view(update, context, sub_id) + return ConversationHandler.END + + +# --- Reminder Settings Conversation --- +async def _display_reminder_settings(query: CallbackQuery, context: CallbackContext, sub_id: int): + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT name, renewal_type, reminders_enabled, reminder_on_due_date, reminder_days FROM subscriptions WHERE id = ?", + (sub_id,)) + sub = cursor.fetchone() + if not sub: + await query.edit_message_text("错误:找不到该订阅。") + return + enabled_text = "❌ 关闭提醒" if sub['reminders_enabled'] else "✅ 开启提醒" + due_date_text = "❌ 关闭到期日提醒" if sub['reminder_on_due_date'] else "✅ 开启到期日提醒" + keyboard = [ + [InlineKeyboardButton(enabled_text, callback_data='remindaction_toggle_enabled')], + [InlineKeyboardButton(due_date_text, callback_data='remindaction_toggle_due_date')] + ] + safe_name = escape_markdown(sub['name'], version=2) + current_status = f"*🔔 提醒设置: {safe_name}*\n\n" + if sub['renewal_type'] == 'manual': + current_status += f"当前提前提醒: *{sub['reminder_days']}天*\n" + keyboard.append([InlineKeyboardButton("⚙️ 更改提前天数", callback_data='remindaction_ask_days')]) + keyboard.append([InlineKeyboardButton("« 返回详情", callback_data=f'view_{sub_id}')]) + await query.edit_message_text(current_status, reply_markup=InlineKeyboardMarkup(keyboard), parse_mode='MarkdownV2') + + +async def remind_settings_start(update: Update, context: CallbackContext): + query = update.callback_query + await query.answer() + sub_id_str = query.data.partition('_')[2] + if not sub_id_str.isdigit(): + await query.edit_message_text("错误:无效的订阅ID。") + return ConversationHandler.END + sub_id = int(sub_id_str) + logger.debug(f"Starting reminder settings for sub_id: {sub_id}") + context.user_data['sub_id_for_action'] = sub_id + await _display_reminder_settings(query, context, sub_id) + return REMIND_SELECT_ACTION + + +async def remind_action_handler(update: Update, context: CallbackContext): + query = update.callback_query + await query.answer() + + if not query.data: + return REMIND_SELECT_ACTION + + action = query.data.partition('remindaction_')[2] + sub_id = context.user_data.get('sub_id_for_action') + if not sub_id: + await query.edit_message_text("错误:会话已过期,请重试。") + return ConversationHandler.END + + if action == 'ask_days': + await query.edit_message_text("请输入您想提前几天收到提醒?(输入0则不提前提醒)") + return REMIND_GET_DAYS + + if action not in ['toggle_enabled', 'toggle_due_date']: + logger.warning(f"Unexpected action '{query.data}' in remind_action_handler") + return REMIND_SELECT_ACTION + + with get_db_connection() as conn: + cursor = conn.cursor() + if action == 'toggle_enabled': + cursor.execute("UPDATE subscriptions SET reminders_enabled = NOT reminders_enabled WHERE id = ?", (sub_id,)) + elif action == 'toggle_due_date': + cursor.execute("UPDATE subscriptions SET reminder_on_due_date = NOT reminder_on_due_date WHERE id = ?", + (sub_id,)) + conn.commit() + await _display_reminder_settings(query, context, sub_id) + return REMIND_SELECT_ACTION + + +async def remind_days_received(update: Update, context: CallbackContext): + sub_id = context.user_data.get('sub_id_for_action') + if not sub_id: + await update.message.reply_text("错误:会话已过期,请重试。") + return ConversationHandler.END + try: + days = int(update.message.text) + if days < 0: + raise ValueError + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute("UPDATE subscriptions SET reminder_days = ? WHERE id = ?", (days, sub_id)) + conn.commit() + await update.message.reply_text(f"✅ 提前提醒天数已设置为: {days}天。") + context.user_data.clear() + await show_subscription_view(update, context, sub_id) + except (ValueError, TypeError): + await update.message.reply_text("请输入一个有效的非负整数。") + return REMIND_GET_DAYS + return ConversationHandler.END + + +# --- Other Commands --- +async def set_currency(update: Update, context: CallbackContext): + user_id, args = update.effective_user.id, context.args + if len(args) != 1: + await update.message.reply_text("用法: /set_currency ``(例如 /set_currency USD)", parse_mode='MarkdownV2') + return + new_currency = args[0].upper() + if len(new_currency) != 3 or not new_currency.isalpha(): + await update.message.reply_text("请输入有效的三字母货币代码(如 USD, CNY)。") + return + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute("INSERT OR REPLACE INTO users (user_id, main_currency) VALUES (?, ?)", (user_id, new_currency)) + conn.commit() + await update.message.reply_text(f"您的主货币已设为 {escape_markdown(new_currency, version=2)}。", + parse_mode='MarkdownV2') + + +async def cancel(update: Update, context: CallbackContext): + context.user_data.clear() + if update.callback_query: + await update.callback_query.answer() + await update.callback_query.edit_message_text('操作已取消。') + else: + await update.message.reply_text('操作已取消。') + return ConversationHandler.END + + +# --- Main --- +def main(): + if not TELEGRAM_TOKEN: + logger.critical("TELEGRAM_TOKEN 环境变量未设置!") + return + + application = Application.builder().token(TELEGRAM_TOKEN).build() + + async def post_init(app: Application): + try: + bot_info = await app.bot.get_me() + logger.info(f"TELEGRAM_TOKEN 验证成功: {bot_info.username}") + except TelegramError as e: + logger.critical(f"TELEGRAM_TOKEN 无效或无法连接 Telegram API: {e}") + raise SystemExit + + commands = [ + BotCommand("start", "🚀 开始使用"), + BotCommand("add_sub", "➕ 添加新订阅"), + BotCommand("list_subs", "📋 列出所有订阅"), + BotCommand("list_categories", "🗂️ 按分类浏览"), + BotCommand("stats", "📊 查看订阅统计"), + BotCommand("import", "📥 导入订阅"), + BotCommand("export", "📤 导出订阅"), + BotCommand("set_currency", "💲 设置主货币"), + BotCommand("help", "ℹ️ 获取帮助"), + BotCommand("cancel", "❌ 取消当前操作") + ] + try: + await app.bot.delete_my_commands() + logger.debug("Cleared existing bot commands") + await app.bot.set_my_commands(commands) + logger.info("Bot commands registered successfully") + except TelegramError as e: + logger.error(f"Failed to register bot commands: {e}") + + app.job_queue.run_daily( + check_and_send_reminders, + time=datetime.time(hour=9, minute=0, tzinfo=datetime.timezone(datetime.timedelta(hours=8))), + name='daily_reminders' + ) + logger.info("Daily reminder job scheduled.") + + application.post_init = post_init + + add_conv = ConversationHandler( + entry_points=[CommandHandler('add_sub', add_sub_start)], + states={ + ADD_NAME: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_name_received)], + ADD_COST: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_cost_received)], + ADD_CURRENCY: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_currency_received)], + ADD_CATEGORY: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_category_received)], + ADD_NEXT_DUE: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_next_due_received)], + ADD_FREQ_UNIT: [CallbackQueryHandler(add_freq_unit_received, pattern='^freq_unit_')], + ADD_FREQ_VALUE: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_freq_value_received)], + ADD_RENEWAL_TYPE: [CallbackQueryHandler(add_renewal_type_received, pattern='^renewal_')], + ADD_NOTES: [ + MessageHandler(filters.TEXT & ~filters.COMMAND, add_notes_received), + CommandHandler('skip', skip_notes) + ], + }, + fallbacks=[CommandHandler('cancel', cancel)] + ) + + edit_conv = ConversationHandler( + entry_points=[CallbackQueryHandler(edit_start, pattern='^edit_')], + states={ + EDIT_SELECT_FIELD: [CallbackQueryHandler(edit_field_selected, pattern='^editfield_')], + EDIT_GET_NEW_VALUE: [ + MessageHandler(filters.TEXT & ~filters.COMMAND, edit_new_value_received), + CallbackQueryHandler(edit_new_value_received, pattern='^editvalue_'), + CommandHandler('empty', edit_new_value_received) + ], + EDIT_FREQ_UNIT: [CallbackQueryHandler(edit_freq_unit_received, pattern='^freq_unit_')], + EDIT_FREQ_VALUE: [MessageHandler(filters.TEXT & ~filters.COMMAND, edit_freq_value_received)], + }, + fallbacks=[ + CommandHandler('cancel', cancel), + # 【修改】使用新的包装函数来确保会话能正确结束 + CallbackQueryHandler(fallback_view_button, pattern='^view_'), + CallbackQueryHandler(edit_start, pattern='^edit_'), + CallbackQueryHandler(remind_settings_start, pattern='^remind_') + ], + per_message=False + ) + + remind_conv = ConversationHandler( + entry_points=[CallbackQueryHandler(remind_settings_start, pattern='^remind_')], + states={ + REMIND_SELECT_ACTION: [CallbackQueryHandler(remind_action_handler, pattern='^remindaction_')], + REMIND_GET_DAYS: [MessageHandler(filters.TEXT & ~filters.COMMAND, remind_days_received)], + }, + fallbacks=[ + CommandHandler('cancel', cancel), + # 【修改】使用新的包装函数来确保会话能正确结束 + CallbackQueryHandler(fallback_view_button, pattern='^view_'), + CallbackQueryHandler(edit_start, pattern='^edit_'), + CallbackQueryHandler(remind_settings_start, pattern='^remind_') + ], + per_message=False + ) + + import_conv = ConversationHandler( + entry_points=[CommandHandler('import', import_start)], + states={ + IMPORT_UPLOAD: [MessageHandler(filters.Document.ALL, import_upload_received)], + }, + fallbacks=[CommandHandler('cancel', cancel)] + ) + + button_pattern = r'^(view_\d+|renewmanual_\d+|delete_\d+|confirmdelete_\d+|renewfromremind_\d+|list_subs_in_category_.+|list_categories|list_all_subs)$' + + application.add_handler(CommandHandler('start', start)) + application.add_handler(CommandHandler('help', help_command)) + application.add_handler(CommandHandler('list_subs', list_subs)) + application.add_handler(CommandHandler('list_categories', list_categories)) + application.add_handler(CommandHandler('set_currency', set_currency)) + application.add_handler(CommandHandler('stats', stats)) + application.add_handler(CommandHandler('export', export_command)) + application.add_handler(CommandHandler('cancel', cancel)) + + application.add_handler(add_conv) + application.add_handler(edit_conv) + application.add_handler(remind_conv) + application.add_handler(import_conv) + application.add_handler(CallbackQueryHandler(button_callback_handler, pattern=button_pattern)) + + logger.info(f"{PROJECT_NAME} Bot is starting...") + application.run_polling() + + +if __name__ == '__main__': + update_past_due_dates() + main() \ No newline at end of file