@@ -1,5 +1,9 @@
import sqlite3
import sqlite3
import asyncio
import os
import os
import sys
import subprocess
import html
import requests
import requests
import datetime
import datetime
import dateparser
import dateparser
@@ -18,7 +22,10 @@ from telegram.ext import (
CallbackContext , CallbackQueryHandler , ConversationHandler
CallbackContext , CallbackQueryHandler , ConversationHandler
)
)
from telegram . error import TelegramError
from telegram . error import TelegramError
from telegram . helpers import escape_markdown
def escape_html ( text , version = None ) :
if text is None :
return ' '
return html . escape ( str ( text ) )
# --- 加载 .env 和设置 ---
# --- 加载 .env 和设置 ---
load_dotenv ( )
load_dotenv ( )
@@ -37,6 +44,11 @@ EXCHANGE_API_KEY = os.getenv('EXCHANGE_API_KEY')
PROJECT_NAME = " SubMind "
PROJECT_NAME = " SubMind "
DB_FILE = ' submind.db '
DB_FILE = ' submind.db '
# 自动更新配置
UPDATE_OWNER_ID = os . getenv ( ' UPDATE_OWNER_ID ' ) # 仅允许此用户执行 /update
AUTO_UPDATE_REMOTE = os . getenv ( ' AUTO_UPDATE_REMOTE ' , ' https://git.llc/zimk/SubMind.git ' ) . strip ( )
AUTO_UPDATE_BRANCH = os . getenv ( ' AUTO_UPDATE_BRANCH ' , ' main ' ) . strip ( ) or ' main '
# --- 对话处理器状态 ---
# --- 对话处理器状态 ---
( ADD_NAME , ADD_COST , ADD_CURRENCY , ADD_CATEGORY , ADD_NEXT_DUE ,
( ADD_NAME , ADD_COST , ADD_CURRENCY , ADD_CATEGORY , ADD_NEXT_DUE ,
ADD_FREQ_UNIT , ADD_FREQ_VALUE , ADD_RENEWAL_TYPE , ADD_NOTES ) = range ( 9 )
ADD_FREQ_UNIT , ADD_FREQ_VALUE , ADD_RENEWAL_TYPE , ADD_NOTES ) = range ( 9 )
@@ -64,23 +76,34 @@ def get_chinese_font():
logger . info ( f " Font ' { font_name } ' not found. Attempting to download... " )
logger . info ( f " Font ' { font_name } ' not found. Attempting to download... " )
os . makedirs ( ' fonts ' , exist_ok = True )
os . makedirs ( ' fonts ' , exist_ok = True )
url = ' https://github.com/wweir/source-han-sans-sc/raw/refs/heads/master/SourceHanSansSC-Regular.otf '
urls = [
' https://github.com/wweir/source-han-sans-sc/raw/refs/heads/master/SourceHanSansSC-Regular.otf ' ,
' https://cdn.jsdelivr.net/gh/wweir/source-han-sans-sc@master/SourceHanSansSC-Regular.otf ' ,
' https://fastly.jsdelivr.net/gh/wweir/source-han-sans-sc@master/SourceHanSansSC-Regular.otf '
]
headers = {
headers = {
' User-Agent ' : ' Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36 '
' User-Agent ' : ' Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36 '
}
}
try :
for url in urls :
response = requests . get ( url , stream = True , headers = headers , timeout = 10 )
try :
response . raise_for_status ( )
logger . info ( f " Trying to download font from: { url } " )
with open ( font_path , ' wb ' ) as f :
response = requests . get ( url , stream = True , headers = headers , timeout = 15 )
for chunk in response . iter_content ( chunk_size = 8192 ) :
response . raise_for_status ( )
f . write ( chunk )
with open ( font_path , ' wb ' ) as f :
logger . info ( f " Font ' { font_name } ' downloaded successfully to ' { font_path } ' . " )
for chunk in response . iter_content ( chunk_size = 8192 ) :
fm . _load_fontmanager ( try_read_cache = False )
f . write ( chunk )
return fm . FontProperties ( fname = font_path )
logger . info ( f " Font ' { font_name } ' downloaded successfully to ' { font_path } ' . " )
except requests . exceptions . RequestException as e:
fm . _load_fontmanager ( try_read_cache = Fals e )
logger . error ( f " Failed to download font. Error: { e } " )
return fm . FontProperties ( fname = font_path )
return fm . FontProperties ( family = ' sans-serif ' )
except requests . exceptions . RequestException as e :
logger . warning ( f " Failed to download font from { url } . Error: { e } " )
continue
logger . error ( " All font download attempts failed. Falling back to default sans-serif. " )
return fm . FontProperties ( family = ' sans-serif ' )
# --- 数据库初始化与迁移 ---
# --- 数据库初始化与迁移 ---
@@ -115,6 +138,8 @@ def init_db():
cursor . execute ( " ALTER TABLE subscriptions ADD COLUMN reminder_on_due_date BOOLEAN DEFAULT TRUE " )
cursor . execute ( " ALTER TABLE subscriptions ADD COLUMN reminder_on_due_date BOOLEAN DEFAULT TRUE " )
if ' notes ' not in columns :
if ' notes ' not in columns :
cursor . execute ( " ALTER TABLE subscriptions ADD COLUMN notes TEXT " )
cursor . execute ( " ALTER TABLE subscriptions ADD COLUMN notes TEXT " )
if ' last_reminded_date ' not in columns :
cursor . execute ( " ALTER TABLE subscriptions ADD COLUMN last_reminded_date DATE " )
cursor . execute ( '''
cursor . execute ( '''
CREATE TABLE IF NOT EXISTS categories (
CREATE TABLE IF NOT EXISTS categories (
@@ -211,7 +236,7 @@ def convert_currency(amount, from_curr, to_curr):
def parse_date ( date_string : str ) - > str :
def parse_date ( date_string : str ) - > str :
today = datetime . datetime . now ( )
today = datetime . datetime . now ( )
try :
try :
dt = dateparser . parse ( date_string , languages = [ ' en ' , ' zh ' ] )
dt = dateparser . parse ( date_string , languages = [ ' en ' , ' zh ' ] , settings = { ' TIMEZONE ' : ' Asia/Shanghai ' , ' RETURN_AS_TIMEZONE_AWARE ' : False } )
if not dt :
if not dt :
return None
return None
has_year_info = any ( c in date_string for c in [ ' 年 ' , ' / ' ] ) or ( re . search ( r ' \ d {4} ' , date_string ) is not None )
has_year_info = any ( c in date_string for c in [ ' 年 ' , ' / ' ] ) or ( re . search ( r ' \ d {4} ' , date_string ) is not None )
@@ -335,9 +360,11 @@ def update_past_due_dates():
async def check_and_send_reminders ( context : CallbackContext ) :
async def check_and_send_reminders ( context : CallbackContext ) :
logger . info ( " Running job: Checking for subscription reminders... " )
logger . info ( " Running job: Checking for subscription reminders... " )
today = datetime . date . today ( )
today = datetime . date . today ( )
today_str = today . strftime ( ' % Y- % m- %d ' )
with get_db_connection ( ) as conn :
with get_db_connection ( ) as conn :
cursor = conn . cursor ( )
cursor = conn . cursor ( )
cursor . execute ( " SELECT * FROM subscriptions WHERE reminders_enabled = TRUE AND next_due IS NOT NULL " )
# 过滤掉今天已经提醒过的订阅
cursor . execute ( " SELECT * FROM subscriptions WHERE reminders_enabled = TRUE AND next_due IS NOT NULL AND (last_reminded_date IS NULL OR last_reminded_date != ?) " , ( today_str , ) )
subs_to_check = cursor . fetchall ( )
subs_to_check = cursor . fetchall ( )
for sub in subs_to_check :
for sub in subs_to_check :
@@ -345,7 +372,7 @@ async def check_and_send_reminders(context: CallbackContext):
due_date = datetime . datetime . strptime ( sub [ ' next_due ' ] , ' % Y- % m- %d ' ) . date ( )
due_date = datetime . datetime . strptime ( sub [ ' next_due ' ] , ' % Y- % m- %d ' ) . date ( )
user_id = sub [ ' user_id ' ]
user_id = sub [ ' user_id ' ]
renewal_type = sub [ ' renewal_type ' ]
renewal_type = sub [ ' renewal_type ' ]
safe_sub_name = escape_markdown ( sub [ ' name ' ] , version = 2 )
safe_sub_name = escape_html ( sub [ ' name ' ] )
message = None
message = None
keyboard = None
keyboard = None
@@ -356,7 +383,7 @@ async def check_and_send_reminders(context: CallbackContext):
] )
] )
if sub [ ' reminder_on_due_date ' ] and due_date == today :
if sub [ ' reminder_on_due_date ' ] and due_date == today :
message = f " 🔔 * 订阅到期提醒* \n \n 您的订阅 ` { safe_sub_name } ` 今天到期。"
message = f " 🔔 <b> 订阅到期提醒</b> \n \n 您的订阅 <code> { safe_sub_name } </code> 今天到期。"
if renewal_type == ' manual ' :
if renewal_type == ' manual ' :
message + = " 请记得手动续费。 "
message + = " 请记得手动续费。 "
else :
else :
@@ -367,21 +394,25 @@ async def check_and_send_reminders(context: CallbackContext):
reminder_date = due_date - datetime . timedelta ( days = sub [ ' reminder_days ' ] )
reminder_date = due_date - datetime . timedelta ( days = sub [ ' reminder_days ' ] )
if reminder_date == today :
if reminder_date == today :
days_left = ( due_date - today ) . days
days_left = ( due_date - today ) . days
days_text = f " * { days_left } 天后* " if days_left > 0 else " *今天* "
days_text = f " <b> { days_left } 天后</b> " if days_left > 0 else " <b>今天</b> "
message = f " 🔔 * 订阅即将到期提醒* \n \n 您的手动续费订阅 ` { safe_sub_name } ` 将在 { days_text } 到期。 "
message = f " 🔔 <b> 订阅即将到期提醒</b> \n \n 您的手动续费订阅 <code> { safe_sub_name } </code> 将在 { days_text } 到期。 "
if message :
if message :
await context . bot . send_message (
await context . bot . send_message (
chat_id = user_id ,
chat_id = user_id ,
text = message ,
text = message ,
parse_mode = ' MarkdownV2 ' ,
parse_mode = ' HTML ' ,
reply_markup = keyboard
reply_markup = keyboard
)
)
# 记录今天已发送提醒
with get_db_connection ( ) as update_conn :
update_cursor = update_conn . cursor ( )
update_cursor . execute ( " UPDATE subscriptions SET last_reminded_date = ? WHERE id = ? " , ( today_str , sub [ ' id ' ] ) )
update_conn . commit ( )
logger . info ( f " Reminder sent for sub_id { sub [ ' id ' ] } " )
except Exception as e :
except Exception as e :
logger . error ( f " Failed to process reminder for sub_id { sub . get ( ' id ' , ' N/A ' ) } : { e } " )
logger . error ( f " Failed to process reminder for sub_id { sub . get ( ' id ' , ' N/A ' ) } : { e } " )
# --- 命令处理器 ---
# --- 命令处理器 ---
async def start ( update : Update , context : CallbackContext ) :
async def start ( update : Update , context : CallbackContext ) :
user_id = update . effective_user . id
user_id = update . effective_user . id
@@ -389,26 +420,26 @@ async def start(update: Update, context: CallbackContext):
cursor = conn . cursor ( )
cursor = conn . cursor ( )
cursor . execute ( ' INSERT OR IGNORE INTO users (user_id) VALUES (?) ' , ( user_id , ) )
cursor . execute ( ' INSERT OR IGNORE INTO users (user_id) VALUES (?) ' , ( user_id , ) )
conn . commit ( )
conn . commit ( )
await update . message . reply_text ( f ' 欢迎使用 { escape_markdown ( PROJECT_NAME , version = 2 )} ! \n 您的私人订阅智能管家。 ' ,
await update . message . reply_text ( f ' 欢迎使用 <b> { escape_html ( PROJECT_NAME ) } </b> ! \n 您的私人订阅智能管家。 ' ,
parse_mode = ' MarkdownV2 ' )
parse_mode = ' HTML ' )
async def help_command ( update : Update , context : CallbackContext ) :
async def help_command ( update : Update , context : CallbackContext ) :
help_text = fr """
help_text = f """
* { escape_markdown ( PROJECT_NAME , version = 2 )} 命令列表*
<b> { escape_html ( PROJECT_NAME ) } 命令列表</b>
* 🌟 核心功能*
<b> 🌟 核心功能</b>
/add\ _sub \ - 引导您添加一个新的订阅
/add_sub - 引导您添加一个新的订阅
/list\ _subs \ - 列出您的所有订阅
/list_subs - 列出您的所有订阅
/list\ _categories \ - 按分类浏览您的订阅
/list_categories - 按分类浏览您的订阅
* 📊 数据管理*
<b> 📊 数据管理</b>
/stats \ - 查看按类别分类的订阅统计
/stats - 查看按类别分类的订阅统计
/import \ - 通过上传 CSV 文件批量导入订阅
/import - 通过上传 CSV 文件批量导入订阅
/export \ - 将您的所有订阅导出为 CSV 文件
/export - 将您的所有订阅导出为 CSV 文件
* ⚙️ 个性化设置*
<b> ⚙️ 个性化设置</b>
/set\ _currency \ `<code> \ ` \ - 设置您的主要货币
/set_currency <代码> - 设置您的主要货币
/cancel \ - 在任何流程中取消当前操作
/cancel - 在任何流程中取消当前操作
"""
"""
await update . message . reply_text ( help_text , parse_mode = ' MarkdownV2 ' )
await update . message . reply_text ( help_text , parse_mode = ' HTML ' )
def make_autopct ( values , currency_code ) :
def make_autopct ( values , currency_code ) :
@@ -425,97 +456,210 @@ def make_autopct(values, currency_code):
async def stats ( update : Update , context : CallbackContext ) :
async def stats ( update : Update , context : CallbackContext ) :
user_id = update . effective_user . id
user_id = update . effective_user . id
await update . message . reply_text ( " 正在为您生成订阅 统计图表 ,请稍候... " )
await update . message . reply_text ( " 正在为您生成更美观的 统计图,请稍候... " )
font_prop = get_chinese_fon t( )
def generate_char t( ) :
main_currency = get_user_main_currency ( user_id )
font_prop = get_chinese_font ( )
with get_db_connection ( ) as conn :
main_currency = get_user_main_currency ( user_id )
df = pd . read_sql_query ( " SELECT * FROM subscriptions WHERE user_id = ? " , conn , params = ( user_id , ) )
with get_db_connection ( ) as conn :
if df . empty :
df = pd . read_sql_query ( " SELECT * FROM subscriptions WHERE user_id = ? " , conn , params = ( user_id , ) )
await update . message . reply_text ( " 您还没有任何订阅数据。 " )
cursor = conn . cursor ( )
return
cursor . execute ( " SELECT from_currency, to_currency, rate FROM exchange_rates WHERE to_currency = ? " , ( main_currency . upper ( ) , ) )
rate_cache = { ( row [ ' from_currency ' ] , row [ ' to_currency ' ] ) : row [ ' rate ' ] for row in cursor . fetchall ( ) }
df [ ' converted_cost ' ] = df . apply ( lambda row : convert_currency ( row [ ' cost ' ] , row [ ' currency ' ] , main_currency ) , axis = 1 )
if df . empty :
unit_to_days = { ' day ' : 1 , ' week ' : 7 , ' month ' : 30.4375 , ' year ' : 365.25 }
return False , " 您还没有任何订阅数据。 "
def normalize_to_ monthly ( row ) :
def fast_convert ( a mou nt, from_curr , to_curr ) :
if pd . isna ( row [ ' frequency_unit ' ] ) or pd . isna ( row [ ' frequency_value ' ] ) or row [ ' frequency_value ' ] == 0 :
if from_curr . upper ( ) == to_curr . upper ( ) :
return 0
return amount
total_days = row [ ' frequency_value ' ] * unit_to_days . get ( row [ ' frequency_unit ' ] , 0 )
cached_rate = rate_cache . get ( ( from_curr . upper ( ) , to_curr . upper ( ) ) )
if total_days == 0 :
if cached_rate is not None :
return 0
return amount * cached_rate
return ( row [ ' converted_cost ' ] / total_days ) * 30.4375
return convert_currency ( amount , from_curr , to_curr )
df [ ' m onthly _cost' ] = df . apply ( normalize_to_monthly , axis = 1 )
df [ ' c onverted _cost' ] = df . apply ( lambda row : fast_convert ( row [ ' cost ' ] , row [ ' currency ' ] , main_currency ) , axis = 1 )
category_costs = df . groupby ( ' category ' ) [ ' monthly_cost ' ] . sum ( ) . sort_values ( ascending = False )
unit_to_days = { ' day ' : 1 , ' week ' : 7 , ' month ' : 30.4375 , ' year ' : 365.25 }
i f category_costs . empty or category_costs . sum ( ) == 0 :
de f normalize_to_monthly ( row ) :
await u pdate . message . reply_text ( " 您的订阅没有有效的费用信息。 " )
if pd . isna ( row [ ' frequency_unit ' ] ) or pd . isna ( row [ ' frequency_value ' ] ) or row [ ' frequency_value ' ] == 0 :
return
return 0
total_days = row [ ' frequency_value ' ] * unit_to_days . get ( row [ ' frequency_unit ' ] , 0 )
if total_days == 0 :
return 0
return ( row [ ' converted_cost ' ] / total_days ) * 30.4375
plt . style . use ( ' seaborn-v0_8-darkgrid ' )
df [ ' monthly_cost ' ] = df . apply ( normalize_to_monthly , axis = 1 )
fig , ax = plt . subplots ( figsize = ( 12 , 12 ) )
category_costs = df . groupby ( ' category ' ) [ ' monthly_cost ' ] . sum ( ) . sort_values ( ascending = False )
image_path = None
try :
if category_costs . empty or category_costs . sum ( ) == 0 :
autopct_function = make_autopct ( category_costs . values , main_currency )
return False , " 您的订阅没有有效的费用信息。 "
wedges , texts , autotexts = ax . pie ( category_costs . values ,
max_categories = 8
labels = category_costs . index ,
if len ( category_costs ) > max_categories :
autopct = autopct_function ,
top = category_costs . iloc [ : max_categories ]
startangle = 140 ,
others_sum = category_costs . iloc [ max_categories : ] . sum ( )
pctdistance = 0.7 ,
if others_sum > 0 :
labeldistance = 1.05 )
category_costs = pd . concat ( [ top , pd . Series ( { ' 其他 ' : others_sum } ) ] )
else :
category_costs = top
ax . set_title ( ' 每月订阅支出分类统计 ' , fontproperties = font_prop , fontsize = 32 , pad = 20 )
total_monthly = category_costs . sum ( )
currency_symbols = { ' USD ' : ' $ ' , ' CNY ' : ' ¥ ' , ' EUR ' : ' € ' , ' GBP ' : ' £ ' , ' JPY ' : ' ¥ ' }
symbol = currency_symbols . get ( main_currency . upper ( ) , f ' { main_currency . upper ( ) } ' )
for text in texts :
def autopct_if_large ( pct ) :
text . set_fontproperties ( font_prop )
if pct < 4 :
text . set_fontsize ( 22 )
return ' '
value = total_monthly * pct / 100
return f " { pct : .1f } % \n { symbol } { value : .2f } "
for autotext in autotexts :
fig = plt . figure ( figsize = ( 15 , 8.5 ) , facecolor = ' #FAFAFA ' )
autotext . set_fontproperties ( font_prop )
gs = fig . add_gridspec ( 1 , 2 , width_ratios = [ 1.1 , 1 ] , wspace = 0.15 )
autotext . set_fontsize ( 20 )
ax_pie = fig . add_subplot ( gs [ 0 , 0 ] )
autotext . set_co lor ( ' white ' )
ax_bar = fig . add_subp lot ( gs [ 0 , 1 ] )
image_path = None
ax . axis ( ' equal ' )
try :
fig . tight_layout ( )
theme_colors = [ ' #3B82F6 ' , ' #10B981 ' , ' #F59E0B ' , ' #EF4444 ' , ' #8B5CF6 ' , ' #EC4899 ' , ' #14B8A6 ' , ' #F97316 ' , ' #6366F1 ' , ' #84CC16 ' ]
if len ( category_costs ) > len ( theme_colors ) :
# 移除导致遮蔽的局部 import, 直接使用全局的 matplotlib 和 plt
extra_colors = [ matplotlib . colors . to_hex ( c ) for c in plt . get_cmap ( ' tab20 ' ) . colors ]
theme_colors . extend ( extra_colors )
with tempfile . NamedTemporaryFile ( prefix = f ' stats_ { user_id } _ ' , suffix = ' .png ' , delete = False ) as tmp :
color_map = { cat : theme_colors [ i ] for i , cat in enumerate ( category_costs . index ) }
image_path = tmp . name
pie_colors = [ color_map [ cat ] for cat in category_costs . index ]
plt . savefig ( image_path )
wedges , texts , autotexts = ax_pie . pie (
category_costs . values ,
labels = category_costs . index ,
autopct = autopct_if_large ,
startangle = 140 ,
counterclock = False ,
pctdistance = 0.75 ,
labeldistance = 1.1 ,
colors = pie_colors ,
wedgeprops = { ' width ' : 0.35 , ' edgecolor ' : ' #FAFAFA ' , ' linewidth ' : 2.5 }
)
for t in texts :
t . set_fontproperties ( font_prop )
t . set_fontsize ( 13 )
t . set_color ( ' #374151 ' )
for t in autotexts :
t . set_fontproperties ( font_prop )
t . set_fontsize ( 10 )
t . set_color ( ' #FFFFFF ' )
t . set_weight ( ' bold ' )
ax_pie . text (
0 , 0 ,
f " 月总支出 \n { symbol } { total_monthly : .2f } " ,
ha = ' center ' , va = ' center ' ,
fontproperties = font_prop ,
fontsize = 18 ,
color = ' #1F2937 ' ,
weight = ' bold '
)
ax_pie . set_title ( ' 支出占比结构 ' , fontproperties = font_prop , fontsize = 18 , pad = 20 , color = ' #111827 ' , weight = ' bold ' )
ax_pie . axis ( ' equal ' )
bar_series = category_costs . sort_values ( ascending = True )
bar_colors = [ color_map [ cat ] for cat in bar_series . index ]
bars = ax_bar . barh ( bar_series . index , bar_series . values , color = bar_colors , height = 0.6 , alpha = 0.95 , edgecolor = ' none ' )
ax_bar . set_title ( ' 各类别月支出对比 ' , fontproperties = font_prop , fontsize = 18 , pad = 20 , color = ' #111827 ' , weight = ' bold ' )
ax_bar . set_xlabel ( f ' 金额( { main_currency . upper ( ) } ) ' , fontproperties = font_prop , fontsize = 12 , color = ' #6B7280 ' , labelpad = 10 )
ax_bar . spines [ ' top ' ] . set_visible ( False )
ax_bar . spines [ ' right ' ] . set_visible ( False )
ax_bar . spines [ ' left ' ] . set_visible ( False )
ax_bar . spines [ ' bottom ' ] . set_color ( ' #E5E7EB ' )
ax_bar . tick_params ( axis = ' x ' , colors = ' #6B7280 ' , labelsize = 11 )
ax_bar . tick_params ( axis = ' y ' , length = 0 , pad = 10 )
ax_bar . grid ( axis = ' x ' , color = ' #F3F4F6 ' , linestyle = ' - ' , linewidth = 1.5 , alpha = 1 )
ax_bar . set_axisbelow ( True )
for label in ax_bar . get_yticklabels ( ) :
label . set_fontproperties ( font_prop )
label . set_fontsize ( 13 )
label . set_color ( ' #374151 ' )
max_val = bar_series . max ( ) if len ( bar_series ) else 0
offset = max_val * 0.02 if max_val > 0 else 0.1
for rect , value in zip ( bars , bar_series . values ) :
ax_bar . text (
rect . get_width ( ) + offset ,
rect . get_y ( ) + rect . get_height ( ) / 2 ,
f " { symbol } { value : .2f } " ,
va = ' center ' ,
ha = ' left ' ,
fontproperties = font_prop ,
fontsize = 11 ,
color = ' #4B5563 ' ,
weight = ' bold '
)
fig . suptitle ( ' 您的订阅统计报告 ' , fontproperties = font_prop , fontsize = 24 , color = ' #0F172A ' , y = 1.02 , weight = ' bold ' )
fig . tight_layout ( rect = [ 0 , 0 , 1 , 0.95 ] )
with tempfile . NamedTemporaryFile ( prefix = f ' stats_ { user_id } _ ' , suffix = ' .png ' , delete = False ) as tmp :
image_path = tmp . name
plt . savefig ( image_path , dpi = 250 , bbox_inches = ' tight ' , facecolor = fig . get_facecolor ( ) )
return True , image_path
finally :
plt . close ( fig )
success , result = await asyncio . to_thread ( generate_chart )
if success :
try :
with open ( result , ' rb ' ) as photo :
await update . message . reply_photo ( photo , caption = " ✨ 已为您生成全新的精美订阅统计图! " )
finally :
if os . path . exists ( result ) :
os . remove ( result )
else :
await update . message . reply_text ( result )
with open ( image_path , ' rb ' ) as photo :
await update . message . reply_photo ( photo , caption = " 这是您按类别统计的每月订阅总支出。 " )
finally :
plt . close ( fig )
if image_path and os . path . exists ( image_path ) :
os . remove ( image_path )
# --- Import/Export Commands ---
# --- Import/Export Commands ---
async def export_command ( update : Update , context : CallbackContext ) :
async def export_command ( update : Update , context : CallbackContext ) :
user_id = update . effective_user . id
user_id = update . effective_user . id
with get_db_connection ( ) as conn :
df = pd . read_sql_query (
# 将重度 I/O 和 CPU 绑定的 pandas 导出操作放入后台线程
" SELECT name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes FROM subscriptions WHERE user_id = ? " ,
def process_export ( ) :
conn , params = ( user_id , ) )
with get_db_connection ( ) as conn :
if df . empt y:
df = pd . read_sql_quer y (
" SELECT name, cost, currency, category, next_due, frequency_unit, frequency_value, renewal_type, notes FROM subscriptions WHERE user_id = ? " ,
conn , params = ( user_id , ) )
if df . empty :
return False , None
tmp = tempfile . NamedTemporaryFile ( prefix = f ' export_ { user_id } _ ' , suffix = ' .csv ' , delete = False )
export_path = tmp . name
tmp . close ( )
df . to_csv ( export_path , index = False , encoding = ' utf-8-sig ' )
return True , export_path
success , export_path = await asyncio . to_thread ( process_export )
if not success :
await update . message . reply_text ( " 您还没有任何订阅数据,无法导出。 " )
await update . message . reply_text ( " 您还没有任何订阅数据,无法导出。 " )
return
return
with tempfile . NamedTemporaryFile ( prefix = f ' export_ { user_id } _ ' , suffix = ' .csv ' , delete = False ) as tmp :
export_path = tmp . name
try :
try :
df . to_csv ( export_path , index = False , encoding = ' utf-8-sig ' )
with open ( export_path , ' rb ' ) as file :
with open ( export_path , ' rb ' ) as file :
await update . message . reply_document ( document = file , filename = ' subscriptions.csv ' ,
await update . message . reply_document ( document = file , filename = ' subscriptions.csv ' ,
caption = " 您的订阅数据已导出为 CSV 文件。 " )
caption = " 您的订阅数据已导出为 CSV 文件。 " )
finally :
finally :
if os . path . exists ( export_path ) :
if export_path and os. path . exists ( export_path ) :
os . remove ( export_path )
os . remove ( export_path )
@@ -612,8 +756,8 @@ async def import_upload_received(update: Update, context: CallbackContext):
# --- Add Subscription Conversation ---
# --- Add Subscription Conversation ---
async def add_sub_start ( update : Update , context : CallbackContext ) :
async def add_sub_start ( update : Update , context : CallbackContext ) :
context . user_data [ ' new_sub_data ' ] = { }
context . user_data [ ' new_sub_data ' ] = { }
await update . message . reply_text ( " 好的,我们来添加一个新订阅。 \n \n 第一步:请输入订阅的 *名称* " ,
await update . message . reply_text ( " 好的,我们来添加一个新订阅。 \n \n 第一步:请输入订阅的 <b>名称</b> " ,
parse_mode = ' MarkdownV2 ' )
parse_mode = ' HTML ' )
return ADD_NAME
return ADD_NAME
@@ -632,7 +776,7 @@ async def add_name_received(update: Update, context: CallbackContext):
await update . message . reply_text ( f " 订阅名称过长,请控制在 { MAX_NAME_LEN } 个字符以内。 " )
await update . message . reply_text ( f " 订阅名称过长,请控制在 { MAX_NAME_LEN } 个字符以内。 " )
return ADD_NAME
return ADD_NAME
sub_data [ ' name ' ] = name
sub_data [ ' name ' ] = name
await update . message . reply_text ( " 第二步:请输入订阅 *费用* " , parse_mode = ' MarkdownV2 ' )
await update . message . reply_text ( " 第二步:请输入订阅 <b>费用</b> " , parse_mode = ' HTML ' )
return ADD_COST
return ADD_COST
@@ -651,7 +795,7 @@ async def add_cost_received(update: Update, context: CallbackContext):
except ( ValueError , TypeError ) :
except ( ValueError , TypeError ) :
await update . message . reply_text ( " 费用必须是有效的非负数字。 " )
await update . message . reply_text ( " 费用必须是有效的非负数字。 " )
return ADD_COST
return ADD_COST
await update . message . reply_text ( " 第三步:请输入 *货币* 代码(例如 USD, CNY) " , parse_mode = ' MarkdownV2 ' )
await update . message . reply_text ( " 第三步:请输入 <b>货币</b> 代码(例如 USD, CNY) " , parse_mode = ' HTML ' )
return ADD_CURRENCY
return ADD_CURRENCY
@@ -667,7 +811,7 @@ async def add_currency_received(update: Update, context: CallbackContext):
await update . message . reply_text ( " 请输入有效的三字母货币代码(如 USD, CNY) 。 " )
await update . message . reply_text ( " 请输入有效的三字母货币代码(如 USD, CNY) 。 " )
return ADD_CURRENCY
return ADD_CURRENCY
sub_data [ ' currency ' ] = currency
sub_data [ ' currency ' ] = currency
await update . message . reply_text ( " 第四步:请为订阅指定一个 *类别* " , parse_mode = ' MarkdownV2 ' )
await update . message . reply_text ( " 第四步:请为订阅指定一个 <b>类别</b> " , parse_mode = ' HTML ' )
return ADD_CATEGORY
return ADD_CATEGORY
@@ -690,8 +834,8 @@ async def add_category_received(update: Update, context: CallbackContext):
cursor = conn . cursor ( )
cursor = conn . cursor ( )
cursor . execute ( " INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?) " , ( user_id , category_name ) )
cursor . execute ( " INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?) " , ( user_id , category_name ) )
conn . commit ( )
conn . commit ( )
await update . message . reply_text ( " 第五步:请输入 * 下一次付款日期* (例如 2025\\ -10 \\ -01 或 10月1日) " ,
await update . message . reply_text ( " 第五步:请输入 <b> 下一次付款日期</b> (例如 2025-10 -01 或 10月1日) " ,
parse_mode = ' MarkdownV2 ' )
parse_mode = ' HTML ' )
return ADD_NEXT_DUE
return ADD_NEXT_DUE
@@ -704,7 +848,7 @@ async def add_next_due_received(update: Update, context: CallbackContext):
parsed_date = parse_date ( update . message . text )
parsed_date = parse_date ( update . message . text )
if not parsed_date :
if not parsed_date :
await update . message . reply_text ( " 无法识别的日期格式,请使用类似 ' 2025\\ -10 \\ -01 ' 或 ' 10月1日 ' 的格式。 " )
await update . message . reply_text ( " 无法识别的日期格式,请使用类似 ' 2025-10 -01 ' 或 ' 10月1日 ' 的格式。 " )
return ADD_NEXT_DUE
return ADD_NEXT_DUE
sub_data [ ' next_due ' ] = parsed_date
sub_data [ ' next_due ' ] = parsed_date
keyboard = [
keyboard = [
@@ -713,8 +857,8 @@ async def add_next_due_received(update: Update, context: CallbackContext):
[ InlineKeyboardButton ( " 月 " , callback_data = ' freq_unit_month ' ) ,
[ InlineKeyboardButton ( " 月 " , callback_data = ' freq_unit_month ' ) ,
InlineKeyboardButton ( " 年 " , callback_data = ' freq_unit_year ' ) ]
InlineKeyboardButton ( " 年 " , callback_data = ' freq_unit_year ' ) ]
]
]
await update . message . reply_text ( " 第六步:请选择付款周期的*单位* " , reply_markup = InlineKeyboardMarkup ( keyboard ) ,
await update . message . reply_text ( " 第六步:请选择付款周期的<b>单位</b> " , reply_markup = InlineKeyboardMarkup ( keyboard ) ,
parse_mode = ' MarkdownV2 ' )
parse_mode = ' HTML ' )
return ADD_FREQ_UNIT
return ADD_FREQ_UNIT
@@ -731,7 +875,7 @@ async def add_freq_unit_received(update: Update, context: CallbackContext):
await query . edit_message_text ( " 错误:无效的周期单位,请重试。 " )
await query . edit_message_text ( " 错误:无效的周期单位,请重试。 " )
return ConversationHandler . END
return ConversationHandler . END
sub_data [ ' unit ' ] = unit
sub_data [ ' unit ' ] = unit
await query . edit_message_text ( " 第七步:请输入周期的*数量* ( 例如: 每3个月, 输入 3) " , parse_mode = ' Markdown ' )
await query . edit_message_text ( " 第七步:请输入周期的<b>数量</b> ( 例如: 每3个月, 输入 3) " , parse_mode = ' HTML ' )
return ADD_FREQ_VALUE
return ADD_FREQ_VALUE
@@ -754,8 +898,8 @@ async def add_freq_value_received(update: Update, context: CallbackContext):
[ InlineKeyboardButton ( " 自动续费 " , callback_data = ' renewal_auto ' ) ,
[ InlineKeyboardButton ( " 自动续费 " , callback_data = ' renewal_auto ' ) ,
InlineKeyboardButton ( " 手动续费 " , callback_data = ' renewal_manual ' ) ]
InlineKeyboardButton ( " 手动续费 " , callback_data = ' renewal_manual ' ) ]
]
]
await update . message . reply_text ( " 第八步:请选择 * 续费方式* " , reply_markup = InlineKeyboardMarkup ( keyboard ) ,
await update . message . reply_text ( " 第八步:请选择 <b> 续费方式</b> " , reply_markup = InlineKeyboardMarkup ( keyboard ) ,
parse_mode = ' MarkdownV2 ' )
parse_mode = ' HTML ' )
return ADD_RENEWAL_TYPE
return ADD_RENEWAL_TYPE
@@ -790,8 +934,8 @@ async def add_notes_received(update: Update, context: CallbackContext):
return ADD_NOTES
return ADD_NOTES
sub_data [ ' notes ' ] = note if note else None
sub_data [ ' notes ' ] = note if note else None
save_subscription ( update . effective_user . id , sub_data )
save_subscription ( update . effective_user . id , sub_data )
await update . message . reply_text ( text = f " ✅ 订阅 ' { escape_markdown ( sub_data . get ( ' name ' , ' ' ) , version = 2 )} ' 已添加! " ,
await update . message . reply_text ( text = f " ✅ 订阅 ' { escape_html ( sub_data . get ( ' name ' , ' ' ) ) } ' 已添加! " ,
parse_mode = ' MarkdownV2 ' )
parse_mode = ' HTML ' )
_clear_action_state ( context , [ ' new_sub_data ' ] )
_clear_action_state ( context , [ ' new_sub_data ' ] )
return ConversationHandler . END
return ConversationHandler . END
@@ -806,8 +950,8 @@ async def skip_notes(update: Update, context: CallbackContext):
sub_data [ ' notes ' ] = None
sub_data [ ' notes ' ] = None
save_subscription ( update . effective_user . id , sub_data )
save_subscription ( update . effective_user . id , sub_data )
await update . message . reply_text ( text = f " ✅ 订阅 ' { escape_markdown ( sub_data . get ( ' name ' , ' ' ) , version = 2 )} ' 已添加! " ,
await update . message . reply_text ( text = f " ✅ 订阅 ' { escape_html ( sub_data . get ( ' name ' , ' ' ) ) } ' 已添加! " ,
parse_mode = ' MarkdownV2 ' )
parse_mode = ' HTML ' )
_clear_action_state ( context , [ ' new_sub_data ' ] )
_clear_action_state ( context , [ ' new_sub_data ' ] )
return ConversationHandler . END
return ConversationHandler . END
@@ -878,22 +1022,19 @@ async def show_subscription_view(update: Update, context: CallbackContext, sub_i
sub [ ' reminders_enabled ' ] , sub [ ' notes ' ] )
sub [ ' reminders_enabled ' ] , sub [ ' notes ' ] )
freq_text = format_frequency ( sub [ ' frequency_unit ' ] , sub [ ' frequency_value ' ] )
freq_text = format_frequency ( sub [ ' frequency_unit ' ] , sub [ ' frequency_value ' ] )
main_currency = get_user_main_currency ( user_id )
main_currency = get_user_main_currency ( user_id )
converted_cost = convert_currency ( cost , currency , main_currency )
converted_cost = await asyncio . to_thread ( convert_currency , cost , currency , main_currency )
safe_name , safe_category , safe_freq = escape_markdown ( name , version = 2 ) , escape_markdown ( category ,
safe_name , safe_category , safe_freq = escape_html ( name ) , escape_html ( category ) , escape_html ( freq_text )
version = 2 ) , escape_markdown (
cost_str , converted_cost_str = escape_html ( f " { cost : .2f } " ) , escape_html ( f " { converted_cost : .2f } " )
freq_text , version = 2 )
cost_str , converted_cost_str = escape_markdown ( f " { cost : .2f } " , version = 2 ) , escape_markdown ( f " { converted_cost : .2f } " ,
version = 2 )
renewal_text = " 手动续费 " if renewal_type == ' manual ' else " 自动续费 "
renewal_text = " 手动续费 " if renewal_type == ' manual ' else " 自动续费 "
reminder_status = " 开启 " if reminders_enabled else " 关闭 "
reminder_status = " 开启 " if reminders_enabled else " 关闭 "
text = ( f " * 订阅详情: { safe_name } * \n \n "
text = ( f " <b> 订阅详情: { safe_name } </b> \n \n "
f " \\ - *费用*: ` { cost_str } { currency . upper ( ) } ` \\ ( \\ ~` { converted_cost_str } { main_currency . upper ( ) } ` \\ )\n "
f " - <b>费用</b>: <code> { cost_str } { currency . upper ( ) } </code> (~<code> { converted_cost_str } { main_currency . upper ( ) } </code> )\n "
f " \\ - *类别*: ` { safe_category } ` \n "
f " - <b>类别</b>: <code> { safe_category } </code> \n "
f " \\ - *下次付款*: ` { next_due } ` \\ (周期: { safe_freq } \\ ) \n "
f " - <b>下次付款</b>: <code> { next_due } </code> (周期: { safe_freq } ) \n "
f " \\ - *续费方式*: ` { renewal_text } ` \n "
f " - <b>续费方式</b>: <code> { renewal_text } </code> \n "
f " \\ - *提醒状态*: ` { reminder_status } ` " )
f " - <b>提醒状态</b>: <code> { reminder_status } </code> " )
if notes :
if notes :
text + = f " \n \\ - *备注* : { escape_markdown ( notes , version = 2 )} "
text + = f " \n - <b>备注</b> : { escape_html ( notes ) } "
keyboard_buttons = [
keyboard_buttons = [
[ InlineKeyboardButton ( " ✏️ 编辑 " , callback_data = f ' edit_ { sub_id } ' ) ,
[ InlineKeyboardButton ( " ✏️ 编辑 " , callback_data = f ' edit_ { sub_id } ' ) ,
InlineKeyboardButton ( " 🗑️ 删除 " , callback_data = f ' delete_ { sub_id } ' ) ] ,
InlineKeyboardButton ( " 🗑️ 删除 " , callback_data = f ' delete_ { sub_id } ' ) ] ,
@@ -914,10 +1055,10 @@ async def show_subscription_view(update: Update, context: CallbackContext, sub_i
logger . debug ( f " Generated buttons for sub_id { sub_id } : edit_ { sub_id } , remind_ { sub_id } " )
logger . debug ( f " Generated buttons for sub_id { sub_id } : edit_ { sub_id } , remind_ { sub_id } " )
if update . callback_query :
if update . callback_query :
await update . callback_query . edit_message_text ( text , reply_markup = InlineKeyboardMarkup ( keyboard_buttons ) ,
await update . callback_query . edit_message_text ( text , reply_markup = InlineKeyboardMarkup ( keyboard_buttons ) ,
parse_mode = ' MarkdownV2 ' )
parse_mode = ' HTML ' )
elif update . effective_message :
elif update . effective_message :
await update . effective_message . reply_text ( text , reply_markup = InlineKeyboardMarkup ( keyboard_buttons ) ,
await update . effective_message . reply_text ( text , reply_markup = InlineKeyboardMarkup ( keyboard_buttons ) ,
parse_mode = ' MarkdownV2 ' )
parse_mode = ' HTML ' )
async def button_callback_handler ( update : Update , context : CallbackContext ) :
async def button_callback_handler ( update : Update , context : CallbackContext ) :
@@ -946,11 +1087,11 @@ async def button_callback_handler(update: Update, context: CallbackContext):
context . user_data [ ' list_subs_in_category ' ] = category
context . user_data [ ' list_subs_in_category ' ] = category
context . user_data [ ' list_subs_in_category_id ' ] = category_id
context . user_data [ ' list_subs_in_category_id ' ] = category_id
keyboard = await get_subs_list_keyboard ( user_id , category_filter = category )
keyboard = await get_subs_list_keyboard ( user_id , category_filter = category )
msg_text = f " 分类“ { escape_markdown ( category , version = 2 )} ” 下的订阅:"
msg_text = f " 分类 <b> { escape_html ( category ) } </b> 下的订阅:"
if not keyboard :
if not keyboard :
msg_text = f " 分类“ { escape_markdown ( category , version = 2 )} ” 下没有订阅。"
msg_text = f " 分类 <b> { escape_html ( category ) } </b> 下没有订阅。"
keyboard = InlineKeyboardMarkup ( [ [ InlineKeyboardButton ( " « 返回分类列表 " , callback_data = ' list_categories ' ) ] ] )
keyboard = InlineKeyboardMarkup ( [ [ InlineKeyboardButton ( " « 返回分类列表 " , callback_data = ' list_categories ' ) ] ] )
await query . edit_message_text ( msg_text , reply_markup = keyboard , parse_mode = ' MarkdownV2 ' )
await query . edit_message_text ( msg_text , reply_markup = keyboard , parse_mode = ' HTML ' )
return
return
if data == ' list_categories ' :
if data == ' list_categories ' :
context . user_data . pop ( ' list_subs_in_category ' , None )
context . user_data . pop ( ' list_subs_in_category ' , None )
@@ -1020,18 +1161,18 @@ async def button_callback_handler(update: Update, context: CallbackContext):
( new_date_str , sub_id , user_id )
( new_date_str , sub_id , user_id )
)
)
conn . commit ( )
conn . commit ( )
safe_sub_name = escape_markdown ( sub [ ' name ' ] , version = 2 )
safe_sub_name = escape_html ( sub [ ' name ' ] )
await query . edit_message_text (
await query . edit_message_text (
text = f " ✅ * 续费成功* \n \n 您的订阅 ` { safe_sub_name } ` 新的到期日为: ` { new_date_str } ` " ,
text = f " ✅ <b> 续费成功</b> \n \n 您的订阅 <code> { safe_sub_name } </code> 新的到期日为: <code> { new_date_str } </code> " ,
parse_mode = ' MarkdownV2 ' ,
parse_mode = ' HTML ' ,
reply_markup = None
reply_markup = None
)
)
else :
else :
await query . answer ( " 续费失败:无法计算新的到期日期。 " , show_alert = True )
await query . answer ( " 续费失败:无法计算新的到期日期。 " , show_alert = True )
else :
else :
await query . answer ( " 续费失败:此订阅可能已被删除或无权限。 " , show_alert = True )
await query . answer ( " 续费失败:此订阅可能已被删除或无权限。 " , show_alert = True )
await query . edit_message_text ( text = query . message . text + " \n \n * (错误:此订阅不存在或无权限)* " ,
await query . edit_message_text ( text = query . message . text + " \n \n <b> (错误:此订阅不存在或无权限)</b> " ,
parse_mode = ' MarkdownV2 ' , reply_markup = None )
parse_mode = ' HTML ' , reply_markup = None )
elif action == ' delete ' :
elif action == ' delete ' :
with get_db_connection ( ) as conn :
with get_db_connection ( ) as conn :
@@ -1058,12 +1199,12 @@ async def button_callback_handler(update: Update, context: CallbackContext):
if ' list_subs_in_category ' in context . user_data :
if ' list_subs_in_category ' in context . user_data :
category = context . user_data [ ' list_subs_in_category ' ]
category = context . user_data [ ' list_subs_in_category ' ]
keyboard = await get_subs_list_keyboard ( user_id , category_filter = category )
keyboard = await get_subs_list_keyboard ( user_id , category_filter = category )
msg_text = f " 分类“ { escape_markdown ( category , version = 2 )} ” 下的订阅:"
msg_text = f " 分类 <b> { escape_html ( category ) } </b> 下的订阅:"
if not keyboard :
if not keyboard :
msg_text = f " 分类“ { escape_markdown ( category , version = 2 )} ” 下没有订阅。"
msg_text = f " 分类 <b> { escape_html ( category ) } </b> 下没有订阅。"
keyboard = InlineKeyboardMarkup (
keyboard = InlineKeyboardMarkup (
[ [ InlineKeyboardButton ( " « 返回分类列表 " , callback_data = ' list_categories ' ) ] ] )
[ [ InlineKeyboardButton ( " « 返回分类列表 " , callback_data = ' list_categories ' ) ] ] )
await query . edit_message_text ( msg_text , reply_markup = keyboard , parse_mode = ' MarkdownV2 ' )
await query . edit_message_text ( msg_text , reply_markup = keyboard , parse_mode = ' HTML ' )
else :
else :
keyboard = await get_subs_list_keyboard ( user_id )
keyboard = await get_subs_list_keyboard ( user_id )
if not keyboard :
if not keyboard :
@@ -1139,16 +1280,16 @@ async def edit_field_selected(update: Update, context: CallbackContext):
[ InlineKeyboardButton ( " 月 " , callback_data = ' freq_unit_month ' ) ,
[ InlineKeyboardButton ( " 月 " , callback_data = ' freq_unit_month ' ) ,
InlineKeyboardButton ( " 年 " , callback_data = ' freq_unit_year ' ) ]
InlineKeyboardButton ( " 年 " , callback_data = ' freq_unit_year ' ) ]
]
]
await query . edit_message_text ( " 请选择新的周期*单位* " , reply_markup = InlineKeyboardMarkup ( keyboard ) ,
await query . edit_message_text ( " 请选择新的周期<b>单位</b> " , reply_markup = InlineKeyboardMarkup ( keyboard ) ,
parse_mode = ' MarkdownV2 ' )
parse_mode = ' HTML ' )
return EDIT_FREQ_UNIT
return EDIT_FREQ_UNIT
else :
else :
field_map = { ' name ' : ' 名称 ' , ' cost ' : ' 费用 ' , ' currency ' : ' 货币 ' , ' category ' : ' 类别 ' , ' next_due ' : ' 下次付款日 ' ,
field_map = { ' name ' : ' 名称 ' , ' cost ' : ' 费用 ' , ' currency ' : ' 货币 ' , ' category ' : ' 类别 ' , ' next_due ' : ' 下次付款日 ' ,
' notes ' : ' 备注 ' }
' notes ' : ' 备注 ' }
prompt = f " 好的,请输入新的 * { field_map . get ( field_to_edit , field_to_edit ) } * 值:"
prompt = f " 好的,请输入新的 <b> { field_map . get ( field_to_edit , field_to_edit ) } </b> 值:"
if field_to_edit == ' notes ' :
if field_to_edit == ' notes ' :
prompt + = " \n (如需清空备注,请输入 /empty ) "
prompt + = " \n (如需清空备注,请输入 /empty ) "
await query . edit_message_text ( prompt , parse_mode = ' MarkdownV2 ' )
await query . edit_message_text ( prompt , parse_mode = ' HTML ' )
return EDIT_GET_NEW_VALUE
return EDIT_GET_NEW_VALUE
@@ -1160,7 +1301,7 @@ async def edit_freq_unit_received(update: Update, context: CallbackContext):
await query . edit_message_text ( " 错误:无效的周期单位,请重试。 " )
await query . edit_message_text ( " 错误:无效的周期单位,请重试。 " )
return ConversationHandler . END
return ConversationHandler . END
context . user_data [ ' new_freq_unit ' ] = unit
context . user_data [ ' new_freq_unit ' ] = unit
await query . edit_message_text ( " 好的,现在请输入新的周期*数量* 。 " , parse_mode = ' MarkdownV2 ' )
await query . edit_message_text ( " 好的,现在请输入新的周期<b>数量</b> 。 " , parse_mode = ' HTML ' )
return EDIT_FREQ_VALUE
return EDIT_FREQ_VALUE
@@ -1209,7 +1350,7 @@ async def edit_new_value_received(update: Update, context: CallbackContext):
await update . effective_message . reply_text ( " 错误:未选择要编辑的字段。 " )
await update . effective_message . reply_text ( " 错误:未选择要编辑的字段。 " )
return ConversationHandler . END
return ConversationHandler . END
db_field = EDITABLE_SUB_FIELDS . get ( field )
db_field = EDITABLE_SUB_FIELDS . get ( field )
if not db_field :
if not db_field or not db_field . isidentifier ( ) :
if update . effective_message :
if update . effective_message :
await update . effective_message . reply_text ( " 错误:不允许编辑该字段。 " )
await update . effective_message . reply_text ( " 错误:不允许编辑该字段。 " )
logger . warning ( f " Blocked unsafe field update attempt: { field } " )
logger . warning ( f " Blocked unsafe field update attempt: { field } " )
@@ -1259,7 +1400,7 @@ async def edit_new_value_received(update: Update, context: CallbackContext):
parsed = parse_date ( str ( new_value ) )
parsed = parse_date ( str ( new_value ) )
if not parsed :
if not parsed :
if message_to_reply :
if message_to_reply :
await message_to_reply . reply_text ( " 无法识别的日期格式,请使用类似 ' 2025\\ -10 \\ -01 ' 或 ' 10月1日 ' 的格式。 " )
await message_to_reply . reply_text ( " 无法识别的日期格式,请使用类似 ' 2025-10 -01 ' 或 ' 10月1日 ' 的格式。 " )
validation_failed = True
validation_failed = True
else :
else :
new_value = parsed
new_value = parsed
@@ -1335,13 +1476,13 @@ async def _display_reminder_settings(query: CallbackQuery, context: CallbackCont
[ InlineKeyboardButton ( enabled_text , callback_data = ' remindaction_toggle_enabled ' ) ] ,
[ InlineKeyboardButton ( enabled_text , callback_data = ' remindaction_toggle_enabled ' ) ] ,
[ InlineKeyboardButton ( due_date_text , callback_data = ' remindaction_toggle_due_date ' ) ]
[ InlineKeyboardButton ( due_date_text , callback_data = ' remindaction_toggle_due_date ' ) ]
]
]
safe_name = escape_markdown ( sub [ ' name ' ] , version = 2 )
safe_name = escape_html ( sub [ ' name ' ] )
current_status = f " * 🔔 提醒设置: { safe_name } * \n \n "
current_status = f " <b> 🔔 提醒设置: { safe_name } </b> \n \n "
if sub [ ' renewal_type ' ] == ' manual ' :
if sub [ ' renewal_type ' ] == ' manual ' :
current_status + = f " 当前提前提醒: * { sub [ ' reminder_days ' ] } 天* \n "
current_status + = f " 当前提前提醒: <b> { sub [ ' reminder_days ' ] } 天</b> \n "
keyboard . append ( [ InlineKeyboardButton ( " ⚙️ 更改提前天数 " , callback_data = ' remindaction_ask_days ' ) ] )
keyboard . append ( [ InlineKeyboardButton ( " ⚙️ 更改提前天数 " , callback_data = ' remindaction_ask_days ' ) ] )
keyboard . append ( [ InlineKeyboardButton ( " « 返回详情 " , callback_data = f ' view_ { sub_id } ' ) ] )
keyboard . append ( [ InlineKeyboardButton ( " « 返回详情 " , callback_data = f ' view_ { sub_id } ' ) ] )
await query . edit_message_text ( current_status , reply_markup = InlineKeyboardMarkup ( keyboard ) , parse_mode = ' MarkdownV2 ' )
await query . edit_message_text ( current_status , reply_markup = InlineKeyboardMarkup ( keyboard ) , parse_mode = ' HTML ' )
async def remind_settings_start ( update : Update , context : CallbackContext ) :
async def remind_settings_start ( update : Update , context : CallbackContext ) :
@@ -1443,7 +1584,7 @@ async def remind_days_received(update: Update, context: CallbackContext):
async def set_currency ( update : Update , context : CallbackContext ) :
async def set_currency ( update : Update , context : CallbackContext ) :
user_id , args = update . effective_user . id , context . args
user_id , args = update . effective_user . id , context . args
if len ( args ) != 1 :
if len ( args ) != 1 :
await update . message . reply_text ( " 用法: /set_currency ` <code>` (例如 /set_currency USD) " , parse_mode = ' MarkdownV2 ' )
await update . message . reply_text ( " 用法: /set_currency <code>代码</ code>(例如 /set_currency USD) " , parse_mode = ' HTML ' )
return
return
new_currency = args [ 0 ] . upper ( )
new_currency = args [ 0 ] . upper ( )
if len ( new_currency ) != 3 or not new_currency . isalpha ( ) :
if len ( new_currency ) != 3 or not new_currency . isalpha ( ) :
@@ -1457,8 +1598,9 @@ async def set_currency(update: Update, context: CallbackContext):
ON CONFLICT(user_id) DO UPDATE SET main_currency = excluded.main_currency
ON CONFLICT(user_id) DO UPDATE SET main_currency = excluded.main_currency
""" , ( user_id , new_currency ) )
""" , ( user_id , new_currency ) )
conn . commit ( )
conn . commit ( )
await update . message . reply_text ( f " 您的主货币已设为 { escape_markdown ( new_currency , version = 2 )} 。 " ,
await update . message . reply_text ( f " 您的主货币已设为 <b> { escape_html ( new_currency ) } </b> 。" ,
parse_mode = ' MarkdownV2 ' )
parse_mode = ' HTML ' )
return ConversationHandler . END
async def cancel ( update : Update , context : CallbackContext ) :
async def cancel ( update : Update , context : CallbackContext ) :
@@ -1471,12 +1613,140 @@ async def cancel(update: Update, context: CallbackContext):
return ConversationHandler . END
return ConversationHandler . END
def _can_run_update ( user_id : int ) - > bool :
""" 仅允许指定 owner 执行自动更新。未配置 owner 时默认拒绝。 """
if not UPDATE_OWNER_ID :
return False
try :
return int ( UPDATE_OWNER_ID ) == int ( user_id )
except ( ValueError , TypeError ) :
return False
def _resolve_update_target ( repo_dir : str ) :
"""
解析更新目标 remote/branch。
优先级:
1) 环境变量 AUTO_UPDATE_REMOTE + AUTO_UPDATE_BRANCH
2) 当前分支上游 @ {u}
3) 远程优先 gitllc, 其次 origin, 分支用 AUTO_UPDATE_BRANCH 或 main
"""
branch = ( AUTO_UPDATE_BRANCH or ' main ' ) . strip ( ) or ' main '
# 1) 明确指定 remote
if AUTO_UPDATE_REMOTE :
return AUTO_UPDATE_REMOTE , branch
# 2) 尝试读取上游分支(如 gitllc/main)
upstream_proc = subprocess . run (
[ " git " , " rev-parse " , " --abbrev-ref " , " --symbolic-full-name " , " @ {u} " ] ,
cwd = repo_dir , capture_output = True , text = True
)
if upstream_proc . returncode == 0 :
upstream = upstream_proc . stdout . strip ( )
if ' / ' in upstream :
remote , up_branch = upstream . split ( ' / ' , 1 )
if remote and up_branch :
return remote , up_branch
# 3) 回退:从远程列表推断
remotes_proc = subprocess . run ( [ " git " , " remote " ] , cwd = repo_dir , capture_output = True , text = True )
if remotes_proc . returncode != 0 :
return None , None
remotes = [ r . strip ( ) for r in remotes_proc . stdout . splitlines ( ) if r . strip ( ) ]
if not remotes :
return None , None
if ' gitllc ' in remotes :
return ' gitllc ' , branch
if ' origin ' in remotes :
return ' origin ' , branch
return remotes [ 0 ] , branch
def _run_cmd ( cmd , cwd ) :
return subprocess . run ( cmd , cwd = cwd , capture_output = True , text = True )
async def update_bot ( update : Update , context : CallbackContext ) :
user_id = update . effective_user . id
if not _can_run_update ( user_id ) :
await update . message . reply_text ( " 无权限执行 /update。 " )
return
await update . message . reply_text ( " 开始检查更新,请稍候… " )
repo_dir = os . path . dirname ( os . path . abspath ( __file__ ) )
try :
remote_name , branch_name = _resolve_update_target ( repo_dir )
if not remote_name or not branch_name :
await update . message . reply_text ( " 更新失败:无法解析 git 远程仓库,请检查仓库 remote 配置。 " )
return
fetch_cmd = [ " git " , " fetch " , remote_name , branch_name ]
fetch_proc = await asyncio . to_thread ( _run_cmd , fetch_cmd , repo_dir )
if fetch_proc . returncode != 0 :
err = ( fetch_proc . stderr or fetch_proc . stdout or " 未知错误 " ) . strip ( )
await update . message . reply_text ( f " 更新失败( fetch) : \n <code> { escape_html ( err ) } </code> " , parse_mode = ' HTML ' )
return
local_rev = await asyncio . to_thread ( _run_cmd , [ " git " , " rev-parse " , " HEAD " ] , repo_dir )
fetched_rev = await asyncio . to_thread ( _run_cmd , [ " git " , " rev-parse " , " FETCH_HEAD " ] , repo_dir )
if local_rev . returncode != 0 or fetched_rev . returncode != 0 :
await update . message . reply_text ( " 更新失败:无法读取当前版本。 " )
return
local_hash = local_rev . stdout . strip ( )
fetched_hash = fetched_rev . stdout . strip ( )
if local_hash == fetched_hash :
await update . message . reply_text ( " 当前已是最新版本,无需更新。 " )
return
reset_proc = await asyncio . to_thread (
_run_cmd ,
[ " git " , " reset " , " --hard " , " FETCH_HEAD " ] ,
repo_dir
)
if reset_proc . returncode != 0 :
err = ( reset_proc . stderr or reset_proc . stdout or " 未知错误 " ) . strip ( )
await update . message . reply_text ( f " 更新失败( reset) : \n <code> { escape_html ( err ) } </code> " , parse_mode = ' HTML ' )
return
pip_proc = await asyncio . to_thread (
_run_cmd ,
[ sys . executable , " -m " , " pip " , " install " , " -r " , " requirements.txt " ] ,
repo_dir
)
if pip_proc . returncode != 0 :
err = ( pip_proc . stderr or pip_proc . stdout or " 未知错误 " ) . strip ( )
await update . message . reply_text ( f " 依赖安装失败: \n <code> { escape_html ( err [ - 1800 : ] ) } </code> " , parse_mode = ' HTML ' )
return
await update . message . reply_text (
f " 更新完成( { escape_html ( remote_name ) } { escape_html ( branch_name ) } ),正在重启机器人… " ,
parse_mode = ' HTML '
)
os . execv ( sys . executable , [ sys . executable ] + sys . argv )
except Exception as e :
logger . error ( f " /update failed: { e } " )
await update . message . reply_text ( f " 更新异常:<code> { escape_html ( str ( e ) ) } </code> " , parse_mode = ' HTML ' )
# --- Main ---
# --- Main ---
def main ( ) :
def main ( ) :
if not TELEGRAM_TOKEN :
if not TELEGRAM_TOKEN :
logger . critical ( " TELEGRAM_TOKEN 环境变量未设置! " )
logger . critical ( " TELEGRAM_TOKEN 环境变量未设置! " )
return
return
if not EXCHANGE_API_KEY :
logger . info ( " 未配置 EXCHANGE_API_KEY, 多货币换算将降级为只使用本地缓存( 若无缓存则不转换) 。 " )
application = Application . builder ( ) . token ( TELEGRAM_TOKEN ) . build ( )
application = Application . builder ( ) . token ( TELEGRAM_TOKEN ) . build ( )
async def post_init ( app : Application ) :
async def post_init ( app : Application ) :
@@ -1496,6 +1766,7 @@ def main():
BotCommand ( " import " , " 📥 导入订阅 " ) ,
BotCommand ( " import " , " 📥 导入订阅 " ) ,
BotCommand ( " export " , " 📤 导出订阅 " ) ,
BotCommand ( " export " , " 📤 导出订阅 " ) ,
BotCommand ( " set_currency " , " 💲 设置主货币 " ) ,
BotCommand ( " set_currency " , " 💲 设置主货币 " ) ,
BotCommand ( " update " , " 🛠️ 拉取最新代码并重启 " ) ,
BotCommand ( " help " , " ℹ ️ 获取帮助" ) ,
BotCommand ( " help " , " ℹ ️ 获取帮助" ) ,
BotCommand ( " cancel " , " ❌ 取消当前操作 " )
BotCommand ( " cancel " , " ❌ 取消当前操作 " )
]
]
@@ -1590,6 +1861,7 @@ def main():
application . add_handler ( CommandHandler ( ' set_currency ' , set_currency ) )
application . add_handler ( CommandHandler ( ' set_currency ' , set_currency ) )
application . add_handler ( CommandHandler ( ' stats ' , stats ) )
application . add_handler ( CommandHandler ( ' stats ' , stats ) )
application . add_handler ( CommandHandler ( ' export ' , export_command ) )
application . add_handler ( CommandHandler ( ' export ' , export_command ) )
application . add_handler ( CommandHandler ( ' update ' , update_bot ) )
application . add_handler ( CommandHandler ( ' cancel ' , cancel ) )
application . add_handler ( CommandHandler ( ' cancel ' , cancel ) )
application . add_handler ( add_conv )
application . add_handler ( add_conv )