@@ -1,6 +1,8 @@
import sqlite3
import sqlite3
import asyncio
import asyncio
import os
import os
import sys
import subprocess
import html
import html
import requests
import requests
import datetime
import datetime
@@ -20,7 +22,10 @@ from telegram.ext import (
CallbackContext , CallbackQueryHandler , ConversationHandler
CallbackContext , CallbackQueryHandler , ConversationHandler
)
)
from telegram . error import TelegramError
from telegram . error import TelegramError
from telegram . helpers import escape_html
def escape_html ( text , version = None ) :
if text is None :
return ' '
return html . escape ( str ( text ) )
# --- 加载 .env 和设置 ---
# --- 加载 .env 和设置 ---
load_dotenv ( )
load_dotenv ( )
@@ -39,6 +44,11 @@ EXCHANGE_API_KEY = os.getenv('EXCHANGE_API_KEY')
PROJECT_NAME = " SubMind "
PROJECT_NAME = " SubMind "
DB_FILE = ' submind.db '
DB_FILE = ' submind.db '
# 自动更新配置
UPDATE_OWNER_ID = os . getenv ( ' UPDATE_OWNER_ID ' ) # 仅允许此用户执行 /update
AUTO_UPDATE_REMOTE = os . getenv ( ' AUTO_UPDATE_REMOTE ' , ' https://git.llc/zimk/SubMind.git ' ) . strip ( )
AUTO_UPDATE_BRANCH = os . getenv ( ' AUTO_UPDATE_BRANCH ' , ' main ' ) . strip ( ) or ' main '
# --- 对话处理器状态 ---
# --- 对话处理器状态 ---
( ADD_NAME , ADD_COST , ADD_CURRENCY , ADD_CATEGORY , ADD_NEXT_DUE ,
( ADD_NAME , ADD_COST , ADD_CURRENCY , ADD_CATEGORY , ADD_NEXT_DUE ,
ADD_FREQ_UNIT , ADD_FREQ_VALUE , ADD_RENEWAL_TYPE , ADD_NOTES ) = range ( 9 )
ADD_FREQ_UNIT , ADD_FREQ_VALUE , ADD_RENEWAL_TYPE , ADD_NOTES ) = range ( 9 )
@@ -373,7 +383,7 @@ async def check_and_send_reminders(context: CallbackContext):
] )
] )
if sub [ ' reminder_on_due_date ' ] and due_date == today :
if sub [ ' reminder_on_due_date ' ] and due_date == today :
message = f " 🔔 * 订阅到期提醒* \n \n 您的订阅 ` { safe_sub_name } ` 今天到期。"
message = f " 🔔 <b> 订阅到期提醒</b> \n \n 您的订阅 <code> { safe_sub_name } </code> 今天到期。"
if renewal_type == ' manual ' :
if renewal_type == ' manual ' :
message + = " 请记得手动续费。 "
message + = " 请记得手动续费。 "
else :
else :
@@ -385,7 +395,7 @@ async def check_and_send_reminders(context: CallbackContext):
if reminder_date == today :
if reminder_date == today :
days_left = ( due_date - today ) . days
days_left = ( due_date - today ) . days
days_text = f " <b> { days_left } 天后</b> " if days_left > 0 else " <b>今天</b> "
days_text = f " <b> { days_left } 天后</b> " if days_left > 0 else " <b>今天</b> "
message = f " 🔔 * 订阅即将到期提醒* \n \n 您的手动续费订阅 ` { safe_sub_name } ` 将在 { days_text } 到期。 "
message = f " 🔔 <b> 订阅即将到期提醒</b> \n \n 您的手动续费订阅 <code> { safe_sub_name } </code> 将在 { days_text } 到期。 "
if message :
if message :
await context . bot . send_message (
await context . bot . send_message (
@@ -415,19 +425,19 @@ async def start(update: Update, context: CallbackContext):
async def help_command ( update : Update , context : CallbackContext ) :
async def help_command ( update : Update , context : CallbackContext ) :
help_text = fr """
help_text = f """
* { escape_html ( PROJECT_NAME ) } 命令列表*
<b> { escape_html ( PROJECT_NAME ) } 命令列表</b>
* 🌟 核心功能*
<b> 🌟 核心功能</b>
/add\ _sub \ - 引导您添加一个新的订阅
/add_sub - 引导您添加一个新的订阅
/list\ _subs \ - 列出您的所有订阅
/list_subs - 列出您的所有订阅
/list\ _categories \ - 按分类浏览您的订阅
/list_categories - 按分类浏览您的订阅
* 📊 数据管理*
<b> 📊 数据管理</b>
/stats \ - 查看按类别分类的订阅统计
/stats - 查看按类别分类的订阅统计
/import \ - 通过上传 CSV 文件批量导入订阅
/import - 通过上传 CSV 文件批量导入订阅
/export \ - 将您的所有订阅导出为 CSV 文件
/export - 将您的所有订阅导出为 CSV 文件
* ⚙️ 个性化设置*
<b> ⚙️ 个性化设置</b>
/set\ _currency \ `<code> \ ` \ - 设置您的主要货币
/set_currency <代码> - 设置您的主要货币
/cancel \ - 在任何流程中取消当前操作
/cancel - 在任何流程中取消当前操作
"""
"""
await update . message . reply_text ( help_text , parse_mode = ' HTML ' )
await update . message . reply_text ( help_text , parse_mode = ' HTML ' )
@@ -513,7 +523,7 @@ async def stats(update: Update, context: CallbackContext):
try :
try :
theme_colors = [ ' #3B82F6 ' , ' #10B981 ' , ' #F59E0B ' , ' #EF4444 ' , ' #8B5CF6 ' , ' #EC4899 ' , ' #14B8A6 ' , ' #F97316 ' , ' #6366F1 ' , ' #84CC16 ' ]
theme_colors = [ ' #3B82F6 ' , ' #10B981 ' , ' #F59E0B ' , ' #EF4444 ' , ' #8B5CF6 ' , ' #EC4899 ' , ' #14B8A6 ' , ' #F97316 ' , ' #6366F1 ' , ' #84CC16 ' ]
if len ( category_costs ) > len ( theme_colors ) :
if len ( category_costs ) > len ( theme_colors ) :
import matplotlib . pyplot as plt
# 移除导致遮蔽的局部 import, 直接使用全局的 matplotlib 和 plt
extra_colors = [ matplotlib . colors . to_hex ( c ) for c in plt . get_cmap ( ' tab20 ' ) . colors ]
extra_colors = [ matplotlib . colors . to_hex ( c ) for c in plt . get_cmap ( ' tab20 ' ) . colors ]
theme_colors . extend ( extra_colors )
theme_colors . extend ( extra_colors )
@@ -593,7 +603,7 @@ async def stats(update: Update, context: CallbackContext):
weight = ' bold '
weight = ' bold '
)
)
fig . suptitle ( ' 📊 您的订阅支出洞察 ' , fontproperties = font_prop , fontsize = 24 , color = ' #0F172A ' , y = 1.02 , weight = ' bold ' )
fig . suptitle ( ' 您的订阅统计报告 ' , fontproperties = font_prop , fontsize = 24 , color = ' #0F172A ' , y = 1.02 , weight = ' bold ' )
fig . tight_layout ( rect = [ 0 , 0 , 1 , 0.95 ] )
fig . tight_layout ( rect = [ 0 , 0 , 1 , 0.95 ] )
with tempfile . NamedTemporaryFile ( prefix = f ' stats_ { user_id } _ ' , suffix = ' .png ' , delete = False ) as tmp :
with tempfile . NamedTemporaryFile ( prefix = f ' stats_ { user_id } _ ' , suffix = ' .png ' , delete = False ) as tmp :
@@ -824,7 +834,7 @@ async def add_category_received(update: Update, context: CallbackContext):
cursor = conn . cursor ( )
cursor = conn . cursor ( )
cursor . execute ( " INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?) " , ( user_id , category_name ) )
cursor . execute ( " INSERT OR IGNORE INTO categories (user_id, name) VALUES (?, ?) " , ( user_id , category_name ) )
conn . commit ( )
conn . commit ( )
await update . message . reply_text ( " 第五步:请输入 * 下一次付款日期* (例如 2025\\ -10 \\ -01 或 10月1日) " ,
await update . message . reply_text ( " 第五步:请输入 <b> 下一次付款日期</b> (例如 2025-10 -01 或 10月1日) " ,
parse_mode = ' HTML ' )
parse_mode = ' HTML ' )
return ADD_NEXT_DUE
return ADD_NEXT_DUE
@@ -838,7 +848,7 @@ async def add_next_due_received(update: Update, context: CallbackContext):
parsed_date = parse_date ( update . message . text )
parsed_date = parse_date ( update . message . text )
if not parsed_date :
if not parsed_date :
await update . message . reply_text ( " 无法识别的日期格式,请使用类似 ' 2025\\ -10 \\ -01 ' 或 ' 10月1日 ' 的格式。 " )
await update . message . reply_text ( " 无法识别的日期格式,请使用类似 ' 2025-10 -01 ' 或 ' 10月1日 ' 的格式。 " )
return ADD_NEXT_DUE
return ADD_NEXT_DUE
sub_data [ ' next_due ' ] = parsed_date
sub_data [ ' next_due ' ] = parsed_date
keyboard = [
keyboard = [
@@ -865,7 +875,7 @@ async def add_freq_unit_received(update: Update, context: CallbackContext):
await query . edit_message_text ( " 错误:无效的周期单位,请重试。 " )
await query . edit_message_text ( " 错误:无效的周期单位,请重试。 " )
return ConversationHandler . END
return ConversationHandler . END
sub_data [ ' unit ' ] = unit
sub_data [ ' unit ' ] = unit
await query . edit_message_text ( " 第七步:请输入周期的<b>数量</b>( 例如: 每3个月, 输入 3) " , parse_mode = ' Markdown ' )
await query . edit_message_text ( " 第七步:请输入周期的<b>数量</b>( 例如: 每3个月, 输入 3) " , parse_mode = ' HTML ' )
return ADD_FREQ_VALUE
return ADD_FREQ_VALUE
@@ -1012,19 +1022,19 @@ async def show_subscription_view(update: Update, context: CallbackContext, sub_i
sub [ ' reminders_enabled ' ] , sub [ ' notes ' ] )
sub [ ' reminders_enabled ' ] , sub [ ' notes ' ] )
freq_text = format_frequency ( sub [ ' frequency_unit ' ] , sub [ ' frequency_value ' ] )
freq_text = format_frequency ( sub [ ' frequency_unit ' ] , sub [ ' frequency_value ' ] )
main_currency = get_user_main_currency ( user_id )
main_currency = get_user_main_currency ( user_id )
converted_cost = convert_currency ( cost , currency , main_currency )
converted_cost = await asyncio . to_thread ( convert_currency , cost , currency , main_currency )
safe_name , safe_category , safe_freq = escape_html ( name ) , escape_html ( category ) , escape_html ( freq_text )
safe_name , safe_category , safe_freq = escape_html ( name ) , escape_html ( category ) , escape_html ( freq_text )
cost_str , converted_cost_str = escape_html ( f " { cost : .2f } " ) , escape_html ( f " { converted_cost : .2f } " )
cost_str , converted_cost_str = escape_html ( f " { cost : .2f } " ) , escape_html ( f " { converted_cost : .2f } " )
renewal_text = " 手动续费 " if renewal_type == ' manual ' else " 自动续费 "
renewal_text = " 手动续费 " if renewal_type == ' manual ' else " 自动续费 "
reminder_status = " 开启 " if reminders_enabled else " 关闭 "
reminder_status = " 开启 " if reminders_enabled else " 关闭 "
text = ( f " * 订阅详情: { safe_name } * \n \n "
text = ( f " <b> 订阅详情: { safe_name } </b> \n \n "
f " \\ - *费用*: ` { cost_str } { currency . upper ( ) } ` \\ ( \\ ~` { converted_cost_str } { main_currency . upper ( ) } ` \\ )\n "
f " - <b>费用</b>: <code> { cost_str } { currency . upper ( ) } </code> (~<code> { converted_cost_str } { main_currency . upper ( ) } </code> )\n "
f " \\ - *类别*: ` { safe_category } ` \n "
f " - <b>类别</b>: <code> { safe_category } </code> \n "
f " \\ - *下次付款*: ` { next_due } ` \\ (周期: { safe_freq } \\ ) \n "
f " - <b>下次付款</b>: <code> { next_due } </code> (周期: { safe_freq } ) \n "
f " \\ - *续费方式*: ` { renewal_text } ` \n "
f " - <b>续费方式</b>: <code> { renewal_text } </code> \n "
f " \\ - *提醒状态*: ` { reminder_status } ` " )
f " - <b>提醒状态</b>: <code> { reminder_status } </code> " )
if notes :
if notes :
text + = f " \n \\ - *备注* : { escape_html ( notes ) } "
text + = f " \n - <b>备注</b> : { escape_html ( notes ) } "
keyboard_buttons = [
keyboard_buttons = [
[ InlineKeyboardButton ( " ✏️ 编辑 " , callback_data = f ' edit_ { sub_id } ' ) ,
[ InlineKeyboardButton ( " ✏️ 编辑 " , callback_data = f ' edit_ { sub_id } ' ) ,
InlineKeyboardButton ( " 🗑️ 删除 " , callback_data = f ' delete_ { sub_id } ' ) ] ,
InlineKeyboardButton ( " 🗑️ 删除 " , callback_data = f ' delete_ { sub_id } ' ) ] ,
@@ -1161,7 +1171,7 @@ async def button_callback_handler(update: Update, context: CallbackContext):
await query . answer ( " 续费失败:无法计算新的到期日期。 " , show_alert = True )
await query . answer ( " 续费失败:无法计算新的到期日期。 " , show_alert = True )
else :
else :
await query . answer ( " 续费失败:此订阅可能已被删除或无权限。 " , show_alert = True )
await query . answer ( " 续费失败:此订阅可能已被删除或无权限。 " , show_alert = True )
await query . edit_message_text ( text = query . message . text + " \n \n * (错误:此订阅不存在或无权限)* " ,
await query . edit_message_text ( text = query . message . text + " \n \n <b> (错误:此订阅不存在或无权限)</b> " ,
parse_mode = ' HTML ' , reply_markup = None )
parse_mode = ' HTML ' , reply_markup = None )
elif action == ' delete ' :
elif action == ' delete ' :
@@ -1390,7 +1400,7 @@ async def edit_new_value_received(update: Update, context: CallbackContext):
parsed = parse_date ( str ( new_value ) )
parsed = parse_date ( str ( new_value ) )
if not parsed :
if not parsed :
if message_to_reply :
if message_to_reply :
await message_to_reply . reply_text ( " 无法识别的日期格式,请使用类似 ' 2025\\ -10 \\ -01 ' 或 ' 10月1日 ' 的格式。 " )
await message_to_reply . reply_text ( " 无法识别的日期格式,请使用类似 ' 2025-10 -01 ' 或 ' 10月1日 ' 的格式。 " )
validation_failed = True
validation_failed = True
else :
else :
new_value = parsed
new_value = parsed
@@ -1469,7 +1479,7 @@ async def _display_reminder_settings(query: CallbackQuery, context: CallbackCont
safe_name = escape_html ( sub [ ' name ' ] )
safe_name = escape_html ( sub [ ' name ' ] )
current_status = f " <b>🔔 提醒设置: { safe_name } </b> \n \n "
current_status = f " <b>🔔 提醒设置: { safe_name } </b> \n \n "
if sub [ ' renewal_type ' ] == ' manual ' :
if sub [ ' renewal_type ' ] == ' manual ' :
current_status + = f " 当前提前提醒: * { sub [ ' reminder_days ' ] } 天* \n "
current_status + = f " 当前提前提醒: <b> { sub [ ' reminder_days ' ] } 天</b> \n "
keyboard . append ( [ InlineKeyboardButton ( " ⚙️ 更改提前天数 " , callback_data = ' remindaction_ask_days ' ) ] )
keyboard . append ( [ InlineKeyboardButton ( " ⚙️ 更改提前天数 " , callback_data = ' remindaction_ask_days ' ) ] )
keyboard . append ( [ InlineKeyboardButton ( " « 返回详情 " , callback_data = f ' view_ { sub_id } ' ) ] )
keyboard . append ( [ InlineKeyboardButton ( " « 返回详情 " , callback_data = f ' view_ { sub_id } ' ) ] )
await query . edit_message_text ( current_status , reply_markup = InlineKeyboardMarkup ( keyboard ) , parse_mode = ' HTML ' )
await query . edit_message_text ( current_status , reply_markup = InlineKeyboardMarkup ( keyboard ) , parse_mode = ' HTML ' )
@@ -1603,6 +1613,131 @@ async def cancel(update: Update, context: CallbackContext):
return ConversationHandler . END
return ConversationHandler . END
def _can_run_update ( user_id : int ) - > bool :
""" 仅允许指定 owner 执行自动更新。未配置 owner 时默认拒绝。 """
if not UPDATE_OWNER_ID :
return False
try :
return int ( UPDATE_OWNER_ID ) == int ( user_id )
except ( ValueError , TypeError ) :
return False
def _resolve_update_target ( repo_dir : str ) :
"""
解析更新目标 remote/branch。
优先级:
1) 环境变量 AUTO_UPDATE_REMOTE + AUTO_UPDATE_BRANCH
2) 当前分支上游 @ {u}
3) 远程优先 gitllc, 其次 origin, 分支用 AUTO_UPDATE_BRANCH 或 main
"""
branch = ( AUTO_UPDATE_BRANCH or ' main ' ) . strip ( ) or ' main '
# 1) 明确指定 remote
if AUTO_UPDATE_REMOTE :
return AUTO_UPDATE_REMOTE , branch
# 2) 尝试读取上游分支(如 gitllc/main)
upstream_proc = subprocess . run (
[ " git " , " rev-parse " , " --abbrev-ref " , " --symbolic-full-name " , " @ {u} " ] ,
cwd = repo_dir , capture_output = True , text = True
)
if upstream_proc . returncode == 0 :
upstream = upstream_proc . stdout . strip ( )
if ' / ' in upstream :
remote , up_branch = upstream . split ( ' / ' , 1 )
if remote and up_branch :
return remote , up_branch
# 3) 回退:从远程列表推断
remotes_proc = subprocess . run ( [ " git " , " remote " ] , cwd = repo_dir , capture_output = True , text = True )
if remotes_proc . returncode != 0 :
return None , None
remotes = [ r . strip ( ) for r in remotes_proc . stdout . splitlines ( ) if r . strip ( ) ]
if not remotes :
return None , None
if ' gitllc ' in remotes :
return ' gitllc ' , branch
if ' origin ' in remotes :
return ' origin ' , branch
return remotes [ 0 ] , branch
def _run_cmd ( cmd , cwd ) :
return subprocess . run ( cmd , cwd = cwd , capture_output = True , text = True )
async def update_bot ( update : Update , context : CallbackContext ) :
user_id = update . effective_user . id
if not _can_run_update ( user_id ) :
await update . message . reply_text ( " 无权限执行 /update。 " )
return
await update . message . reply_text ( " 开始检查更新,请稍候… " )
repo_dir = os . path . dirname ( os . path . abspath ( __file__ ) )
try :
remote_name , branch_name = _resolve_update_target ( repo_dir )
if not remote_name or not branch_name :
await update . message . reply_text ( " 更新失败:无法解析 git 远程仓库,请检查仓库 remote 配置。 " )
return
fetch_cmd = [ " git " , " fetch " , remote_name , branch_name ]
fetch_proc = await asyncio . to_thread ( _run_cmd , fetch_cmd , repo_dir )
if fetch_proc . returncode != 0 :
err = ( fetch_proc . stderr or fetch_proc . stdout or " 未知错误 " ) . strip ( )
await update . message . reply_text ( f " 更新失败( fetch) : \n <code> { escape_html ( err ) } </code> " , parse_mode = ' HTML ' )
return
local_rev = await asyncio . to_thread ( _run_cmd , [ " git " , " rev-parse " , " HEAD " ] , repo_dir )
fetched_rev = await asyncio . to_thread ( _run_cmd , [ " git " , " rev-parse " , " FETCH_HEAD " ] , repo_dir )
if local_rev . returncode != 0 or fetched_rev . returncode != 0 :
await update . message . reply_text ( " 更新失败:无法读取当前版本。 " )
return
local_hash = local_rev . stdout . strip ( )
fetched_hash = fetched_rev . stdout . strip ( )
if local_hash == fetched_hash :
await update . message . reply_text ( " 当前已是最新版本,无需更新。 " )
return
reset_proc = await asyncio . to_thread (
_run_cmd ,
[ " git " , " reset " , " --hard " , " FETCH_HEAD " ] ,
repo_dir
)
if reset_proc . returncode != 0 :
err = ( reset_proc . stderr or reset_proc . stdout or " 未知错误 " ) . strip ( )
await update . message . reply_text ( f " 更新失败( reset) : \n <code> { escape_html ( err ) } </code> " , parse_mode = ' HTML ' )
return
pip_proc = await asyncio . to_thread (
_run_cmd ,
[ sys . executable , " -m " , " pip " , " install " , " -r " , " requirements.txt " ] ,
repo_dir
)
if pip_proc . returncode != 0 :
err = ( pip_proc . stderr or pip_proc . stdout or " 未知错误 " ) . strip ( )
await update . message . reply_text ( f " 依赖安装失败: \n <code> { escape_html ( err [ - 1800 : ] ) } </code> " , parse_mode = ' HTML ' )
return
await update . message . reply_text (
f " 更新完成( { escape_html ( remote_name ) } { escape_html ( branch_name ) } ),正在重启机器人… " ,
parse_mode = ' HTML '
)
os . execv ( sys . executable , [ sys . executable ] + sys . argv )
except Exception as e :
logger . error ( f " /update failed: { e } " )
await update . message . reply_text ( f " 更新异常:<code> { escape_html ( str ( e ) ) } </code> " , parse_mode = ' HTML ' )
# --- Main ---
# --- Main ---
def main ( ) :
def main ( ) :
if not TELEGRAM_TOKEN :
if not TELEGRAM_TOKEN :
@@ -1631,6 +1766,7 @@ def main():
BotCommand ( " import " , " 📥 导入订阅 " ) ,
BotCommand ( " import " , " 📥 导入订阅 " ) ,
BotCommand ( " export " , " 📤 导出订阅 " ) ,
BotCommand ( " export " , " 📤 导出订阅 " ) ,
BotCommand ( " set_currency " , " 💲 设置主货币 " ) ,
BotCommand ( " set_currency " , " 💲 设置主货币 " ) ,
BotCommand ( " update " , " 🛠️ 拉取最新代码并重启 " ) ,
BotCommand ( " help " , " ℹ ️ 获取帮助" ) ,
BotCommand ( " help " , " ℹ ️ 获取帮助" ) ,
BotCommand ( " cancel " , " ❌ 取消当前操作 " )
BotCommand ( " cancel " , " ❌ 取消当前操作 " )
]
]
@@ -1725,6 +1861,7 @@ def main():
application . add_handler ( CommandHandler ( ' set_currency ' , set_currency ) )
application . add_handler ( CommandHandler ( ' set_currency ' , set_currency ) )
application . add_handler ( CommandHandler ( ' stats ' , stats ) )
application . add_handler ( CommandHandler ( ' stats ' , stats ) )
application . add_handler ( CommandHandler ( ' export ' , export_command ) )
application . add_handler ( CommandHandler ( ' export ' , export_command ) )
application . add_handler ( CommandHandler ( ' update ' , update_bot ) )
application . add_handler ( CommandHandler ( ' cancel ' , cancel ) )
application . add_handler ( CommandHandler ( ' cancel ' , cancel ) )
application . add_handler ( add_conv )
application . add_handler ( add_conv )