hardening: close low-risk gaps and improve import validation

This commit is contained in:
Xiaolan Bot
2026-02-22 02:11:26 +08:00
parent 98a863f567
commit 36b136289c

View File

@@ -263,25 +263,9 @@ def _parse_category_id_from_callback(data: str) -> int | None:
return int(payload) if payload.isdigit() else None
async def get_subs_list_keyboard(user_id: int, category_filter: str = None) -> InlineKeyboardMarkup:
with get_db_connection() as conn:
cursor = conn.cursor()
query, params = "SELECT id, name FROM subscriptions WHERE user_id = ? ", [user_id]
if category_filter:
query += "AND category = ? "
params.append(category_filter)
query += "ORDER BY next_due ASC"
cursor.execute(query, tuple(params))
subs = cursor.fetchall()
if not subs:
return None
buttons = [InlineKeyboardButton(name, callback_data=f'view_{sub_id}') for sub_id, name in subs]
keyboard = [buttons[i:i + 2] for i in range(0, len(buttons), 2)]
if category_filter:
keyboard.append([InlineKeyboardButton("« 返回分类列表", callback_data='list_categories')])
else:
keyboard.append([InlineKeyboardButton("🗂️ 按分类浏览", callback_data='list_categories')])
return InlineKeyboardMarkup(keyboard)
def _clear_action_state(context: CallbackContext, keys: list[str]):
for key in keys:
context.user_data.pop(key, None)
# --- 自动任务 ---
@@ -435,7 +419,9 @@ async def stats(update: Update, context: CallbackContext):
plt.style.use('seaborn-v0_8-darkgrid')
fig, ax = plt.subplots(figsize=(12, 12))
image_path = None
try:
autopct_function = make_autopct(category_costs.values, main_currency)
wedges, texts, autotexts = ax.pie(category_costs.values,
@@ -457,18 +443,18 @@ async def stats(update: Update, context: CallbackContext):
autotext.set_color('white')
ax.axis('equal')
fig.tight_layout()
with tempfile.NamedTemporaryFile(prefix=f'stats_{user_id}_', suffix='.png', delete=False) as tmp:
image_path = tmp.name
plt.savefig(image_path)
plt.close(fig)
try:
plt.savefig(image_path)
with open(image_path, 'rb') as photo:
await update.message.reply_photo(photo, caption="这是您按类别统计的每月订阅总支出。")
finally:
if os.path.exists(image_path):
plt.close(fig)
if image_path and os.path.exists(image_path):
os.remove(image_path)
@@ -486,9 +472,9 @@ async def export_command(update: Update, context: CallbackContext):
with tempfile.NamedTemporaryFile(prefix=f'export_{user_id}_', suffix='.csv', delete=False) as tmp:
export_path = tmp.name
try:
df.to_csv(export_path, index=False, encoding='utf-8-sig')
try:
with open(export_path, 'rb') as file:
await update.message.reply_document(document=file, filename='subscriptions.csv',
caption="您的订阅数据已导出为 CSV 文件。")
@@ -545,14 +531,20 @@ async def import_upload_received(update: Update, context: CallbackContext):
renewal_type = str(row['renewal_type']).lower()
if renewal_type not in valid_renewal_types:
raise ValueError(f"无效续费类型: {renewal_type}")
notes = str(row['notes']) if pd.notna(row['notes']) else None
notes = str(row['notes']).strip() if pd.notna(row['notes']) else None
name = str(row['name']).strip()
category = str(row['category']).strip()
if not name:
raise ValueError("名称不能为空")
if not category:
raise ValueError("类别不能为空")
records.append((
user_id, row['name'], cost, currency, row['category'],
user_id, name, cost, currency, category,
next_due, frequency_unit, frequency_value, renewal_type, notes
))
except Exception as e:
logger.error(f"Invalid row in CSV: {row.to_dict()}, error: {e}")
await update.message.reply_text(f"导入失败,行数据无效:{row.to_dict()},错误{e}")
logger.error(f"Invalid row in CSV import, error: {e}")
await update.message.reply_text(f"导入失败,存在无效行:{e}")
return ConversationHandler.END
with get_db_connection() as conn:
@@ -926,6 +918,12 @@ async def button_callback_handler(update: Update, context: CallbackContext):
parse_mode='MarkdownV2', reply_markup=None)
elif action == 'delete':
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT 1 FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id))
if not cursor.fetchone():
await query.answer("错误:找不到该订阅或无权限。", show_alert=True)
return
keyboard = InlineKeyboardMarkup([
[InlineKeyboardButton("✅ 是的,删除", callback_data=f'confirmdelete_{sub_id}'),
InlineKeyboardButton("❌ 取消", callback_data=f'view_{sub_id}')]
@@ -935,7 +933,11 @@ async def button_callback_handler(update: Update, context: CallbackContext):
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("DELETE FROM subscriptions WHERE id = ? AND user_id = ?", (sub_id, user_id))
deleted = cursor.rowcount
conn.commit()
if deleted == 0:
await query.answer("错误:找不到该订阅或无权限。", show_alert=True)
return
await query.answer("订阅已删除")
if 'list_subs_in_category' in context.user_data:
category = context.user_data['list_subs_in_category']
@@ -1068,7 +1070,7 @@ async def edit_freq_value_received(update: Update, context: CallbackContext):
conn.commit()
await update.message.reply_text("✅ 周期已更新!")
context.user_data.clear()
_clear_action_state(context, ['sub_id_for_action', 'new_freq_unit', 'field_to_edit'])
await show_subscription_view(update, context, sub_id)
return ConversationHandler.END
@@ -1131,6 +1133,11 @@ async def edit_new_value_received(update: Update, context: CallbackContext):
validation_failed = True
else:
new_value = parsed
elif field == 'renewal_type':
if str(new_value) not in ('auto', 'manual'):
if message_to_reply:
await message_to_reply.reply_text("续费方式只能为 auto 或 manual。")
validation_failed = True
elif field == 'category':
new_value = str(new_value).strip()
if not new_value:
@@ -1161,7 +1168,7 @@ async def edit_new_value_received(update: Update, context: CallbackContext):
elif message_to_reply:
await message_to_reply.reply_text("✅ 字段已更新!")
context.user_data.clear()
_clear_action_state(context, ['sub_id_for_action', 'field_to_edit', 'new_freq_unit'])
await show_subscription_view(update, context, sub_id)
return ConversationHandler.END
@@ -1282,7 +1289,7 @@ async def remind_days_received(update: Update, context: CallbackContext):
return ConversationHandler.END
conn.commit()
await update.message.reply_text(f"✅ 提前提醒天数已设置为: {days}天。")
context.user_data.clear()
_clear_action_state(context, ['sub_id_for_action'])
await show_subscription_view(update, context, sub_id)
except (ValueError, TypeError):
await update.message.reply_text("请输入一个有效的非负整数。")