import asyncio
import datetime
import html
import logging
import math
import os
import re
import sqlite3
import tempfile

import dateparser
import matplotlib
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import pandas as pd
import requests
from dateutil.relativedelta import relativedelta
from dotenv import load_dotenv
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup, BotCommand, CallbackQuery
from telegram.error import TelegramError
from telegram.ext import (
    Application, CommandHandler, MessageHandler, filters,
    CallbackContext, CallbackQueryHandler, ConversationHandler
)

try:
    from telegram.helpers import escape_html
except ImportError:
    # NOTE(review): telegram.helpers does not appear to export escape_html;
    # fall back to the standard-library escaper so the module still imports.
    from html import escape as escape_html
# --- Load .env and settings ---
load_dotenv()
# --- Logging configuration ---
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO
)
# Only WARNING and above are emitted for the "httpx" logger.
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
# --- Environment variables and project configuration ---
# Both values are None when the corresponding environment variable is unset.
TELEGRAM_TOKEN = os.getenv('TELEGRAM_TOKEN')
EXCHANGE_API_KEY = os.getenv('EXCHANGE_API_KEY')
PROJECT_NAME = "SubMind"
DB_FILE = 'submind.db'
# --- Conversation handler states ---
# Every group is numbered from 0 independently, so values repeat across groups.
# NOTE(review): presumably each group belongs to its own ConversationHandler — confirm.
(ADD_NAME, ADD_COST, ADD_CURRENCY, ADD_CATEGORY, ADD_NEXT_DUE,
 ADD_FREQ_UNIT, ADD_FREQ_VALUE, ADD_RENEWAL_TYPE, ADD_NOTES) = range(9)
(EDIT_SELECT_FIELD, EDIT_GET_NEW_VALUE, EDIT_FREQ_UNIT, EDIT_FREQ_VALUE) = range(4)
(REMIND_SELECT_ACTION, REMIND_GET_DAYS) = range(2)
(IMPORT_UPLOAD,) = range(1)
# --- 辅助函数:数据库连接管理 ---
def get_db_connection():
    """Open a new SQLite connection to ``DB_FILE``.

    Returns:
        sqlite3.Connection: a connection whose ``row_factory`` is
        ``sqlite3.Row``, so result columns can be read by name or index.

    Note:
        Using the connection as a context manager only commits or rolls back
        the transaction; sqlite3 does not close the connection on exit.
    """
    connection = sqlite3.connect(DB_FILE)
    connection.row_factory = sqlite3.Row
    return connection
# --- 字体管理 ---
def get_chinese_font():
    """Return font properties for Source Han Sans SC, downloading it if needed.

    The font is looked up at ``fonts/SourceHanSansSC-Regular.otf``. When it is
    missing, each mirror URL is tried in turn. The download is streamed into a
    temporary file in the same directory and moved into place only after it
    completed, so an interrupted transfer never leaves a truncated font that
    later calls would mistake for a valid one.

    Returns:
        matplotlib.font_manager.FontProperties: properties for the local font
        file, or for the generic ``sans-serif`` family if every download fails.
    """
    font_name = 'SourceHanSansSC-Regular.otf'
    font_path = os.path.join('fonts', font_name)
    if os.path.exists(font_path):
        logger.debug(f"Found font at {font_path}")
        return fm.FontProperties(fname=font_path)
    logger.info(f"Font '{font_name}' not found. Attempting to download...")
    os.makedirs('fonts', exist_ok=True)
    urls = [
        'https://github.com/wweir/source-han-sans-sc/raw/refs/heads/master/SourceHanSansSC-Regular.otf',
        'https://cdn.jsdelivr.net/gh/wweir/source-han-sans-sc@master/SourceHanSansSC-Regular.otf',
        'https://fastly.jsdelivr.net/gh/wweir/source-han-sans-sc@master/SourceHanSansSC-Regular.otf'
    ]
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }
    for url in urls:
        partial_path = None
        try:
            logger.info(f"Trying to download font from: {url}")
            # The context manager releases the streamed connection in all cases.
            with requests.get(url, stream=True, headers=headers, timeout=15) as response:
                response.raise_for_status()
                # A unique temp name keeps concurrent callers from clobbering each other.
                with tempfile.NamedTemporaryFile(dir='fonts', suffix='.part', delete=False) as tmp:
                    partial_path = tmp.name
                    for chunk in response.iter_content(chunk_size=8192):
                        tmp.write(chunk)
            # Publish the font only once it has been written completely.
            os.replace(partial_path, font_path)
        except (requests.exceptions.RequestException, OSError) as e:
            logger.warning(f"Failed to download font from {url}. Error: {e}")
            if partial_path and os.path.exists(partial_path):
                try:
                    os.remove(partial_path)
                except OSError:
                    logger.debug(f"Could not remove partial download {partial_path}")
            continue
        logger.info(f"Font '{font_name}' downloaded successfully to '{font_path}'.")
        fm._load_fontmanager(try_read_cache=False)
        return fm.FontProperties(fname=font_path)
    logger.error("All font download attempts failed. Falling back to default sans-serif.")
    return fm.FontProperties(family='sans-serif')
# --- 数据库初始化与迁移 ---
def init_db():
    """Create the database schema if missing and upgrade older schemas.

    Creates the ``subscriptions``, ``categories``, ``users`` and
    ``exchange_rates`` tables plus their indexes, adds any column that an
    existing ``subscriptions`` table lacks, runs the legacy frequency
    migration and commits.
    """
    # (column, statement) pairs; a statement runs only when its column is absent.
    column_upgrades = (
        ('frequency_unit', "ALTER TABLE subscriptions ADD COLUMN frequency_unit TEXT"),
        ('frequency_value', "ALTER TABLE subscriptions ADD COLUMN frequency_value INTEGER"),
        ('reminders_enabled', "ALTER TABLE subscriptions ADD COLUMN reminders_enabled BOOLEAN DEFAULT TRUE"),
        ('reminder_days', "ALTER TABLE subscriptions ADD COLUMN reminder_days INTEGER DEFAULT 3"),
        ('reminder_on_due_date', "ALTER TABLE subscriptions ADD COLUMN reminder_on_due_date BOOLEAN DEFAULT TRUE"),
        ('notes', "ALTER TABLE subscriptions ADD COLUMN notes TEXT"),
        ('last_reminded_date', "ALTER TABLE subscriptions ADD COLUMN last_reminded_date DATE"),
    )
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute('''
CREATE TABLE IF NOT EXISTS subscriptions (
id INTEGER PRIMARY KEY, user_id INTEGER, name TEXT, cost REAL, currency TEXT,
category TEXT, next_due DATE, frequency TEXT,
renewal_type TEXT DEFAULT 'auto',
reminders_enabled BOOLEAN DEFAULT TRUE,
reminder_days INTEGER DEFAULT 3,
reminder_on_due_date BOOLEAN DEFAULT TRUE,
frequency_unit TEXT,
frequency_value INTEGER,
notes TEXT
)
''')
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_subscriptions_user_id ON subscriptions(user_id)')
        # Column 1 of each PRAGMA table_info row is the column name.
        cursor.execute("PRAGMA table_info(subscriptions)")
        present = {info[1] for info in cursor.fetchall()}
        for column, statement in column_upgrades:
            if column not in present:
                cursor.execute(statement)
        cursor.execute('''
CREATE TABLE IF NOT EXISTS categories (
id INTEGER PRIMARY KEY, user_id INTEGER, name TEXT, UNIQUE(user_id, name)
)
''')
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_categories_user_id ON categories(user_id)')
        cursor.execute('''
CREATE TABLE IF NOT EXISTS users (
user_id INTEGER PRIMARY KEY, main_currency TEXT DEFAULT "USD", language TEXT DEFAULT "en"
)
''')
        cursor.execute('''
CREATE TABLE IF NOT EXISTS exchange_rates (
from_currency TEXT, to_currency TEXT, rate REAL, last_updated TIMESTAMP,
PRIMARY KEY (from_currency, to_currency)
)
''')
        migrate_frequency_data(conn, cursor)
        conn.commit()
def migrate_frequency_data(conn, cursor):
    """Fill ``frequency_unit``/``frequency_value`` from the legacy ``frequency`` text.

    Only rows with a non-NULL ``frequency`` and a NULL ``frequency_unit`` are
    touched. Labels are matched case-insensitively; unknown labels are left
    unchanged.

    Args:
        conn: connection used to commit the updates.
        cursor: cursor on that connection used for the queries.
    """
    cursor.execute("SELECT id, frequency FROM subscriptions WHERE frequency IS NOT NULL AND frequency_unit IS NULL")
    pending_rows = cursor.fetchall()
    if len(pending_rows) == 0:
        return
    legacy_to_structured = {
        'daily': ('day', 1), 'weekly': ('week', 1), '周付': ('week', 1), 'monthly': ('month', 1),
        '月付': ('month', 1), 'quarterly': ('month', 3), '季付': ('month', 3), '半年': ('month', 6),
        'half-year': ('month', 6), 'biannually': ('month', 6), 'yearly': ('year', 1), '年付': ('year', 1)
    }
    for row_id, legacy_label in pending_rows:
        mapped = legacy_to_structured.get(str(legacy_label).lower())
        if mapped is None:
            continue
        unit, value = mapped
        cursor.execute("UPDATE subscriptions SET frequency_unit = ?, frequency_value = ? WHERE id = ?",
                       (unit, value, row_id))
    conn.commit()
# Runs at import time, so the schema is created/upgraded before any code below uses it.
init_db()
# --- 辅助函数 ---
def get_user_main_currency(user_id):
    """Return the ``main_currency`` stored for ``user_id``.

    Returns:
        The value from the ``users`` table, or ``'USD'`` when the user has no row.
    """
    with get_db_connection() as conn:
        row = conn.cursor().execute(
            'SELECT main_currency FROM users WHERE user_id = ?', (user_id,)
        ).fetchone()
    if row:
        return row['main_currency']
    return 'USD'
def convert_currency(amount, from_curr, to_curr):
    """Convert ``amount`` from one currency to another using a cached rate.

    A rate cached in ``exchange_rates`` is used while it is younger than 24
    hours. Otherwise the rate is fetched from exchangerate-api.com and cached.

    Args:
        amount: numeric amount to convert.
        from_curr: source currency code (case-insensitive).
        to_curr: target currency code (case-insensitive).

    Returns:
        The converted amount. The original ``amount`` is returned unchanged
        when both codes match, when no API key is configured, or when the
        lookup fails and no cached rate exists. A stale cached rate is used as
        a fallback when the API call fails.
    """
    src, dst = from_curr.upper(), to_curr.upper()
    if src == dst:
        return amount
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute('''
SELECT rate, last_updated FROM exchange_rates
WHERE from_currency = ? AND to_currency = ?
''', (src, dst))
        result = cursor.fetchone()
        now = datetime.datetime.now()
        cache_validity = datetime.timedelta(hours=24)
        if result:
            try:
                is_fresh = (now - datetime.datetime.fromisoformat(result['last_updated'])) < cache_validity
            except (TypeError, ValueError):
                # A missing or malformed timestamp is treated as an expired entry.
                is_fresh = False
            if is_fresh:
                logger.debug(f"Using cached exchange rate for {from_curr} to {to_curr}: {result['rate']}")
                return amount * result['rate']
        if not EXCHANGE_API_KEY:
            logger.warning("EXCHANGE_API_KEY not set, returning original amount")
            return amount
        try:
            # Only the pair rate is needed, so the amount is not sent to the API.
            url = f"https://v6.exchangerate-api.com/v6/{EXCHANGE_API_KEY}/pair/{src}/{dst}"
            response = requests.get(url, timeout=5)
            response.raise_for_status()
            data = response.json()
            rate = data.get('conversion_rate') if isinstance(data, dict) else None
            # Never cache a made-up rate: an error payload has no usable conversion_rate.
            if isinstance(rate, bool) or not isinstance(rate, (int, float)) or rate <= 0:
                reason = data.get('error-type', 'missing conversion_rate') if isinstance(data, dict) else 'malformed payload'
                raise ValueError(f"unexpected API response: {reason}")
            cursor.execute('''
INSERT OR REPLACE INTO exchange_rates (from_currency, to_currency, rate, last_updated)
VALUES (?, ?, ?, ?)
''', (src, dst, rate, now.isoformat()))
            conn.commit()
            logger.debug(f"Updated exchange rate cache for {from_curr} to {to_curr}: {rate}")
            return amount * rate
        except (requests.exceptions.RequestException, ValueError) as e:
            # The request URL embeds the API key; keep it out of the logs.
            detail = str(e).replace(EXCHANGE_API_KEY, '***')
            logger.error(f"Currency conversion API error: {detail}")
            if result:
                logger.warning(f"Falling back to cached rate: {result['rate']}")
                return amount * result['rate']
            logger.warning("No cached rate available, returning original amount")
            return amount
def parse_date(date_string: str) -> str | None:
    """Parse a free-form English/Chinese date into ``YYYY-MM-DD``.

    When the input carries no year hint (no ``年``, no ``/`` and no four-digit
    number) and the parsed date is already in the past, the date is moved to
    the following year.

    Args:
        date_string: text entered by the user.

    Returns:
        The ISO date string, or ``None`` if the text cannot be parsed.
    """
    today = datetime.datetime.now()
    try:
        dt = dateparser.parse(date_string, languages=['en', 'zh'], settings={'TIMEZONE': 'Asia/Shanghai', 'RETURN_AS_TIMEZONE_AWARE': False})
        if not dt:
            return None
        has_year_info = any(c in date_string for c in ['年', '/']) or (re.search(r'\d{4}', date_string) is not None)
        if not has_year_info and dt.date() < today.date():
            try:
                dt = dt.replace(year=dt.year + 1)
            except ValueError:
                # Feb 29 does not exist in the following year; use Feb 28
                # instead of reporting a valid date as unparseable.
                dt = dt.replace(year=dt.year + 1, day=28)
        return dt.strftime('%Y-%m-%d')
    except Exception as e:
        logger.error(f"Date parsing failed for string '{date_string}'. Error: {e}")
        return None
def calculate_new_due_date(base_date, unit, value):
    """Return ``base_date`` advanced by ``value`` units.

    Args:
        base_date: date (or datetime) to start from.
        unit: ``'day'``, ``'week'``, ``'month'`` or ``'year'`` (case-insensitive).
        value: number of units to add.

    Returns:
        The shifted date, or ``None`` when the unit is unknown, ``value`` is
        ``None``, or the resulting offset is zero.
    """
    # A missing value is invalid input like an unknown unit; report it the same
    # way instead of raising TypeError.
    if value is None:
        return None
    keyword = {'day': 'days', 'week': 'weeks', 'month': 'months', 'year': 'years'}.get(str(unit).lower())
    if keyword is None:
        return None
    # Build only the offset that is needed.
    delta = relativedelta(**{keyword: value})
    # An all-zero relativedelta is falsy, so a zero offset yields None.
    return base_date + delta if delta else None
def format_frequency(unit, value) -> str:
    """Render a billing period as Chinese text.

    Args:
        unit: period unit such as ``'day'``, ``'week'``, ``'month'``, ``'year'``.
        value: number of units per period.

    Returns:
        ``"未知"`` when the unit is empty or the value is ``None``; a short
        form such as ``"每月"`` for a single known unit; otherwise
        ``"每 <value> <unit>"`` with known units translated.
    """
    if not unit or value is None:
        return "未知"
    every_single = {'day': '每天', 'week': '每周', 'month': '每月', 'year': '每年'}
    if value == 1 and unit in every_single:
        return every_single[unit]
    counted_units = {'day': '天', 'week': '周', 'month': '个月', 'year': '年'}
    unit_label = counted_units.get(unit, unit)
    return f"每 {value} {unit_label}"
# Prefix of the callback_data produced by _build_category_callback_data().
CATEGORY_CB_PREFIX = "list_subs_in_category_id_"
# Every key maps to an identically named value.
# NOTE(review): presumably a whitelist of editable subscription columns — confirm against the edit handlers.
EDITABLE_SUB_FIELDS = {
    'name': 'name',
    'cost': 'cost',
    'currency': 'currency',
    'category': 'category',
    'next_due': 'next_due',
    'renewal_type': 'renewal_type',
    'notes': 'notes'
}
# Maximum accepted lengths (in characters) for user-supplied text.
MAX_NAME_LEN = 128
MAX_CATEGORY_LEN = 64
MAX_NOTES_LEN = 1000
# Accepted values for the billing period unit and the renewal type.
VALID_FREQ_UNITS = {'day', 'week', 'month', 'year'}
VALID_RENEWAL_TYPES = {'auto', 'manual'}
def _build_category_callback_data(category_id: int) -> str:
return f"{CATEGORY_CB_PREFIX}{category_id}"
def _parse_category_id_from_callback(data: str) -> int | None:
payload = data.replace(CATEGORY_CB_PREFIX, '', 1)
return int(payload) if payload.isdigit() else None
async def get_subs_list_keyboard(user_id: int, category_filter: str = None) -> InlineKeyboardMarkup:
    """Build the inline keyboard listing a user's subscriptions.

    Args:
        user_id: owner of the subscriptions.
        category_filter: when truthy, only subscriptions in this category are listed.

    Returns:
        A keyboard with two subscription buttons per row, ordered by
        ``next_due``, followed by a navigation row; ``None`` when the query
        returns no subscriptions.
    """
    sql_parts = ["SELECT id, name FROM subscriptions WHERE user_id = ? "]
    sql_args = [user_id]
    if category_filter:
        sql_parts.append("AND category = ? ")
        sql_args.append(category_filter)
    sql_parts.append("ORDER BY next_due ASC")
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("".join(sql_parts), tuple(sql_args))
        subs = cursor.fetchall()
    if not subs:
        return None
    rows = []
    for start in range(0, len(subs), 2):
        rows.append([
            InlineKeyboardButton(name, callback_data=f'view_{sub_id}')
            for sub_id, name in subs[start:start + 2]
        ])
    # The last row leads to the category list; its label depends on the current view.
    nav_label = "« 返回分类列表" if category_filter else "🗂️ 按分类浏览"
    rows.append([InlineKeyboardButton(nav_label, callback_data='list_categories')])
    return InlineKeyboardMarkup(rows)
def _clear_action_state(context: CallbackContext, keys: list[str]):
for key in keys:
context.user_data.pop(key, None)
def _get_new_sub_data_or_end(update: Update, context: CallbackContext):
sub_data = context.user_data.get('new_sub_data')
if sub_data is None:
message_obj = update.message or (update.callback_query.message if update.callback_query else None)
if message_obj:
# 统一提示,避免 KeyError 导致会话崩溃
return None, message_obj
return sub_data, None
# --- 自动任务 ---
def update_past_due_dates():
    """Roll overdue auto-renewing subscriptions forward to their next due date.

    For every subscription with ``renewal_type = 'auto'`` whose ``next_due`` is
    before today, the billing period is added repeatedly until the date lies
    after today, and the result is stored. Rows whose period cannot be
    computed are left unchanged; a failure on one row is logged and does not
    stop the others.
    """
    today = datetime.date.today()
    with get_db_connection() as conn:
        cursor = conn.cursor()
        # Bind the date as ISO text: sqlite3's implicit date adapter is
        # deprecated since Python 3.12 and produces the same string.
        cursor.execute("SELECT * FROM subscriptions WHERE next_due < ? AND renewal_type = 'auto'",
                       (today.isoformat(),))
        past_due_subs = cursor.fetchall()
        if not past_due_subs:
            return
        for sub in past_due_subs:
            try:
                last_due_date = datetime.datetime.strptime(sub['next_due'], '%Y-%m-%d').date()
                new_due_date = last_due_date
                while new_due_date <= today:
                    calculated_date = calculate_new_due_date(new_due_date, sub['frequency_unit'],
                                                             sub['frequency_value'])
                    # Stop when no date can be computed or the period does not
                    # move forward (e.g. a negative value); otherwise the loop
                    # would never reach a date after today.
                    if not calculated_date or calculated_date <= new_due_date:
                        break
                    new_due_date = calculated_date
                if new_due_date > last_due_date:
                    cursor.execute("UPDATE subscriptions SET next_due = ? WHERE id = ?",
                                   (new_due_date.strftime('%Y-%m-%d'), sub['id']))
            except Exception as e:
                logger.error(f"Failed to update subscription {sub['id']}: {e}")
        conn.commit()
async def check_and_send_reminders(context: CallbackContext):
    """Job callback: send due-date reminders and record that they were sent.

    Considers subscriptions with reminders enabled that were not already
    reminded today. A reminder is sent when the subscription is due today (if
    ``reminder_on_due_date`` is set) or, for manual renewals only,
    ``reminder_days`` days before the due date. Manual renewals get an
    "I have renewed" button. After a successful send ``last_reminded_date`` is
    set to today. A failure on one subscription is logged and does not stop
    the others.

    Args:
        context: job context; ``context.bot`` is used to send the messages.
    """
    logger.info("Running job: Checking for subscription reminders...")
    today = datetime.date.today()
    today_str = today.strftime('%Y-%m-%d')
    with get_db_connection() as conn:
        cursor = conn.cursor()
        # Skip subscriptions that were already reminded today.
        cursor.execute("SELECT * FROM subscriptions WHERE reminders_enabled = TRUE AND next_due IS NOT NULL AND (last_reminded_date IS NULL OR last_reminded_date != ?)", (today_str,))
        subs_to_check = cursor.fetchall()
    for sub in subs_to_check:
        try:
            due_date = datetime.datetime.strptime(sub['next_due'], '%Y-%m-%d').date()
            user_id = sub['user_id']
            is_manual = sub['renewal_type'] == 'manual'
            safe_sub_name = escape_html(sub['name'])
            message = None
            # Messages are sent with parse_mode='HTML', so the markup must be
            # HTML tags; Markdown markers would be shown literally.
            if sub['reminder_on_due_date'] and due_date == today:
                message = f"🔔 <b>订阅到期提醒</b>\n\n您的订阅 <code>{safe_sub_name}</code> 今天到期。"
                if is_manual:
                    message += " 请记得手动续费。"
                else:
                    message += " 将会自动续费。"
            # reminder_days may be NULL; treat that as "no advance reminder".
            elif is_manual and (sub['reminder_days'] or 0) > 0:
                reminder_date = due_date - datetime.timedelta(days=sub['reminder_days'])
                if reminder_date == today:
                    days_left = (due_date - today).days
                    days_text = f"{days_left}天后" if days_left > 0 else "今天"
                    message = f"🔔 <b>订阅即将到期提醒</b>\n\n您的手动续费订阅 <code>{safe_sub_name}</code> 将在 {days_text} 到期。"
            if not message:
                continue
            keyboard = None
            if is_manual:
                keyboard = InlineKeyboardMarkup([
                    [InlineKeyboardButton("✅ 我已续费", callback_data=f"renewfromremind_{sub['id']}")]
                ])
            await context.bot.send_message(
                chat_id=user_id,
                text=message,
                parse_mode='HTML',
                reply_markup=keyboard
            )
            # Record that today's reminder has been sent.
            with get_db_connection() as update_conn:
                update_cursor = update_conn.cursor()
                update_cursor.execute("UPDATE subscriptions SET last_reminded_date = ? WHERE id = ?", (today_str, sub['id']))
                update_conn.commit()
            logger.info(f"Reminder sent for sub_id {sub['id']}")
        except Exception as e:
            # sqlite3.Row has no .get(); index it so the handler itself cannot raise.
            logger.error(f"Failed to process reminder for sub_id {sub['id']}: {e}")
# --- 命令处理器 ---
async def start(update: Update, context: CallbackContext):
    """Handle /start: register the user if unknown and send the welcome text."""
    uid = update.effective_user.id
    with get_db_connection() as conn:
        # INSERT OR IGNORE leaves an existing user row untouched.
        conn.cursor().execute('INSERT OR IGNORE INTO users (user_id) VALUES (?)', (uid,))
        conn.commit()
    greeting = f'欢迎使用 {escape_html(PROJECT_NAME)}!\n您的私人订阅智能管家。'
    await update.message.reply_text(greeting, parse_mode='HTML')
async def help_command(update: Update, context: CallbackContext):
    """Handle /help: reply with the command overview.

    The text is sent with ``parse_mode='HTML'``, so headings use ``<b>`` tags
    and commands are written without MarkdownV2 backslash escapes, which HTML
    mode would display literally.
    """
    help_lines = [
        f"<b>{escape_html(PROJECT_NAME)} 命令列表</b>",
        "<b>🌟 核心功能</b>",
        "/add_sub - 引导您添加一个新的订阅",
        "/list_subs - 列出您的所有订阅",
        "/list_categories - 按分类浏览您的订阅",
        "<b>📊 数据管理</b>",
        "/stats - 查看按类别分类的订阅统计",
        "/import - 通过上传 CSV 文件批量导入订阅",
        "/export - 将您的所有订阅导出为 CSV 文件",
        "<b>⚙️ 个性化设置</b>",
        # NOTE(review): the previous text had an empty code span after
        # /set_currency, presumably an argument hint — confirm and restore it.
        "/set_currency - 设置您的主要货币",
        "/cancel - 在任何流程中取消当前操作",
    ]
    await update.message.reply_text("\n".join(help_lines), parse_mode='HTML')
def make_autopct(values, currency_code):
    """Create a pie-chart label formatter showing amount and percentage.

    Args:
        values: the numeric values the percentages refer to.
        currency_code: currency code; known codes map to a symbol, others are
            used as given followed by a space.

    Returns:
        A function taking a percentage and returning
        ``"<symbol><amount>\\n(<pct>%)"``.
    """
    known_symbols = {'USD': '$', 'CNY': '¥', 'EUR': '€', 'GBP': '£', 'JPY': '¥'}
    code_key = currency_code.upper()
    if code_key in known_symbols:
        prefix = known_symbols[code_key]
    else:
        prefix = f'{currency_code} '

    def _label(pct):
        # The total is recomputed on each call, exactly as the values are at that time.
        amount = float(pct * sum(values) / 100.0)
        return f'{prefix}{amount:.2f}\n({pct:.1f}%)'

    return _label
async def stats(update: Update, context: CallbackContext):
    """Handle /stats: send a chart of monthly spending per category.

    Costs are converted to the user's main currency, normalised to a monthly
    amount and summed per category. The chart (donut plus horizontal bars) is
    rendered in a worker thread, sent as a photo and the temporary image is
    removed afterwards. If there is nothing to chart, a text message is sent
    instead.
    """
    user_id = update.effective_user.id
    await update.message.reply_text("正在为您生成更美观的统计图,请稍候...")

    def generate_chart():
        """Render the chart; return (True, png_path) or (False, message for the user)."""
        font_prop = get_chinese_font()
        main_currency = get_user_main_currency(user_id)
        with get_db_connection() as conn:
            df = pd.read_sql_query("SELECT * FROM subscriptions WHERE user_id = ?", conn, params=(user_id,))
            cursor = conn.cursor()
            cursor.execute("SELECT from_currency, to_currency, rate FROM exchange_rates WHERE to_currency = ?", (main_currency.upper(),))
            rate_cache = {(row['from_currency'], row['to_currency']): row['rate'] for row in cursor.fetchall()}
        if df.empty:
            return False, "您还没有任何订阅数据。"

        def fast_convert(amount, from_curr, to_curr):
            # Use the rates loaded above; fall back to convert_currency for other pairs.
            if from_curr.upper() == to_curr.upper():
                return amount
            cached_rate = rate_cache.get((from_curr.upper(), to_curr.upper()))
            if cached_rate is not None:
                return amount * cached_rate
            return convert_currency(amount, from_curr, to_curr)

        df['converted_cost'] = df.apply(lambda row: fast_convert(row['cost'], row['currency'], main_currency), axis=1)
        unit_to_days = {'day': 1, 'week': 7, 'month': 30.4375, 'year': 365.25}

        def normalize_to_monthly(row):
            # Cost per day of the billing period, scaled to an average month.
            if pd.isna(row['frequency_unit']) or pd.isna(row['frequency_value']) or row['frequency_value'] == 0:
                return 0
            total_days = row['frequency_value'] * unit_to_days.get(row['frequency_unit'], 0)
            if total_days == 0:
                return 0
            return (row['converted_cost'] / total_days) * 30.4375

        df['monthly_cost'] = df.apply(normalize_to_monthly, axis=1)
        category_costs = df.groupby('category')['monthly_cost'].sum().sort_values(ascending=False)
        if category_costs.empty or category_costs.sum() == 0:
            return False, "您的订阅没有有效的费用信息。"
        # Keep the largest categories and fold the rest into one "other" slice.
        max_categories = 8
        if len(category_costs) > max_categories:
            top = category_costs.iloc[:max_categories]
            others_sum = category_costs.iloc[max_categories:].sum()
            if others_sum > 0:
                category_costs = pd.concat([top, pd.Series({'其他': others_sum})])
            else:
                category_costs = top
        total_monthly = category_costs.sum()
        currency_symbols = {'USD': '$', 'CNY': '¥', 'EUR': '€', 'GBP': '£', 'JPY': '¥'}
        symbol = currency_symbols.get(main_currency.upper(), f'{main_currency.upper()} ')

        def autopct_if_large(pct):
            # Slices below 4% get no label.
            if pct < 4:
                return ''
            value = total_monthly * pct / 100
            return f"{pct:.1f}%\n{symbol}{value:.2f}"

        # `plt` must remain the module-level import: importing it anywhere in
        # this function would make the name local and this call would raise
        # UnboundLocalError.
        fig = plt.figure(figsize=(15, 8.5), facecolor='#FAFAFA')
        try:
            gs = fig.add_gridspec(1, 2, width_ratios=[1.1, 1], wspace=0.15)
            ax_pie = fig.add_subplot(gs[0, 0])
            ax_bar = fig.add_subplot(gs[0, 1])
            theme_colors = ['#3B82F6', '#10B981', '#F59E0B', '#EF4444', '#8B5CF6', '#EC4899', '#14B8A6', '#F97316', '#6366F1', '#84CC16']
            if len(category_costs) > len(theme_colors):
                theme_colors.extend(matplotlib.colors.to_hex(c) for c in plt.get_cmap('tab20').colors)
            color_map = {cat: theme_colors[i] for i, cat in enumerate(category_costs.index)}
            pie_colors = [color_map[cat] for cat in category_costs.index]
            wedges, texts, autotexts = ax_pie.pie(
                category_costs.values,
                labels=category_costs.index,
                autopct=autopct_if_large,
                startangle=140,
                counterclock=False,
                pctdistance=0.75,
                labeldistance=1.1,
                colors=pie_colors,
                wedgeprops={'width': 0.35, 'edgecolor': '#FAFAFA', 'linewidth': 2.5}
            )
            for t in texts:
                t.set_fontproperties(font_prop)
                t.set_fontsize(13)
                t.set_color('#374151')
            for t in autotexts:
                t.set_fontproperties(font_prop)
                t.set_fontsize(10)
                t.set_color('#FFFFFF')
                t.set_weight('bold')
            ax_pie.text(
                0, 0,
                f"月总支出\n{symbol}{total_monthly:.2f}",
                ha='center', va='center',
                fontproperties=font_prop,
                fontsize=18,
                color='#1F2937',
                weight='bold'
            )
            ax_pie.set_title('支出占比结构', fontproperties=font_prop, fontsize=18, pad=20, color='#111827', weight='bold')
            ax_pie.axis('equal')
            bar_series = category_costs.sort_values(ascending=True)
            bar_colors = [color_map[cat] for cat in bar_series.index]
            bars = ax_bar.barh(bar_series.index, bar_series.values, color=bar_colors, height=0.6, alpha=0.95, edgecolor='none')
            ax_bar.set_title('各类别月支出对比', fontproperties=font_prop, fontsize=18, pad=20, color='#111827', weight='bold')
            ax_bar.set_xlabel(f'金额({main_currency.upper()})', fontproperties=font_prop, fontsize=12, color='#6B7280', labelpad=10)
            ax_bar.spines['top'].set_visible(False)
            ax_bar.spines['right'].set_visible(False)
            ax_bar.spines['left'].set_visible(False)
            ax_bar.spines['bottom'].set_color('#E5E7EB')
            ax_bar.tick_params(axis='x', colors='#6B7280', labelsize=11)
            ax_bar.tick_params(axis='y', length=0, pad=10)
            ax_bar.grid(axis='x', color='#F3F4F6', linestyle='-', linewidth=1.5, alpha=1)
            ax_bar.set_axisbelow(True)
            for label in ax_bar.get_yticklabels():
                label.set_fontproperties(font_prop)
                label.set_fontsize(13)
                label.set_color('#374151')
            max_val = bar_series.max() if len(bar_series) else 0
            offset = max_val * 0.02 if max_val > 0 else 0.1
            for rect, value in zip(bars, bar_series.values):
                ax_bar.text(
                    rect.get_width() + offset,
                    rect.get_y() + rect.get_height() / 2,
                    f"{symbol}{value:.2f}",
                    va='center',
                    ha='left',
                    fontproperties=font_prop,
                    fontsize=11,
                    color='#4B5563',
                    weight='bold'
                )
            fig.suptitle('📊 您的订阅支出洞察', fontproperties=font_prop, fontsize=24, color='#0F172A', y=1.02, weight='bold')
            fig.tight_layout(rect=[0, 0, 1, 0.95])
            with tempfile.NamedTemporaryFile(prefix=f'stats_{user_id}_', suffix='.png', delete=False) as tmp:
                image_path = tmp.name
            try:
                # Save this figure explicitly: plt.savefig() writes pyplot's
                # global "current figure", which another worker thread may
                # have replaced in the meantime.
                fig.savefig(image_path, dpi=250, bbox_inches='tight', facecolor=fig.get_facecolor())
            except Exception:
                # Do not leave the temporary file behind when rendering fails.
                if os.path.exists(image_path):
                    os.remove(image_path)
                raise
            return True, image_path
        finally:
            plt.close(fig)

    success, result = await asyncio.to_thread(generate_chart)
    if success:
        try:
            with open(result, 'rb') as photo:
                await update.message.reply_photo(photo, caption="✨ 已为您生成全新的精美订阅统计图!")
        finally:
            if os.path.exists(result):
                os.remove(result)
    else:
        await update.message.reply_text(result)
# --- Import/Export Commands ---
async def export_command(update: Update, context: CallbackContext):
    """Handle /export: send the user's subscriptions as a CSV document.

    The query and CSV writing run in a worker thread. The temporary file is
    removed after the upload attempt. When the user has no subscriptions a
    text message is sent instead.
    """
    user_id = update.effective_user.id

    def process_export():
        """Write the CSV; return (True, path) or (False, None) when there is no data."""
        export_sql = "SELECT name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes FROM subscriptions WHERE user_id = ?"
        with get_db_connection() as conn:
            frame = pd.read_sql_query(export_sql, conn, params=(user_id,))
        if frame.empty:
            return False, None
        # Reserve a unique file name, close the handle, then let pandas write to the path.
        with tempfile.NamedTemporaryFile(prefix=f'export_{user_id}_', suffix='.csv', delete=False) as handle:
            csv_path = handle.name
        frame.to_csv(csv_path, index=False, encoding='utf-8-sig')
        return True, csv_path

    success, export_path = await asyncio.to_thread(process_export)
    if success:
        try:
            with open(export_path, 'rb') as file:
                await update.message.reply_document(document=file, filename='subscriptions.csv',
                                                    caption="您的订阅数据已导出为 CSV 文件。")
        finally:
            if export_path and os.path.exists(export_path):
                os.remove(export_path)
    else:
        await update.message.reply_text("您还没有任何订阅数据,无法导出。")
async def import_start(update: Update, context: CallbackContext):
    """Start the import conversation: ask for a CSV file and list the expected columns."""
    prompt = (
        "请上传一个 CSV 文件以导入订阅数据。\n"
        "文件应包含以下列:name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes(notes 可为空)。"
    )
    await update.message.reply_text(prompt)
    return IMPORT_UPLOAD
def _parse_import_row(user_id, row):
    """Validate one CSV row and return the parameter tuple for the INSERT.

    Args:
        user_id: owner of the imported subscription.
        row: mapping of column name to raw cell value.

    Returns:
        ``(user_id, name, cost, currency, category, next_due, frequency_unit,
        frequency_value, renewal_type, notes)``.

    Raises:
        ValueError: when a cell is empty or invalid; the message is shown to the user.
    """
    def cell(column):
        # Missing columns and empty cells both become '' rather than "nan".
        raw = row.get(column)
        return '' if pd.isna(raw) else str(raw).strip()

    cost = float(cell('cost'))
    if cost < 0:
        raise ValueError("费用不能为负数")
    if not math.isfinite(cost):
        raise ValueError("费用必须是有效的非负数字")
    currency = cell('currency').upper()
    # isascii() rejects non-Latin letters that isalpha() alone would accept.
    if not (len(currency) == 3 and currency.isascii() and currency.isalpha()):
        raise ValueError(f"无效货币代码: {currency}")
    next_due = parse_date(cell('next_due'))
    if not next_due:
        raise ValueError(f"无效日期格式: {cell('next_due')}")
    frequency_unit = cell('frequency_unit').lower()
    if frequency_unit not in VALID_FREQ_UNITS:
        raise ValueError(f"无效周期单位: {frequency_unit}")
    # Accept "3" as well as "3.0" (pandas writes integers as floats when the
    # column also contains empty values), but nothing fractional.
    raw_frequency = float(cell('frequency_value'))
    if not raw_frequency.is_integer() or raw_frequency <= 0:
        raise ValueError(f"无效周期数量: {cell('frequency_value')}")
    frequency_value = int(raw_frequency)
    renewal_type = cell('renewal_type').lower()
    if renewal_type not in VALID_RENEWAL_TYPES:
        raise ValueError(f"无效续费类型: {renewal_type}")
    notes = cell('notes') or None
    if notes and len(notes) > MAX_NOTES_LEN:
        raise ValueError(f"备注过长(>{MAX_NOTES_LEN})")
    name = cell('name')
    category = cell('category')
    if not name:
        raise ValueError("名称不能为空")
    if not category:
        raise ValueError("类别不能为空")
    if len(name) > MAX_NAME_LEN:
        raise ValueError(f"名称过长(>{MAX_NAME_LEN})")
    if len(category) > MAX_CATEGORY_LEN:
        raise ValueError(f"类别过长(>{MAX_CATEGORY_LEN})")
    return (
        user_id, name, cost, currency, category,
        next_due, frequency_unit, frequency_value, renewal_type, notes
    )


async def import_upload_received(update: Update, context: CallbackContext):
    """Import subscriptions from an uploaded CSV file.

    The upload is rejected unless its file name ends in ``.csv``
    (case-insensitive). All rows are validated first; if any row is invalid
    nothing is imported. The ``notes`` column is optional. The downloaded
    temporary file is always removed.

    Returns:
        ``IMPORT_UPLOAD`` to ask again when the upload is not a CSV file,
        otherwise ``ConversationHandler.END``.
    """
    user_id = update.effective_user.id
    document = update.message.document
    # file_name is optional on a document, so guard against None.
    if not document or not (document.file_name or '').lower().endswith('.csv'):
        await update.message.reply_text("请上传一个有效的 CSV 文件。")
        return IMPORT_UPLOAD
    file = await document.get_file()
    with tempfile.NamedTemporaryFile(prefix=f'import_{user_id}_', suffix='.csv', delete=False) as tmp:
        file_path = tmp.name
    try:
        await file.download_to_drive(file_path)
        # Read every cell as text and keep empty cells empty, so values such
        # as "NA" or "007" are not reinterpreted and blanks do not become NaN.
        df = pd.read_csv(file_path, encoding='utf-8-sig', dtype=str, keep_default_na=False)
        required_columns = ['name', 'cost', 'currency', 'category', 'next_due', 'frequency_unit', 'frequency_value',
                            'renewal_type']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            await update.message.reply_text(f"CSV 文件缺少以下必要列:{', '.join(missing_columns)}")
            return ConversationHandler.END
        records = []
        for row in df.to_dict('records'):
            try:
                records.append(_parse_import_row(user_id, row))
            except (ValueError, TypeError) as e:
                logger.error(f"Invalid row in CSV import, error: {e}")
                await update.message.reply_text(f"导入失败,存在无效行:{e}")
                return ConversationHandler.END
        # Each category is registered once, in first-seen order.
        category_rows = [(user_id, category) for category in dict.fromkeys(record[4] for record in records)]
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.executemany('''
INSERT INTO subscriptions (user_id, name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
''', records)
            cursor.executemany("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", category_rows)
            conn.commit()
        await update.message.reply_text(f"✅ 成功导入 {len(records)} 条订阅数据!")
    except Exception as e:
        logger.error(f"Import failed: {e}")
        await update.message.reply_text(f"导入失败:{e}")
    finally:
        if os.path.exists(file_path):
            os.remove(file_path)
    return ConversationHandler.END
# --- Add Subscription Conversation ---
async def add_sub_start(update: Update, context: CallbackContext):
    """Begin the add-subscription conversation with an empty draft and ask for the name."""
    context.user_data['new_sub_data'] = {}
    first_prompt = "好的,我们来添加一个新订阅。\n\n第一步:请输入订阅的 名称"
    await update.message.reply_text(first_prompt, parse_mode='HTML')
    return ADD_NAME
async def add_name_received(update: Update, context: CallbackContext):
    """Store the subscription name and ask for the cost.

    Returns:
        ``ADD_COST`` on success, ``ADD_NAME`` when the name is empty or too
        long, ``ConversationHandler.END`` when the draft is gone.
    """
    sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
    if sub_data is None:
        if err_msg_obj:
            await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
        return ConversationHandler.END
    name = update.message.text.strip()
    problem = None
    if not name:
        problem = "订阅名称不能为空。"
    elif len(name) > MAX_NAME_LEN:
        problem = f"订阅名称过长,请控制在 {MAX_NAME_LEN} 个字符以内。"
    if problem is not None:
        await update.message.reply_text(problem)
        return ADD_NAME
    sub_data['name'] = name
    await update.message.reply_text("第二步:请输入订阅 费用", parse_mode='HTML')
    return ADD_COST
async def add_cost_received(update: Update, context: CallbackContext):
    """Store the subscription cost and ask for the currency.

    The cost must be a finite, non-negative number; ``nan`` and ``inf`` parse
    as floats but are rejected.

    Returns:
        ``ADD_CURRENCY`` on success, ``ADD_COST`` when the input is invalid,
        ``ConversationHandler.END`` when the draft is gone.
    """
    sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
    if sub_data is None:
        if err_msg_obj:
            await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
        return ConversationHandler.END
    try:
        cost = float(update.message.text)
        # float() accepts "nan"/"inf", and "nan < 0" is False, so check finiteness too.
        if not math.isfinite(cost) or cost < 0:
            raise ValueError("费用不能为负数")
        sub_data['cost'] = cost
    except (ValueError, TypeError):
        await update.message.reply_text("费用必须是有效的非负数字。")
        return ADD_COST
    await update.message.reply_text("第三步:请输入 货币 代码(例如 USD, CNY)", parse_mode='HTML')
    return ADD_CURRENCY
async def add_currency_received(update: Update, context: CallbackContext):
    """Store the currency code and ask for the category.

    The code is trimmed, upper-cased and must consist of exactly three ASCII
    letters.

    Returns:
        ``ADD_CATEGORY`` on success, ``ADD_CURRENCY`` when the code is
        invalid, ``ConversationHandler.END`` when the draft is gone.
    """
    sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
    if sub_data is None:
        if err_msg_obj:
            await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
        return ConversationHandler.END
    currency = update.message.text.strip().upper()
    # str.isalpha() is true for any Unicode letters (e.g. three CJK
    # characters), so also require ASCII.
    if not (len(currency) == 3 and currency.isascii() and currency.isalpha()):
        await update.message.reply_text("请输入有效的三字母货币代码(如 USD, CNY)。")
        return ADD_CURRENCY
    sub_data['currency'] = currency
    await update.message.reply_text("第四步:请为订阅指定一个 类别", parse_mode='HTML')
    return ADD_CATEGORY
async def add_category_received(update: Update, context: CallbackContext):
    """Store the category, register it for the user and ask for the next due date.

    Returns:
        ``ADD_NEXT_DUE`` on success, ``ADD_CATEGORY`` when the category is
        empty or too long, ``ConversationHandler.END`` when the draft is gone.
    """
    sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
    if sub_data is None:
        if err_msg_obj:
            await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
        return ConversationHandler.END
    user_id, category_name = update.effective_user.id, update.message.text.strip()
    if not category_name:
        await update.message.reply_text("类别不能为空。")
        return ADD_CATEGORY
    if len(category_name) > MAX_CATEGORY_LEN:
        await update.message.reply_text(f"类别名称过长,请控制在 {MAX_CATEGORY_LEN} 个字符以内。")
        return ADD_CATEGORY
    sub_data['category'] = category_name
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, category_name))
        conn.commit()
    # Sent with parse_mode='HTML': use <b> for emphasis and no MarkdownV2
    # backslash escapes, which HTML mode would display literally.
    await update.message.reply_text("第五步:请输入 <b>下一次付款日期</b>(例如 2025-10-01 或 10月1日)",
                                    parse_mode='HTML')
    return ADD_NEXT_DUE
async def add_next_due_received(update: Update, context: CallbackContext):
    """Store the next due date and ask for the billing period unit.

    Returns:
        ``ADD_FREQ_UNIT`` on success, ``ADD_NEXT_DUE`` when the date cannot
        be parsed, ``ConversationHandler.END`` when the draft is gone.
    """
    sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
    if sub_data is None:
        if err_msg_obj:
            await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
        return ConversationHandler.END
    parsed_date = parse_date(update.message.text)
    if not parsed_date:
        # Plain-text message: MarkdownV2 backslash escapes would be shown literally.
        await update.message.reply_text("无法识别的日期格式,请使用类似 '2025-10-01' 或 '10月1日' 的格式。")
        return ADD_NEXT_DUE
    sub_data['next_due'] = parsed_date
    keyboard = [
        [InlineKeyboardButton("天", callback_data='freq_unit_day'),
         InlineKeyboardButton("周", callback_data='freq_unit_week')],
        [InlineKeyboardButton("月", callback_data='freq_unit_month'),
         InlineKeyboardButton("年", callback_data='freq_unit_year')]
    ]
    await update.message.reply_text("第六步:请选择付款周期的单位", reply_markup=InlineKeyboardMarkup(keyboard),
                                    parse_mode='HTML')
    return ADD_FREQ_UNIT
async def add_freq_unit_received(update: Update, context: CallbackContext):
    """Store the period unit chosen via inline button and ask for the count.

    Returns:
        ``ADD_FREQ_VALUE`` on success; ``ConversationHandler.END`` when the
        draft is gone or the callback carries an unknown unit.
    """
    sub_data, _ = _get_new_sub_data_or_end(update, context)
    query = update.callback_query
    await query.answer()
    if sub_data is None:
        await query.edit_message_text("会话已过期,请重新使用 /add_sub 开始。")
        return ConversationHandler.END
    # Everything after the first 'freq_unit_' marker is the chosen unit.
    _, _, chosen_unit = query.data.partition('freq_unit_')
    if chosen_unit in VALID_FREQ_UNITS:
        sub_data['unit'] = chosen_unit
        await query.edit_message_text("第七步:请输入周期的数量(例如:每3个月,输入 3)", parse_mode='Markdown')
        return ADD_FREQ_VALUE
    await query.edit_message_text("错误:无效的周期单位,请重试。")
    return ConversationHandler.END
async def add_freq_value_received(update: Update, context: CallbackContext):
    """Store the number of units per billing period and ask for the renewal type.

    Returns:
        ``ADD_RENEWAL_TYPE`` on success, ``ADD_FREQ_VALUE`` when the input is
        not a positive integer, ``ConversationHandler.END`` when the draft is gone.
    """
    sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
    if sub_data is None:
        if err_msg_obj:
            await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
        return ConversationHandler.END
    try:
        periods = int(update.message.text)
    except (ValueError, TypeError):
        periods = None
    if periods is None or periods <= 0:
        await update.message.reply_text("请输入一个有效的正整数。")
        return ADD_FREQ_VALUE
    sub_data['value'] = periods
    renewal_buttons = [[
        InlineKeyboardButton("自动续费", callback_data='renewal_auto'),
        InlineKeyboardButton("手动续费", callback_data='renewal_manual'),
    ]]
    await update.message.reply_text("第八步:请选择 续费方式", reply_markup=InlineKeyboardMarkup(renewal_buttons),
                                    parse_mode='HTML')
    return ADD_RENEWAL_TYPE
async def add_renewal_type_received(update: Update, context: CallbackContext):
    """Store the renewal type chosen via inline button and ask for optional notes.

    Returns:
        ``ADD_NOTES`` on success; ``ConversationHandler.END`` when the draft
        is gone or the callback carries an unknown renewal type.
    """
    sub_data, _ = _get_new_sub_data_or_end(update, context)
    query = update.callback_query
    await query.answer()
    if sub_data is None:
        await query.edit_message_text("会话已过期,请重新使用 /add_sub 开始。")
        return ConversationHandler.END
    # Everything after the first 'renewal_' marker is the chosen type.
    _, _, chosen_type = query.data.partition('renewal_')
    if chosen_type in VALID_RENEWAL_TYPES:
        sub_data['renewal_type'] = chosen_type
        await query.edit_message_text("最后一步(可选):需要添加备注吗?\n(如:共享账号、用途等。不需要请 /skip)")
        return ADD_NOTES
    await query.edit_message_text("错误:无效的续费类型,请重试。")
    return ConversationHandler.END
async def add_notes_received(update: Update, context: CallbackContext):
    """Store the notes, save the subscription and end the conversation.

    Returns:
        ``ADD_NOTES`` when the note is too long, otherwise
        ``ConversationHandler.END``.
    """
    sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
    if sub_data is None:
        _clear_action_state(context, ['new_sub_data'])
        if err_msg_obj:
            await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
        return ConversationHandler.END
    note = update.message.text.strip()
    if len(note) > MAX_NOTES_LEN:
        await update.message.reply_text(f"备注过长,请控制在 {MAX_NOTES_LEN} 个字符以内。")
        return ADD_NOTES
    # An empty note is stored as None.
    sub_data['notes'] = note or None
    save_subscription(update.effective_user.id, sub_data)
    confirmation = f"✅ 订阅 '{escape_html(sub_data.get('name', ''))}' 已添加!"
    await update.message.reply_text(text=confirmation, parse_mode='HTML')
    _clear_action_state(context, ['new_sub_data'])
    return ConversationHandler.END
async def skip_notes(update: Update, context: CallbackContext):
    """Handle /skip at the notes step: save the subscription without notes and end."""
    sub_data, err_msg_obj = _get_new_sub_data_or_end(update, context)
    if sub_data is None:
        _clear_action_state(context, ['new_sub_data'])
        if err_msg_obj:
            await err_msg_obj.reply_text("会话已过期,请重新使用 /add_sub 开始。")
        return ConversationHandler.END
    sub_data['notes'] = None
    save_subscription(update.effective_user.id, sub_data)
    confirmation = f"✅ 订阅 '{escape_html(sub_data.get('name', ''))}' 已添加!"
    await update.message.reply_text(text=confirmation, parse_mode='HTML')
    _clear_action_state(context, ['new_sub_data'])
    return ConversationHandler.END
def save_subscription(user_id, data):
    """Insert one subscription row for ``user_id``.

    Args:
        user_id: Owner of the new subscription.
        data: Mapping read with ``.get`` for the keys ``name``, ``cost``,
            ``currency``, ``category``, ``next_due``, ``unit``, ``value``,
            ``renewal_type`` (defaults to ``'auto'``) and ``notes``.
    """
    row = (
        user_id,
        data.get('name'),
        data.get('cost'),
        data.get('currency'),
        data.get('category'),
        data.get('next_due'),
        data.get('unit'),
        data.get('value'),
        data.get('renewal_type', 'auto'),
        data.get('notes'),
    )
    conn = get_db_connection()
    try:
        # "with conn" commits on success and rolls back on error, but it does
        # not close the connection; that is done in the finally clause.
        with conn:
            conn.execute(
                '''
                INSERT INTO subscriptions (user_id, name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''',
                row,
            )
    finally:
        conn.close()
# --- List, View, Edit, Delete ---
async def list_subs(update: Update, context: CallbackContext):
    """Reply to ``/list_subs`` with a keyboard of the user's subscriptions."""
    user_id = update.effective_user.id
    # update.message is None when the command arrives as an edited message
    # (CommandHandler accepts those by default), so reply via effective_message.
    message = update.effective_message
    keyboard = await get_subs_list_keyboard(user_id)
    if not keyboard:
        await message.reply_text("您还没有任何订阅。")
        return
    await message.reply_text("您的所有订阅:", reply_markup=keyboard)
async def list_categories(update: Update, context: CallbackContext):
    """Show the user's categories as a two-column inline keyboard.

    Works both as the ``/list_categories`` command (sends a new message) and
    when invoked for a callback query (edits the message the button is on).
    """
    user_id = update.effective_user.id
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT id, name FROM categories WHERE user_id = ? ORDER BY name", (user_id,))
        categories = cursor.fetchall()

    async def respond(text, markup=None):
        # Edit in place for button presses; otherwise reply. effective_message
        # is used because update.message is None for edited commands.
        if update.callback_query:
            await update.callback_query.edit_message_text(text, reply_markup=markup)
        elif update.effective_message:
            await update.effective_message.reply_text(text, reply_markup=markup)

    if not categories:
        await respond("您还没有任何分类。")
        return

    buttons = [
        InlineKeyboardButton(cat[1], callback_data=_build_category_callback_data(cat[0]))
        for cat in categories
    ]
    # Lay the category buttons out two per row, then add the "all" shortcut.
    keyboard = [buttons[i:i + 2] for i in range(0, len(buttons), 2)]
    keyboard.append([InlineKeyboardButton("查看全部订阅", callback_data="list_all_subs")])
    await respond("请选择一个分类:", InlineKeyboardMarkup(keyboard))
async def show_subscription_view(update: Update, context: CallbackContext, sub_id: int):
    """Render the detail card for one subscription owned by the calling user.

    The card is sent with ``parse_mode='HTML'``, so it is built from HTML tags
    and every interpolated value is HTML-escaped. For a callback query the
    existing message is edited; otherwise a new message is sent.

    Args:
        update: Incoming update; the user id is taken from ``effective_user``.
        context: Handler context; ``user_data['list_subs_in_category']`` and
            ``['list_subs_in_category_id']`` decide where the back button leads.
        sub_id: Primary key of the subscription to display.
    """
    user_id = update.effective_user.id
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id))
        sub = cursor.fetchone()
    if not sub:
        logger.error(f"Subscription with id {sub_id} not found for user {user_id}")
        if update.effective_message:
            await update.effective_message.reply_text("错误:找不到该订阅。")
        return

    cost, currency = sub['cost'], sub['currency']
    renewal_type, notes = sub['renewal_type'], sub['notes']
    freq_text = format_frequency(sub['frequency_unit'], sub['frequency_value'])
    main_currency = get_user_main_currency(user_id)
    converted_cost = convert_currency(cost, currency, main_currency)

    safe_name = escape_html(sub['name'])
    # Guard against a NULL category so escaping cannot fail on None.
    safe_category = escape_html(sub['category'] or '')
    safe_freq = escape_html(freq_text)
    safe_currency = escape_html(currency.upper())
    safe_main_currency = escape_html(main_currency.upper())
    safe_next_due = escape_html(str(sub['next_due']))
    cost_str, converted_cost_str = f"{cost:.2f}", f"{converted_cost:.2f}"
    renewal_text = "手动续费" if renewal_type == 'manual' else "自动续费"
    reminder_status = "开启" if sub['reminders_enabled'] else "关闭"

    # HTML markup (not MarkdownV2) to match the parse_mode used when sending.
    text = (f"<b>订阅详情: {safe_name}</b>\n\n"
            f"- <b>费用</b>: <code>{cost_str} {safe_currency}</code> "
            f"(~<code>{converted_cost_str} {safe_main_currency}</code>)\n"
            f"- <b>类别</b>: <code>{safe_category}</code>\n"
            f"- <b>下次付款</b>: <code>{safe_next_due}</code> (周期: {safe_freq})\n"
            f"- <b>续费方式</b>: <code>{renewal_text}</code>\n"
            f"- <b>提醒状态</b>: <code>{reminder_status}</code>")
    if notes:
        text += f"\n- <b>备注</b>: {escape_html(notes)}"

    keyboard_buttons = [
        [InlineKeyboardButton("✏️ 编辑", callback_data=f'edit_{sub_id}'),
         InlineKeyboardButton("🗑️ 删除", callback_data=f'delete_{sub_id}')],
        [InlineKeyboardButton("🔔 提醒设置", callback_data=f'remind_{sub_id}')]
    ]
    if renewal_type == 'manual':
        # Manual subscriptions get a one-tap "renew" button on top.
        keyboard_buttons.insert(0, [InlineKeyboardButton("✅ 续费", callback_data=f'renewmanual_{sub_id}')])

    if 'list_subs_in_category' in context.user_data:
        # The user arrived from a category listing: go back to that category
        # when its id is known, otherwise to the category overview.
        category_id = context.user_data.get('list_subs_in_category_id')
        back_cb = _build_category_callback_data(category_id) if category_id else 'list_categories'
        keyboard_buttons.append([InlineKeyboardButton("« 返回分类订阅", callback_data=back_cb)])
    else:
        keyboard_buttons.append([InlineKeyboardButton("« 返回全部订阅", callback_data='list_all_subs')])
    logger.debug(f"Generated buttons for sub_id {sub_id}: edit_{sub_id}, remind_{sub_id}")

    markup = InlineKeyboardMarkup(keyboard_buttons)
    if update.callback_query:
        await update.callback_query.edit_message_text(text, reply_markup=markup, parse_mode='HTML')
    elif update.effective_message:
        await update.effective_message.reply_text(text, reply_markup=markup, parse_mode='HTML')
async def _edit_to_category_listing(query, user_id, category):
    """Edit the callback message into the subscription list of one category."""
    keyboard = await get_subs_list_keyboard(user_id, category_filter=category)
    msg_text = f"分类 {escape_html(category)} 下的订阅:"
    if not keyboard:
        msg_text = f"分类 {escape_html(category)} 下没有订阅。"
        keyboard = InlineKeyboardMarkup(
            [[InlineKeyboardButton("« 返回分类列表", callback_data='list_categories')]])
    await query.edit_message_text(msg_text, reply_markup=keyboard, parse_mode='HTML')


async def _edit_to_full_listing(query, user_id):
    """Edit the callback message into the list of all subscriptions."""
    keyboard = await get_subs_list_keyboard(user_id)
    if not keyboard:
        await query.edit_message_text("您还没有任何订阅。")
        return
    await query.edit_message_text("您的所有订阅:", reply_markup=keyboard)


def _renew_from_today(sub_id, user_id):
    """Move a subscription's next_due one period forward, counted from today.

    Returns ``(found, name, new_date_str)``. ``found`` is False when the row
    does not exist for this user; ``new_date_str`` is None when
    ``calculate_new_due_date`` produced no date (nothing is written then).
    """
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT name, frequency_unit, frequency_value FROM subscriptions WHERE id = ? AND user_id = ?",
            (sub_id, user_id)
        )
        sub = cursor.fetchone()
        if not sub:
            return False, None, None
        new_due_date = calculate_new_due_date(
            datetime.date.today(), sub['frequency_unit'], sub['frequency_value'])
        if not new_due_date:
            return True, sub['name'], None
        new_date_str = new_due_date.strftime('%Y-%m-%d')
        cursor.execute(
            "UPDATE subscriptions SET next_due = ? WHERE id = ? AND user_id = ?",
            (new_date_str, sub_id, user_id)
        )
        conn.commit()
        return True, sub['name'], new_date_str


async def button_callback_handler(update: Update, context: CallbackContext):
    """Dispatch inline-button presses that are not owned by a conversation.

    Handles three navigation callbacks (a category, ``list_categories``,
    ``list_all_subs``) and the ``<action>_<sub_id>`` callbacks ``view``,
    ``renewmanual``, ``renewfromremind``, ``delete`` and ``confirmdelete``.
    Every database statement is scoped to the pressing user's id.
    """
    query = update.callback_query
    await query.answer()
    data = query.data
    user_id = query.from_user.id
    logger.debug(f"Received callback query: {data} from user {user_id}")

    # --- navigation callbacks -------------------------------------------
    if data.startswith(CATEGORY_CB_PREFIX):
        category_id = _parse_category_id_from_callback(data)
        if not category_id:
            await query.edit_message_text("错误:无效或已过期的分类,请重新选择。")
            return
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM categories WHERE id = ? AND user_id = ?", (category_id, user_id))
            row = cursor.fetchone()
        if not row:
            await query.edit_message_text("错误:分类不存在或无权限。")
            return
        category = row['name']
        # Remember the category so the detail view / delete flow can return here.
        context.user_data['list_subs_in_category'] = category
        context.user_data['list_subs_in_category_id'] = category_id
        await _edit_to_category_listing(query, user_id, category)
        return

    if data in ('list_categories', 'list_all_subs'):
        context.user_data.pop('list_subs_in_category', None)
        context.user_data.pop('list_subs_in_category_id', None)
        if data == 'list_categories':
            await list_categories(update, context)
        else:
            await _edit_to_full_listing(query, user_id)
        return

    # --- per-subscription callbacks: "<action>_<sub_id>" -----------------
    action, _, sub_id_str = data.partition('_')
    sub_id = int(sub_id_str) if sub_id_str.isdigit() else None
    if not sub_id:
        logger.error(f"Invalid sub_id in callback data: {data}")
        await query.edit_message_text("错误:无效的订阅ID。")
        return

    # NOTE(review): the query is already answered at the top of this handler.
    # The additional query.answer(..., show_alert=True) calls below are kept
    # as in the original; confirm that Telegram clients still display an alert
    # sent as a second answer to the same query.
    if action == 'view':
        await show_subscription_view(update, context, sub_id)

    elif action == 'renewmanual':
        found, _name, new_date_str = _renew_from_today(sub_id, user_id)
        if not found:
            await query.answer("续费失败:订阅不存在或无权限。", show_alert=True)
        elif new_date_str is None:
            await query.answer("续费失败:无法计算新的到期日期。", show_alert=True)
        else:
            await query.answer(f"✅ 续费成功!新到期日: {new_date_str}", show_alert=True)
            await show_subscription_view(update, context, sub_id)

    elif action == 'renewfromremind':
        found, sub_name, new_date_str = _renew_from_today(sub_id, user_id)
        if not found:
            await query.answer("续费失败:此订阅可能已被删除或无权限。", show_alert=True)
            # The message is re-sent as HTML, so the previous plain text must
            # be escaped and the suffix must use an HTML tag rather than "*".
            previous_text = escape_html(getattr(query.message, 'text', None) or '')
            await query.edit_message_text(
                text=previous_text + "\n\n<b>(错误:此订阅不存在或无权限)</b>",
                parse_mode='HTML', reply_markup=None)
        elif new_date_str is None:
            await query.answer("续费失败:无法计算新的到期日期。", show_alert=True)
        else:
            safe_sub_name = escape_html(sub_name)
            await query.edit_message_text(
                text=f"✅ 续费成功\n\n您的订阅 {safe_sub_name} 新的到期日为: {new_date_str}",
                parse_mode='HTML',
                reply_markup=None
            )

    elif action == 'delete':
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT 1 FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id))
            exists = cursor.fetchone() is not None
        if not exists:
            await query.answer("错误:找不到该订阅或无权限。", show_alert=True)
            return
        # Ask for confirmation; "cancel" simply re-opens the detail view.
        keyboard = InlineKeyboardMarkup([
            [InlineKeyboardButton("✅ 是的,删除", callback_data=f'confirmdelete_{sub_id}'),
             InlineKeyboardButton("❌ 取消", callback_data=f'view_{sub_id}')]
        ])
        await query.edit_message_text(text="您确定要删除这个订阅吗?", reply_markup=keyboard)

    elif action == 'confirmdelete':
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id))
            deleted = cursor.rowcount
            conn.commit()
        if deleted == 0:
            await query.answer("错误:找不到该订阅或无权限。", show_alert=True)
            return
        await query.answer("订阅已删除")
        # Return to whichever listing the user came from.
        if 'list_subs_in_category' in context.user_data:
            await _edit_to_category_listing(query, user_id, context.user_data['list_subs_in_category'])
        else:
            await _edit_to_full_listing(query, user_id)
# --- 【新增】包装函数,用于在会话中处理“返回”按钮 ---
async def fallback_view_button(update: Update, context: CallbackContext):
    """Conversation fallback for ``view_...`` button presses.

    Hands the update to ``button_callback_handler`` and then returns
    ``ConversationHandler.END`` so that the conversation this fallback belongs
    to (for example editing or reminder settings) is terminated.
    """
    # Let the shared button logic process the callback first ...
    await button_callback_handler(update, context)
    # ... then explicitly finish the surrounding conversation.
    return ConversationHandler.END
# --- Edit Subscription Conversation ---
async def edit_start(update: Update, context: CallbackContext):
    """Entry point of the edit conversation (callback data ``edit_<id>``).

    Verifies that the subscription exists for the pressing user, remembers its
    id in ``user_data['sub_id_for_action']`` and shows the field picker.
    Returns ``EDIT_SELECT_FIELD``, or ends the conversation on a bad id.
    """
    callback = update.callback_query
    await callback.answer()
    raw_id = callback.data.split('_')[1]
    owner_id = callback.from_user.id
    if not raw_id.isdigit():
        await callback.edit_message_text("错误:无效的订阅ID。")
        return ConversationHandler.END

    sub_id = int(raw_id)
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT 1 FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, owner_id))
        if not cursor.fetchone():
            await callback.edit_message_text("错误:找不到该订阅或无权限。")
            return ConversationHandler.END

    logger.debug(f"Starting edit for sub_id: {sub_id}")
    context.user_data['sub_id_for_action'] = sub_id

    # (button label, field key) pairs, two buttons per row.
    field_rows = (
        (("名称", "name"), ("费用", "cost")),
        (("货币", "currency"), ("类别", "category")),
        (("下次付款日", "next_due"), ("周期", "frequency")),
        (("续费方式", "renewal_type"), ("📝 备注", "notes")),
    )
    keyboard = [
        [InlineKeyboardButton(label, callback_data=f"editfield_{key}") for label, key in row]
        for row in field_rows
    ]
    keyboard.append([InlineKeyboardButton("« 返回详情", callback_data=f'view_{sub_id}')])
    await callback.edit_message_text("请选择您想编辑的字段:", reply_markup=InlineKeyboardMarkup(keyboard))
    return EDIT_SELECT_FIELD
async def edit_field_selected(update: Update, context: CallbackContext):
    """React to the field picker (callback data ``editfield_<field>``).

    Stores the chosen field in ``user_data['field_to_edit']`` and then either
    shows a button choice (renewal type, frequency unit) or asks the user to
    type the new value.
    """
    callback = update.callback_query
    await callback.answer()
    chosen = callback.data.partition('_')[2]
    context.user_data['field_to_edit'] = chosen

    if chosen == 'renewal_type':
        renewal_buttons = [[
            InlineKeyboardButton("自动续费", callback_data='editvalue_auto'),
            InlineKeyboardButton("手动续费", callback_data='editvalue_manual'),
        ]]
        await callback.edit_message_text("请选择新的续费方式:",
                                         reply_markup=InlineKeyboardMarkup(renewal_buttons))
        return EDIT_GET_NEW_VALUE

    if chosen == 'frequency':
        unit_buttons = [
            [InlineKeyboardButton("天", callback_data='freq_unit_day'),
             InlineKeyboardButton("周", callback_data='freq_unit_week')],
            [InlineKeyboardButton("月", callback_data='freq_unit_month'),
             InlineKeyboardButton("年", callback_data='freq_unit_year')],
        ]
        await callback.edit_message_text("请选择新的周期单位",
                                         reply_markup=InlineKeyboardMarkup(unit_buttons),
                                         parse_mode='HTML')
        return EDIT_FREQ_UNIT

    # Free-text fields: show the localized label, falling back to the raw key.
    labels = {'name': '名称', 'cost': '费用', 'currency': '货币', 'category': '类别',
              'next_due': '下次付款日', 'notes': '备注'}
    prompt = f"好的,请输入新的 {labels.get(chosen, chosen)} 值:"
    if chosen == 'notes':
        prompt += "\n(如需清空备注,请输入 /empty )"
    await callback.edit_message_text(prompt, parse_mode='HTML')
    return EDIT_GET_NEW_VALUE
async def edit_freq_unit_received(update: Update, context: CallbackContext):
    """Store the new frequency unit (callback data ``freq_unit_<unit>``).

    Returns ``EDIT_FREQ_VALUE`` after asking for the period count, or ends the
    conversation when the unit is not in ``VALID_FREQ_UNITS``.
    """
    callback = update.callback_query
    await callback.answer()
    _before, _marker, chosen_unit = callback.data.partition('freq_unit_')
    if chosen_unit in VALID_FREQ_UNITS:
        context.user_data['new_freq_unit'] = chosen_unit
        await callback.edit_message_text("好的,现在请输入新的周期数量。", parse_mode='HTML')
        return EDIT_FREQ_VALUE

    await callback.edit_message_text("错误:无效的周期单位,请重试。")
    return ConversationHandler.END
async def edit_freq_value_received(update: Update, context: CallbackContext):
    """Persist the new billing frequency once the period count is typed.

    Re-prompts (returns ``EDIT_FREQ_VALUE``) unless the text is a positive
    integer. On success writes unit and value, clears the edit state, shows the
    detail view and ends the conversation.
    """
    user_id = update.effective_user.id

    try:
        count = int(update.message.text)
    except (ValueError, TypeError):
        count = 0  # forces the "invalid input" path below
    if count <= 0:
        await update.message.reply_text("请输入一个有效的正整数。")
        return EDIT_FREQ_VALUE

    unit = context.user_data.get('new_freq_unit')
    try:
        sub_id = int(context.user_data.get('sub_id_for_action'))
    except (ValueError, TypeError):
        # Missing id: the per-user edit state is no longer available.
        await update.message.reply_text("错误:会话已过期,请重试。")
        return ConversationHandler.END

    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "UPDATE subscriptions SET frequency_unit = ?, frequency_value = ? WHERE id = ? AND user_id = ?",
            (unit, count, sub_id, user_id))
        if cursor.rowcount == 0:
            await update.message.reply_text("错误:找不到该订阅或无权限。")
            return ConversationHandler.END
        conn.commit()

    await update.message.reply_text("✅ 周期已更新!")
    _clear_action_state(context, ['sub_id_for_action', 'new_freq_unit', 'field_to_edit'])
    await show_subscription_view(update, context, sub_id)
    return ConversationHandler.END
def _coerce_edit_value(field, raw_value):
    """Validate and normalize a new value for one editable field.

    Returns ``(value, error)``. ``error`` is None on success; otherwise it is
    the message to show the user and ``value`` must be ignored.
    """
    if field == 'cost':
        try:
            amount = float(raw_value)
        except (ValueError, TypeError):
            return None, "费用必须是有效的非负数字。"
        # The chained comparison also rejects NaN and infinity, which float()
        # would otherwise accept from input such as "nan" or "inf".
        if not 0 <= amount < float('inf'):
            return None, "费用必须是有效的非负数字。"
        return amount, None

    if field == 'name':
        name = str(raw_value).strip()
        if not name:
            return None, "名称不能为空。"
        if len(name) > MAX_NAME_LEN:
            return None, f"名称过长,请控制在 {MAX_NAME_LEN} 个字符以内。"
        return name, None

    if field == 'currency':
        code = str(raw_value).upper()
        if not (len(code) == 3 and code.isalpha()):
            return None, "请输入有效的三字母货币代码(如 USD, CNY)。"
        return code, None

    if field == 'next_due':
        parsed = parse_date(str(raw_value))
        if not parsed:
            # Sent as plain text, so the example date must not carry
            # MarkdownV2 escape backslashes.
            return None, "无法识别的日期格式,请使用类似 '2025-10-01' 或 '10月1日' 的格式。"
        return parsed, None

    if field == 'renewal_type':
        if str(raw_value) not in VALID_RENEWAL_TYPES:
            return None, "续费方式只能为 auto 或 manual。"
        return raw_value, None

    if field == 'notes':
        if raw_value is None:
            # /empty: clear the note. Without this check str(None) would store
            # the literal text "None".
            return None, None
        note = str(raw_value).strip()
        if note and len(note) > MAX_NOTES_LEN:
            return None, f"备注过长,请控制在 {MAX_NOTES_LEN} 个字符以内。"
        return (note or None), None

    if field == 'category':
        category = str(raw_value).strip()
        if not category:
            return None, "类别不能为空。"
        if len(category) > MAX_CATEGORY_LEN:
            return None, f"类别名称过长,请控制在 {MAX_CATEGORY_LEN} 个字符以内。"
        return category, None

    return raw_value, None


async def edit_new_value_received(update: Update, context: CallbackContext):
    """Apply a new value to the field chosen earlier in the edit conversation.

    The value comes from a text message, from an ``editvalue_<value>`` button,
    or is cleared with ``/empty`` (notes only). Invalid input re-prompts by
    returning ``EDIT_GET_NEW_VALUE``; every other outcome ends the
    conversation. The column name is taken from ``EDITABLE_SUB_FIELDS`` and is
    never built from user input.
    """
    user_id = update.effective_user.id
    field = context.user_data.get('field_to_edit')
    message_to_reply = update.effective_message

    async def reply(text):
        if message_to_reply:
            await message_to_reply.reply_text(text)

    try:
        sub_id = int(context.user_data.get('sub_id_for_action'))
    except (ValueError, TypeError):
        await reply("错误:无效的订阅ID。")
        return ConversationHandler.END
    if not field:
        await reply("错误:未选择要编辑的字段。")
        return ConversationHandler.END

    db_field = EDITABLE_SUB_FIELDS.get(field)
    if not db_field or not db_field.isidentifier():
        await reply("错误:不允许编辑该字段。")
        logger.warning(f"Blocked unsafe field update attempt: {field}")
        return ConversationHandler.END

    # Work out where the new value comes from.
    query = update.callback_query
    if update.message and update.message.text == '/empty' and field == 'notes':
        raw_value = None
    elif query:
        raw_value = query.data.split('_')[1]
    elif update.message:
        raw_value = update.message.text
    else:
        await reply("错误:未提供新值。")
        return ConversationHandler.END

    new_value, error = _coerce_edit_value(field, raw_value)
    if error is not None:
        await reply(error)
        return EDIT_GET_NEW_VALUE

    if field == 'category':
        # Make sure the category exists for this user before assigning it.
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, new_value))
            conn.commit()

    with get_db_connection() as conn:
        cursor = conn.cursor()
        # db_field comes from the EDITABLE_SUB_FIELDS lookup checked above.
        cursor.execute(f"UPDATE subscriptions SET {db_field} = ? WHERE id = ? AND user_id = ?",
                       (new_value, sub_id, user_id))
        if cursor.rowcount == 0:
            await reply("错误:找不到该订阅或无权限。")
            return ConversationHandler.END
        conn.commit()

    if query:
        await query.answer("✅ 字段已更新!")
    else:
        await reply("✅ 字段已更新!")
    _clear_action_state(context, ['sub_id_for_action', 'field_to_edit', 'new_freq_unit'])
    await show_subscription_view(update, context, sub_id)
    return ConversationHandler.END
# --- Reminder Settings Conversation ---
async def _display_reminder_settings(query: CallbackQuery, context: CallbackContext, sub_id: int):
    """Edit the callback message into the reminder-settings panel.

    Shows toggle buttons for reminders and due-date reminders; for manual
    renewals it also shows the current lead time and a button to change it.
    The text is sent with ``parse_mode='HTML'``.
    """
    user_id = query.from_user.id
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT name, renewal_type, reminders_enabled, reminder_on_due_date, reminder_days "
            "FROM subscriptions WHERE id = ? AND user_id = ?",
            (sub_id, user_id)
        )
        sub = cursor.fetchone()
    if not sub:
        await query.edit_message_text("错误:找不到该订阅或无权限。")
        return

    # Each button offers the opposite of the current state.
    enabled_text = "❌ 关闭提醒" if sub['reminders_enabled'] else "✅ 开启提醒"
    due_date_text = "❌ 关闭到期日提醒" if sub['reminder_on_due_date'] else "✅ 开启到期日提醒"
    keyboard = [
        [InlineKeyboardButton(enabled_text, callback_data='remindaction_toggle_enabled')],
        [InlineKeyboardButton(due_date_text, callback_data='remindaction_toggle_due_date')]
    ]
    safe_name = escape_html(sub['name'])
    current_status = f"🔔 提醒设置: {safe_name}\n\n"
    if sub['renewal_type'] == 'manual':
        # HTML bold tag: Markdown asterisks would be shown literally here.
        current_status += f"当前提前提醒: <b>{sub['reminder_days']}天</b>\n"
        keyboard.append([InlineKeyboardButton("⚙️ 更改提前天数", callback_data='remindaction_ask_days')])
    keyboard.append([InlineKeyboardButton("« 返回详情", callback_data=f'view_{sub_id}')])
    await query.edit_message_text(current_status, reply_markup=InlineKeyboardMarkup(keyboard), parse_mode='HTML')
async def remind_settings_start(update: Update, context: CallbackContext):
    """Entry point of the reminder conversation (callback data ``remind_<id>``).

    Checks that the subscription exists for the pressing user, stores its id in
    ``user_data['sub_id_for_action']`` and shows the settings panel. Returns
    ``REMIND_SELECT_ACTION``, or ends the conversation on a bad id.
    """
    callback = update.callback_query
    await callback.answer()
    raw_id = callback.data.partition('_')[2]
    owner_id = callback.from_user.id
    if not raw_id.isdigit():
        await callback.edit_message_text("错误:无效的订阅ID。")
        return ConversationHandler.END

    sub_id = int(raw_id)
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT 1 FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, owner_id))
        if not cursor.fetchone():
            await callback.edit_message_text("错误:找不到该订阅或无权限。")
            return ConversationHandler.END

    logger.debug(f"Starting reminder settings for sub_id: {sub_id}")
    context.user_data['sub_id_for_action'] = sub_id
    await _display_reminder_settings(callback, context, sub_id)
    return REMIND_SELECT_ACTION
async def remind_action_handler(update: Update, context: CallbackContext):
    """Handle a button of the reminder panel (``remindaction_<action>``).

    ``ask_days`` switches to ``REMIND_GET_DAYS``; the two toggle actions flip
    the matching flag column and redraw the panel. Unknown actions are logged
    and leave the conversation in ``REMIND_SELECT_ACTION``.
    """
    callback = update.callback_query
    await callback.answer()
    if not callback.data:
        return REMIND_SELECT_ACTION

    action = callback.data.partition('remindaction_')[2]
    sub_id = context.user_data.get('sub_id_for_action')
    if not sub_id:
        await callback.edit_message_text("错误:会话已过期,请重试。")
        return ConversationHandler.END
    owner_id = callback.from_user.id

    if action == 'ask_days':
        await callback.edit_message_text("请输入您想提前几天收到提醒?(输入0则不提前提醒)")
        return REMIND_GET_DAYS

    # One flip statement per supported toggle action.
    toggle_statements = {
        'toggle_enabled': (
            "UPDATE subscriptions SET reminders_enabled = CASE WHEN reminders_enabled THEN 0 ELSE 1 END "
            "WHERE id = ? AND user_id = ?"
        ),
        'toggle_due_date': (
            "UPDATE subscriptions SET reminder_on_due_date = CASE WHEN reminder_on_due_date THEN 0 ELSE 1 END "
            "WHERE id = ? AND user_id = ?"
        ),
    }
    statement = toggle_statements.get(action)
    if statement is None:
        logger.warning(f"Unexpected action '{callback.data}' in remind_action_handler")
        return REMIND_SELECT_ACTION

    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(statement, (sub_id, owner_id))
        if cursor.rowcount == 0:
            await callback.edit_message_text("错误:找不到该订阅或无权限。")
            return ConversationHandler.END
        conn.commit()

    await _display_reminder_settings(callback, context, sub_id)
    return REMIND_SELECT_ACTION
async def remind_days_received(update: Update, context: CallbackContext):
    """Store how many days before the due date the reminder should fire.

    Re-prompts (returns ``REMIND_GET_DAYS``) unless the text is a non-negative
    integer. On success updates ``reminder_days``, clears the stored
    subscription id, shows the detail view and ends the conversation.
    """
    sub_id = context.user_data.get('sub_id_for_action')
    if not sub_id:
        await update.message.reply_text("错误:会话已过期,请重试。")
        return ConversationHandler.END
    user_id = update.effective_user.id

    # Only the parse is guarded. Wrapping the database write and the view
    # rendering as well would report unrelated ValueError/TypeError failures as
    # "invalid number" after the value had already been saved.
    try:
        days = int(update.message.text)
    except (ValueError, TypeError):
        days = -1
    if days < 0:
        await update.message.reply_text("请输入一个有效的非负整数。")
        return REMIND_GET_DAYS

    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("UPDATE subscriptions SET reminder_days = ? WHERE id = ? AND user_id = ?",
                       (days, sub_id, user_id))
        if cursor.rowcount == 0:
            await update.message.reply_text("错误:找不到该订阅或无权限。")
            return ConversationHandler.END
        conn.commit()

    await update.message.reply_text(f"✅ 提前提醒天数已设置为: {days}天。")
    _clear_action_state(context, ['sub_id_for_action'])
    await show_subscription_view(update, context, sub_id)
    return ConversationHandler.END
# --- Other Commands ---
async def set_currency(update: Update, context: CallbackContext):
    """Handle ``/set_currency <CODE>``: store the user's main currency.

    The code must be exactly three letters; it is upper-cased and upserted into
    the ``users`` table. Returns ``ConversationHandler.END`` after a successful
    update (kept for compatibility with the original return value).
    """
    user_id, args = update.effective_user.id, context.args or []
    # update.message is None when the command arrives as an edited message.
    message = update.effective_message

    if len(args) != 1:
        await message.reply_text("用法: /set_currency 代码(例如 /set_currency USD)", parse_mode='HTML')
        return
    new_currency = args[0].upper()
    if len(new_currency) != 3 or not new_currency.isalpha():
        await message.reply_text("请输入有效的三字母货币代码(如 USD, CNY)。")
        return

    with get_db_connection() as conn:
        cursor = conn.cursor()
        # Insert the user row, or overwrite main_currency if it already exists.
        cursor.execute("""
            INSERT INTO users (user_id, main_currency)
            VALUES (?, ?)
            ON CONFLICT(user_id) DO UPDATE SET main_currency = excluded.main_currency
        """, (user_id, new_currency))
        conn.commit()

    await message.reply_text(f"您的主货币已设为 {escape_html(new_currency)}。",
                             parse_mode='HTML')
    return ConversationHandler.END
async def cancel(update: Update, context: CallbackContext):
    """Abort whatever multi-step action is in progress.

    Clears the per-user action state, tells the user the operation was
    cancelled and returns ``ConversationHandler.END``.
    """
    _clear_action_state(context, ['new_sub_data', 'sub_id_for_action', 'field_to_edit', 'new_freq_unit'])
    query = update.callback_query
    if query:
        await query.answer()
        await query.edit_message_text('操作已取消。')
    elif update.effective_message:
        # effective_message also covers an edited /cancel, where
        # update.message is None.
        await update.effective_message.reply_text('操作已取消。')
    return ConversationHandler.END
# --- Main ---
def main():
    """Build the Telegram application, register all handlers and start polling.

    Returns early (after logging) when ``TELEGRAM_TOKEN`` is not configured.
    """
    if not TELEGRAM_TOKEN:
        logger.critical("TELEGRAM_TOKEN 环境变量未设置!")
        return
    if not EXCHANGE_API_KEY:
        logger.info("未配置 EXCHANGE_API_KEY,多货币换算将降级为只使用本地缓存(若无缓存则不转换)。")

    application = Application.builder().token(TELEGRAM_TOKEN).build()

    async def post_init(app: Application):
        """Verify the token, publish the command menu and schedule reminders."""
        try:
            bot_info = await app.bot.get_me()
            logger.info(f"TELEGRAM_TOKEN 验证成功: {bot_info.username}")
        except TelegramError as e:
            logger.critical(f"TELEGRAM_TOKEN 无效或无法连接 Telegram API: {e}")
            raise SystemExit

        commands = [
            BotCommand("start", "🚀 开始使用"),
            BotCommand("add_sub", "➕ 添加新订阅"),
            BotCommand("list_subs", "📋 列出所有订阅"),
            BotCommand("list_categories", "🗂️ 按分类浏览"),
            BotCommand("stats", "📊 查看订阅统计"),
            BotCommand("import", "📥 导入订阅"),
            BotCommand("export", "📤 导出订阅"),
            BotCommand("set_currency", "💲 设置主货币"),
            BotCommand("help", "ℹ️ 获取帮助"),
            BotCommand("cancel", "❌ 取消当前操作")
        ]
        try:
            # Replace whatever command list was registered before.
            await app.bot.delete_my_commands()
            logger.debug("Cleared existing bot commands")
            await app.bot.set_my_commands(commands)
            logger.info("Bot commands registered successfully")
        except TelegramError as e:
            logger.error(f"Failed to register bot commands: {e}")

        # Daily reminder run at 09:00 in a fixed UTC+8 offset.
        app.job_queue.run_daily(
            check_and_send_reminders,
            time=datetime.time(hour=9, minute=0, tzinfo=datetime.timezone(datetime.timedelta(hours=8))),
            name='daily_reminders'
        )
        logger.info("Daily reminder job scheduled.")

    application.post_init = post_init

    add_conv = ConversationHandler(
        entry_points=[CommandHandler('add_sub', add_sub_start)],
        states={
            ADD_NAME: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_name_received)],
            ADD_COST: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_cost_received)],
            ADD_CURRENCY: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_currency_received)],
            ADD_CATEGORY: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_category_received)],
            ADD_NEXT_DUE: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_next_due_received)],
            ADD_FREQ_UNIT: [CallbackQueryHandler(add_freq_unit_received, pattern='^freq_unit_')],
            ADD_FREQ_VALUE: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_freq_value_received)],
            ADD_RENEWAL_TYPE: [CallbackQueryHandler(add_renewal_type_received, pattern='^renewal_')],
            ADD_NOTES: [
                MessageHandler(filters.TEXT & ~filters.COMMAND, add_notes_received),
                CommandHandler('skip', skip_notes)
            ],
        },
        fallbacks=[CommandHandler('cancel', cancel)]
    )

    # NOTE(review): the edit and reminder state constants are both created
    # with range(), so EDIT_SELECT_FIELD and REMIND_SELECT_ACTION share the
    # same value. The cross-conversation fallbacks below (remind_settings_start
    # inside edit_conv, edit_start inside remind_conv) therefore return a state
    # that the *current* conversation interprets as one of its own. Confirm
    # whether that is intended.
    edit_conv = ConversationHandler(
        entry_points=[CallbackQueryHandler(edit_start, pattern=r'^edit_\d+$')],
        states={
            EDIT_SELECT_FIELD: [CallbackQueryHandler(edit_field_selected, pattern='^editfield_')],
            EDIT_GET_NEW_VALUE: [
                MessageHandler(filters.TEXT & ~filters.COMMAND, edit_new_value_received),
                CallbackQueryHandler(edit_new_value_received, pattern='^editvalue_'),
                CommandHandler('empty', edit_new_value_received)
            ],
            EDIT_FREQ_UNIT: [CallbackQueryHandler(edit_freq_unit_received, pattern='^freq_unit_')],
            EDIT_FREQ_VALUE: [MessageHandler(filters.TEXT & ~filters.COMMAND, edit_freq_value_received)],
        },
        fallbacks=[
            CommandHandler('cancel', cancel),
            # The wrapper returns END so the conversation is really closed
            # when the user goes back to the detail view.
            CallbackQueryHandler(fallback_view_button, pattern=r'^view_\d+$'),
            CallbackQueryHandler(edit_start, pattern=r'^edit_\d+$'),
            CallbackQueryHandler(remind_settings_start, pattern=r'^remind_\d+$')
        ],
        per_message=False
    )

    remind_conv = ConversationHandler(
        entry_points=[CallbackQueryHandler(remind_settings_start, pattern=r'^remind_\d+$')],
        states={
            REMIND_SELECT_ACTION: [CallbackQueryHandler(remind_action_handler, pattern='^remindaction_')],
            REMIND_GET_DAYS: [MessageHandler(filters.TEXT & ~filters.COMMAND, remind_days_received)],
        },
        fallbacks=[
            CommandHandler('cancel', cancel),
            # Same wrapper as in edit_conv: show the detail view, then END.
            CallbackQueryHandler(fallback_view_button, pattern=r'^view_\d+$'),
            CallbackQueryHandler(edit_start, pattern=r'^edit_\d+$'),
            CallbackQueryHandler(remind_settings_start, pattern=r'^remind_\d+$')
        ],
        per_message=False
    )

    import_conv = ConversationHandler(
        entry_points=[CommandHandler('import', import_start)],
        states={
            IMPORT_UPLOAD: [MessageHandler(filters.Document.ALL, import_upload_received)],
        },
        fallbacks=[CommandHandler('cancel', cancel)]
    )

    button_pattern = r'^(view_\d+|renewmanual_\d+|delete_\d+|confirmdelete_\d+|renewfromremind_\d+|list_subs_in_category_id_\d+|list_categories|list_all_subs)$'

    application.add_handler(CommandHandler('start', start))
    application.add_handler(CommandHandler('help', help_command))
    application.add_handler(CommandHandler('list_subs', list_subs))
    application.add_handler(CommandHandler('list_categories', list_categories))
    application.add_handler(CommandHandler('set_currency', set_currency))
    application.add_handler(CommandHandler('stats', stats))
    application.add_handler(CommandHandler('export', export_command))
    application.add_handler(add_conv)
    application.add_handler(edit_conv)
    application.add_handler(remind_conv)
    application.add_handler(import_conv)
    # The stand-alone /cancel must come AFTER the conversations. Within one
    # handler group only the first matching handler runs, so registering it
    # earlier would shadow every conversation's own /cancel fallback and the
    # conversation would never actually end.
    application.add_handler(CommandHandler('cancel', cancel))
    application.add_handler(CallbackQueryHandler(button_callback_handler, pattern=button_pattern))

    logger.info(f"{PROJECT_NAME} Bot is starting...")
    application.run_polling()
# Script entry point: run update_past_due_dates() (defined elsewhere in this
# file) once, then start the bot via main().
if __name__ == '__main__':
    update_past_due_dates()
    main()