From a1c7c5a223d22017508998599388c5adf9a90713 Mon Sep 17 00:00:00 2001 From: clarkzjw Date: Mon, 27 Feb 2023 15:48:09 -0800 Subject: bot: add check_user decorator to check user login status --- bot.py | 8 ++++---- callback.py | 49 +++++++++++++++++++++++++++---------------------- command.py | 15 +++++++++------ util.py | 27 +++++++++++++++++++++++++++ 4 files changed, 67 insertions(+), 32 deletions(-) diff --git a/bot.py b/bot.py index bcfab8f..364a85f 100644 --- a/bot.py +++ b/bot.py @@ -114,10 +114,10 @@ async def process_oauth_login_callback(update: FediLoginCallbackUpdate, context: user.access_key = encrypt(access_token, ENCRYPT_KEY) user.save() - text = "You have successfully logged in to your Mastodon account!" - await context.bot.delete_message(chat_id=user.telegram_user_id, message_id=context.user_data[PROMPT_FEDI_LOGIN]) - await context.bot.send_message(chat_id=user.telegram_user_id, text=text) - await context.bot.send_message(chat_id=user.telegram_user_id, text=PROMPT_CHOOSE_ACTION, reply_markup=MAIN_MENU) + text = "You have successfully logged in to your Mastodon account!" + await context.bot.delete_message(chat_id=user.telegram_user_id, message_id=context.user_data[PROMPT_FEDI_LOGIN]) + await context.bot.send_message(chat_id=user.telegram_user_id, text=text) + await context.bot.send_message(chat_id=user.telegram_user_id, text=PROMPT_CHOOSE_ACTION, reply_markup=MAIN_MENU) async def main() -> None: diff --git a/callback.py b/callback.py index 7530afb..5fe4593 100644 --- a/callback.py +++ b/callback.py @@ -12,7 +12,7 @@ from config import BOT_SCOPE, ENCRYPT_KEY from dbstore.peewee_store import User, db, TOOT_VISIBILITY_PRIVATE, TOOT_VISIBILITY_PUBLIC, TOOT_VISIBILITY_UNLISTED import uuid from mastodon import Mastodon -from util import decrypt +from util import decrypt, check_user def generate_uuid(): @@ -118,7 +118,8 @@ async def callback_generate_fedi_login_url(update: Update, context: ContextTypes return FEDI_LOGIN -async def callback_location_sharing(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: +@check_user +async def callback_location_sharing(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: if update.message.venue is not None: context.user_data["fsq_id"] = update.message.venue.foursquare_id context.user_data["title"] = update.message.venue.title @@ -128,11 +129,10 @@ async def callback_location_sharing(update: Update, context: ContextTypes.DEFAUL poi = query_poi_by_fsq_id(context.user_data.get("fsq_id")) content = generate_toot_text(poi["name"], poi["locality"], poi["region"], poi["latitude"], poi["longitude"]) - u = get_user_by_id(str(update.effective_user.id)) - content_type = "text/markdown" if u["home_instance_type"] == "pleroma" else None + content_type = "text/markdown" if user["home_instance_type"] == "pleroma" else None status = get_mastodon_client(update.effective_user.id).status_post(content, - visibility=u["tool_visibility"], + visibility=user["tool_visibility"], content_type=content_type, media_ids=[]) @@ -174,7 +174,8 @@ async def _process_location_search(keyword, lat, lon) -> list: return keyboard -async def callback_location_keyword_search(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: +@check_user +async def callback_location_keyword_search(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: await context.bot.delete_message(update.effective_chat.id, context.user_data.get(PROMPT_LOCATION_KEYWORD)) key = update.effective_message.text @@ -191,7 +192,8 @@ async def callback_location_keyword_search(update: Update, context: ContextTypes return LOCATION_CONFIRMATION -async def callback_skip_location_keyword_search(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: +@check_user +async def callback_skip_location_keyword_search(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: query = update.callback_query await query.answer() await query.message.delete() @@ -202,7 +204,7 @@ async def callback_skip_location_keyword_search(update: Update, context: Context return LOCATION_CONFIRMATION -async def _process_location_selection(context: ContextTypes.DEFAULT_TYPE) -> int: +async def _process_location_selection(context: ContextTypes.DEFAULT_TYPE, user: User) -> int: poi_name = context.user_data.get("poi_name") if context.user_data.get("fsq_id") is not None: poi = get_poi_by_fsq_id(context.user_data.get("fsq_id")) @@ -212,11 +214,10 @@ async def _process_location_selection(context: ContextTypes.DEFAULT_TYPE) -> int content = generate_toot_text(poi_name, "", "", context.user_data.get("latitude"), context.user_data.get("longitude")) - u = get_user_by_id(context.user_data["user_id"]) - content_type = "text/markdown" if u["home_instance_type"] == "pleroma" else None + content_type = "text/markdown" if user["home_instance_type"] == "pleroma" else None status = get_mastodon_client(context.user_data["user_id"]).status_post(content, - visibility=u["toot_visibility"], + visibility=user["toot_visibility"], content_type=content_type, media_ids=[]) @@ -236,7 +237,8 @@ async def _process_location_selection(context: ContextTypes.DEFAULT_TYPE) -> int return ADD_COMMENT -async def callback_location_confirmation(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: +@check_user +async def callback_location_confirmation(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: query = update.callback_query await query.answer() context.user_data["fsq_id"] = query.data @@ -245,15 +247,16 @@ async def callback_location_confirmation(update: Update, context: ContextTypes.D context.user_data["chat_id"] = update.effective_chat.id - return await _process_location_selection(context) + return await _process_location_selection(context, user) -async def callback_manual_location(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: +@check_user +async def callback_manual_location(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: context.user_data["poi_name"] = update.effective_message.text context.user_data["chat_id"] = update.effective_chat.id context.user_data["user_id"] = update.effective_user.id - return await _process_location_selection(context) + return await _process_location_selection(context, user) async def _process_comment(context: ContextTypes.DEFAULT_TYPE) -> int: @@ -264,14 +267,13 @@ async def _process_comment(context: ContextTypes.DEFAULT_TYPE) -> int: return ADD_MEDIA -async def callback_add_comment(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: +@check_user +async def callback_add_comment(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: context.user_data["chat_id"] = update.effective_chat.id context.user_data["user_id"] = update.effective_user.id await context.bot.delete_message(update.effective_chat.id, context.user_data.get(PROMPT_ADD_COMMENT)) - with db.connection_context(): - u = User.get(User.telegram_user_id == context.user_data["user_id"]) - content_type = "text/markdown" if u.home_instance_type == "pleroma" else None + content_type = "text/markdown" if user["home_instance_type"] == "pleroma" else None comment = update.effective_message.text get_mastodon_client(update.effective_user.id).status_update(id=context.user_data.get(KEY_TOOT_STATUS_ID), @@ -283,14 +285,16 @@ async def callback_add_comment(update: Update, context: ContextTypes.DEFAULT_TYP return await _process_comment(context) -async def callback_skip_comment(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: +@check_user +async def callback_skip_comment(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: context.user_data["chat_id"] = update.effective_chat.id await context.bot.delete_message(update.effective_chat.id, context.user_data.get(PROMPT_ADD_COMMENT)) return await _process_comment(context) -async def callback_add_media(update: Update, context: CallbackContext): +@check_user +async def callback_add_media(update: Update, context: CallbackContext, user: User): await update.message.reply_chat_action(ChatAction.TYPING) try: @@ -342,7 +346,8 @@ async def callback_add_media(update: Update, context: CallbackContext): await update.message.reply_text(text=PROMPT_DONE, reply_markup=MAIN_MENU) -async def callback_skip_media(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: +@check_user +async def callback_skip_media(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: query = update.callback_query await query.answer() diff --git a/command.py b/command.py index f124e2a..db85538 100644 --- a/command.py +++ b/command.py @@ -1,11 +1,11 @@ -from telegram import Update +from telegram import Update, User from telegram.constants import ParseMode from telegram.error import BadRequest from telegram.ext import ContextTypes, ConversationHandler from dbstore.peewee_store import get_user_access_key, get_user_home_instance, delete_user_by_id, update_user_visibility -from dbstore.peewee_store import get_user_by_id from dbstore.peewee_store import TOOT_VISIBILITY_PRIVATE, TOOT_VISIBILITY_UNLISTED, TOOT_VISIBILITY_PUBLIC from config import * +from util import check_user async def start_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: @@ -34,10 +34,11 @@ async def tos_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> Non await update.message.reply_text(PROMPT_TOS, parse_mode=ParseMode.HTML, reply_markup=MAIN_MENU) +@check_user async def list_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: result = get_user_home_instance(str(update.effective_user.id)) if len(result) == 0: - await update.message.reply_text(PROMPT_LIST_NO_RESULT, parse_mode=ParseMode.HTML) + pass else: await update.message.reply_text(f"You are linked with the following Fediverse accounts:\n\n" f"Instance: {result['home_instance']}\n" @@ -47,26 +48,28 @@ async def list_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No reply_markup=MAIN_MENU) +@check_user async def logout_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: if delete_user_by_id(str(update.effective_user.id)): await update.message.reply_text(PROMPT_LOGOUT_SUCCESS, parse_mode=ParseMode.HTML, reply_markup=LOGIN_MENU) -async def toggle_visibility_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: +@check_user +async def toggle_visibility_command(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: visibility_menu = InlineKeyboardMarkup([ [InlineKeyboardButton("Private", callback_data=TOOT_VISIBILITY_PRIVATE)], [InlineKeyboardButton("Unlisted", callback_data=TOOT_VISIBILITY_UNLISTED)], [InlineKeyboardButton("Public", callback_data=TOOT_VISIBILITY_PUBLIC)] ]) - user = get_user_by_id(str(update.effective_user.id)) await update.message.reply_text(PROMPT_TOGGLE_VIS.format(user["toot_visibility"]), parse_mode=ParseMode.HTML, reply_markup=visibility_menu) return WAIT_VISIBILITY -async def callback_toggle_visibility(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: +@check_user +async def callback_toggle_visibility(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: query = update.callback_query await query.answer() diff --git a/util.py b/util.py index 61efa0b..a0cda5e 100644 --- a/util.py +++ b/util.py @@ -1,4 +1,7 @@ from cryptography.fernet import Fernet +from telegram import Update +from dbstore.peewee_store import db, User, get_user_by_id +import functools def encrypt(input: str, key: str) -> str: @@ -11,6 +14,30 @@ def decrypt(input: str, key: str) -> str: return f.decrypt(input).decode('utf-8') +def check_user(fn): + """ + Decorator: loads User model and passes it to the function or stops the request. + Ref: https://shallowdepth.online/posts/2021/12/using-python-decorators-to-process-and-authorize-requests/ + """ + + @functools.wraps(fn) + async def wrapper(*args, **kwargs): + # Expects that Update object is always the first arg + update: Update = args[0] + user = get_user_by_id(str(update.effective_user.id)) + + if len(user) == 0: + await update.effective_message.reply_text("You are not logged in. Use `/login` to link your account first") + return + else: + # TODO + # check access_key is still valid + pass + return await fn(*args, **kwargs, user=user) + + return wrapper + + # if __name__ == "__main__": # key = Fernet.generate_key().decode('utf-8') # print(key) -- cgit v1.2.3