refactor: centralize enum validation for unit and renewal type

This commit is contained in:
Xiaolan Bot
2026-02-22 11:33:41 +08:00
parent 052966e07c
commit d212d73c2a

View File

@@ -255,6 +255,8 @@ EDITABLE_SUB_FIELDS = {
MAX_NAME_LEN = 128 MAX_NAME_LEN = 128
MAX_CATEGORY_LEN = 64 MAX_CATEGORY_LEN = 64
MAX_NOTES_LEN = 1000 MAX_NOTES_LEN = 1000
VALID_FREQ_UNITS = {'day', 'week', 'month', 'year'}
VALID_RENEWAL_TYPES = {'auto', 'manual'}
def _build_category_callback_data(category_id: int) -> str: def _build_category_callback_data(category_id: int) -> str:
@@ -679,7 +681,11 @@ async def add_next_due_received(update: Update, context: CallbackContext):
async def add_freq_unit_received(update: Update, context: CallbackContext): async def add_freq_unit_received(update: Update, context: CallbackContext):
query = update.callback_query query = update.callback_query
await query.answer() await query.answer()
context.user_data['new_sub_data']['unit'] = query.data.split('_')[2] unit = query.data.split('_')[2]
if unit not in VALID_FREQ_UNITS:
await query.edit_message_text("错误:无效的周期单位,请重试。")
return ConversationHandler.END
context.user_data['new_sub_data']['unit'] = unit
await query.edit_message_text("第七步:请输入周期的*数量*例如每3个月输入 3", parse_mode='Markdown') await query.edit_message_text("第七步:请输入周期的*数量*例如每3个月输入 3", parse_mode='Markdown')
return ADD_FREQ_VALUE return ADD_FREQ_VALUE
@@ -705,7 +711,11 @@ async def add_freq_value_received(update: Update, context: CallbackContext):
async def add_renewal_type_received(update: Update, context: CallbackContext): async def add_renewal_type_received(update: Update, context: CallbackContext):
query = update.callback_query query = update.callback_query
await query.answer() await query.answer()
context.user_data['new_sub_data']['renewal_type'] = query.data.split('_')[1] renewal_type = query.data.split('_')[1]
if renewal_type not in VALID_RENEWAL_TYPES:
await query.edit_message_text("错误:无效的续费类型,请重试。")
return ConversationHandler.END
context.user_data['new_sub_data']['renewal_type'] = renewal_type
await query.edit_message_text("最后一步(可选):需要添加备注吗?\n(如:共享账号、用途等。不需要请 /skip") await query.edit_message_text("最后一步(可选):需要添加备注吗?\n(如:共享账号、用途等。不需要请 /skip")
return ADD_NOTES return ADD_NOTES
@@ -1083,7 +1093,11 @@ async def edit_field_selected(update: Update, context: CallbackContext):
async def edit_freq_unit_received(update: Update, context: CallbackContext): async def edit_freq_unit_received(update: Update, context: CallbackContext):
query = update.callback_query query = update.callback_query
await query.answer() await query.answer()
context.user_data['new_freq_unit'] = query.data.split('_')[2] unit = query.data.split('_')[2]
if unit not in VALID_FREQ_UNITS:
await query.edit_message_text("错误:无效的周期单位,请重试。")
return ConversationHandler.END
context.user_data['new_freq_unit'] = unit
await query.edit_message_text("好的,现在请输入新的周期*数量*。", parse_mode='MarkdownV2') await query.edit_message_text("好的,现在请输入新的周期*数量*。", parse_mode='MarkdownV2')
return EDIT_FREQ_VALUE return EDIT_FREQ_VALUE
@@ -1188,7 +1202,7 @@ async def edit_new_value_received(update: Update, context: CallbackContext):
else: else:
new_value = parsed new_value = parsed
elif field == 'renewal_type': elif field == 'renewal_type':
if str(new_value) not in ('auto', 'manual'): if str(new_value) not in VALID_RENEWAL_TYPES:
if message_to_reply: if message_to_reply:
await message_to_reply.reply_text("续费方式只能为 auto 或 manual。") await message_to_reply.reply_text("续费方式只能为 auto 或 manual。")
validation_failed = True validation_failed = True