Files
SubMind/SubMind.py
2025-12-08 09:30:12 +08:00

1325 lines
60 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import sqlite3
import os
import requests
import datetime
import dateparser
import logging
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import re
from dotenv import load_dotenv
from dateutil.relativedelta import relativedelta
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup, BotCommand, CallbackQuery
from telegram.ext import (
Application, CommandHandler, MessageHandler, filters,
CallbackContext, CallbackQueryHandler, ConversationHandler
)
from telegram.error import TelegramError
from telegram.helpers import escape_markdown
# --- Load .env before any setting below is read ---
load_dotenv()

# --- Logging ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
# Raise the "httpx" logger's threshold so its per-request INFO lines are hidden.
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)

# --- Secrets from the environment and fixed project settings ---
TELEGRAM_TOKEN = os.environ.get('TELEGRAM_TOKEN')        # None when unset
EXCHANGE_API_KEY = os.environ.get('EXCHANGE_API_KEY')    # None when unset
PROJECT_NAME = "SubMind"
DB_FILE = 'submind.db'                                   # relative path -> current working directory

# --- Conversation-handler states (each conversation numbers its own from 0) ---
[ADD_NAME, ADD_COST, ADD_CURRENCY, ADD_CATEGORY, ADD_NEXT_DUE,
 ADD_FREQ_UNIT, ADD_FREQ_VALUE, ADD_RENEWAL_TYPE, ADD_NOTES] = range(9)
[EDIT_SELECT_FIELD, EDIT_GET_NEW_VALUE, EDIT_FREQ_UNIT, EDIT_FREQ_VALUE] = range(4)
[REMIND_SELECT_ACTION, REMIND_GET_DAYS] = range(2)
[IMPORT_UPLOAD] = range(1)
# --- Helper: database connection management ---
class _ClosingConnection(sqlite3.Connection):
    """sqlite3 connection whose ``with`` block also closes it.

    The stock ``sqlite3.Connection`` context manager only commits (or rolls
    back on error) and leaves the connection open, so each
    ``with get_db_connection() as conn:`` block would otherwise keep a
    connection alive until it happens to be garbage-collected.
    """

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            # Keep the standard commit-on-success / rollback-on-error behaviour.
            return super().__exit__(exc_type, exc_value, traceback)
        finally:
            self.close()


def get_db_connection(db_file=None):
    """Open a connection to the bot's SQLite database.

    Rows come back as ``sqlite3.Row`` (indexable by position or column name).
    Used as a context manager, the connection commits/rolls back as usual and
    is then closed.

    :param db_file: database path; ``None`` means the module-level ``DB_FILE``.
    :returns: an open connection.
    """
    conn = sqlite3.connect(DB_FILE if db_file is None else db_file, factory=_ClosingConnection)
    conn.row_factory = sqlite3.Row
    return conn
# --- Font management ---
def get_chinese_font():
    """Return FontProperties for a CJK-capable font, downloading it on first use.

    The Source Han Sans SC font is cached at ``fonts/SourceHanSansSC-Regular.otf``.
    If it is missing it is downloaded from GitHub. On any download or file-system
    failure a generic ``sans-serif`` FontProperties is returned instead, so chart
    rendering still works (CJK glyphs may then be missing).
    """
    font_name = 'SourceHanSansSC-Regular.otf'
    font_path = os.path.join('fonts', font_name)
    if os.path.exists(font_path):
        logger.debug(f"Found font at {font_path}")
        return fm.FontProperties(fname=font_path)
    logger.info(f"Font '{font_name}' not found. Attempting to download...")
    url = 'https://github.com/wweir/source-han-sans-sc/raw/refs/heads/master/SourceHanSansSC-Regular.otf'
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }
    # Download to a temporary name and rename only when complete. An interrupted
    # download then never leaves a truncated file that the exists() check above
    # would accept as a valid font on the next call.
    tmp_path = font_path + '.part'
    try:
        os.makedirs('fonts', exist_ok=True)
        with requests.get(url, stream=True, headers=headers, timeout=10) as response:
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(tmp_path, font_path)
    except (requests.exceptions.RequestException, OSError) as e:
        logger.error(f"Failed to download font. Error: {e}")
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup of the partial file
        return fm.FontProperties(family='sans-serif')
    logger.info(f"Font '{font_name}' downloaded successfully to '{font_path}'.")
    # Register just this file through the public API instead of rebuilding the
    # whole font cache via the private fm._load_fontmanager().
    fm.fontManager.addfont(font_path)
    return fm.FontProperties(fname=font_path)
# --- Database initialisation and migrations ---
def init_db():
    """Create the schema if needed and apply in-place column migrations.

    Safe to run on every start-up:
    - tables and indexes use ``IF NOT EXISTS``;
    - columns added after the first schema are ALTERed in only when
      ``PRAGMA table_info`` shows them missing;
    - legacy free-text frequencies are converted by ``migrate_frequency_data``.
    """
    # Columns that older databases may lack, with their full column definitions.
    added_columns = (
        ('frequency_unit', 'TEXT'),
        ('frequency_value', 'INTEGER'),
        ('reminders_enabled', 'BOOLEAN DEFAULT TRUE'),
        ('reminder_days', 'INTEGER DEFAULT 3'),
        ('reminder_on_due_date', 'BOOLEAN DEFAULT TRUE'),
        ('notes', 'TEXT'),
    )
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS subscriptions (
                id INTEGER PRIMARY KEY, user_id INTEGER, name TEXT, cost REAL, currency TEXT,
                category TEXT, next_due DATE, frequency TEXT,
                renewal_type TEXT DEFAULT 'auto',
                reminders_enabled BOOLEAN DEFAULT TRUE,
                reminder_days INTEGER DEFAULT 3,
                reminder_on_due_date BOOLEAN DEFAULT TRUE,
                frequency_unit TEXT,
                frequency_value INTEGER,
                notes TEXT
            )
        ''')
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_subscriptions_user_id ON subscriptions(user_id)')
        cursor.execute("PRAGMA table_info(subscriptions)")
        existing_columns = {info[1] for info in cursor.fetchall()}
        for column, definition in added_columns:
            if column not in existing_columns:
                # Safe to interpolate: both parts come from the constant tuple above.
                cursor.execute(f"ALTER TABLE subscriptions ADD COLUMN {column} {definition}")
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS categories (
                id INTEGER PRIMARY KEY, user_id INTEGER, name TEXT, UNIQUE(user_id, name)
            )
        ''')
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_categories_user_id ON categories(user_id)')
        # Single-quoted string defaults: double quotes denote identifiers in SQL
        # and only work through SQLite's legacy double-quoted-string fallback.
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS users (
                user_id INTEGER PRIMARY KEY, main_currency TEXT DEFAULT 'USD', language TEXT DEFAULT 'en'
            )
        ''')
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS exchange_rates (
                from_currency TEXT, to_currency TEXT, rate REAL, last_updated TIMESTAMP,
                PRIMARY KEY (from_currency, to_currency)
            )
        ''')
        migrate_frequency_data(conn, cursor)
        conn.commit()
def migrate_frequency_data(conn, cursor):
    """Back-fill ``frequency_unit``/``frequency_value`` from the legacy ``frequency`` text.

    Only rows that have a legacy ``frequency`` and no ``frequency_unit`` yet are
    considered. Recognised English/Chinese labels (matched case-insensitively and
    ignoring surrounding whitespace) are mapped to a (unit, value) pair.
    Unrecognised labels are left untouched.

    :param conn: open sqlite3 connection; committed when there was work to do.
    :param cursor: cursor on ``conn`` used for the SELECT and the UPDATEs.
    """
    cursor.execute("SELECT id, frequency FROM subscriptions WHERE frequency IS NOT NULL AND frequency_unit IS NULL")
    subs_to_migrate = cursor.fetchall()
    if not subs_to_migrate:
        return
    freq_map = {
        'daily': ('day', 1), 'weekly': ('week', 1), '周付': ('week', 1), 'monthly': ('month', 1),
        '月付': ('month', 1), 'quarterly': ('month', 3), '季付': ('month', 3), '半年': ('month', 6),
        'half-year': ('month', 6), 'biannually': ('month', 6), 'yearly': ('year', 1), '年付': ('year', 1)
    }
    updates = []
    for sub_id, freq_str in subs_to_migrate:
        # Normalise whitespace as well as case so e.g. ' Monthly ' still maps.
        unit, value = freq_map.get(str(freq_str).strip().lower(), (None, None))
        if unit and value:
            updates.append((unit, value, sub_id))
    if updates:
        cursor.executemany("UPDATE subscriptions SET frequency_unit = ?, frequency_value = ? WHERE id = ?",
                           updates)
    conn.commit()
# Ensure the schema exists (and is migrated) as soon as this module is imported.
init_db()
# --- Helper functions ---
def get_user_main_currency(user_id):
    """Return the user's preferred currency code from ``users``.

    Falls back to ``'USD'`` when the user has no row yet.
    """
    with get_db_connection() as conn:
        row = conn.cursor().execute(
            'SELECT main_currency FROM users WHERE user_id = ?', (user_id,)
        ).fetchone()
    if row is None:
        return 'USD'
    return row['main_currency']
def convert_currency(amount, from_curr, to_curr):
    """Convert ``amount`` from ``from_curr`` to ``to_curr`` (codes are case-insensitive).

    Rates are cached in ``exchange_rates`` for 24 hours. A stale or missing rate
    is refreshed from exchangerate-api.com. If refreshing is impossible (no API
    key, network/HTTP error, or a response without a usable ``conversion_rate``),
    a stale cached rate is used if there is one. Failing that, ``amount`` is
    returned unconverted.
    """
    src, dst = from_curr.upper(), to_curr.upper()
    if src == dst:
        return amount
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute('''
            SELECT rate, last_updated FROM exchange_rates
            WHERE from_currency = ? AND to_currency = ?
        ''', (src, dst))
        result = cursor.fetchone()
        now = datetime.datetime.now()
        cache_validity = datetime.timedelta(hours=24)
        if result and (now - datetime.datetime.fromisoformat(result['last_updated'])) < cache_validity:
            logger.debug(f"Using cached exchange rate for {src} to {dst}: {result['rate']}")
            return amount * result['rate']

        rate = None
        if not EXCHANGE_API_KEY:
            logger.warning("EXCHANGE_API_KEY not set, cannot refresh exchange rate")
        else:
            try:
                # The pair endpoint without an amount returns just the rate,
                # which is all that is cached and used here.
                url = f"https://v6.exchangerate-api.com/v6/{EXCHANGE_API_KEY}/pair/{src}/{dst}"
                response = requests.get(url, timeout=5)
                response.raise_for_status()
                data = response.json()
                # Never cache a made-up rate: an error payload has no conversion_rate.
                if (not isinstance(data, dict) or data.get('result') == 'error'
                        or data.get('conversion_rate') is None):
                    raise ValueError(f"unexpected API response: {data}")
                rate = float(data['conversion_rate'])
            except (requests.exceptions.RequestException, ValueError, TypeError) as e:
                logger.error(f"Currency conversion API error: {e}")

        if rate is not None:
            cursor.execute('''
                INSERT OR REPLACE INTO exchange_rates (from_currency, to_currency, rate, last_updated)
                VALUES (?, ?, ?, ?)
            ''', (src, dst, rate, now.isoformat()))
            conn.commit()
            logger.debug(f"Updated exchange rate cache for {src} to {dst}: {rate}")
            return amount * rate
        if result:
            logger.warning(f"Falling back to cached rate: {result['rate']}")
            return amount * result['rate']
        logger.warning("No cached rate available, returning original amount")
        return amount
def parse_date(date_string: str) -> "str | None":
    """Parse a free-form English/Chinese date into ``'YYYY-MM-DD'``.

    If the input carries no explicit year (no '年', no '/', no 4-digit number)
    and the parsed date is already in the past, it is moved forward one year,
    so a bare month/day always refers to its next occurrence.

    :returns: the ISO date string, or None when the text cannot be parsed.
    """
    today = datetime.datetime.now()
    try:
        dt = dateparser.parse(date_string, languages=['en', 'zh'])
        if not dt:
            return None
        # '年', a '/'-separated date or a 4-digit number means the user gave a year.
        # (The marker must be a real character: an empty-string marker matches
        # every input and would disable the roll-forward below.)
        has_year_info = (any(marker in date_string for marker in ('年', '/'))
                         or re.search(r'\d{4}', date_string) is not None)
        if not has_year_info and dt.date() < today.date():
            # relativedelta clamps Feb 29 to Feb 28 where replace(year=...) would raise.
            dt = dt + relativedelta(years=1)
        return dt.strftime('%Y-%m-%d')
    except Exception as e:
        logger.error(f"Date parsing failed for string '{date_string}'. Error: {e}")
        return None
def calculate_new_due_date(base_date, unit, value):
    """Return ``base_date`` advanced by ``value`` periods of ``unit``.

    :param base_date: a ``date``/``datetime`` to advance.
    :param unit: 'day', 'week', 'month' or 'year' (case-insensitive).
    :param value: number of periods.
    :returns: the new date, or None when the unit is unknown or ``value`` is
        missing or not positive. A zero/negative step would never move the
        date forward, which makes ``update_past_due_dates``' catch-up loop
        spin forever.
    """
    if value is None or value <= 0:
        return None
    keyword = {'day': 'days', 'week': 'weeks', 'month': 'months', 'year': 'years'}.get(str(unit).lower())
    if keyword is None:
        return None
    return base_date + relativedelta(**{keyword: value})
def format_frequency(unit, value) -> str:
    """Render a billing period in Chinese.

    Examples: ('month', 1) -> '每月', ('month', 3) -> '3 个月', ('day', 2) -> '2 天'.
    Returns '未知' when the unit or value is missing; unknown units are shown verbatim.
    """
    if not unit or value is None:
        return "未知"
    # Labels printed after a count ('3 天', '2 周', '3 个月', '2 年'). They must be
    # non-empty, or the count would be shown with no unit at all.
    unit_map = {'day': '天', 'week': '周', 'month': '个月', 'year': '年'}
    if value == 1:
        single_unit_map = {'day': '每天', 'week': '每周', 'month': '每月', 'year': '每年'}
        return single_unit_map.get(unit, f"{value} {unit_map.get(unit, unit)}")
    return f"{value} {unit_map.get(unit, unit)}"
async def get_subs_list_keyboard(user_id: int, category_filter: str = None) -> InlineKeyboardMarkup:
    """Build the two-column subscription picker for ``user_id``.

    Subscriptions are ordered by ``next_due``; a truthy ``category_filter``
    limits the list to that category. Each button opens ``view_<id>``. The last
    row links to the category browser. Returns None when nothing matches.
    """
    sql = "SELECT id, name FROM subscriptions WHERE user_id = ? "
    args = [user_id]
    if category_filter:
        sql += "AND category = ? "
        args.append(category_filter)
    sql += "ORDER BY next_due ASC"
    with get_db_connection() as conn:
        rows = conn.cursor().execute(sql, tuple(args)).fetchall()
    if not rows:
        return None
    grid = []
    for start_idx in range(0, len(rows), 2):
        pair = rows[start_idx:start_idx + 2]
        grid.append([InlineKeyboardButton(r[1], callback_data=f'view_{r[0]}') for r in pair])
    if category_filter:
        nav_button = InlineKeyboardButton("« 返回分类列表", callback_data='list_categories')
    else:
        nav_button = InlineKeyboardButton("🗂️ 按分类浏览", callback_data='list_categories')
    grid.append([nav_button])
    return InlineKeyboardMarkup(grid)
# --- Scheduled jobs ---
def update_past_due_dates():
    """Roll ``next_due`` of overdue auto-renewing subscriptions forward.

    Every ``renewal_type = 'auto'`` subscription whose due date is before today
    is advanced by whole billing periods until it is today or later. A date
    that lands exactly on today is kept, so the due-today reminder can still
    fire. Failures on individual subscriptions are logged and skipped.
    """
    today = datetime.date.today()
    with get_db_connection() as conn:
        cursor = conn.cursor()
        # next_due is stored as 'YYYY-MM-DD' text; bind an ISO string rather than
        # relying on sqlite3's implicit date adapter (deprecated in Python 3.12).
        cursor.execute("SELECT * FROM subscriptions WHERE next_due < ? AND renewal_type = 'auto'",
                       (today.isoformat(),))
        past_due_subs = cursor.fetchall()
        if not past_due_subs:
            return
        for sub in past_due_subs:
            try:
                last_due_date = datetime.datetime.strptime(sub['next_due'], '%Y-%m-%d').date()
                new_due_date = last_due_date
                # '<' mirrors the SELECT above: a subscription due today is current, not overdue.
                while new_due_date < today:
                    calculated_date = calculate_new_due_date(new_due_date, sub['frequency_unit'],
                                                             sub['frequency_value'])
                    # Stop on an uncomputable or non-advancing step instead of looping forever.
                    if not calculated_date or calculated_date <= new_due_date:
                        break
                    new_due_date = calculated_date
                if new_due_date > last_due_date:
                    cursor.execute("UPDATE subscriptions SET next_due = ? WHERE id = ?",
                                   (new_due_date.strftime('%Y-%m-%d'), sub['id']))
            except Exception as e:
                logger.error(f"Failed to update subscription {sub['id']}: {e}")
        conn.commit()
async def check_and_send_reminders(context: CallbackContext):
    """Job callback: send today's subscription reminders to their owners.

    For every subscription with ``reminders_enabled`` and a ``next_due``:
    - due today and ``reminder_on_due_date`` set: a due-date notice. Manual
      renewals get a "renewed" button; auto renewals are told they renew
      automatically.
    - otherwise, manual renewals with ``reminder_days > 0``: an advance notice
      exactly ``reminder_days`` before the due date.
    A failure on one subscription is logged and does not stop the others.
    """
    logger.info("Running job: Checking for subscription reminders...")
    today = datetime.date.today()
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM subscriptions WHERE reminders_enabled = TRUE AND next_due IS NOT NULL")
        subs_to_check = cursor.fetchall()
    # Rows are fully fetched, so no DB connection is held across the network awaits below.
    for sub in subs_to_check:
        try:
            due_date = datetime.datetime.strptime(sub['next_due'], '%Y-%m-%d').date()
            user_id = sub['user_id']
            renewal_type = sub['renewal_type']
            safe_sub_name = escape_markdown(sub['name'], version=2)
            message = None
            keyboard = None
            if renewal_type == 'manual':
                keyboard = InlineKeyboardMarkup([
                    [InlineKeyboardButton("✅ 我已续费", callback_data=f"renewfromremind_{sub['id']}")]
                ])
            if sub['reminder_on_due_date'] and due_date == today:
                message = f"🔔 *订阅到期提醒*\n\n您的订阅 `{safe_sub_name}` 今天到期。"
                if renewal_type == 'manual':
                    message += " 请记得手动续费。"
                else:
                    message += " 将会自动续费。"
                    keyboard = None
            elif renewal_type == 'manual' and sub['reminder_days'] > 0:
                reminder_date = due_date - datetime.timedelta(days=sub['reminder_days'])
                if reminder_date == today:
                    days_left = (due_date - today).days
                    days_text = f"*{days_left}天后*" if days_left > 0 else "*今天*"
                    message = f"🔔 *订阅即将到期提醒*\n\n您的手动续费订阅 `{safe_sub_name}` 将在 {days_text} 到期。"
            if message:
                await context.bot.send_message(
                    chat_id=user_id,
                    text=message,
                    parse_mode='MarkdownV2',
                    reply_markup=keyboard
                )
        except Exception as e:
            # sqlite3.Row has no .get(); index it directly so this handler cannot
            # itself raise and abort the reminders for all remaining subscriptions.
            logger.error(f"Failed to process reminder for sub_id {sub['id']}: {e}")
# --- Command handlers ---
async def start(update: Update, context: CallbackContext):
    """/start: register the caller in ``users`` (no-op if already known) and greet them."""
    uid = update.effective_user.id
    with get_db_connection() as conn:
        conn.cursor().execute('INSERT OR IGNORE INTO users (user_id) VALUES (?)', (uid,))
        conn.commit()
    greeting = f'欢迎使用 {escape_markdown(PROJECT_NAME, version=2)}\n您的私人订阅智能管家。'
    await update.message.reply_text(greeting, parse_mode='MarkdownV2')
async def help_command(update: Update, context: CallbackContext):
    """/help: send the command overview, formatted as MarkdownV2."""
    # Raw f-string: the backslashes are MarkdownV2 escapes for Telegram, not
    # Python escape sequences (which would also trigger invalid-escape warnings).
    # Every reserved MarkdownV2 character outside an entity must be escaped,
    # including the '>' in "<code>"; left bare, Telegram rejects the whole message.
    help_text = rf"""
*{escape_markdown(PROJECT_NAME, version=2)} 命令列表*
*🌟 核心功能*
/add\_sub \- 引导您添加一个新的订阅
/list\_subs \- 列出您的所有订阅
/list\_categories \- 按分类浏览您的订阅
*📊 数据管理*
/stats \- 查看按类别分类的订阅统计
/import \- 通过上传 CSV 文件批量导入订阅
/export \- 将您的所有订阅导出为 CSV 文件
*⚙️ 个性化设置*
/set\_currency \`<code\>\` \- 设置您的主要货币
/cancel \- 在任何流程中取消当前操作
"""
    await update.message.reply_text(help_text, parse_mode='MarkdownV2')
def make_autopct(values, currency_code):
    """Build a matplotlib ``autopct`` callback labelling each wedge with amount and share.

    :param values: the wedge values passed to ``ax.pie``; used to turn the
        percentage matplotlib supplies back into an absolute amount.
    :param currency_code: currency code; common ones are shown as their symbol,
        any other code as ``'<code> '``.
    :returns: ``f(pct)`` producing ``'<symbol><amount:.2f>\\n(<pct:.1f>%)'``.
    """
    currency_symbols = {'USD': '$', 'CNY': '¥', 'EUR': '€', 'GBP': '£', 'JPY': '¥'}
    symbol = currency_symbols.get(currency_code.upper(), f'{currency_code} ')
    # The total is the same for every wedge, so compute it once.
    total = sum(values)

    def my_autopct(pct):
        val = float(pct * total / 100.0)
        return f'{symbol}{val:.2f}\n({pct:.1f}%)'

    return my_autopct
async def stats(update: Update, context: CallbackContext):
    """/stats: send a pie chart of the caller's monthly spend per category.

    Each cost is converted to the user's main currency, then normalised to a
    30.4375-day month (cost / period_days * 30.4375). Subscriptions with an
    unknown or zero period count as 0. The chart is written to a temporary PNG
    in the working directory, sent, and deleted.
    """
    user_id = update.effective_user.id
    await update.message.reply_text("正在为您生成订阅统计图表,请稍候...")
    font_prop = get_chinese_font()
    main_currency = get_user_main_currency(user_id)
    with get_db_connection() as conn:
        df = pd.read_sql_query("SELECT * FROM subscriptions WHERE user_id = ?", conn, params=(user_id,))
    if df.empty:
        await update.message.reply_text("您还没有任何订阅数据。")
        return
    df['converted_cost'] = df.apply(lambda row: convert_currency(row['cost'], row['currency'], main_currency), axis=1)
    # Average day counts so different billing periods compare fairly.
    unit_to_days = {'day': 1, 'week': 7, 'month': 30.4375, 'year': 365.25}

    def normalize_to_monthly(row):
        """Monthly-equivalent cost of one row; 0 when its period is unknown or zero."""
        if pd.isna(row['frequency_unit']) or pd.isna(row['frequency_value']) or row['frequency_value'] == 0:
            return 0
        total_days = row['frequency_value'] * unit_to_days.get(row['frequency_unit'], 0)
        if total_days == 0:
            return 0
        return (row['converted_cost'] / total_days) * 30.4375

    df['monthly_cost'] = df.apply(normalize_to_monthly, axis=1)
    category_costs = df.groupby('category')['monthly_cost'].sum().sort_values(ascending=False)
    if category_costs.empty or category_costs.sum() == 0:
        await update.message.reply_text("您的订阅没有有效的费用信息。")
        return
    image_path = f'stats_{user_id}.png'
    plt.style.use('seaborn-v0_8-darkgrid')
    fig, ax = plt.subplots(figsize=(12, 12))
    try:
        try:
            autopct_function = make_autopct(category_costs.values, main_currency)
            _wedges, texts, autotexts = ax.pie(category_costs.values,
                                               labels=category_costs.index,
                                               autopct=autopct_function,
                                               startangle=140,
                                               pctdistance=0.7,
                                               labeldistance=1.05)
            ax.set_title('每月订阅支出分类统计', fontproperties=font_prop, fontsize=32, pad=20)
            for text in texts:
                text.set_fontproperties(font_prop)
                text.set_fontsize(22)
            for autotext in autotexts:
                autotext.set_fontproperties(font_prop)
                autotext.set_fontsize(20)
                autotext.set_color('white')
            ax.axis('equal')
            fig.tight_layout()
            # Save this figure explicitly, not whatever pyplot considers "current".
            fig.savefig(image_path)
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)
        with open(image_path, 'rb') as photo:
            await update.message.reply_photo(photo, caption="这是您按类别统计的每月订阅总支出。")
    finally:
        if os.path.exists(image_path):
            os.remove(image_path)
# --- Import / export commands ---
async def export_command(update: Update, context: CallbackContext):
    """/export: send the caller's subscriptions as a CSV document (UTF-8 with BOM).

    The CSV is written to a per-user temporary file in the working directory,
    which is always removed after sending.
    """
    uid = update.effective_user.id
    with get_db_connection() as conn:
        frame = pd.read_sql_query(
            "SELECT name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes FROM subscriptions WHERE user_id = ?",
            conn, params=(uid,))
    if frame.empty:
        await update.message.reply_text("您还没有任何订阅数据,无法导出。")
        return
    out_path = f'export_{uid}.csv'
    frame.to_csv(out_path, index=False, encoding='utf-8-sig')
    try:
        with open(out_path, 'rb') as handle:
            await update.message.reply_document(
                document=handle,
                filename='subscriptions.csv',
                caption="您的订阅数据已导出为 CSV 文件。",
            )
    finally:
        if os.path.exists(out_path):
            os.remove(out_path)
async def import_start(update: Update, context: CallbackContext):
    """Entry point of /import: explain the expected CSV layout and wait for the upload."""
    # Plain-text message (no parse_mode); the column list must match what
    # import_upload_received validates.
    await update.message.reply_text(
        "请上传一个 CSV 文件以导入订阅数据。\n"
        "文件应包含以下列:name, cost, currency, category, next_due, frequency_unit, "
        "frequency_value, renewal_type, notes(notes 可为空)。")
    return IMPORT_UPLOAD
def _parse_import_row(row, user_id):
    """Validate one CSV row and return the tuple to INSERT into ``subscriptions``.

    Raises ValueError (with a user-facing reason) on invalid data.
    """
    import math

    name = row['name']
    if pd.isna(name) or not str(name).strip():
        raise ValueError("名称不能为空")
    category = row['category']
    if pd.isna(category) or not str(category).strip():
        raise ValueError("类别不能为空")
    cost = float(row['cost'])
    # float() accepts 'nan'/'inf' and NaN cells; neither is a usable price.
    if not math.isfinite(cost):
        raise ValueError(f"无效费用: {row['cost']}")
    if cost < 0:
        raise ValueError("费用不能为负数")
    currency = str(row['currency']).strip().upper()
    # isalpha() alone accepts non-Latin letters; currency codes are 3 ASCII letters.
    if not (len(currency) == 3 and currency.isascii() and currency.isalpha()):
        raise ValueError(f"无效货币代码: {currency}")
    next_due = parse_date(str(row['next_due']))
    if not next_due:
        raise ValueError(f"无效日期格式: {row['next_due']}")
    frequency_unit = str(row['frequency_unit']).strip().lower()
    if frequency_unit not in ('day', 'week', 'month', 'year'):
        raise ValueError(f"无效周期单位: {frequency_unit}")
    frequency_value = int(row['frequency_value'])
    if frequency_value <= 0:
        raise ValueError(f"无效周期数量: {frequency_value}")
    renewal_type = str(row['renewal_type']).strip().lower()
    if renewal_type not in ('auto', 'manual'):
        raise ValueError(f"无效续费类型: {renewal_type}")
    # 'notes' is not a required column, so it may be absent altogether.
    raw_notes = row.get('notes')
    notes = str(raw_notes) if pd.notna(raw_notes) else None
    # str(): pandas may hand back numpy scalars, which sqlite3 cannot bind.
    return (user_id, str(name).strip(), cost, currency, str(category).strip(),
            next_due, frequency_unit, frequency_value, renewal_type, notes)


async def import_upload_received(update: Update, context: CallbackContext):
    """Import subscriptions from an uploaded CSV (all-or-nothing).

    Every row is validated first. The first invalid row aborts the whole import
    with an explanation. Otherwise all rows are inserted and their categories
    registered. The downloaded file is always deleted afterwards.
    """
    user_id = update.effective_user.id
    document = update.message.document
    # file_name is optional in the Bot API; also accept '.CSV' and other casings.
    if not document or not (document.file_name or '').lower().endswith('.csv'):
        await update.message.reply_text("请上传一个有效的 CSV 文件。")
        return IMPORT_UPLOAD
    file = await document.get_file()
    file_path = f'import_{user_id}.csv'
    try:
        await file.download_to_drive(file_path)
        df = pd.read_csv(file_path, encoding='utf-8-sig')
        required_columns = ['name', 'cost', 'currency', 'category', 'next_due', 'frequency_unit', 'frequency_value',
                            'renewal_type']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            await update.message.reply_text(f"CSV 文件缺少以下必要列:{', '.join(missing_columns)}")
            return ConversationHandler.END
        records = []
        for _, row in df.iterrows():
            try:
                records.append(_parse_import_row(row, user_id))
            except Exception as e:
                logger.error(f"Invalid row in CSV: {row.to_dict()}, error: {e}")
                await update.message.reply_text(f"导入失败,行数据无效:{row.to_dict()},错误:{e}")
                return ConversationHandler.END
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.executemany('''
                INSERT INTO subscriptions (user_id, name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ''', records)
            for category in {record[4] for record in records}:
                cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, category))
            conn.commit()
        await update.message.reply_text(f"✅ 成功导入 {len(records)} 条订阅数据!")
    except Exception as e:
        logger.error(f"Import failed: {e}")
        await update.message.reply_text(f"导入失败:{e}")
    finally:
        if os.path.exists(file_path):
            os.remove(file_path)
    return ConversationHandler.END
# --- Add-subscription conversation ---
async def add_sub_start(update: Update, context: CallbackContext):
    """Entry point of /add_sub: start an empty draft and ask for the name."""
    context.user_data['new_sub_data'] = dict()
    prompt = "好的,我们来添加一个新订阅。\n\n第一步:请输入订阅的 *名称*"
    await update.message.reply_text(prompt, parse_mode='MarkdownV2')
    return ADD_NAME
async def add_name_received(update: Update, context: CallbackContext):
    """Store the typed subscription name in the draft and ask for the cost."""
    draft = context.user_data['new_sub_data']
    draft['name'] = update.message.text
    await update.message.reply_text("第二步:请输入订阅 *费用*", parse_mode='MarkdownV2')
    return ADD_COST
async def add_cost_received(update: Update, context: CallbackContext):
    """Validate the cost (a finite, non-negative number) and ask for the currency.

    Invalid input re-prompts and keeps the conversation in ADD_COST.
    """
    import math

    try:
        cost = float(update.message.text)
    except (ValueError, TypeError):
        cost = None
    # float() also accepts 'nan' and 'inf'; neither is a usable price.
    if cost is None or not math.isfinite(cost) or cost < 0:
        await update.message.reply_text("费用必须是有效的非负数字。")
        return ADD_COST
    context.user_data['new_sub_data']['cost'] = cost
    # Full-width parentheses: ASCII '(' and ')' are reserved in MarkdownV2.
    await update.message.reply_text("第三步:请输入 *货币* 代码(例如 USD, CNY)", parse_mode='MarkdownV2')
    return ADD_CURRENCY
async def add_currency_received(update: Update, context: CallbackContext):
    """Validate a three-letter currency code (case-insensitive) and ask for the category."""
    currency = update.message.text.strip().upper()
    # isalpha() alone accepts CJK and other letters; currency codes are 3 ASCII letters.
    if not (len(currency) == 3 and currency.isascii() and currency.isalpha()):
        await update.message.reply_text("请输入有效的三字母货币代码(如 USD, CNY)")
        return ADD_CURRENCY
    context.user_data['new_sub_data']['currency'] = currency
    await update.message.reply_text("第四步:请为订阅指定一个 *类别*", parse_mode='MarkdownV2')
    return ADD_CATEGORY
async def add_category_received(update: Update, context: CallbackContext):
    """Store a non-empty category, register it for the user, and ask for the due date."""
    user_id, category_name = update.effective_user.id, update.message.text.strip()
    if not category_name:
        await update.message.reply_text("类别不能为空。")
        return ADD_CATEGORY
    context.user_data['new_sub_data']['category'] = category_name
    with get_db_connection() as conn:
        cursor = conn.cursor()
        # Make the category available to /list_categories; duplicates are ignored.
        cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, category_name))
        conn.commit()
    # MarkdownV2: '-' is escaped and the parentheses are full-width, since ASCII
    # '(' / ')' / '-' are reserved characters.
    await update.message.reply_text("第五步:请输入 *下一次付款日期*(例如 2025\\-10\\-01 或 10月1日)",
                                    parse_mode='MarkdownV2')
    return ADD_NEXT_DUE
async def add_next_due_received(update: Update, context: CallbackContext):
    """Parse the next payment date and offer the billing-period unit buttons."""
    parsed_date = parse_date(update.message.text)
    if not parsed_date:
        # Plain-text reply (no parse_mode), so no MarkdownV2 backslash escapes here.
        await update.message.reply_text("无法识别的日期格式,请使用类似 '2025-10-01' 或 '10月1日' 的格式。")
        return ADD_NEXT_DUE
    context.user_data['new_sub_data']['next_due'] = parsed_date
    # Telegram rejects inline buttons with empty text, so every unit needs a label.
    keyboard = [
        [InlineKeyboardButton("天", callback_data='freq_unit_day'),
         InlineKeyboardButton("周", callback_data='freq_unit_week')],
        [InlineKeyboardButton("月", callback_data='freq_unit_month'),
         InlineKeyboardButton("年", callback_data='freq_unit_year')]
    ]
    await update.message.reply_text("第六步:请选择付款周期的*单位*", reply_markup=InlineKeyboardMarkup(keyboard),
                                    parse_mode='MarkdownV2')
    return ADD_FREQ_UNIT
async def add_freq_unit_received(update: Update, context: CallbackContext):
    """Store the unit chosen via a ``freq_unit_<unit>`` button and ask for the period count."""
    query = update.callback_query
    await query.answer()
    # callback_data is 'freq_unit_<unit>', so the unit is the third '_'-separated field.
    context.user_data['new_sub_data']['unit'] = query.data.split('_')[2]
    await query.edit_message_text("第七步:请输入周期的*数量*(例如每3个月输入 3)", parse_mode='Markdown')
    return ADD_FREQ_VALUE
async def add_freq_value_received(update: Update, context: CallbackContext):
    """Store a positive integer period count, then offer auto/manual renewal buttons.

    Anything that is not a positive integer re-prompts and stays in ADD_FREQ_VALUE.
    """
    try:
        periods = int(update.message.text)
    except (ValueError, TypeError):
        periods = 0
    if periods <= 0:
        await update.message.reply_text("请输入一个有效的正整数。")
        return ADD_FREQ_VALUE
    context.user_data['new_sub_data']['value'] = periods
    choice_row = [
        InlineKeyboardButton("自动续费", callback_data='renewal_auto'),
        InlineKeyboardButton("手动续费", callback_data='renewal_manual'),
    ]
    await update.message.reply_text("第八步:请选择 *续费方式*", reply_markup=InlineKeyboardMarkup([choice_row]),
                                    parse_mode='MarkdownV2')
    return ADD_RENEWAL_TYPE
async def add_renewal_type_received(update: Update, context: CallbackContext):
    """Store 'auto' or 'manual' from a ``renewal_<type>`` button and ask for optional notes."""
    query = update.callback_query
    await query.answer()
    context.user_data['new_sub_data']['renewal_type'] = query.data.split('_')[1]
    await query.edit_message_text("最后一步(可选):需要添加备注吗?\n(如:共享账号、用途等。不需要请 /skip)")
    return ADD_NOTES
async def _finish_add_subscription(update: Update, context: CallbackContext, notes):
    """Save the drafted subscription with ``notes``, confirm, and end the conversation.

    Shared final step of ``add_notes_received`` and ``skip_notes``. If the draft
    is missing (e.g. the conversation state was lost), an error is reported
    instead and nothing is saved.
    """
    sub_data = context.user_data.get('new_sub_data')
    if not sub_data:
        await update.message.reply_text("发生错误,请重试。")
        return ConversationHandler.END
    sub_data['notes'] = notes
    save_subscription(update.effective_user.id, sub_data)
    safe_name = escape_markdown(sub_data.get('name', ''), version=2)
    # Full-width '!': an ASCII '!' is reserved in MarkdownV2 and would be rejected.
    await update.message.reply_text(text=f"✅ 订阅 '{safe_name}' 已添加!", parse_mode='MarkdownV2')
    context.user_data.clear()
    return ConversationHandler.END


async def add_notes_received(update: Update, context: CallbackContext):
    """Final step: use the typed text as the notes and save the subscription."""
    return await _finish_add_subscription(update, context, update.message.text)


async def skip_notes(update: Update, context: CallbackContext):
    """/skip during the notes step: save the subscription without notes."""
    return await _finish_add_subscription(update, context, None)
def save_subscription(user_id, data):
    """Insert one subscription built from the add-conversation draft ``data``.

    Keys read from ``data``: name, cost, currency, category, next_due,
    unit, value, renewal_type (default 'auto'), notes. Missing keys become NULL.
    """
    values = (
        user_id,
        data.get('name'),
        data.get('cost'),
        data.get('currency'),
        data.get('category'),
        data.get('next_due'),
        data.get('unit'),
        data.get('value'),
        data.get('renewal_type', 'auto'),
        data.get('notes'),
    )
    with get_db_connection() as conn:
        conn.cursor().execute('''
            INSERT INTO subscriptions (user_id, name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        ''', values)
        conn.commit()
# --- List, view, edit, delete ---
async def list_subs(update: Update, context: CallbackContext):
    """/list_subs: reply with one button per subscription, or a notice if there are none."""
    keyboard = await get_subs_list_keyboard(update.effective_user.id)
    if keyboard:
        await update.message.reply_text("您的所有订阅:", reply_markup=keyboard)
    else:
        await update.message.reply_text("您还没有任何订阅。")
async def list_categories(update: Update, context: CallbackContext):
    """Show the caller's categories as a two-column button grid plus a "show all" button.

    Used both as the /list_categories command (sends a reply) and from inline
    buttons (edits the pressed message in place).
    """
    uid = update.effective_user.id
    with get_db_connection() as conn:
        rows = conn.cursor().execute(
            "SELECT name FROM categories WHERE user_id = ? ORDER BY name", (uid,)
        ).fetchall()
    names = [row[0] for row in rows]
    if update.callback_query:
        send = update.callback_query.edit_message_text
    else:
        send = update.message.reply_text
    if not names:
        await send("您还没有任何分类。")
        return
    # NOTE(review): Telegram limits callback_data to 64 bytes; a long (e.g. CJK)
    # category name may overflow it and make this message fail. Confirm.
    grid = []
    row_buttons = []
    for name in names:
        row_buttons.append(InlineKeyboardButton(name, callback_data=f"list_subs_in_category_{name}"))
        if len(row_buttons) == 2:
            grid.append(row_buttons)
            row_buttons = []
    if row_buttons:
        grid.append(row_buttons)
    grid.append([InlineKeyboardButton("查看全部订阅", callback_data="list_all_subs")])
    await send("请选择一个分类:", reply_markup=InlineKeyboardMarkup(grid))
async def show_subscription_view(update: Update, context: CallbackContext, sub_id: int):
    """Render the MarkdownV2 detail card of one of the caller's subscriptions.

    Shows cost (plus its value in the user's main currency), category, next due
    date with billing period, renewal mode, reminder state and notes. Action
    buttons: renew (manual renewals only), edit, delete, reminder settings, and
    a "back" button to the category list or the full list. Edits the message
    when triggered from a button, otherwise replies. If the subscription does
    not exist for this user, an error is logged and reported.
    """
    user_id = update.effective_user.id
    with get_db_connection() as conn:
        sub = conn.cursor().execute(
            "SELECT * FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id)
        ).fetchone()
    if sub is None:
        logger.error(f"Subscription with id {sub_id} not found for user {user_id}")
        if update.effective_message:
            await update.effective_message.reply_text("错误:找不到该订阅。")
        return

    cost = sub['cost']
    currency = sub['currency']
    renewal_type = sub['renewal_type']
    next_due = sub['next_due']
    notes = sub['notes']
    freq_text = format_frequency(sub['frequency_unit'], sub['frequency_value'])
    main_currency = get_user_main_currency(user_id)
    converted_cost = convert_currency(cost, currency, main_currency)

    safe_name = escape_markdown(sub['name'], version=2)
    safe_category = escape_markdown(sub['category'], version=2)
    safe_freq = escape_markdown(freq_text, version=2)
    cost_str = escape_markdown(f"{cost:.2f}", version=2)
    converted_cost_str = escape_markdown(f"{converted_cost:.2f}", version=2)
    if renewal_type == 'manual':
        renewal_text = "手动续费"
    else:
        renewal_text = "自动续费"
    if sub['reminders_enabled']:
        reminder_status = "开启"
    else:
        reminder_status = "关闭"

    text = (f"*订阅详情: {safe_name}*\n\n"
            f"\\- *费用*: `{cost_str} {currency.upper()}` \\(\\~`{converted_cost_str} {main_currency.upper()}`\\)\n"
            f"\\- *类别*: `{safe_category}`\n"
            f"\\- *下次付款*: `{next_due}` \\(周期: {safe_freq}\\)\n"
            f"\\- *续费方式*: `{renewal_text}`\n"
            f"\\- *提醒状态*: `{reminder_status}`")
    if notes:
        text += f"\n\\- *备注*: {escape_markdown(notes, version=2)}"

    button_rows = []
    if renewal_type == 'manual':
        button_rows.append([InlineKeyboardButton("✅ 续费", callback_data=f'renewmanual_{sub_id}')])
    button_rows.append([InlineKeyboardButton("✏️ 编辑", callback_data=f'edit_{sub_id}'),
                        InlineKeyboardButton("🗑️ 删除", callback_data=f'delete_{sub_id}')])
    button_rows.append([InlineKeyboardButton("🔔 提醒设置", callback_data=f'remind_{sub_id}')])
    if 'list_subs_in_category' in context.user_data:
        back_target = f"list_subs_in_category_{context.user_data['list_subs_in_category']}"
        button_rows.append([InlineKeyboardButton("« 返回分类订阅", callback_data=back_target)])
    else:
        button_rows.append([InlineKeyboardButton("« 返回全部订阅", callback_data='list_all_subs')])
    logger.debug(f"Generated buttons for sub_id {sub_id}: edit_{sub_id}, remind_{sub_id}")

    markup = InlineKeyboardMarkup(button_rows)
    if update.callback_query:
        await update.callback_query.edit_message_text(text, reply_markup=markup, parse_mode='MarkdownV2')
    elif update.effective_message:
        await update.effective_message.reply_text(text, reply_markup=markup, parse_mode='MarkdownV2')
async def _show_category_subs(query: CallbackQuery, user_id: int, category: str):
    """Edit the pressed message into the subscription list of one category."""
    keyboard = await get_subs_list_keyboard(user_id, category_filter=category)
    safe_category = escape_markdown(category, version=2)
    if keyboard:
        msg_text = f"分类“{safe_category}”下的订阅:"
    else:
        msg_text = f"分类“{safe_category}”下没有订阅。"
        keyboard = InlineKeyboardMarkup([[InlineKeyboardButton("« 返回分类列表", callback_data='list_categories')]])
    await query.edit_message_text(msg_text, reply_markup=keyboard, parse_mode='MarkdownV2')


async def _show_all_subs(query: CallbackQuery, user_id: int):
    """Edit the pressed message into the caller's full subscription list."""
    keyboard = await get_subs_list_keyboard(user_id)
    if not keyboard:
        await query.edit_message_text("您还没有任何订阅。")
    else:
        await query.edit_message_text("您的所有订阅:", reply_markup=keyboard)


def _renew_from_today(user_id: int, sub_id: int):
    """Set a subscription's ``next_due`` to one billing period after today.

    Only a row owned by ``user_id`` is read or written.

    :returns: ``(found, name, new_date_str)``. ``found`` is False when no such
        subscription exists. ``new_date_str`` is None when the period cannot be
        computed, in which case nothing is written.
    """
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT name, frequency_unit, frequency_value FROM subscriptions WHERE id = ? AND user_id = ?",
                       (sub_id, user_id))
        sub = cursor.fetchone()
        if sub is None:
            return False, None, None
        new_due_date = calculate_new_due_date(datetime.date.today(), sub['frequency_unit'], sub['frequency_value'])
        if not new_due_date:
            return True, sub['name'], None
        new_date_str = new_due_date.strftime('%Y-%m-%d')
        cursor.execute("UPDATE subscriptions SET next_due = ? WHERE id = ? AND user_id = ?",
                       (new_date_str, sub_id, user_id))
        conn.commit()
        return True, sub['name'], new_date_str


async def button_callback_handler(update: Update, context: CallbackContext):
    """Dispatch inline-keyboard presses that are not owned by a conversation.

    Navigation: ``list_subs_in_category_<name>``, ``list_categories`` and
    ``list_all_subs``. Per-subscription actions use ``<action>_<sub_id>`` with
    action in view / renewmanual / renewfromremind / delete / confirmdelete.

    Telegram accepts only one answer per callback query, so each branch answers
    exactly once — with an alert where there is something to report. A second
    answer would fail and skip the message refresh that follows it.
    """
    query = update.callback_query
    data = query.data
    user_id = query.from_user.id
    logger.debug(f"Received callback query: {data} from user {user_id}")

    category_prefix = 'list_subs_in_category_'
    if data.startswith(category_prefix):
        await query.answer()
        # Strip only the leading prefix; the category name itself may contain it.
        category = data[len(category_prefix):]
        context.user_data['list_subs_in_category'] = category
        await _show_category_subs(query, user_id, category)
        return
    if data == 'list_categories':
        await query.answer()
        context.user_data.pop('list_subs_in_category', None)
        await list_categories(update, context)
        return
    if data == 'list_all_subs':
        await query.answer()
        context.user_data.pop('list_subs_in_category', None)
        await _show_all_subs(query, user_id)
        return

    action, _, sub_id_str = data.partition('_')
    sub_id = int(sub_id_str) if sub_id_str.isdigit() else None
    if not sub_id:
        await query.answer()
        logger.error(f"Invalid sub_id in callback data: {data}")
        await query.edit_message_text("错误:无效的订阅ID。")
        return

    if action == 'view':
        await query.answer()
        await show_subscription_view(update, context, sub_id)
    elif action == 'renewmanual':
        found, _name, new_date_str = _renew_from_today(user_id, sub_id)
        if not found:
            await query.answer("续费失败:订阅不存在。", show_alert=True)
        elif new_date_str is None:
            await query.answer("续费失败:无法计算新的到期日期。", show_alert=True)
        else:
            await query.answer(f"✅ 续费成功!新到期日: {new_date_str}", show_alert=True)
            await show_subscription_view(update, context, sub_id)
    elif action == 'renewfromremind':
        found, name, new_date_str = _renew_from_today(user_id, sub_id)
        if not found:
            await query.answer("续费失败:此订阅可能已被删除。", show_alert=True)
            # The reminder is re-sent as MarkdownV2, so its plain text and the
            # appended note must both be escaped.
            original_text = escape_markdown(query.message.text or '', version=2)
            note = escape_markdown("(错误:此订阅已被删除)", version=2)
            await query.edit_message_text(text=f"{original_text}\n\n*{note}*",
                                          parse_mode='MarkdownV2', reply_markup=None)
        elif new_date_str is None:
            await query.answer("续费失败:无法计算新的到期日期。", show_alert=True)
        else:
            await query.answer()
            safe_sub_name = escape_markdown(name, version=2)
            await query.edit_message_text(
                text=f"✅ *续费成功*\n\n您的订阅 `{safe_sub_name}` 新的到期日为: `{new_date_str}`",
                parse_mode='MarkdownV2',
                reply_markup=None
            )
    elif action == 'delete':
        await query.answer()
        keyboard = InlineKeyboardMarkup([
            [InlineKeyboardButton("✅ 是的,删除", callback_data=f'confirmdelete_{sub_id}'),
             InlineKeyboardButton("❌ 取消", callback_data=f'view_{sub_id}')]
        ])
        await query.edit_message_text(text="您确定要删除这个订阅吗?", reply_markup=keyboard)
    elif action == 'confirmdelete':
        with get_db_connection() as conn:
            conn.execute("DELETE FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id))
            conn.commit()
        await query.answer("订阅已删除")
        if 'list_subs_in_category' in context.user_data:
            await _show_category_subs(query, user_id, context.user_data['list_subs_in_category'])
        else:
            await _show_all_subs(query, user_id)
    else:
        # Unknown action: still acknowledge the press so the client's spinner stops.
        await query.answer()
# --- 【新增】包装函数,用于在会话中处理“返回”按钮 ---
async def fallback_view_button(update: Update, context: CallbackContext):
    """Handle a ``view_<id>`` button press caught as a conversation fallback.

    Rendering of the subscription screen is delegated to the shared
    ``button_callback_handler``; afterwards ``ConversationHandler.END`` is
    returned so the conversation that caught the press (edit or reminder
    settings) is terminated rather than left in its previous state.
    """
    await button_callback_handler(update, context)
    # Explicitly leave the conversation; returning None would keep its state.
    end_state = ConversationHandler.END
    return end_state
# --- Edit Subscription Conversation ---
async def edit_start(update: Update, context: CallbackContext):
    """Entry point of the edit conversation, triggered by an ``edit_<id>`` button.

    The id part of the callback data is stored (as a string) in
    ``context.user_data['sub_id_for_action']`` and the message is replaced by a
    keyboard listing the editable fields.

    Returns:
        ``EDIT_SELECT_FIELD`` so the next ``editfield_*`` press reaches
        ``edit_field_selected``.
    """
    cb = update.callback_query
    await cb.answer()
    target_id = cb.data.split('_')[1]
    logger.debug(f"Starting edit for sub_id: {target_id}")
    context.user_data['sub_id_for_action'] = target_id

    # (label, callback_data) pairs, two buttons per row.
    field_rows = (
        (("名称", "editfield_name"), ("费用", "editfield_cost")),
        (("货币", "editfield_currency"), ("类别", "editfield_category")),
        (("下次付款日", "editfield_next_due"), ("周期", "editfield_frequency")),
        (("续费方式", "editfield_renewal_type"), ("📝 备注", "editfield_notes")),
    )
    rows = [
        [InlineKeyboardButton(label, callback_data=data) for label, data in pair]
        for pair in field_rows
    ]
    rows.append([InlineKeyboardButton("« 返回详情", callback_data=f'view_{target_id}')])
    await cb.edit_message_text("请选择您想编辑的字段:", reply_markup=InlineKeyboardMarkup(rows))
    return EDIT_SELECT_FIELD
async def edit_field_selected(update: Update, context: CallbackContext):
    """Handle an ``editfield_<field>`` press and prompt for the new value.

    The field name (everything after the first underscore of the callback data)
    is stored in ``context.user_data['field_to_edit']``.

    * ``renewal_type``: shows ``editvalue_auto`` / ``editvalue_manual`` buttons
      and returns ``EDIT_GET_NEW_VALUE``.
    * ``frequency``: shows the ``freq_unit_*`` keyboard and returns
      ``EDIT_FREQ_UNIT``.
    * any other field: sends a MarkdownV2 text prompt and returns
      ``EDIT_GET_NEW_VALUE``.
    """
    query = update.callback_query
    await query.answer()
    field_to_edit = query.data.partition('_')[2]
    context.user_data['field_to_edit'] = field_to_edit
    if field_to_edit == 'renewal_type':
        keyboard = [
            [InlineKeyboardButton("自动续费", callback_data='editvalue_auto'),
             InlineKeyboardButton("手动续费", callback_data='editvalue_manual')]
        ]
        await query.edit_message_text("请选择新的续费方式:", reply_markup=InlineKeyboardMarkup(keyboard))
        return EDIT_GET_NEW_VALUE
    if field_to_edit == 'frequency':
        # Each unit button needs a visible label; unlabeled buttons give the
        # user nothing to choose by.
        keyboard = [
            [InlineKeyboardButton("天", callback_data='freq_unit_day'),
             InlineKeyboardButton("周", callback_data='freq_unit_week')],
            [InlineKeyboardButton("月", callback_data='freq_unit_month'),
             InlineKeyboardButton("年", callback_data='freq_unit_year')]
        ]
        await query.edit_message_text("请选择新的周期*单位*", reply_markup=InlineKeyboardMarkup(keyboard),
                                      parse_mode='MarkdownV2')
        return EDIT_FREQ_UNIT
    field_map = {'name': '名称', 'cost': '费用', 'currency': '货币', 'category': '类别', 'next_due': '下次付款日',
                 'notes': '备注'}
    # Escape the label: an unmapped raw field name could contain MarkdownV2
    # reserved characters such as '_'.
    label = escape_markdown(field_map.get(field_to_edit, field_to_edit), version=2)
    prompt = f"好的,请输入新的 *{label}* 值:"
    if field_to_edit == 'notes':
        # '(' and ')' are reserved in MarkdownV2 and must be backslash-escaped,
        # otherwise Telegram rejects the whole message with a parse error.
        prompt += "\n\\(如需清空备注,请输入 /empty\\)"
    await query.edit_message_text(prompt, parse_mode='MarkdownV2')
    return EDIT_GET_NEW_VALUE
async def edit_freq_unit_received(update: Update, context: CallbackContext):
    """Remember the billing-period unit picked from the ``freq_unit_*`` keyboard.

    The unit is the third ``_``-separated token of the callback data
    (``freq_unit_month`` -> ``month``). It is kept in
    ``context.user_data['new_freq_unit']`` until the count is received.

    Returns:
        ``EDIT_FREQ_VALUE``, the state in which the count is typed.
    """
    cb = update.callback_query
    await cb.answer()
    chosen_unit = cb.data.split('_')[2]
    context.user_data['new_freq_unit'] = chosen_unit
    await cb.edit_message_text("好的,现在请输入新的周期*数量*。", parse_mode='MarkdownV2')
    return EDIT_FREQ_VALUE
async def edit_freq_value_received(update: Update, context: CallbackContext):
    """Store a new billing period (unit + positive integer count) for a subscription.

    The unit comes from ``context.user_data['new_freq_unit']`` (set by
    ``edit_freq_unit_received``) and the subscription id from
    ``context.user_data['sub_id_for_action']`` (set by ``edit_start``). The
    UPDATE is restricted to the sending user's own rows.

    Returns:
        ``EDIT_FREQ_VALUE`` to re-prompt on invalid input; otherwise
        ``ConversationHandler.END``.
    """
    user_id = update.effective_user.id
    message = update.effective_message
    try:
        value = int(message.text)
        if value <= 0:
            raise ValueError
    except (ValueError, TypeError):
        await message.reply_text("请输入一个有效的正整数。")
        return EDIT_FREQ_VALUE
    unit = context.user_data.get('new_freq_unit')
    try:
        sub_id = int(context.user_data.get('sub_id_for_action'))
    except (ValueError, TypeError):
        sub_id = None
    if sub_id is None or not unit:
        # user_data can be cleared while this state is still active (e.g. by
        # /cancel or another flow); int(None) used to raise here unhandled.
        await message.reply_text("错误:会话已过期,请重试。")
        return ConversationHandler.END
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("UPDATE subscriptions SET frequency_unit = ?, frequency_value = ? WHERE id = ? AND user_id = ?",
                       (unit, value, sub_id, user_id))
        conn.commit()
    await message.reply_text("✅ 周期已更新!")
    context.user_data.clear()
    await show_subscription_view(update, context, sub_id)
    return ConversationHandler.END
# Columns edit_new_value_received may write. The column name is interpolated
# into the UPDATE statement (identifiers cannot be bound as SQL parameters) and
# originates from client-supplied callback data, so it must be checked against
# this whitelist before use.
_EDITABLE_TEXT_FIELDS = frozenset(
    ('name', 'cost', 'currency', 'category', 'next_due', 'renewal_type', 'notes'))


def _validate_edit_value(field, raw_value):
    """Validate and normalise ``raw_value`` for subscription column ``field``.

    Args:
        field: Column name, one of ``_EDITABLE_TEXT_FIELDS``.
        raw_value: Text typed by the user, the suffix of an ``editvalue_*``
            callback, or ``None`` (``/empty`` while editing notes).

    Returns:
        ``(value, error)``: on success ``error`` is ``None`` and ``value`` is what
        should be stored; on failure ``error`` is the plain-text message to send
        back to the user.
    """
    import math

    if field == 'cost':
        try:
            cost = float(raw_value)
        except (ValueError, TypeError):
            cost = None
        # float() also accepts 'nan' and 'inf', which are not usable prices.
        if cost is None or not math.isfinite(cost) or cost < 0:
            return None, "费用必须是有效的非负数字。"
        return cost, None
    if field == 'currency':
        code = str(raw_value).upper()
        if len(code) == 3 and code.isalpha():
            return code, None
        return None, "请输入有效的三字母货币代码(如 USD, CNY"
    if field == 'next_due':
        parsed = parse_date(str(raw_value))
        if not parsed:
            # Sent without parse_mode, so no MarkdownV2 backslash escaping here.
            return None, "无法识别的日期格式,请使用类似 '2025-10-01' 或 '10月1日' 的格式。"
        return parsed, None
    if field == 'category':
        category = str(raw_value).strip()
        if not category:
            return None, "类别不能为空。"
        return category, None
    if field == 'renewal_type' and raw_value not in ('auto', 'manual'):
        # Typed text is also routed here; only the two button values are valid.
        return None, "请点击上方按钮选择续费方式。"
    return raw_value, None


async def edit_new_value_received(update: Update, context: CallbackContext):
    """Receive, validate and store the new value of the field chosen in ``edit_field_selected``.

    The value comes from one of:

    * an ``editvalue_<value>`` button (renewal type only),
    * ``/empty`` (clears the notes; rejected for every other field),
    * a plain text message.

    Returns:
        ``EDIT_GET_NEW_VALUE`` to re-prompt after invalid input, otherwise
        ``ConversationHandler.END``.
    """
    user_id = update.effective_user.id
    field = context.user_data.get('field_to_edit')
    message_to_reply = update.effective_message
    try:
        sub_id = int(context.user_data.get('sub_id_for_action'))
    except (ValueError, TypeError):
        if message_to_reply:
            await message_to_reply.reply_text("错误无效的订阅ID。")
        return ConversationHandler.END
    if not field:
        if message_to_reply:
            await message_to_reply.reply_text("错误:未选择要编辑的字段。")
        return ConversationHandler.END
    if field not in _EDITABLE_TEXT_FIELDS:
        logger.warning(f"Rejected edit of non-editable field {field!r} for sub_id {sub_id}")
        if message_to_reply:
            await message_to_reply.reply_text("错误:不支持编辑该字段。")
        return ConversationHandler.END

    query = update.callback_query
    if query:
        if field != 'renewal_type':
            # A stale renewal-type button pressed while another field is being edited.
            await query.answer()
            return EDIT_GET_NEW_VALUE
        raw_value = query.data.split('_')[1]
    elif message_to_reply and message_to_reply.text is not None:
        text = message_to_reply.text
        tokens = text.split(maxsplit=1)
        # In group chats the command may arrive as '/empty@BotName'.
        if tokens and tokens[0].split('@', 1)[0] == '/empty':
            if field != 'notes':
                await message_to_reply.reply_text("/empty 仅可用于清空备注,请输入新的值。")
                return EDIT_GET_NEW_VALUE
            raw_value = None
        else:
            raw_value = text
    else:
        if message_to_reply:
            await message_to_reply.reply_text("错误:未提供新值。")
        return ConversationHandler.END

    new_value, error = _validate_edit_value(field, raw_value)
    if error:
        if message_to_reply:
            await message_to_reply.reply_text(error)
        return EDIT_GET_NEW_VALUE

    with get_db_connection() as conn:
        cursor = conn.cursor()
        if field == 'category':
            # Register the category for this user (OR IGNORE: presumably a no-op
            # when it already exists, via a UNIQUE constraint — confirm schema).
            cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, new_value))
        # Interpolation is safe: `field` was checked against _EDITABLE_TEXT_FIELDS.
        cursor.execute(f"UPDATE subscriptions SET {field} = ? WHERE id = ? AND user_id = ?",
                       (new_value, sub_id, user_id))
        conn.commit()
    if query:
        await query.answer("✅ 字段已更新!")
    elif message_to_reply:
        await message_to_reply.reply_text("✅ 字段已更新!")
    context.user_data.clear()
    await show_subscription_view(update, context, sub_id)
    return ConversationHandler.END
# --- Reminder Settings Conversation ---
async def _display_reminder_settings(query: CallbackQuery, context: CallbackContext, sub_id: int):
    """Render the reminder-settings screen for ``sub_id`` into the callback's message.

    Shows toggles for the general reminder flag and the due-date reminder. For
    subscriptions whose ``renewal_type`` is ``'manual'``, it also shows the
    advance-notice days and a button to change them. The lookup is restricted to
    subscriptions owned by the user who pressed the button (``query.from_user``),
    so a forged callback id for someone else's subscription reports "not found".
    """
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT name, renewal_type, reminders_enabled, reminder_on_due_date, reminder_days "
            "FROM subscriptions WHERE id = ? AND user_id = ?",
            (sub_id, query.from_user.id))
        sub = cursor.fetchone()
    if not sub:
        await query.edit_message_text("错误:找不到该订阅。")
        return
    # Each label names the action the button performs (the opposite of the current state).
    enabled_text = "❌ 关闭提醒" if sub['reminders_enabled'] else "✅ 开启提醒"
    due_date_text = "❌ 关闭到期日提醒" if sub['reminder_on_due_date'] else "✅ 开启到期日提醒"
    keyboard = [
        [InlineKeyboardButton(enabled_text, callback_data='remindaction_toggle_enabled')],
        [InlineKeyboardButton(due_date_text, callback_data='remindaction_toggle_due_date')]
    ]
    safe_name = escape_markdown(sub['name'], version=2)
    current_status = f"*🔔 提醒设置: {safe_name}*\n\n"
    if sub['renewal_type'] == 'manual':
        current_status += f"当前提前提醒: *{sub['reminder_days']}天*\n"
        keyboard.append([InlineKeyboardButton("⚙️ 更改提前天数", callback_data='remindaction_ask_days')])
    keyboard.append([InlineKeyboardButton("« 返回详情", callback_data=f'view_{sub_id}')])
    await query.edit_message_text(current_status, reply_markup=InlineKeyboardMarkup(keyboard), parse_mode='MarkdownV2')
async def remind_settings_start(update: Update, context: CallbackContext):
    """Entry point of the reminder-settings conversation (``remind_<id>`` button).

    The id part of the callback data must be all digits. It is then stored as an
    int in ``context.user_data['sub_id_for_action']`` and the settings screen is
    drawn.

    Returns:
        ``REMIND_SELECT_ACTION`` on success; ``ConversationHandler.END`` when the
        id is not numeric.
    """
    cb = update.callback_query
    await cb.answer()
    _, _, raw_id = cb.data.partition('_')
    if raw_id.isdigit():
        target_id = int(raw_id)
        logger.debug(f"Starting reminder settings for sub_id: {target_id}")
        context.user_data['sub_id_for_action'] = target_id
        await _display_reminder_settings(cb, context, target_id)
        return REMIND_SELECT_ACTION
    await cb.edit_message_text("错误无效的订阅ID。")
    return ConversationHandler.END
async def remind_action_handler(update: Update, context: CallbackContext):
    """Handle the buttons of the reminder-settings screen (``remindaction_*``).

    * ``ask_days``: asks for the number of days of advance notice and moves to
      ``REMIND_GET_DAYS``.
    * ``toggle_enabled`` / ``toggle_due_date``: flips the matching flag and
      redraws the settings screen.

    The UPDATE only touches rows owned by the user who pressed the button, so a
    subscription id smuggled in via forged callback data cannot modify another
    user's subscription.

    Returns:
        The next conversation state, or ``ConversationHandler.END`` when the
        subscription id is missing from ``context.user_data``.
    """
    query = update.callback_query
    await query.answer()
    if not query.data:
        return REMIND_SELECT_ACTION
    action = query.data.partition('remindaction_')[2]
    sub_id = context.user_data.get('sub_id_for_action')
    if not sub_id:
        await query.edit_message_text("错误:会话已过期,请重试。")
        return ConversationHandler.END
    if action == 'ask_days':
        await query.edit_message_text("请输入您想提前几天收到提醒输入0则不提前提醒")
        return REMIND_GET_DAYS
    # One static statement per action; nothing from the callback data reaches the SQL text.
    toggle_sql = {
        'toggle_enabled': "UPDATE subscriptions SET reminders_enabled = NOT reminders_enabled "
                          "WHERE id = ? AND user_id = ?",
        'toggle_due_date': "UPDATE subscriptions SET reminder_on_due_date = NOT reminder_on_due_date "
                           "WHERE id = ? AND user_id = ?",
    }
    sql = toggle_sql.get(action)
    if sql is None:
        logger.warning(f"Unexpected action '{query.data}' in remind_action_handler")
        return REMIND_SELECT_ACTION
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(sql, (sub_id, update.effective_user.id))
        conn.commit()
    await _display_reminder_settings(query, context, sub_id)
    return REMIND_SELECT_ACTION
async def remind_days_received(update: Update, context: CallbackContext):
    """Store how many days before the due date a reminder should be sent.

    Accepts any non-negative integer (0 disables the advance reminder, per the
    prompt in ``remind_action_handler``). The UPDATE is restricted to the sending
    user's own subscription.

    Returns:
        ``REMIND_GET_DAYS`` to re-prompt on invalid input; otherwise
        ``ConversationHandler.END``.
    """
    message = update.effective_message
    sub_id = context.user_data.get('sub_id_for_action')
    if not sub_id:
        await message.reply_text("错误:会话已过期,请重试。")
        return ConversationHandler.END
    # Keep the try limited to parsing, so failures in the DB write or the view
    # are not misreported as bad user input.
    try:
        days = int(message.text)
    except (ValueError, TypeError):
        days = None
    if days is None or days < 0:
        await message.reply_text("请输入一个有效的非负整数。")
        return REMIND_GET_DAYS
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("UPDATE subscriptions SET reminder_days = ? WHERE id = ? AND user_id = ?",
                       (days, sub_id, update.effective_user.id))
        conn.commit()
    await message.reply_text(f"✅ 提前提醒天数已设置为: {days}天。")
    context.user_data.clear()
    await show_subscription_view(update, context, sub_id)
    return ConversationHandler.END
# --- Other Commands ---
async def set_currency(update: Update, context: CallbackContext):
    """Handle ``/set_currency <code>``: store the user's main currency.

    The argument must be exactly three letters and is upper-cased before being
    written to the ``users`` table.
    """
    user_id = update.effective_user.id
    args = context.args or []
    # effective_message also covers edited commands, where update.message is None.
    message = update.effective_message
    if len(args) != 1:
        # MarkdownV2: '_', '(' and ')' outside the code span are reserved and must
        # be escaped, otherwise Telegram rejects the message with a parse error.
        await message.reply_text("用法: /set\\_currency `<code>`\\(例如 /set\\_currency USD\\)",
                                 parse_mode='MarkdownV2')
        return
    new_currency = args[0].upper()
    if len(new_currency) != 3 or not new_currency.isalpha():
        await message.reply_text("请输入有效的三字母货币代码(如 USD, CNY")
        return
    with get_db_connection() as conn:
        cursor = conn.cursor()
        # NOTE(review): on a key conflict SQLite's REPLACE deletes the old row and
        # inserts a new one, so any other columns of `users` revert to their
        # defaults. If the table has more columns, consider an UPSERT
        # (ON CONFLICT(user_id) DO UPDATE) — confirm the schema first.
        cursor.execute("INSERT OR REPLACE INTO users (user_id, main_currency) VALUES (?, ?)", (user_id, new_currency))
        conn.commit()
    await message.reply_text(f"您的主货币已设为 {escape_markdown(new_currency, version=2)}",
                             parse_mode='MarkdownV2')
async def cancel(update: Update, context: CallbackContext):
    """Abort the current operation: clear per-user state and confirm to the user.

    Works for a callback-button press (edits the button's message) as well as for
    a ``/cancel`` command. Returns ``ConversationHandler.END`` so that, when used
    as a conversation fallback, the conversation is terminated.
    """
    context.user_data.clear()
    query = update.callback_query
    if query:
        await query.answer()
        await query.edit_message_text('操作已取消。')
    elif update.effective_message:
        # CommandHandler also fires for edited messages, where update.message is
        # None; effective_message covers both cases.
        await update.effective_message.reply_text('操作已取消。')
    return ConversationHandler.END
# --- Main ---
def main():
    """Build the bot application, register all handlers and start long polling.

    Registration order matters: within one handler group, python-telegram-bot
    hands each update to the first handler whose check passes. The conversation
    handlers are therefore registered *before* the global ``/cancel`` handler.
    Otherwise a ``/cancel`` sent mid-conversation is consumed by the global
    handler, the conversation's own ``cancel`` fallback (which returns
    ``ConversationHandler.END``) never runs, and the conversation stays stuck in
    its current state.
    """
    if not TELEGRAM_TOKEN:
        logger.critical("TELEGRAM_TOKEN 环境变量未设置!")
        return
    application = Application.builder().token(TELEGRAM_TOKEN).build()

    async def post_init(app: Application):
        """Verify the token, publish the command menu and schedule the reminder job."""
        try:
            bot_info = await app.bot.get_me()
            logger.info(f"TELEGRAM_TOKEN 验证成功: {bot_info.username}")
        except TelegramError as e:
            logger.critical(f"TELEGRAM_TOKEN 无效或无法连接 Telegram API: {e}")
            raise SystemExit
        commands = [
            BotCommand("start", "🚀 开始使用"),
            BotCommand("add_sub", " 添加新订阅"),
            BotCommand("list_subs", "📋 列出所有订阅"),
            BotCommand("list_categories", "🗂️ 按分类浏览"),
            BotCommand("stats", "📊 查看订阅统计"),
            BotCommand("import", "📥 导入订阅"),
            BotCommand("export", "📤 导出订阅"),
            BotCommand("set_currency", "💲 设置主货币"),
            BotCommand("help", " 获取帮助"),
            BotCommand("cancel", "❌ 取消当前操作")
        ]
        try:
            await app.bot.delete_my_commands()
            logger.debug("Cleared existing bot commands")
            await app.bot.set_my_commands(commands)
            logger.info("Bot commands registered successfully")
        except TelegramError as e:
            logger.error(f"Failed to register bot commands: {e}")
        # job_queue is None when python-telegram-bot is installed without the
        # optional job-queue extra; fail loudly instead of with AttributeError.
        if app.job_queue is None:
            logger.error("JobQueue is unavailable (install python-telegram-bot[job-queue]); "
                         "daily reminders are disabled.")
            return
        # Runs every day at 09:00 UTC+8.
        app.job_queue.run_daily(
            check_and_send_reminders,
            time=datetime.time(hour=9, minute=0, tzinfo=datetime.timezone(datetime.timedelta(hours=8))),
            name='daily_reminders'
        )
        logger.info("Daily reminder job scheduled.")

    async def switch_to_reminder_settings(update: Update, context: CallbackContext):
        """Edit-conversation fallback for ``remind_<id>``: show reminder settings and end the edit flow.

        Returning remind_settings_start's state (REMIND_SELECT_ACTION == 0) from the
        edit conversation would park it in EDIT_SELECT_FIELD (also 0). The
        ``remindaction_*`` buttons would then match no handler, because the
        reminder conversation was never entered.
        """
        await remind_settings_start(update, context)
        return ConversationHandler.END

    application.post_init = post_init
    add_conv = ConversationHandler(
        entry_points=[CommandHandler('add_sub', add_sub_start)],
        states={
            ADD_NAME: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_name_received)],
            ADD_COST: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_cost_received)],
            ADD_CURRENCY: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_currency_received)],
            ADD_CATEGORY: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_category_received)],
            ADD_NEXT_DUE: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_next_due_received)],
            ADD_FREQ_UNIT: [CallbackQueryHandler(add_freq_unit_received, pattern='^freq_unit_')],
            ADD_FREQ_VALUE: [MessageHandler(filters.TEXT & ~filters.COMMAND, add_freq_value_received)],
            ADD_RENEWAL_TYPE: [CallbackQueryHandler(add_renewal_type_received, pattern='^renewal_')],
            ADD_NOTES: [
                MessageHandler(filters.TEXT & ~filters.COMMAND, add_notes_received),
                CommandHandler('skip', skip_notes)
            ],
        },
        fallbacks=[CommandHandler('cancel', cancel)]
    )
    edit_conv = ConversationHandler(
        entry_points=[CallbackQueryHandler(edit_start, pattern='^edit_')],
        states={
            EDIT_SELECT_FIELD: [CallbackQueryHandler(edit_field_selected, pattern='^editfield_')],
            EDIT_GET_NEW_VALUE: [
                MessageHandler(filters.TEXT & ~filters.COMMAND, edit_new_value_received),
                CallbackQueryHandler(edit_new_value_received, pattern='^editvalue_'),
                CommandHandler('empty', edit_new_value_received)
            ],
            EDIT_FREQ_UNIT: [CallbackQueryHandler(edit_freq_unit_received, pattern='^freq_unit_')],
            EDIT_FREQ_VALUE: [MessageHandler(filters.TEXT & ~filters.COMMAND, edit_freq_value_received)],
        },
        fallbacks=[
            CommandHandler('cancel', cancel),
            # Wrapper that shows the detail view and ends the conversation.
            CallbackQueryHandler(fallback_view_button, pattern='^view_'),
            CallbackQueryHandler(edit_start, pattern='^edit_'),
            CallbackQueryHandler(switch_to_reminder_settings, pattern='^remind_')
        ],
        per_message=False
    )
    remind_conv = ConversationHandler(
        entry_points=[
            CallbackQueryHandler(remind_settings_start, pattern='^remind_'),
            # The settings screen can be shown while this conversation is not
            # active (see switch_to_reminder_settings), so its buttons must also
            # be able to start it. remind_action_handler ends immediately if the
            # subscription id is missing from user_data.
            CallbackQueryHandler(remind_action_handler, pattern='^remindaction_'),
        ],
        states={
            REMIND_SELECT_ACTION: [CallbackQueryHandler(remind_action_handler, pattern='^remindaction_')],
            REMIND_GET_DAYS: [MessageHandler(filters.TEXT & ~filters.COMMAND, remind_days_received)],
        },
        fallbacks=[
            CommandHandler('cancel', cancel),
            # Wrapper that shows the detail view and ends the conversation.
            CallbackQueryHandler(fallback_view_button, pattern='^view_'),
            CallbackQueryHandler(edit_start, pattern='^edit_'),
            CallbackQueryHandler(remind_settings_start, pattern='^remind_')
        ],
        per_message=False
    )
    import_conv = ConversationHandler(
        entry_points=[CommandHandler('import', import_start)],
        states={
            IMPORT_UPLOAD: [MessageHandler(filters.Document.ALL, import_upload_received)],
        },
        fallbacks=[CommandHandler('cancel', cancel)]
    )
    button_pattern = r'^(view_\d+|renewmanual_\d+|delete_\d+|confirmdelete_\d+|renewfromremind_\d+|list_subs_in_category_.+|list_categories|list_all_subs)$'
    application.add_handler(CommandHandler('start', start))
    application.add_handler(CommandHandler('help', help_command))
    application.add_handler(CommandHandler('list_subs', list_subs))
    application.add_handler(CommandHandler('list_categories', list_categories))
    application.add_handler(CommandHandler('set_currency', set_currency))
    application.add_handler(CommandHandler('stats', stats))
    application.add_handler(CommandHandler('export', export_command))
    application.add_handler(add_conv)
    application.add_handler(edit_conv)
    application.add_handler(remind_conv)
    application.add_handler(import_conv)
    # Must come after the conversations so their own /cancel fallbacks (which
    # end the conversation) get the first chance at the update.
    application.add_handler(CommandHandler('cancel', cancel))
    application.add_handler(CallbackQueryHandler(button_callback_handler, pattern=button_pattern))
    logger.info(f"{PROJECT_NAME} Bot is starting...")
    application.run_polling()
if __name__ == '__main__':
    # Script entry point. update_past_due_dates() (defined elsewhere in this
    # file) runs once, synchronously, before main() builds the bot and starts
    # polling. NOTE(review): judging by its name it rolls overdue next_due dates
    # forward — confirm against its definition.
    update_past_due_dates()
    main()