1325 lines
60 KiB
Python
1325 lines
60 KiB
Python
import sqlite3
|
||
import os
|
||
import requests
|
||
import datetime
|
||
import dateparser
|
||
import logging
|
||
import pandas as pd
|
||
import matplotlib
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib.font_manager as fm
|
||
import re
|
||
from dotenv import load_dotenv
|
||
from dateutil.relativedelta import relativedelta
|
||
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup, BotCommand, CallbackQuery
|
||
from telegram.ext import (
|
||
Application, CommandHandler, MessageHandler, filters,
|
||
CallbackContext, CallbackQueryHandler, ConversationHandler
|
||
)
|
||
from telegram.error import TelegramError
|
||
from telegram.helpers import escape_markdown
|
||
|
||
# --- Load .env and settings ---
load_dotenv()

# --- Logging configuration ---
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO
)
# Only WARNING and above is emitted for the "httpx" logger.
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)

# --- Environment variables and project configuration ---
TELEGRAM_TOKEN = os.getenv('TELEGRAM_TOKEN')
EXCHANGE_API_KEY = os.getenv('EXCHANGE_API_KEY')
PROJECT_NAME = "SubMind"
DB_FILE = 'submind.db'

# --- Conversation handler states ---
# Each group is numbered independently starting at 0, so the integer values
# overlap between groups (e.g. ADD_NAME == EDIT_SELECT_FIELD == 0).
(ADD_NAME, ADD_COST, ADD_CURRENCY, ADD_CATEGORY, ADD_NEXT_DUE,
 ADD_FREQ_UNIT, ADD_FREQ_VALUE, ADD_RENEWAL_TYPE, ADD_NOTES) = range(9)
(EDIT_SELECT_FIELD, EDIT_GET_NEW_VALUE, EDIT_FREQ_UNIT, EDIT_FREQ_VALUE) = range(4)
(REMIND_SELECT_ACTION, REMIND_GET_DAYS) = range(2)
(IMPORT_UPLOAD,) = range(1)
|
||
|
||
|
||
# --- Helper: database connection management ---
def get_db_connection():
    """Open a connection to ``DB_FILE`` whose rows are ``sqlite3.Row`` objects.

    Rows can therefore be read by column name (``row['name']``) as well as by
    position.

    NOTE(review): callers use ``with get_db_connection() as conn``; per the
    sqlite3 documentation that context manager commits or rolls back but does
    not close the connection.
    """
    connection = sqlite3.connect(DB_FILE)
    connection.row_factory = sqlite3.Row
    return connection
|
||
|
||
|
||
# --- Font management ---
def get_chinese_font():
    """Return font properties for a CJK-capable font, downloading it if needed.

    The font file is looked up at ``fonts/SourceHanSansSC-Regular.otf``. When
    it is missing it is downloaded to a temporary ``.part`` file and only moved
    into place once the download has completed, so an interrupted download can
    never leave a truncated font that later calls would pick up.

    Returns:
        ``FontProperties`` for the downloaded font, or for the generic
        ``sans-serif`` family when the download or the file write fails.
    """
    font_name = 'SourceHanSansSC-Regular.otf'
    font_path = os.path.join('fonts', font_name)

    if os.path.exists(font_path):
        logger.debug(f"Found font at {font_path}")
        return fm.FontProperties(fname=font_path)

    logger.info(f"Font '{font_name}' not found. Attempting to download...")
    os.makedirs('fonts', exist_ok=True)

    url = 'https://github.com/wweir/source-han-sans-sc/raw/refs/heads/master/SourceHanSansSC-Regular.otf'
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }
    partial_path = font_path + '.part'

    try:
        response = requests.get(url, stream=True, headers=headers, timeout=10)
        try:
            response.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        finally:
            response.close()
        # Publish the file only after every chunk has been written.
        os.replace(partial_path, font_path)
    except (requests.exceptions.RequestException, OSError) as e:
        logger.error(f"Failed to download font. Error: {e}")
        try:
            if os.path.exists(partial_path):
                os.remove(partial_path)
        except OSError:
            # Best effort: a leftover .part file is never used as the font.
            pass
        return fm.FontProperties(family='sans-serif')

    logger.info(f"Font '{font_name}' downloaded successfully to '{font_path}'.")
    fm._load_fontmanager(try_read_cache=False)
    return fm.FontProperties(fname=font_path)
|
||
|
||
|
||
# --- Database initialisation and migration ---
def init_db():
    """Create the tables and indexes the bot uses and bring old schemas up to date.

    Creates ``subscriptions``, ``categories``, ``users`` and ``exchange_rates``
    if they do not exist, adds any column of ``subscriptions`` that the
    existing table lacks, then calls ``migrate_frequency_data`` and commits.
    """
    # Columns that are added with ALTER TABLE when the existing table lacks
    # them: (column name, SQL type / default clause).
    optional_columns = (
        ('frequency_unit', 'TEXT'),
        ('frequency_value', 'INTEGER'),
        ('reminders_enabled', 'BOOLEAN DEFAULT TRUE'),
        ('reminder_days', 'INTEGER DEFAULT 3'),
        ('reminder_on_due_date', 'BOOLEAN DEFAULT TRUE'),
        ('notes', 'TEXT'),
    )

    with get_db_connection() as conn:
        cursor = conn.cursor()

        cursor.execute('''
            CREATE TABLE IF NOT EXISTS subscriptions (
                id INTEGER PRIMARY KEY, user_id INTEGER, name TEXT, cost REAL, currency TEXT,
                category TEXT, next_due DATE, frequency TEXT,
                renewal_type TEXT DEFAULT 'auto',
                reminders_enabled BOOLEAN DEFAULT TRUE,
                reminder_days INTEGER DEFAULT 3,
                reminder_on_due_date BOOLEAN DEFAULT TRUE,
                frequency_unit TEXT,
                frequency_value INTEGER,
                notes TEXT
            )
        ''')
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_subscriptions_user_id ON subscriptions(user_id)')

        # PRAGMA table_info yields one row per column; index 1 is its name.
        cursor.execute("PRAGMA table_info(subscriptions)")
        present = [info[1] for info in cursor.fetchall()]
        for column, declaration in optional_columns:
            if column not in present:
                cursor.execute(f"ALTER TABLE subscriptions ADD COLUMN {column} {declaration}")

        cursor.execute('''
            CREATE TABLE IF NOT EXISTS categories (
                id INTEGER PRIMARY KEY, user_id INTEGER, name TEXT, UNIQUE(user_id, name)
            )
        ''')
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_categories_user_id ON categories(user_id)')
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS users (
                user_id INTEGER PRIMARY KEY, main_currency TEXT DEFAULT "USD", language TEXT DEFAULT "en"
            )
        ''')
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS exchange_rates (
                from_currency TEXT, to_currency TEXT, rate REAL, last_updated TIMESTAMP,
                PRIMARY KEY (from_currency, to_currency)
            )
        ''')

        migrate_frequency_data(conn, cursor)
        conn.commit()
|
||
|
||
|
||
def migrate_frequency_data(conn, cursor):
    """Fill ``frequency_unit`` / ``frequency_value`` from the legacy ``frequency`` text.

    Only rows that have a ``frequency`` string but no ``frequency_unit`` are
    touched. Strings that are not in the lookup table are left unchanged.

    Args:
        conn: Open connection; committed after the updates are issued.
        cursor: Cursor on ``conn`` used for the select and the updates.
    """
    legacy_periods = {
        'daily': ('day', 1),
        'weekly': ('week', 1), '周付': ('week', 1),
        'monthly': ('month', 1), '月付': ('month', 1),
        'quarterly': ('month', 3), '季付': ('month', 3),
        '半年': ('month', 6), 'half-year': ('month', 6), 'biannually': ('month', 6),
        'yearly': ('year', 1), '年付': ('year', 1),
    }

    cursor.execute("SELECT id, frequency FROM subscriptions WHERE frequency IS NOT NULL AND frequency_unit IS NULL")
    pending = cursor.fetchall()
    if not pending:
        return

    for row_id, legacy_text in pending:
        mapped = legacy_periods.get(str(legacy_text).lower())
        if mapped is None:
            continue
        period_unit, period_count = mapped
        cursor.execute("UPDATE subscriptions SET frequency_unit = ?, frequency_value = ? WHERE id = ?",
                       (period_unit, period_count, row_id))
    conn.commit()
|
||
|
||
|
||
# Runs at import time: creates any missing tables/columns before handlers use them.
init_db()
|
||
|
||
|
||
# --- Helper functions ---
def get_user_main_currency(user_id):
    """Return the ``main_currency`` stored for ``user_id``, or ``'USD'`` if the user has no row."""
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute('SELECT main_currency FROM users WHERE user_id = ?', (user_id,))
        row = cursor.fetchone()
    if not row:
        return 'USD'
    return row['main_currency']
|
||
|
||
|
||
def convert_currency(amount, from_curr, to_curr):
    """Convert ``amount`` from one currency to another using a 24-hour rate cache.

    Lookup order: identical currencies short-circuit; a cached rate younger
    than 24 hours is used; otherwise the exchange-rate API is queried and the
    result cached. If the API call fails, or its response carries no usable
    ``conversion_rate``, a stale cached rate is used when one exists.

    Args:
        amount: Numeric amount in ``from_curr``.
        from_curr: Source currency code (case-insensitive).
        to_curr: Target currency code (case-insensitive).

    Returns:
        The converted amount, or ``amount`` unchanged when no rate can be
        obtained (no API key, or API failure without any cached rate).
    """
    source, target = from_curr.upper(), to_curr.upper()
    if source == target:
        return amount

    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute('''
            SELECT rate, last_updated FROM exchange_rates
            WHERE from_currency = ? AND to_currency = ?
        ''', (source, target))
        result = cursor.fetchone()

        now = datetime.datetime.now()
        cache_validity = datetime.timedelta(hours=24)
        if result and (now - datetime.datetime.fromisoformat(result['last_updated'])) < cache_validity:
            logger.debug(f"Using cached exchange rate for {from_curr} to {to_curr}: {result['rate']}")
            return amount * result['rate']

        if not EXCHANGE_API_KEY:
            logger.warning("EXCHANGE_API_KEY not set, returning original amount")
            return amount

        try:
            url = f"https://v6.exchangerate-api.com/v6/{EXCHANGE_API_KEY}/pair/{source}/{target}/{amount}"
            response = requests.get(url, timeout=5)
            response.raise_for_status()
            rate = response.json().get('conversion_rate')
            # A payload without a positive numeric rate must not be cached:
            # defaulting to 1.0 would silently "convert" at par for 24 hours.
            if isinstance(rate, bool) or not isinstance(rate, (int, float)) or rate <= 0:
                raise ValueError("response has no usable conversion_rate")
        except (requests.exceptions.RequestException, ValueError) as e:
            # Log only the exception type: the exception text can contain the
            # request URL, which embeds the API key.
            logger.error(f"Currency conversion API error: {type(e).__name__} ({source} -> {target})")
            if result:
                logger.warning(f"Falling back to cached rate: {result['rate']}")
                return amount * result['rate']
            logger.warning("No cached rate available, returning original amount")
            return amount

        cursor.execute('''
            INSERT OR REPLACE INTO exchange_rates (from_currency, to_currency, rate, last_updated)
            VALUES (?, ?, ?, ?)
        ''', (source, target, rate, now.isoformat()))
        conn.commit()
        logger.debug(f"Updated exchange rate cache for {from_curr} to {to_curr}: {rate}")
        return amount * rate
|
||
|
||
|
||
def parse_date(date_string: str) -> "str | None":
    """Parse a free-form English/Chinese date and return it as ``YYYY-MM-DD``.

    When the input carries no year information (no ``年``, no ``/`` and no
    four-digit number) and the parsed date is already in the past, the date is
    moved one year forward. The shift uses ``relativedelta`` so that Feb 29
    becomes Feb 28 instead of raising and rejecting the input.

    Args:
        date_string: Text typed by the user or read from a CSV cell.

    Returns:
        The ISO date string, or ``None`` when the text is empty, cannot be
        parsed, or parsing raises.
    """
    if not date_string or not date_string.strip():
        return None

    today = datetime.datetime.now()
    try:
        dt = dateparser.parse(date_string, languages=['en', 'zh'])
        if not dt:
            return None
        has_year_info = any(c in date_string for c in ['年', '/']) or (re.search(r'\d{4}', date_string) is not None)
        if not has_year_info and dt.date() < today.date():
            dt = dt + relativedelta(years=1)
        return dt.strftime('%Y-%m-%d')
    except Exception as e:
        logger.error(f"Date parsing failed for string '{date_string}'. Error: {e}")
        return None
|
||
|
||
|
||
def calculate_new_due_date(base_date, unit, value):
    """Return ``base_date`` advanced by ``value`` periods of ``unit``.

    Args:
        base_date: ``date`` / ``datetime`` to start from.
        unit: One of ``'day'``, ``'week'``, ``'month'``, ``'year'``
            (case-insensitive).
        value: Number of units to add.

    Returns:
        The shifted date, or ``None`` when ``unit`` is not recognised or
        ``value`` is ``None`` (e.g. a row whose period was never set).
    """
    # Map the stored unit name to the matching relativedelta keyword.
    unit_keywords = {'day': 'days', 'week': 'weeks', 'month': 'months', 'year': 'years'}
    keyword = unit_keywords.get(str(unit).lower())
    if keyword is None or value is None:
        return None
    return base_date + relativedelta(**{keyword: value})
|
||
|
||
|
||
def format_frequency(unit, value) -> str:
    """Render a billing period as Chinese text.

    Returns ``"未知"`` when ``unit`` is falsy or ``value`` is ``None``. A value
    of 1 with a known unit uses the short form (``每月``); everything else uses
    ``每 <value> <unit>``, with unknown units printed as given.
    """
    if value is None or not unit:
        return "未知"

    plural_names = {'day': '天', 'week': '周', 'month': '个月', 'year': '年'}
    generic = f"每 {value} {plural_names.get(unit, unit)}"
    if value != 1:
        return generic

    every_names = {'day': '每天', 'week': '每周', 'month': '每月', 'year': '每年'}
    return every_names.get(unit, generic)
|
||
|
||
|
||
async def get_subs_list_keyboard(user_id: int, category_filter: str = None) -> InlineKeyboardMarkup:
    """Build the inline keyboard that lists a user's subscriptions.

    Subscriptions are ordered by ``next_due`` and laid out two buttons per
    row; each button carries ``view_<id>`` as callback data. The last row
    links to the category list.

    Args:
        user_id: Owner of the subscriptions.
        category_filter: When given, only subscriptions in this category.

    Returns:
        The keyboard, or ``None`` when no subscription matches.
    """
    sql = "SELECT id, name FROM subscriptions WHERE user_id = ? "
    args = [user_id]
    if category_filter:
        sql += "AND category = ? "
        args.append(category_filter)
    sql += "ORDER BY next_due ASC"

    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(sql, tuple(args))
        rows = cursor.fetchall()

    if not rows:
        return None

    layout = []
    for offset in range(0, len(rows), 2):
        layout.append([
            InlineKeyboardButton(name, callback_data=f'view_{sub_id}')
            for sub_id, name in rows[offset:offset + 2]
        ])

    footer_label = "« 返回分类列表" if category_filter else "🗂️ 按分类浏览"
    layout.append([InlineKeyboardButton(footer_label, callback_data='list_categories')])
    return InlineKeyboardMarkup(layout)
|
||
|
||
|
||
# --- Scheduled jobs ---
def update_past_due_dates():
    """Roll the due date of overdue auto-renew subscriptions forward past today.

    For every ``renewal_type = 'auto'`` subscription whose ``next_due`` is
    before today, the billing period is added repeatedly until the date lies
    after today, and the row is updated. A failure on one subscription is
    logged and does not stop the others.
    """
    today = datetime.date.today()
    with get_db_connection() as conn:
        cursor = conn.cursor()
        # Bind the ISO string explicitly rather than relying on sqlite3's
        # implicit date adapter (deprecated since Python 3.12).
        cursor.execute("SELECT * FROM subscriptions WHERE next_due < ? AND renewal_type = 'auto'",
                       (today.isoformat(),))
        past_due_subs = cursor.fetchall()
        if not past_due_subs:
            return

        for sub in past_due_subs:
            try:
                last_due_date = datetime.datetime.strptime(sub['next_due'], '%Y-%m-%d').date()
                new_due_date = last_due_date
                while new_due_date <= today:
                    calculated_date = calculate_new_due_date(new_due_date, sub['frequency_unit'],
                                                             sub['frequency_value'])
                    # Stop when no period is available or when the period does
                    # not move the date forward (zero/negative value); without
                    # this check such a row would loop forever.
                    if not calculated_date or calculated_date <= new_due_date:
                        break
                    new_due_date = calculated_date
                if new_due_date > last_due_date:
                    cursor.execute("UPDATE subscriptions SET next_due = ? WHERE id = ?",
                                   (new_due_date.strftime('%Y-%m-%d'), sub['id']))
            except Exception as e:
                logger.error(f"Failed to update subscription {sub['id']}: {e}")
        conn.commit()
|
||
|
||
|
||
async def check_and_send_reminders(context: CallbackContext):
    """Send due-date and advance reminders for subscriptions with reminders enabled.

    Rules, evaluated per subscription against today's date:

    * Due today and ``reminder_on_due_date`` set: a "due today" message. Manual
      subscriptions also get an "I have renewed" button.
    * Otherwise, manual subscriptions with ``reminder_days > 0``: an advance
      message exactly ``reminder_days`` days before the due date, with the
      same button.

    A failure for one subscription is logged and the remaining ones are still
    processed.
    """
    logger.info("Running job: Checking for subscription reminders...")
    today = datetime.date.today()
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM subscriptions WHERE reminders_enabled = TRUE AND next_due IS NOT NULL")
        subs_to_check = cursor.fetchall()

    for sub in subs_to_check:
        try:
            due_date = datetime.datetime.strptime(sub['next_due'], '%Y-%m-%d').date()
            user_id = sub['user_id']
            renewal_type = sub['renewal_type']
            safe_sub_name = escape_markdown(sub['name'], version=2)
            # A NULL reminder_days counts as "no advance reminder".
            reminder_days = sub['reminder_days'] or 0

            message = None
            keyboard = None

            if renewal_type == 'manual':
                keyboard = InlineKeyboardMarkup([
                    [InlineKeyboardButton("✅ 我已续费", callback_data=f"renewfromremind_{sub['id']}")]
                ])

            if sub['reminder_on_due_date'] and due_date == today:
                message = f"🔔 *订阅到期提醒*\n\n您的订阅 `{safe_sub_name}` 今天到期。"
                if renewal_type == 'manual':
                    message += " 请记得手动续费。"
                else:
                    message += " 将会自动续费。"
                    keyboard = None

            elif renewal_type == 'manual' and reminder_days > 0:
                reminder_date = due_date - datetime.timedelta(days=reminder_days)
                if reminder_date == today:
                    days_left = (due_date - today).days
                    days_text = f"*{days_left}天后*" if days_left > 0 else "*今天*"
                    message = f"🔔 *订阅即将到期提醒*\n\n您的手动续费订阅 `{safe_sub_name}` 将在 {days_text} 到期。"

            if message:
                await context.bot.send_message(
                    chat_id=user_id,
                    text=message,
                    parse_mode='MarkdownV2',
                    reply_markup=keyboard
                )

        except Exception as e:
            # sqlite3.Row has no .get(); index access keeps this handler from
            # raising and aborting the remaining reminders.
            logger.error(f"Failed to process reminder for sub_id {sub['id']}: {e}")
|
||
|
||
|
||
# --- Command handlers ---
async def start(update: Update, context: CallbackContext):
    """Handle /start: register the user if new and send the welcome message."""
    user_id = update.effective_user.id
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute('INSERT OR IGNORE INTO users (user_id) VALUES (?)', (user_id,))
        conn.commit()
    # '!' is a reserved character in MarkdownV2 and must be escaped, otherwise
    # Telegram rejects the message.
    await update.message.reply_text(
        f'欢迎使用 {escape_markdown(PROJECT_NAME, version=2)}\\!\n您的私人订阅智能管家。',
        parse_mode='MarkdownV2')
|
||
|
||
|
||
async def help_command(update: Update, context: CallbackContext):
    """Handle /help: send the command overview as a MarkdownV2 message."""
    # Raw f-string: the backslashes are MarkdownV2 escapes meant for Telegram,
    # not Python escape sequences. '>' is reserved in MarkdownV2 as well, so it
    # is escaped too.
    help_text = rf"""
*{escape_markdown(PROJECT_NAME, version=2)} 命令列表*
*🌟 核心功能*
/add\_sub \- 引导您添加一个新的订阅
/list\_subs \- 列出您的所有订阅
/list\_categories \- 按分类浏览您的订阅
*📊 数据管理*
/stats \- 查看按类别分类的订阅统计
/import \- 通过上传 CSV 文件批量导入订阅
/export \- 将您的所有订阅导出为 CSV 文件
*⚙️ 个性化设置*
/set\_currency \`<code\>\` \- 设置您的主要货币
/cancel \- 在任何流程中取消当前操作
"""
    await update.message.reply_text(help_text, parse_mode='MarkdownV2')
|
||
|
||
|
||
def make_autopct(values, currency_code):
    """Create a pie-chart label formatter showing the amount and the percentage.

    Args:
        values: The slice values passed to the pie chart; their sum is used to
            turn a percentage back into an amount.
        currency_code: Currency of ``values``. Known codes are shown as a
            symbol; any other code is printed as given, followed by a space.

    Returns:
        A function ``pct -> str`` producing ``"<symbol><amount>\\n(<pct>%)"``.
    """
    known_symbols = {'USD': '$', 'CNY': '¥', 'EUR': '€', 'GBP': '£', 'JPY': '¥'}
    lookup_key = currency_code.upper()
    if lookup_key in known_symbols:
        prefix = known_symbols[lookup_key]
    else:
        prefix = f'{currency_code} '

    def my_autopct(pct):
        amount = float(pct * sum(values) / 100.0)
        return f'{prefix}{amount:.2f}\n({pct:.1f}%)'

    return my_autopct
|
||
|
||
|
||
async def stats(update: Update, context: CallbackContext):
    """Handle /stats: send a pie chart of monthly spending per category.

    Each subscription's cost is converted to the user's main currency and
    normalised to a monthly amount (period length in days, scaled to 30.4375
    days). The chart is rendered into memory rather than to a fixed-name file,
    so concurrent requests cannot overwrite each other's image, and the figure
    is always closed even when rendering fails.
    """
    import io  # local import: only this handler needs it

    user_id = update.effective_user.id
    await update.message.reply_text("正在为您生成订阅统计图表,请稍候...")

    font_prop = get_chinese_font()
    main_currency = get_user_main_currency(user_id)
    with get_db_connection() as conn:
        df = pd.read_sql_query("SELECT * FROM subscriptions WHERE user_id = ?", conn, params=(user_id,))
    if df.empty:
        await update.message.reply_text("您还没有任何订阅数据。")
        return

    df['converted_cost'] = df.apply(lambda row: convert_currency(row['cost'], row['currency'], main_currency), axis=1)
    unit_to_days = {'day': 1, 'week': 7, 'month': 30.4375, 'year': 365.25}

    def normalize_to_monthly(row):
        # Rows without a usable period contribute nothing to the chart.
        if pd.isna(row['frequency_unit']) or pd.isna(row['frequency_value']) or row['frequency_value'] == 0:
            return 0
        total_days = row['frequency_value'] * unit_to_days.get(row['frequency_unit'], 0)
        if total_days == 0:
            return 0
        return (row['converted_cost'] / total_days) * 30.4375

    df['monthly_cost'] = df.apply(normalize_to_monthly, axis=1)
    category_costs = df.groupby('category')['monthly_cost'].sum().sort_values(ascending=False)

    if category_costs.empty or category_costs.sum() == 0:
        await update.message.reply_text("您的订阅没有有效的费用信息。")
        return

    plt.style.use('seaborn-v0_8-darkgrid')
    fig, ax = plt.subplots(figsize=(12, 12))
    image = io.BytesIO()
    image.name = f'stats_{user_id}.png'
    try:
        autopct_function = make_autopct(category_costs.values, main_currency)

        wedges, texts, autotexts = ax.pie(category_costs.values,
                                          labels=category_costs.index,
                                          autopct=autopct_function,
                                          startangle=140,
                                          pctdistance=0.7,
                                          labeldistance=1.05)

        ax.set_title('每月订阅支出分类统计', fontproperties=font_prop, fontsize=32, pad=20)

        for text in texts:
            text.set_fontproperties(font_prop)
            text.set_fontsize(22)

        for autotext in autotexts:
            autotext.set_fontproperties(font_prop)
            autotext.set_fontsize(20)
            autotext.set_color('white')

        ax.axis('equal')

        fig.tight_layout()
        # Save this figure explicitly instead of pyplot's "current" figure.
        fig.savefig(image, format='png')
    finally:
        plt.close(fig)

    image.seek(0)
    await update.message.reply_photo(image, caption="这是您按类别统计的每月订阅总支出。")
|
||
|
||
|
||
# --- Import/Export Commands ---
|
||
async def export_command(update: Update, context: CallbackContext):
    """Handle /export: send the user's subscriptions as a UTF-8 (with BOM) CSV file.

    The CSV is written to ``export_<user_id>.csv``, sent as
    ``subscriptions.csv`` and removed afterwards whether or not sending
    succeeded.
    """
    user_id = update.effective_user.id
    select_sql = "SELECT name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes FROM subscriptions WHERE user_id = ?"
    with get_db_connection() as conn:
        frame = pd.read_sql_query(select_sql, conn, params=(user_id,))

    if frame.empty:
        await update.message.reply_text("您还没有任何订阅数据,无法导出。")
        return

    export_path = f'export_{user_id}.csv'
    frame.to_csv(export_path, index=False, encoding='utf-8-sig')

    try:
        with open(export_path, 'rb') as csv_file:
            await update.message.reply_document(
                document=csv_file,
                filename='subscriptions.csv',
                caption="您的订阅数据已导出为 CSV 文件。",
            )
    finally:
        if os.path.exists(export_path):
            os.remove(export_path)
|
||
|
||
|
||
async def import_start(update: Update, context: CallbackContext):
    """Handle /import: describe the expected CSV columns and wait for the upload."""
    instructions = (
        "请上传一个 CSV 文件以导入订阅数据。\n"
        "文件应包含以下列:name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes(notes 可为空)。"
    )
    await update.message.reply_text(instructions)
    return IMPORT_UPLOAD
|
||
|
||
|
||
def _parse_import_row(user_id, row):
    """Validate one CSV row and return the tuple to insert into ``subscriptions``.

    Raises:
        ValueError / TypeError / KeyError: when a required cell is missing or
            invalid; the message is shown to the user by the caller.
    """
    name = str(row['name']).strip() if pd.notna(row['name']) else ''
    if not name:
        raise ValueError("名称不能为空")
    cost = float(row['cost'])
    # `not (cost >= 0)` also rejects NaN, which `cost < 0` lets through.
    if not (cost >= 0) or cost == float('inf'):
        raise ValueError("费用不能为负数")
    currency = str(row['currency']).strip().upper()
    if not (len(currency) == 3 and currency.isalpha()):
        raise ValueError(f"无效货币代码: {currency}")
    category = str(row['category']).strip() if pd.notna(row['category']) else ''
    if not category:
        raise ValueError("类别不能为空")
    next_due = parse_date(str(row['next_due']))
    if not next_due:
        raise ValueError(f"无效日期格式: {row['next_due']}")
    frequency_unit = str(row['frequency_unit']).strip().lower()
    if frequency_unit not in ('day', 'week', 'month', 'year'):
        raise ValueError(f"无效周期单位: {frequency_unit}")
    frequency_value = int(row['frequency_value'])
    if frequency_value <= 0:
        raise ValueError(f"无效周期数量: {frequency_value}")
    renewal_type = str(row['renewal_type']).strip().lower()
    if renewal_type not in ('auto', 'manual'):
        raise ValueError(f"无效续费类型: {renewal_type}")
    # The notes column is optional, so the cell may be absent or empty.
    raw_notes = row.get('notes')
    notes = str(raw_notes) if pd.notna(raw_notes) else None
    return (user_id, name, cost, currency, category,
            next_due, frequency_unit, frequency_value, renewal_type, notes)


async def import_upload_received(update: Update, context: CallbackContext):
    """Import subscriptions from an uploaded CSV file.

    The file must contain the columns name, cost, currency, category,
    next_due, frequency_unit, frequency_value and renewal_type; ``notes`` is
    optional. The import is all-or-nothing: the first invalid row aborts it
    and nothing is written.

    Returns:
        ``IMPORT_UPLOAD`` to ask again when the message is not a CSV document,
        otherwise ``ConversationHandler.END``.
    """
    user_id = update.effective_user.id
    document = update.message.document
    # file_name can be missing, and the extension may be upper-case.
    if not document or not (document.file_name or '').lower().endswith('.csv'):
        await update.message.reply_text("请上传一个有效的 CSV 文件。")
        return IMPORT_UPLOAD

    file = await document.get_file()
    file_path = f'import_{user_id}.csv'
    try:
        await file.download_to_drive(file_path)
        df = pd.read_csv(file_path, encoding='utf-8-sig')
        required_columns = ['name', 'cost', 'currency', 'category', 'next_due', 'frequency_unit', 'frequency_value',
                            'renewal_type']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            await update.message.reply_text(f"CSV 文件缺少以下必要列:{', '.join(missing_columns)}")
            return ConversationHandler.END

        records = []
        for _, row in df.iterrows():
            try:
                records.append(_parse_import_row(user_id, row))
            except Exception as e:
                logger.error(f"Invalid row in CSV: {row.to_dict()}, error: {e}")
                await update.message.reply_text(f"导入失败,行数据无效:{row.to_dict()},错误:{e}")
                return ConversationHandler.END

        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.executemany('''
                INSERT INTO subscriptions (user_id, name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ''', records)
            # Register each distinct category once.
            cursor.executemany("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)",
                               [(user_id, category) for category in {record[4] for record in records}])
            conn.commit()

        await update.message.reply_text(f"✅ 成功导入 {len(records)} 条订阅数据!")
    except Exception as e:
        logger.error(f"Import failed: {e}")
        await update.message.reply_text(f"导入失败:{e}")
    finally:
        if os.path.exists(file_path):
            os.remove(file_path)
    return ConversationHandler.END
|
||
|
||
|
||
# --- Add Subscription Conversation ---
|
||
async def add_sub_start(update: Update, context: CallbackContext):
    """Begin the add-subscription conversation and ask for the name (step 1)."""
    # Start from an empty draft; later steps fill it in key by key.
    context.user_data['new_sub_data'] = {}
    prompt = "好的,我们来添加一个新订阅。\n\n第一步:请输入订阅的 *名称*"
    await update.message.reply_text(prompt, parse_mode='MarkdownV2')
    return ADD_NAME
|
||
|
||
|
||
async def add_name_received(update: Update, context: CallbackContext):
    """Store the subscription name and ask for the cost (step 2)."""
    draft = context.user_data['new_sub_data']
    draft['name'] = update.message.text
    await update.message.reply_text("第二步:请输入订阅 *费用*", parse_mode='MarkdownV2')
    return ADD_COST
|
||
|
||
|
||
async def add_cost_received(update: Update, context: CallbackContext):
    """Validate and store the cost, then ask for the currency (step 3).

    The cost must parse as a finite, non-negative number; ``nan`` and ``inf``
    are rejected along with negative values. On invalid input the user is
    asked again and the conversation stays in ``ADD_COST``.
    """
    import math  # local import: only this validation needs it

    try:
        cost = float(update.message.text)
        if not math.isfinite(cost) or cost < 0:
            raise ValueError("费用不能为负数")
    except (ValueError, TypeError):
        await update.message.reply_text("费用必须是有效的非负数字。")
        return ADD_COST
    context.user_data['new_sub_data']['cost'] = cost
    # Parentheses are reserved in MarkdownV2 and must be escaped.
    await update.message.reply_text("第三步:请输入 *货币* 代码\\(例如 USD, CNY\\)", parse_mode='MarkdownV2')
    return ADD_CURRENCY
|
||
|
||
|
||
async def add_currency_received(update: Update, context: CallbackContext):
    """Validate and store the currency code, then ask for the category (step 4).

    The upper-cased input must be exactly three letters; otherwise the user is
    asked again.
    """
    code = update.message.text.upper()
    looks_valid = len(code) == 3 and code.isalpha()
    if not looks_valid:
        await update.message.reply_text("请输入有效的三字母货币代码(如 USD, CNY)。")
        return ADD_CURRENCY

    context.user_data['new_sub_data']['currency'] = code
    await update.message.reply_text("第四步:请为订阅指定一个 *类别*", parse_mode='MarkdownV2')
    return ADD_CATEGORY
|
||
|
||
|
||
async def add_category_received(update: Update, context: CallbackContext):
    """Store the category, register it for the user, then ask for the due date (step 5).

    An empty (whitespace-only) category is rejected and asked for again.
    """
    user_id, category_name = update.effective_user.id, update.message.text.strip()
    if not category_name:
        await update.message.reply_text("类别不能为空。")
        return ADD_CATEGORY
    context.user_data['new_sub_data']['category'] = category_name
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, category_name))
        conn.commit()
    # Parentheses and hyphens are reserved in MarkdownV2 and must be escaped.
    await update.message.reply_text("第五步:请输入 *下一次付款日期*\\(例如 2025\\-10\\-01 或 10月1日\\)",
                                    parse_mode='MarkdownV2')
    return ADD_NEXT_DUE
|
||
|
||
|
||
async def add_next_due_received(update: Update, context: CallbackContext):
    """Parse and store the next due date, then ask for the period unit (step 6).

    When the date cannot be parsed the user is asked again and the
    conversation stays in ``ADD_NEXT_DUE``.
    """
    parsed_date = parse_date(update.message.text)
    if not parsed_date:
        # Sent as plain text (no parse_mode), so the example date must not
        # carry MarkdownV2 backslash escapes.
        await update.message.reply_text("无法识别的日期格式,请使用类似 '2025-10-01' 或 '10月1日' 的格式。")
        return ADD_NEXT_DUE
    context.user_data['new_sub_data']['next_due'] = parsed_date
    keyboard = [
        [InlineKeyboardButton("天", callback_data='freq_unit_day'),
         InlineKeyboardButton("周", callback_data='freq_unit_week')],
        [InlineKeyboardButton("月", callback_data='freq_unit_month'),
         InlineKeyboardButton("年", callback_data='freq_unit_year')]
    ]
    await update.message.reply_text("第六步:请选择付款周期的*单位*", reply_markup=InlineKeyboardMarkup(keyboard),
                                    parse_mode='MarkdownV2')
    return ADD_FREQ_UNIT
|
||
|
||
|
||
async def add_freq_unit_received(update: Update, context: CallbackContext):
    """Store the period unit chosen via button and ask for the period count (step 7)."""
    pressed = update.callback_query
    await pressed.answer()
    # Callback data has the form 'freq_unit_<unit>'; the third part is the unit.
    chosen_unit = pressed.data.split('_')[2]
    context.user_data['new_sub_data']['unit'] = chosen_unit
    await pressed.edit_message_text("第七步:请输入周期的*数量*(例如:每3个月,输入 3)", parse_mode='Markdown')
    return ADD_FREQ_VALUE
|
||
|
||
|
||
async def add_freq_value_received(update: Update, context: CallbackContext):
    """Validate and store the period count, then ask for the renewal type (step 8).

    The input must be a positive integer; otherwise the user is asked again.
    """
    try:
        count = int(update.message.text)
        if count <= 0:
            raise ValueError
        context.user_data['new_sub_data']['value'] = count
    except (ValueError, TypeError):
        await update.message.reply_text("请输入一个有效的正整数。")
        return ADD_FREQ_VALUE

    choices = InlineKeyboardMarkup([[
        InlineKeyboardButton("自动续费", callback_data='renewal_auto'),
        InlineKeyboardButton("手动续费", callback_data='renewal_manual'),
    ]])
    await update.message.reply_text("第八步:请选择 *续费方式*", reply_markup=choices, parse_mode='MarkdownV2')
    return ADD_RENEWAL_TYPE
|
||
|
||
|
||
async def add_renewal_type_received(update: Update, context: CallbackContext):
    """Store the renewal type chosen via button and ask for optional notes (last step)."""
    pressed = update.callback_query
    await pressed.answer()
    # Callback data has the form 'renewal_<type>'.
    chosen_type = pressed.data.split('_')[1]
    context.user_data['new_sub_data']['renewal_type'] = chosen_type
    await pressed.edit_message_text("最后一步(可选):需要添加备注吗?\n(如:共享账号、用途等。不需要请 /skip)")
    return ADD_NOTES
|
||
|
||
|
||
async def add_notes_received(update: Update, context: CallbackContext):
    """Store the notes, save the new subscription and end the conversation.

    The draft is cleared right after saving and before the confirmation is
    sent, so a failed confirmation cannot leave the draft behind to be saved a
    second time.
    """
    sub_data = context.user_data.get('new_sub_data')
    if not sub_data:
        await update.message.reply_text("发生错误,请重试。")
        return ConversationHandler.END
    sub_data['notes'] = update.message.text
    save_subscription(update.effective_user.id, sub_data)
    safe_name = escape_markdown(sub_data.get('name', ''), version=2)
    context.user_data.clear()
    # '!' is reserved in MarkdownV2 and must be escaped.
    await update.message.reply_text(text=f"✅ 订阅 '{safe_name}' 已添加\\!",
                                    parse_mode='MarkdownV2')
    return ConversationHandler.END
|
||
|
||
|
||
async def skip_notes(update: Update, context: CallbackContext):
    """Handle /skip at the notes step: save the subscription without notes.

    The draft is cleared right after saving and before the confirmation is
    sent, so a failed confirmation cannot leave the draft behind to be saved a
    second time.
    """
    sub_data = context.user_data.get('new_sub_data')
    if not sub_data:
        await update.message.reply_text("发生错误,请重试。")
        return ConversationHandler.END
    sub_data['notes'] = None
    save_subscription(update.effective_user.id, sub_data)
    safe_name = escape_markdown(sub_data.get('name', ''), version=2)
    context.user_data.clear()
    # '!' is reserved in MarkdownV2 and must be escaped.
    await update.message.reply_text(text=f"✅ 订阅 '{safe_name}' 已添加\\!",
                                    parse_mode='MarkdownV2')
    return ConversationHandler.END
|
||
|
||
|
||
def save_subscription(user_id, data):
    """Insert one subscription row for ``user_id`` from the conversation draft.

    Args:
        user_id: Owner of the subscription.
        data: Draft dict with the keys ``name``, ``cost``, ``currency``,
            ``category``, ``next_due``, ``unit``, ``value``, ``renewal_type``
            and ``notes``. Missing keys are stored as NULL, except
            ``renewal_type`` which defaults to ``'auto'``.
    """
    insert_sql = '''
        INSERT INTO subscriptions (user_id, name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    '''
    row = (
        user_id,
        data.get('name'),
        data.get('cost'),
        data.get('currency'),
        data.get('category'),
        data.get('next_due'),
        data.get('unit'),
        data.get('value'),
        data.get('renewal_type', 'auto'),
        data.get('notes'),
    )
    with get_db_connection() as conn:
        conn.cursor().execute(insert_sql, row)
        conn.commit()
|
||
|
||
|
||
# --- List, View, Edit, Delete ---
|
||
async def list_subs(update: Update, context: CallbackContext):
    """Handle /list_subs: show all of the user's subscriptions as buttons."""
    markup = await get_subs_list_keyboard(update.effective_user.id)
    if markup:
        await update.message.reply_text("您的所有订阅:", reply_markup=markup)
    else:
        await update.message.reply_text("您还没有任何订阅。")
|
||
|
||
|
||
async def list_categories(update: Update, context: CallbackContext):
    """Show the user's categories as buttons, two per row.

    Works both as a command handler (replies with a new message) and from a
    button press (edits the message the button belongs to).

    NOTE(review): the category name is embedded in ``callback_data``; Telegram
    limits callback data to 64 bytes, so long names may be rejected — confirm.
    """
    user_id = update.effective_user.id
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM categories WHERE user_id = ? ORDER BY name", (user_id,))
        category_rows = cursor.fetchall()

    from_button = update.callback_query

    if not category_rows:
        if from_button:
            await from_button.edit_message_text("您还没有任何分类。")
        else:
            await update.message.reply_text("您还没有任何分类。")
        return

    layout = []
    for offset in range(0, len(category_rows), 2):
        layout.append([
            InlineKeyboardButton(entry[0], callback_data=f"list_subs_in_category_{entry[0]}")
            for entry in category_rows[offset:offset + 2]
        ])
    layout.append([InlineKeyboardButton("查看全部订阅", callback_data="list_all_subs")])
    markup = InlineKeyboardMarkup(layout)

    if from_button:
        await from_button.edit_message_text("请选择一个分类:", reply_markup=markup)
    else:
        await update.message.reply_text("请选择一个分类:", reply_markup=markup)
|
||
|
||
|
||
async def show_subscription_view(update: Update, context: CallbackContext, sub_id: int):
    """Show the detail card of one subscription with its action buttons.

    The card lists cost (plus the amount converted to the user's main
    currency), category, next due date with period, renewal type, reminder
    status and, if present, the notes. Buttons: edit, delete, reminder
    settings, a renew button for manual subscriptions, and a back button that
    returns to the category view stored in ``user_data`` or to the full list.

    Args:
        update: Incoming update; a callback query is answered by editing its
            message, anything else by replying.
        context: Handler context; ``user_data['list_subs_in_category']``
            selects the back target.
        sub_id: Id of the subscription, looked up for the current user only.
    """
    user_id = update.effective_user.id
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id))
        sub = cursor.fetchone()

    if not sub:
        logger.error(f"Subscription with id {sub_id} not found for user {user_id}")
        if update.effective_message:
            await update.effective_message.reply_text("错误:找不到该订阅。")
        return

    cost = sub['cost']
    currency = sub['currency']
    next_due = sub['next_due']
    renewal_type = sub['renewal_type']
    notes = sub['notes']

    freq_text = format_frequency(sub['frequency_unit'], sub['frequency_value'])
    main_currency = get_user_main_currency(user_id)
    converted_cost = convert_currency(cost, currency, main_currency)

    safe_name = escape_markdown(sub['name'], version=2)
    safe_category = escape_markdown(sub['category'], version=2)
    safe_freq = escape_markdown(freq_text, version=2)
    cost_str = escape_markdown(f"{cost:.2f}", version=2)
    converted_cost_str = escape_markdown(f"{converted_cost:.2f}", version=2)

    if renewal_type == 'manual':
        renewal_text = "手动续费"
    else:
        renewal_text = "自动续费"
    reminder_status = "开启" if sub['reminders_enabled'] else "关闭"

    card_lines = [
        f"*订阅详情: {safe_name}*",
        "",
        f"\\- *费用*: `{cost_str} {currency.upper()}` \\(\\~`{converted_cost_str} {main_currency.upper()}`\\)",
        f"\\- *类别*: `{safe_category}`",
        f"\\- *下次付款*: `{next_due}` \\(周期: {safe_freq}\\)",
        f"\\- *续费方式*: `{renewal_text}`",
        f"\\- *提醒状态*: `{reminder_status}`",
    ]
    if notes:
        card_lines.append(f"\\- *备注*: {escape_markdown(notes, version=2)}")
    text = "\n".join(card_lines)

    button_rows = []
    if renewal_type == 'manual':
        button_rows.append([InlineKeyboardButton("✅ 续费", callback_data=f'renewmanual_{sub_id}')])
    button_rows.append([
        InlineKeyboardButton("✏️ 编辑", callback_data=f'edit_{sub_id}'),
        InlineKeyboardButton("🗑️ 删除", callback_data=f'delete_{sub_id}'),
    ])
    button_rows.append([InlineKeyboardButton("🔔 提醒设置", callback_data=f'remind_{sub_id}')])

    if 'list_subs_in_category' in context.user_data:
        cat_filter = context.user_data['list_subs_in_category']
        back_button = InlineKeyboardButton("« 返回分类订阅", callback_data=f'list_subs_in_category_{cat_filter}')
    else:
        back_button = InlineKeyboardButton("« 返回全部订阅", callback_data='list_all_subs')
    button_rows.append([back_button])

    logger.debug(f"Generated buttons for sub_id {sub_id}: edit_{sub_id}, remind_{sub_id}")

    markup = InlineKeyboardMarkup(button_rows)
    if update.callback_query:
        await update.callback_query.edit_message_text(text, reply_markup=markup, parse_mode='MarkdownV2')
    elif update.effective_message:
        await update.effective_message.reply_text(text, reply_markup=markup, parse_mode='MarkdownV2')
|
||
|
||
|
||
async def _show_category_subs(query, user_id, category):
    """Replace the current message with the subscription list of one category."""
    keyboard = await get_subs_list_keyboard(user_id, category_filter=category)
    safe_category = escape_markdown(category, version=2)
    msg_text = f"分类“{safe_category}”下的订阅:"
    if not keyboard:
        msg_text = f"分类“{safe_category}”下没有订阅。"
        keyboard = InlineKeyboardMarkup([[InlineKeyboardButton("« 返回分类列表", callback_data='list_categories')]])
    await query.edit_message_text(msg_text, reply_markup=keyboard, parse_mode='MarkdownV2')


async def _show_all_subs(query, user_id):
    """Replace the current message with the user's complete subscription list."""
    keyboard = await get_subs_list_keyboard(user_id)
    if not keyboard:
        await query.edit_message_text("您还没有任何订阅。")
        return
    await query.edit_message_text("您的所有订阅:", reply_markup=keyboard)


def _renew_owned_subscription(sub_id, user_id):
    """Move a subscription's next_due forward by one period, counted from today.

    The row is looked up and updated only when it belongs to ``user_id``.

    Returns:
        A ``(row, new_date_str)`` pair. ``row`` is None when no matching
        subscription exists; ``new_date_str`` is None when
        ``calculate_new_due_date`` could not produce a date (nothing is
        written in that case).
    """
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT name, frequency_unit, frequency_value FROM subscriptions WHERE id = ? AND user_id = ?",
            (sub_id, user_id))
        sub = cursor.fetchone()
        if not sub:
            return None, None
        new_due_date = calculate_new_due_date(datetime.date.today(), sub['frequency_unit'], sub['frequency_value'])
        if not new_due_date:
            return sub, None
        new_date_str = new_due_date.strftime('%Y-%m-%d')
        cursor.execute("UPDATE subscriptions SET next_due = ? WHERE id = ? AND user_id = ?",
                       (new_date_str, sub_id, user_id))
        conn.commit()
    return sub, new_date_str


async def button_callback_handler(update: Update, context: CallbackContext):
    """Dispatch inline-keyboard callbacks that are not owned by a conversation.

    Handles the list navigation buttons (``list_subs_in_category_<name>``,
    ``list_categories``, ``list_all_subs``) and the per-subscription actions
    encoded as ``<action>_<id>``: ``view``, ``renewmanual``,
    ``renewfromremind``, ``delete`` and ``confirmdelete``.

    Each callback query is answered exactly once. The answer is deferred
    until the outcome is known so that alert texts actually reach the user.
    """
    query = update.callback_query
    data = query.data
    user_id = query.from_user.id
    logger.debug(f"Received callback query: {data} from user {user_id}")

    category_prefix = 'list_subs_in_category_'
    if data.startswith(category_prefix):
        await query.answer()
        # Strip only the leading prefix; the category name itself may contain it.
        category = data[len(category_prefix):]
        context.user_data['list_subs_in_category'] = category
        await _show_category_subs(query, user_id, category)
        return
    if data == 'list_categories':
        await query.answer()
        context.user_data.pop('list_subs_in_category', None)
        await list_categories(update, context)
        return
    if data == 'list_all_subs':
        await query.answer()
        context.user_data.pop('list_subs_in_category', None)
        await _show_all_subs(query, user_id)
        return

    action, _, sub_id_str = data.partition('_')
    sub_id = int(sub_id_str) if sub_id_str.isdigit() else None
    if not sub_id:
        await query.answer()
        logger.error(f"Invalid sub_id in callback data: {data}")
        await query.edit_message_text("错误:无效的订阅ID。")
        return

    if action == 'view':
        await query.answer()
        await show_subscription_view(update, context, sub_id)

    elif action == 'renewmanual':
        sub, new_date_str = _renew_owned_subscription(sub_id, user_id)
        if sub is None:
            await query.answer("续费失败:订阅不存在。", show_alert=True)
        elif new_date_str is None:
            await query.answer("续费失败:无法计算新的到期日期。", show_alert=True)
        else:
            await query.answer(f"✅ 续费成功!新到期日: {new_date_str}", show_alert=True)
            await show_subscription_view(update, context, sub_id)

    elif action == 'renewfromremind':
        sub, new_date_str = _renew_owned_subscription(sub_id, user_id)
        if sub is None:
            await query.answer("续费失败:此订阅可能已被删除。", show_alert=True)
            # The existing message text is plain text here, so it must be
            # escaped before being re-sent as MarkdownV2; the parentheses of
            # the appended note are reserved characters as well.
            original_text = escape_markdown(getattr(query.message, 'text', None) or "", version=2)
            await query.edit_message_text(text=original_text + "\n\n*\\(错误:此订阅已被删除\\)*",
                                          parse_mode='MarkdownV2', reply_markup=None)
        elif new_date_str is None:
            await query.answer("续费失败:无法计算新的到期日期。", show_alert=True)
        else:
            await query.answer()
            safe_sub_name = escape_markdown(sub['name'], version=2)
            await query.edit_message_text(
                text=f"✅ *续费成功*\n\n您的订阅 `{safe_sub_name}` 新的到期日为: `{new_date_str}`",
                parse_mode='MarkdownV2',
                reply_markup=None
            )

    elif action == 'delete':
        await query.answer()
        keyboard = InlineKeyboardMarkup([
            [InlineKeyboardButton("✅ 是的,删除", callback_data=f'confirmdelete_{sub_id}'),
             InlineKeyboardButton("❌ 取消", callback_data=f'view_{sub_id}')]
        ])
        await query.edit_message_text(text="您确定要删除这个订阅吗?", reply_markup=keyboard)

    elif action == 'confirmdelete':
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id))
            conn.commit()
        await query.answer("订阅已删除")
        # Return to whichever list the user was browsing before the detail view.
        if 'list_subs_in_category' in context.user_data:
            await _show_category_subs(query, user_id, context.user_data['list_subs_in_category'])
        else:
            await _show_all_subs(query, user_id)

    else:
        # Unknown action: still acknowledge so the client stops its spinner.
        await query.answer()
|
||
|
||
|
||
# --- [Added] Wrapper that handles the "back" button from inside a conversation ---
async def fallback_view_button(update: Update, context: CallbackContext):
    """Serve a ``view_...`` button press while a conversation is active.

    Used as a conversation fallback: the generic button handler renders the
    subscription details first, and the END state returned afterwards makes
    the surrounding conversation (e.g. editing, reminder settings) terminate.
    """
    # Render the screen through the shared button logic.
    await button_callback_handler(update, context)
    # Report END so the active conversation is closed.
    finished = ConversationHandler.END
    return finished
|
||
|
||
|
||
# --- Edit Subscription Conversation ---
|
||
async def edit_start(update: Update, context: CallbackContext):
    """Entry point of the edit conversation: show the field picker.

    Expects callback data of the form ``edit_<id>``. The id is validated and
    stored as an int under ``user_data['sub_id_for_action']`` for the later
    conversation steps.

    Returns:
        EDIT_SELECT_FIELD, or ConversationHandler.END when the id is invalid.
    """
    query = update.callback_query
    await query.answer()
    sub_id_str = query.data.partition('_')[2]
    # Callback data comes from the client; reject anything that is not a plain id.
    if not sub_id_str.isdigit():
        await query.edit_message_text("错误:无效的订阅ID。")
        return ConversationHandler.END
    sub_id = int(sub_id_str)
    logger.debug(f"Starting edit for sub_id: {sub_id}")
    context.user_data['sub_id_for_action'] = sub_id
    keyboard = [
        [InlineKeyboardButton("名称", callback_data="editfield_name"),
         InlineKeyboardButton("费用", callback_data="editfield_cost")],
        [InlineKeyboardButton("货币", callback_data="editfield_currency"),
         InlineKeyboardButton("类别", callback_data="editfield_category")],
        [InlineKeyboardButton("下次付款日", callback_data="editfield_next_due"),
         InlineKeyboardButton("周期", callback_data="editfield_frequency")],
        [InlineKeyboardButton("续费方式", callback_data="editfield_renewal_type"),
         InlineKeyboardButton("📝 备注", callback_data="editfield_notes")],
        [InlineKeyboardButton("« 返回详情", callback_data=f'view_{sub_id}')]
    ]
    await query.edit_message_text("请选择您想编辑的字段:", reply_markup=InlineKeyboardMarkup(keyboard))
    return EDIT_SELECT_FIELD
|
||
|
||
|
||
async def edit_field_selected(update: Update, context: CallbackContext):
    """Record which field the user wants to edit and prompt for its new value.

    Callback data has the form ``editfield_<field>``. ``renewal_type`` and
    ``frequency`` are answered with choice keyboards; every other field is
    answered with a free-text prompt.

    Returns:
        EDIT_FREQ_UNIT for ``frequency``, otherwise EDIT_GET_NEW_VALUE.
    """
    query = update.callback_query
    await query.answer()
    field_to_edit = query.data.partition('_')[2]
    context.user_data['field_to_edit'] = field_to_edit

    if field_to_edit == 'renewal_type':
        keyboard = [
            [InlineKeyboardButton("自动续费", callback_data='editvalue_auto'),
             InlineKeyboardButton("手动续费", callback_data='editvalue_manual')]
        ]
        await query.edit_message_text("请选择新的续费方式:", reply_markup=InlineKeyboardMarkup(keyboard))
        return EDIT_GET_NEW_VALUE

    if field_to_edit == 'frequency':
        keyboard = [
            [InlineKeyboardButton("天", callback_data='freq_unit_day'),
             InlineKeyboardButton("周", callback_data='freq_unit_week')],
            [InlineKeyboardButton("月", callback_data='freq_unit_month'),
             InlineKeyboardButton("年", callback_data='freq_unit_year')]
        ]
        await query.edit_message_text("请选择新的周期*单位*", reply_markup=InlineKeyboardMarkup(keyboard),
                                      parse_mode='MarkdownV2')
        return EDIT_FREQ_UNIT

    field_map = {'name': '名称', 'cost': '费用', 'currency': '货币', 'category': '类别', 'next_due': '下次付款日',
                 'notes': '备注'}
    # The fallback label is raw callback text, so escape it for MarkdownV2.
    label = field_map.get(field_to_edit) or escape_markdown(field_to_edit, version=2)
    prompt = f"好的,请输入新的 *{label}* 值:"
    if field_to_edit == 'notes':
        # '(' and ')' are reserved in MarkdownV2 and must be backslash-escaped,
        # otherwise Telegram rejects the whole message.
        prompt += "\n\\(如需清空备注,请输入 /empty \\)"
    await query.edit_message_text(prompt, parse_mode='MarkdownV2')
    return EDIT_GET_NEW_VALUE
|
||
|
||
|
||
async def edit_freq_unit_received(update: Update, context: CallbackContext):
    """Remember the chosen frequency unit and ask for the interval count.

    The unit is the third ``_``-separated token of the callback data
    (``freq_unit_<unit>``) and is stored under ``user_data['new_freq_unit']``.

    Returns:
        EDIT_FREQ_VALUE, the state that waits for the numeric value.
    """
    callback = update.callback_query
    await callback.answer()
    tokens = callback.data.split('_')
    context.user_data['new_freq_unit'] = tokens[2]
    await callback.edit_message_text("好的,现在请输入新的周期*数量*。", parse_mode='MarkdownV2')
    next_state = EDIT_FREQ_VALUE
    return next_state
|
||
|
||
|
||
async def edit_freq_value_received(update: Update, context: CallbackContext):
    """Persist the new billing frequency (unit chosen earlier + value typed now).

    Returns:
        EDIT_FREQ_VALUE when the text is not a positive integer (the user is
        asked again), otherwise ConversationHandler.END.
    """
    user_id = update.effective_user.id
    # effective_message also covers edited messages, where update.message is None.
    message = update.effective_message
    try:
        value = int(message.text)
        if value <= 0:
            raise ValueError
    except (ValueError, TypeError):
        await message.reply_text("请输入一个有效的正整数。")
        return EDIT_FREQ_VALUE

    unit = context.user_data.get('new_freq_unit')
    try:
        sub_id = int(context.user_data.get('sub_id_for_action'))
    except (ValueError, TypeError):
        sub_id = None
    if sub_id is None or not unit:
        # The conversation data is gone (e.g. cleared by another command).
        await message.reply_text("错误:会话已过期,请重试。")
        return ConversationHandler.END

    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("UPDATE subscriptions SET frequency_unit = ?, frequency_value = ? WHERE id = ? AND user_id = ?",
                       (unit, value, sub_id, user_id))
        conn.commit()
    await message.reply_text("✅ 周期已更新!")
    # Drop only this conversation's keys; 'list_subs_in_category' must survive
    # because the detail view uses it to pick the right "back" button.
    for key in ('sub_id_for_action', 'field_to_edit', 'new_freq_unit'):
        context.user_data.pop(key, None)
    await show_subscription_view(update, context, sub_id)
    return ConversationHandler.END
|
||
|
||
|
||
async def edit_new_value_received(update: Update, context: CallbackContext):
    """Validate and store the new value for the field chosen in the edit flow.

    The value arrives either as a text message, as the ``/empty`` command
    (clears the notes), or as an ``editvalue_<value>`` button press (renewal
    type). The target field and subscription id are read from ``user_data``.

    Returns:
        EDIT_GET_NEW_VALUE when the value is rejected (the user may retry),
        otherwise ConversationHandler.END.
    """
    # `field` is interpolated into the UPDATE statement below and originates
    # from client-supplied callback data, so it must match a known column.
    editable_columns = ('name', 'cost', 'currency', 'category', 'next_due', 'renewal_type', 'notes')

    user_id = update.effective_user.id
    query = update.callback_query
    message_to_reply = update.effective_message
    field = context.user_data.get('field_to_edit')

    try:
        sub_id = int(context.user_data.get('sub_id_for_action'))
    except (ValueError, TypeError):
        if query:
            await query.answer()
        if message_to_reply:
            await message_to_reply.reply_text("错误:无效的订阅ID。")
        return ConversationHandler.END
    if not field or field not in editable_columns:
        if query:
            await query.answer()
        if message_to_reply:
            await message_to_reply.reply_text("错误:未选择要编辑的字段。")
        return ConversationHandler.END

    async def reject(reason):
        """Tell the user why the value was refused and stay in the same state."""
        if query:
            await query.answer(reason, show_alert=True)
        elif message_to_reply:
            await message_to_reply.reply_text(reason)
        return EDIT_GET_NEW_VALUE

    if query:
        # Buttons only exist for the renewal type; ignore stale presses otherwise.
        if field != 'renewal_type':
            return await reject("请直接发送新的值。")
        new_value = query.data.partition('_')[2]
    else:
        text = message_to_reply.text if message_to_reply else None
        if not text:
            if message_to_reply:
                await message_to_reply.reply_text("错误:未提供新值。")
            return ConversationHandler.END
        tokens = text.split()
        # Accept both "/empty" and "/empty@botname".
        is_empty_command = bool(tokens) and tokens[0].partition('@')[0] == '/empty'
        if is_empty_command:
            if field != 'notes':
                return await reject("/empty 仅可用于清空备注,请重新输入。")
            new_value = None
        else:
            new_value = text

    if field == 'cost':
        try:
            new_value = float(new_value)
        except (ValueError, TypeError):
            return await reject("费用必须是有效的非负数字。")
        # The chained comparison also rejects NaN and infinity.
        if not (0 <= new_value < float('inf')):
            return await reject("费用必须是有效的非负数字。")
    elif field == 'currency':
        new_value = str(new_value).upper()
        if not (len(new_value) == 3 and new_value.isalpha()):
            return await reject("请输入有效的三字母货币代码(如 USD, CNY)。")
    elif field == 'next_due':
        parsed = parse_date(str(new_value))
        if not parsed:
            # Sent as plain text, so the example date must not carry Markdown escapes.
            return await reject("无法识别的日期格式,请使用类似 '2025-10-01' 或 '10月1日' 的格式。")
        new_value = parsed
    elif field == 'category':
        new_value = str(new_value).strip()
        if not new_value:
            return await reject("类别不能为空。")
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?)", (user_id, new_value))
            conn.commit()
    elif field == 'renewal_type':
        if new_value not in ('auto', 'manual'):
            return await reject("请选择有效的续费方式。")

    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(f"UPDATE subscriptions SET {field} = ? WHERE id = ? AND user_id = ?",
                       (new_value, sub_id, user_id))
        conn.commit()

    if query:
        await query.answer("✅ 字段已更新!")
    elif message_to_reply:
        await message_to_reply.reply_text("✅ 字段已更新!")

    # Drop only this conversation's keys; 'list_subs_in_category' must survive
    # because the detail view uses it to pick the right "back" button.
    for key in ('sub_id_for_action', 'field_to_edit', 'new_freq_unit'):
        context.user_data.pop(key, None)
    await show_subscription_view(update, context, sub_id)
    return ConversationHandler.END
|
||
|
||
|
||
# --- Reminder Settings Conversation ---
|
||
async def _display_reminder_settings(query: CallbackQuery, context: CallbackContext, sub_id: int):
    """Render the reminder-settings menu for one subscription into the query's message.

    The subscription is looked up only among the rows owned by the user who
    pressed the button. Manual-renewal subscriptions additionally show the
    current lead time and a button to change it.
    """
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT name, renewal_type, reminders_enabled, reminder_on_due_date, reminder_days "
            "FROM subscriptions WHERE id = ? AND user_id = ?",
            (sub_id, query.from_user.id))
        sub = cursor.fetchone()

    if not sub:
        await query.edit_message_text("错误:找不到该订阅。")
        return

    # Each button is labelled with the action it performs, i.e. the opposite
    # of the current setting.
    enabled_text = "❌ 关闭提醒" if sub['reminders_enabled'] else "✅ 开启提醒"
    due_date_text = "❌ 关闭到期日提醒" if sub['reminder_on_due_date'] else "✅ 开启到期日提醒"
    keyboard = [
        [InlineKeyboardButton(enabled_text, callback_data='remindaction_toggle_enabled')],
        [InlineKeyboardButton(due_date_text, callback_data='remindaction_toggle_due_date')]
    ]
    safe_name = escape_markdown(sub['name'], version=2)
    current_status = f"*🔔 提醒设置: {safe_name}*\n\n"
    if sub['renewal_type'] == 'manual':
        current_status += f"当前提前提醒: *{sub['reminder_days']}天*\n"
        keyboard.append([InlineKeyboardButton("⚙️ 更改提前天数", callback_data='remindaction_ask_days')])
    keyboard.append([InlineKeyboardButton("« 返回详情", callback_data=f'view_{sub_id}')])
    await query.edit_message_text(current_status, reply_markup=InlineKeyboardMarkup(keyboard), parse_mode='MarkdownV2')
|
||
|
||
|
||
async def remind_settings_start(update: Update, context: CallbackContext):
    """Entry point of the reminder conversation: open the settings menu.

    Callback data has the form ``remind_<id>``. A valid id is stored under
    ``user_data['sub_id_for_action']`` before the menu is rendered.

    Returns:
        REMIND_SELECT_ACTION, or ConversationHandler.END for a malformed id.
    """
    callback = update.callback_query
    await callback.answer()
    _prefix, _sep, raw_id = callback.data.partition('_')
    if raw_id.isdigit():
        subscription_id = int(raw_id)
        logger.debug(f"Starting reminder settings for sub_id: {subscription_id}")
        context.user_data['sub_id_for_action'] = subscription_id
        await _display_reminder_settings(callback, context, subscription_id)
        return REMIND_SELECT_ACTION
    # Anything that is not a plain number ends the conversation right away.
    await callback.edit_message_text("错误:无效的订阅ID。")
    return ConversationHandler.END
|
||
|
||
|
||
async def remind_action_handler(update: Update, context: CallbackContext):
    """Handle a button of the reminder-settings menu (``remindaction_<action>``).

    ``ask_days`` switches to the text prompt for the lead time; the two
    toggle actions flip the matching flag of the caller's own subscription
    and redraw the menu.

    Returns:
        REMIND_GET_DAYS after ``ask_days``, ConversationHandler.END when the
        session data is missing, otherwise REMIND_SELECT_ACTION.
    """
    query = update.callback_query
    await query.answer()

    if not query.data:
        return REMIND_SELECT_ACTION

    action = query.data.partition('remindaction_')[2]
    sub_id = context.user_data.get('sub_id_for_action')
    if not sub_id:
        await query.edit_message_text("错误:会话已过期,请重试。")
        return ConversationHandler.END

    if action == 'ask_days':
        await query.edit_message_text("请输入您想提前几天收到提醒?(输入0则不提前提醒)")
        return REMIND_GET_DAYS

    # Fixed action -> column table; the column name never comes from user input.
    toggle_columns = {
        'toggle_enabled': 'reminders_enabled',
        'toggle_due_date': 'reminder_on_due_date',
    }
    column = toggle_columns.get(action)
    if column is None:
        logger.warning(f"Unexpected action '{query.data}' in remind_action_handler")
        return REMIND_SELECT_ACTION

    with get_db_connection() as conn:
        cursor = conn.cursor()
        # Restrict the write to rows owned by the user who pressed the button.
        cursor.execute(f"UPDATE subscriptions SET {column} = NOT {column} WHERE id = ? AND user_id = ?",
                       (sub_id, query.from_user.id))
        conn.commit()
    await _display_reminder_settings(query, context, sub_id)
    return REMIND_SELECT_ACTION
|
||
|
||
|
||
async def remind_days_received(update: Update, context: CallbackContext):
    """Store how many days ahead of the due date the user wants to be reminded.

    Returns:
        REMIND_GET_DAYS when the text is not a non-negative integer (the user
        is asked again), otherwise ConversationHandler.END.
    """
    # effective_message also covers edited messages, where update.message is None.
    message = update.effective_message
    sub_id = context.user_data.get('sub_id_for_action')
    if not sub_id:
        await message.reply_text("错误:会话已过期,请重试。")
        return ConversationHandler.END

    # Keep the try body limited to parsing, so errors raised while saving or
    # rendering are not misreported as bad input.
    try:
        days = int(message.text)
        if days < 0:
            raise ValueError
    except (ValueError, TypeError):
        await message.reply_text("请输入一个有效的非负整数。")
        return REMIND_GET_DAYS

    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("UPDATE subscriptions SET reminder_days = ? WHERE id = ? AND user_id = ?",
                       (days, sub_id, update.effective_user.id))
        conn.commit()
    await message.reply_text(f"✅ 提前提醒天数已设置为: {days}天。")
    # Drop only this conversation's key; 'list_subs_in_category' must survive
    # because the detail view uses it to pick the right "back" button.
    context.user_data.pop('sub_id_for_action', None)
    await show_subscription_view(update, context, sub_id)
    return ConversationHandler.END
|
||
|
||
|
||
# --- Other Commands ---
|
||
async def set_currency(update: Update, context: CallbackContext):
    """Handle ``/set_currency <code>``: store the user's main currency.

    Exactly one argument is expected, a three-letter alphabetic code; it is
    upper-cased before being written to the ``users`` table.
    """
    # effective_message also covers edited commands, where update.message is None.
    message = update.effective_message
    user_id, args = update.effective_user.id, context.args
    if len(args) != 1:
        # '_' starts italics in MarkdownV2, so the command name must be escaped.
        await message.reply_text("用法: /set\\_currency `<code>`(例如 /set\\_currency USD)", parse_mode='MarkdownV2')
        return
    new_currency = args[0].upper()
    if len(new_currency) != 3 or not new_currency.isalpha():
        await message.reply_text("请输入有效的三字母货币代码(如 USD, CNY)。")
        return
    with get_db_connection() as conn:
        cursor = conn.cursor()
        # NOTE(review): INSERT OR REPLACE rewrites the whole row; if `users`
        # has more columns than these two they would be reset — confirm schema.
        cursor.execute("INSERT OR REPLACE INTO users (user_id, main_currency) VALUES (?, ?)", (user_id, new_currency))
        conn.commit()
    await message.reply_text(f"您的主货币已设为 {escape_markdown(new_currency, version=2)}。",
                             parse_mode='MarkdownV2')
|
||
|
||
|
||
async def cancel(update: Update, context: CallbackContext):
    """Abort the current operation: clear ``user_data`` and confirm to the user.

    Works for both a button press and the ``/cancel`` command.

    Returns:
        ConversationHandler.END, so an active conversation is terminated.
    """
    context.user_data.clear()
    if update.callback_query:
        await update.callback_query.answer()
        await update.callback_query.edit_message_text('操作已取消。')
    elif update.effective_message:
        # effective_message also covers an edited /cancel, where update.message is None.
        await update.effective_message.reply_text('操作已取消。')
    return ConversationHandler.END
|
||
|
||
|
||
# --- Main ---
|
||
def main():
    """Build the Telegram application, register every handler and start polling.

    Returns early (after logging) when TELEGRAM_TOKEN is not configured.
    The nested ``post_init`` hook verifies the token, publishes the command
    menu and schedules the daily reminder job.
    """
    if not TELEGRAM_TOKEN:
        logger.critical("TELEGRAM_TOKEN 环境变量未设置!")
        return

    application = Application.builder().token(TELEGRAM_TOKEN).build()

    async def post_init(app: Application):
        """Startup hook: verify the token, publish commands, schedule reminders."""
        try:
            bot_info = await app.bot.get_me()
            logger.info(f"TELEGRAM_TOKEN 验证成功: {bot_info.username}")
        except TelegramError as e:
            logger.critical(f"TELEGRAM_TOKEN 无效或无法连接 Telegram API: {e}")
            raise SystemExit

        commands = [
            BotCommand("start", "🚀 开始使用"),
            BotCommand("add_sub", "➕ 添加新订阅"),
            BotCommand("list_subs", "📋 列出所有订阅"),
            BotCommand("list_categories", "🗂️ 按分类浏览"),
            BotCommand("stats", "📊 查看订阅统计"),
            BotCommand("import", "📥 导入订阅"),
            BotCommand("export", "📤 导出订阅"),
            BotCommand("set_currency", "💲 设置主货币"),
            BotCommand("help", "ℹ️ 获取帮助"),
            BotCommand("cancel", "❌ 取消当前操作")
        ]
        try:
            await app.bot.delete_my_commands()
            logger.debug("Cleared existing bot commands")
            await app.bot.set_my_commands(commands)
            logger.info("Bot commands registered successfully")
        except TelegramError as e:
            logger.error(f"Failed to register bot commands: {e}")

        # job_queue is None when python-telegram-bot is installed without the
        # job-queue extra; keep the bot usable instead of crashing on startup.
        if app.job_queue is None:
            logger.error("JobQueue is not available; daily reminders are disabled.")
            return
        app.job_queue.run_daily(
            check_and_send_reminders,
            # 09:00 at a fixed UTC+8 offset.
            time=datetime.time(hour=9, minute=0, tzinfo=datetime.timezone(datetime.timedelta(hours=8))),
            name='daily_reminders'
        )
        logger.info("Daily reminder job scheduled.")

    application.post_init = post_init

    # Plain text that is not a command; shared by every free-text state.
    plain_text = filters.TEXT & ~filters.COMMAND

    add_conv = ConversationHandler(
        entry_points=[CommandHandler('add_sub', add_sub_start)],
        states={
            ADD_NAME: [MessageHandler(plain_text, add_name_received)],
            ADD_COST: [MessageHandler(plain_text, add_cost_received)],
            ADD_CURRENCY: [MessageHandler(plain_text, add_currency_received)],
            ADD_CATEGORY: [MessageHandler(plain_text, add_category_received)],
            ADD_NEXT_DUE: [MessageHandler(plain_text, add_next_due_received)],
            ADD_FREQ_UNIT: [CallbackQueryHandler(add_freq_unit_received, pattern='^freq_unit_')],
            ADD_FREQ_VALUE: [MessageHandler(plain_text, add_freq_value_received)],
            ADD_RENEWAL_TYPE: [CallbackQueryHandler(add_renewal_type_received, pattern='^renewal_')],
            ADD_NOTES: [
                MessageHandler(plain_text, add_notes_received),
                CommandHandler('skip', skip_notes)
            ],
        },
        fallbacks=[CommandHandler('cancel', cancel)]
    )

    edit_conv = ConversationHandler(
        entry_points=[CallbackQueryHandler(edit_start, pattern='^edit_')],
        states={
            EDIT_SELECT_FIELD: [CallbackQueryHandler(edit_field_selected, pattern='^editfield_')],
            EDIT_GET_NEW_VALUE: [
                MessageHandler(plain_text, edit_new_value_received),
                CallbackQueryHandler(edit_new_value_received, pattern='^editvalue_'),
                CommandHandler('empty', edit_new_value_received)
            ],
            EDIT_FREQ_UNIT: [CallbackQueryHandler(edit_freq_unit_received, pattern='^freq_unit_')],
            EDIT_FREQ_VALUE: [MessageHandler(plain_text, edit_freq_value_received)],
        },
        fallbacks=[
            CommandHandler('cancel', cancel),
            # The wrapper ends the conversation after showing the detail view.
            CallbackQueryHandler(fallback_view_button, pattern='^view_'),
            CallbackQueryHandler(edit_start, pattern='^edit_'),
            # NOTE(review): remind_settings_start returns REMIND_SELECT_ACTION,
            # which has the same integer value as EDIT_SELECT_FIELD (both are
            # the first item of their range()), so as a fallback here it moves
            # this conversation to EDIT_SELECT_FIELD — confirm this is intended.
            CallbackQueryHandler(remind_settings_start, pattern='^remind_')
        ],
        per_message=False
    )

    remind_conv = ConversationHandler(
        entry_points=[CallbackQueryHandler(remind_settings_start, pattern='^remind_')],
        states={
            REMIND_SELECT_ACTION: [CallbackQueryHandler(remind_action_handler, pattern='^remindaction_')],
            REMIND_GET_DAYS: [MessageHandler(plain_text, remind_days_received)],
        },
        fallbacks=[
            CommandHandler('cancel', cancel),
            # The wrapper ends the conversation after showing the detail view.
            CallbackQueryHandler(fallback_view_button, pattern='^view_'),
            CallbackQueryHandler(edit_start, pattern='^edit_'),
            CallbackQueryHandler(remind_settings_start, pattern='^remind_')
        ],
        per_message=False
    )

    import_conv = ConversationHandler(
        entry_points=[CommandHandler('import', import_start)],
        states={
            IMPORT_UPLOAD: [MessageHandler(filters.Document.ALL, import_upload_received)],
        },
        fallbacks=[CommandHandler('cancel', cancel)]
    )

    # Callback data served by the generic button handler (outside conversations).
    button_pattern = r'^(view_\d+|renewmanual_\d+|delete_\d+|confirmdelete_\d+|renewfromremind_\d+|list_subs_in_category_.+|list_categories|list_all_subs)$'

    application.add_handler(CommandHandler('start', start))
    application.add_handler(CommandHandler('help', help_command))
    application.add_handler(CommandHandler('list_subs', list_subs))
    application.add_handler(CommandHandler('list_categories', list_categories))
    application.add_handler(CommandHandler('set_currency', set_currency))
    application.add_handler(CommandHandler('stats', stats))
    application.add_handler(CommandHandler('export', export_command))
    application.add_handler(CommandHandler('cancel', cancel))

    # Conversations are registered before the generic button handler so they
    # get the first chance at callback queries they own.
    application.add_handler(add_conv)
    application.add_handler(edit_conv)
    application.add_handler(remind_conv)
    application.add_handler(import_conv)
    application.add_handler(CallbackQueryHandler(button_callback_handler, pattern=button_pattern))

    logger.info(f"{PROJECT_NAME} Bot is starting...")
    application.run_polling()
|
||
|
||
|
||
if __name__ == '__main__':
    # Run the past-due maintenance step first, then start the bot.
    update_past_due_dates()
    main()