hardening: close low-risk gaps and improve import validation
This commit is contained in:
121
SubMind.py
121
SubMind.py
@@ -263,25 +263,9 @@ def _parse_category_id_from_callback(data: str) -> int | None:
|
||||
return int(payload) if payload.isdigit() else None
|
||||
|
||||
|
||||
async def get_subs_list_keyboard(user_id: int, category_filter: str = None) -> InlineKeyboardMarkup:
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
query, params = "SELECT id, name FROM subscriptions WHERE user_id = ? ", [user_id]
|
||||
if category_filter:
|
||||
query += "AND category = ? "
|
||||
params.append(category_filter)
|
||||
query += "ORDER BY next_due ASC"
|
||||
cursor.execute(query, tuple(params))
|
||||
subs = cursor.fetchall()
|
||||
if not subs:
|
||||
return None
|
||||
buttons = [InlineKeyboardButton(name, callback_data=f'view_{sub_id}') for sub_id, name in subs]
|
||||
keyboard = [buttons[i:i + 2] for i in range(0, len(buttons), 2)]
|
||||
if category_filter:
|
||||
keyboard.append([InlineKeyboardButton("« 返回分类列表", callback_data='list_categories')])
|
||||
else:
|
||||
keyboard.append([InlineKeyboardButton("🗂️ 按分类浏览", callback_data='list_categories')])
|
||||
return InlineKeyboardMarkup(keyboard)
|
||||
def _clear_action_state(context: CallbackContext, keys: list[str]):
|
||||
for key in keys:
|
||||
context.user_data.pop(key, None)
|
||||
|
||||
|
||||
# --- 自动任务 ---
|
||||
@@ -435,40 +419,42 @@ async def stats(update: Update, context: CallbackContext):
|
||||
|
||||
plt.style.use('seaborn-v0_8-darkgrid')
|
||||
fig, ax = plt.subplots(figsize=(12, 12))
|
||||
|
||||
autopct_function = make_autopct(category_costs.values, main_currency)
|
||||
|
||||
wedges, texts, autotexts = ax.pie(category_costs.values,
|
||||
labels=category_costs.index,
|
||||
autopct=autopct_function,
|
||||
startangle=140,
|
||||
pctdistance=0.7,
|
||||
labeldistance=1.05)
|
||||
|
||||
ax.set_title('每月订阅支出分类统计', fontproperties=font_prop, fontsize=32, pad=20)
|
||||
|
||||
for text in texts:
|
||||
text.set_fontproperties(font_prop)
|
||||
text.set_fontsize(22)
|
||||
|
||||
for autotext in autotexts:
|
||||
autotext.set_fontproperties(font_prop)
|
||||
autotext.set_fontsize(20)
|
||||
autotext.set_color('white')
|
||||
|
||||
ax.axis('equal')
|
||||
|
||||
fig.tight_layout()
|
||||
with tempfile.NamedTemporaryFile(prefix=f'stats_{user_id}_', suffix='.png', delete=False) as tmp:
|
||||
image_path = tmp.name
|
||||
plt.savefig(image_path)
|
||||
plt.close(fig)
|
||||
image_path = None
|
||||
|
||||
try:
|
||||
autopct_function = make_autopct(category_costs.values, main_currency)
|
||||
|
||||
wedges, texts, autotexts = ax.pie(category_costs.values,
|
||||
labels=category_costs.index,
|
||||
autopct=autopct_function,
|
||||
startangle=140,
|
||||
pctdistance=0.7,
|
||||
labeldistance=1.05)
|
||||
|
||||
ax.set_title('每月订阅支出分类统计', fontproperties=font_prop, fontsize=32, pad=20)
|
||||
|
||||
for text in texts:
|
||||
text.set_fontproperties(font_prop)
|
||||
text.set_fontsize(22)
|
||||
|
||||
for autotext in autotexts:
|
||||
autotext.set_fontproperties(font_prop)
|
||||
autotext.set_fontsize(20)
|
||||
autotext.set_color('white')
|
||||
|
||||
ax.axis('equal')
|
||||
fig.tight_layout()
|
||||
|
||||
with tempfile.NamedTemporaryFile(prefix=f'stats_{user_id}_', suffix='.png', delete=False) as tmp:
|
||||
image_path = tmp.name
|
||||
|
||||
plt.savefig(image_path)
|
||||
|
||||
with open(image_path, 'rb') as photo:
|
||||
await update.message.reply_photo(photo, caption="这是您按类别统计的每月订阅总支出。")
|
||||
finally:
|
||||
if os.path.exists(image_path):
|
||||
plt.close(fig)
|
||||
if image_path and os.path.exists(image_path):
|
||||
os.remove(image_path)
|
||||
|
||||
|
||||
@@ -486,9 +472,9 @@ async def export_command(update: Update, context: CallbackContext):
|
||||
with tempfile.NamedTemporaryFile(prefix=f'export_{user_id}_', suffix='.csv', delete=False) as tmp:
|
||||
export_path = tmp.name
|
||||
|
||||
df.to_csv(export_path, index=False, encoding='utf-8-sig')
|
||||
|
||||
try:
|
||||
df.to_csv(export_path, index=False, encoding='utf-8-sig')
|
||||
|
||||
with open(export_path, 'rb') as file:
|
||||
await update.message.reply_document(document=file, filename='subscriptions.csv',
|
||||
caption="您的订阅数据已导出为 CSV 文件。")
|
||||
@@ -545,14 +531,20 @@ async def import_upload_received(update: Update, context: CallbackContext):
|
||||
renewal_type = str(row['renewal_type']).lower()
|
||||
if renewal_type not in valid_renewal_types:
|
||||
raise ValueError(f"无效续费类型: {renewal_type}")
|
||||
notes = str(row['notes']) if pd.notna(row['notes']) else None
|
||||
notes = str(row['notes']).strip() if pd.notna(row['notes']) else None
|
||||
name = str(row['name']).strip()
|
||||
category = str(row['category']).strip()
|
||||
if not name:
|
||||
raise ValueError("名称不能为空")
|
||||
if not category:
|
||||
raise ValueError("类别不能为空")
|
||||
records.append((
|
||||
user_id, row['name'], cost, currency, row['category'],
|
||||
user_id, name, cost, currency, category,
|
||||
next_due, frequency_unit, frequency_value, renewal_type, notes
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid row in CSV: {row.to_dict()}, error: {e}")
|
||||
await update.message.reply_text(f"导入失败,行数据无效:{row.to_dict()},错误:{e}")
|
||||
logger.error(f"Invalid row in CSV import, error: {e}")
|
||||
await update.message.reply_text(f"导入失败,存在无效行:{e}")
|
||||
return ConversationHandler.END
|
||||
|
||||
with get_db_connection() as conn:
|
||||
@@ -926,6 +918,12 @@ async def button_callback_handler(update: Update, context: CallbackContext):
|
||||
parse_mode='MarkdownV2', reply_markup=None)
|
||||
|
||||
elif action == 'delete':
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT 1 FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id))
|
||||
if not cursor.fetchone():
|
||||
await query.answer("错误:找不到该订阅或无权限。", show_alert=True)
|
||||
return
|
||||
keyboard = InlineKeyboardMarkup([
|
||||
[InlineKeyboardButton("✅ 是的,删除", callback_data=f'confirmdelete_{sub_id}'),
|
||||
InlineKeyboardButton("❌ 取消", callback_data=f'view_{sub_id}')]
|
||||
@@ -935,7 +933,11 @@ async def button_callback_handler(update: Update, context: CallbackContext):
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id))
|
||||
deleted = cursor.rowcount
|
||||
conn.commit()
|
||||
if deleted == 0:
|
||||
await query.answer("错误:找不到该订阅或无权限。", show_alert=True)
|
||||
return
|
||||
await query.answer("订阅已删除")
|
||||
if 'list_subs_in_category' in context.user_data:
|
||||
category = context.user_data['list_subs_in_category']
|
||||
@@ -1068,7 +1070,7 @@ async def edit_freq_value_received(update: Update, context: CallbackContext):
|
||||
conn.commit()
|
||||
|
||||
await update.message.reply_text("✅ 周期已更新!")
|
||||
context.user_data.clear()
|
||||
_clear_action_state(context, ['sub_id_for_action', 'new_freq_unit', 'field_to_edit'])
|
||||
await show_subscription_view(update, context, sub_id)
|
||||
return ConversationHandler.END
|
||||
|
||||
@@ -1131,6 +1133,11 @@ async def edit_new_value_received(update: Update, context: CallbackContext):
|
||||
validation_failed = True
|
||||
else:
|
||||
new_value = parsed
|
||||
elif field == 'renewal_type':
|
||||
if str(new_value) not in ('auto', 'manual'):
|
||||
if message_to_reply:
|
||||
await message_to_reply.reply_text("续费方式只能为 auto 或 manual。")
|
||||
validation_failed = True
|
||||
elif field == 'category':
|
||||
new_value = str(new_value).strip()
|
||||
if not new_value:
|
||||
@@ -1161,7 +1168,7 @@ async def edit_new_value_received(update: Update, context: CallbackContext):
|
||||
elif message_to_reply:
|
||||
await message_to_reply.reply_text("✅ 字段已更新!")
|
||||
|
||||
context.user_data.clear()
|
||||
_clear_action_state(context, ['sub_id_for_action', 'field_to_edit', 'new_freq_unit'])
|
||||
await show_subscription_view(update, context, sub_id)
|
||||
return ConversationHandler.END
|
||||
|
||||
@@ -1282,7 +1289,7 @@ async def remind_days_received(update: Update, context: CallbackContext):
|
||||
return ConversationHandler.END
|
||||
conn.commit()
|
||||
await update.message.reply_text(f"✅ 提前提醒天数已设置为: {days}天。")
|
||||
context.user_data.clear()
|
||||
_clear_action_state(context, ['sub_id_for_action'])
|
||||
await show_subscription_view(update, context, sub_id)
|
||||
except (ValueError, TypeError):
|
||||
await update.message.reply_text("请输入一个有效的非负整数。")
|
||||
|
||||
Reference in New Issue
Block a user