#!/usr/bin/env python
"""Check-in bot entry point: Telegram webhook plus Mastodon OAuth callback server.

A python-telegram-bot ``Application`` is built without its own updater and run
next to a Starlette app served by uvicorn. The Starlette routes receive Telegram
updates and the Mastodon OAuth redirect and put both into the application's
``update_queue``.
"""
import asyncio
import logging
from dataclasses import dataclass
from http import HTTPStatus

import uvicorn
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.routing import Route
from telegram import Update
from telegram.ext import (
    Application,
    CallbackContext,
    ContextTypes,
    ExtBot,
    CallbackQueryHandler,
    CommandHandler,
    MessageHandler,
    filters,
    ConversationHandler,
    TypeHandler,
)
from mastodon import Mastodon

from config import (
    BOT_TOKEN,
    TELEGRAM_WEBHOOK_URL,
    HEALTHCHECK_URL,
    FEDI_LOGIN_CALLBACK_URL,
    BOT_DOMAIN,
    BOT_PORT,
    FEDI_LOGIN,
    WAIT_LOCATION,
    PROMPT_FEDI_LOGIN,
    LOCATION_SEARCH_KEYWORD,
    LOCATION_CONFIRMATION,
    ADD_MEDIA,
    ADD_COMMENT,
    BOT_SCOPE,
    MAIN_MENU,
)
from callback import (
    callback_generate_fedi_login_url,
    callback_skip_media,
    callback_location_sharing,
    callback_manual_location,
    callback_location_confirmation,
    callback_location_keyword_search,
    callback_skip_location_keyword_search,
    callback_add_comment,
    callback_skip_comment,
    callback_add_media,
)
from command import (
    start_command,
    fedi_login_command,
    cancel_command,
    help_command,
)
from prompt.string import PROMPT_CHOOSE_ACTION
from dbstore.peewee_store import db, User, get_user_by_state

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)


@dataclass
class FediLoginCallbackUpdate:
    """Custom update queued when Mastodon redirects back after OAuth authorization."""

    code: str  # OAuth authorization code taken from the redirect query string
    state: str  # OAuth ``state`` value used to find the pending User row
    # The user's ``telegram_user_id``; converted with int() when the context is built.
    user_id: str


class FediLoginCallbackContext(CallbackContext[ExtBot, dict, dict, dict]):
    """
    Custom CallbackContext class that makes `user_data` available for updates of
    type `FediLoginCallbackUpdate`.
    """

    @classmethod
    def from_update(cls, update: object, application: "Application") -> "FediLoginCallbackContext":
        """Build a context bound to the Telegram user of a ``FediLoginCallbackUpdate``.

        Any other update type is delegated to the default ``CallbackContext`` logic.
        """
        if isinstance(update, FediLoginCallbackUpdate):
            # Supplying user_id is what lets PTB attach the matching context.user_data.
            return cls(application=application, user_id=int(update.user_id))
        return super().from_update(update, application)


async def process_oauth_login_callback(update: FediLoginCallbackUpdate, context: FediLoginCallbackContext) -> None:
    """Finish the Mastodon login for the user whose OAuth state is ``update.state``.

    If that user has no access key yet, the OAuth ``code`` is exchanged for an
    access token with ``Mastodon.log_in`` and saved. The user is then told in
    Telegram that the login succeeded and is shown the main menu. The earlier
    login prompt message is deleted when its id is still in ``user_data``.
    """
    with db.connection_context():
        user = User.get_or_none(User.state == update.state)
        if user is None:
            # The row may have been removed (or its state changed) after the HTTP
            # callback queued this update; there is nothing left to finish.
            logger.warning("No user matches the OAuth state of a login callback; ignoring it")
            return
        if not user.access_key:  # treats both "" and None as "not logged in yet"
            mastodon_client = Mastodon(
                client_id=user.client_id,
                client_secret=user.client_secret,
                api_base_url=user.home_instance,
            )
            # NOTE(review): log_in() performs a synchronous HTTP request on the event loop.
            user.access_key = mastodon_client.log_in(
                code=update.code,
                redirect_uri="{}{}".format(BOT_DOMAIN, FEDI_LOGIN_CALLBACK_URL),
                scopes=BOT_SCOPE,
            )
            user.save()
        chat_id = user.telegram_user_id

    # Telegram calls run outside the DB context so no connection is held across network awaits.
    # user_data may no longer hold the prompt id (e.g. in-memory data lost on restart);
    # a missing id must not prevent the success messages below.
    prompt_message_id = context.user_data.get(PROMPT_FEDI_LOGIN)
    if prompt_message_id is not None:
        await context.bot.delete_message(chat_id=chat_id, message_id=prompt_message_id)
    text = "You have successfully logged in to your Mastodon account!"
    await context.bot.send_message(chat_id=chat_id, text=text)
    await context.bot.send_message(chat_id=chat_id, text=PROMPT_CHOOSE_ACTION, reply_markup=MAIN_MENU)


async def main() -> None:
    """Register handlers, point Telegram at the webhook and serve HTTP until shutdown."""
    context_types = ContextTypes(context=FediLoginCallbackContext)
    # updater(None): updates arrive through the Starlette routes below, not PTB's updater.
    application = (
        Application.builder().updater(None).token(BOT_TOKEN).context_types(context_types).build()
    )

    checkin_handler = ConversationHandler(
        entry_points=[
            CommandHandler("start", start_command),
            CommandHandler("login", fedi_login_command),
            MessageHandler(filters.LOCATION, callback_location_sharing),
        ],
        states={
            FEDI_LOGIN: [
                MessageHandler(filters.TEXT & ~filters.COMMAND, callback_generate_fedi_login_url),
            ],
            WAIT_LOCATION: [
                MessageHandler(filters.LOCATION, callback_location_sharing),
            ],
            LOCATION_SEARCH_KEYWORD: [
                MessageHandler(filters.TEXT & ~filters.COMMAND, callback_location_keyword_search),
                CallbackQueryHandler(callback_skip_location_keyword_search),
            ],
            LOCATION_CONFIRMATION: [
                CallbackQueryHandler(callback_location_confirmation),
                MessageHandler(filters.TEXT & ~filters.COMMAND, callback_manual_location),
            ],
            ADD_COMMENT: [
                MessageHandler(filters.TEXT & ~filters.COMMAND, callback_add_comment),
                CallbackQueryHandler(callback_skip_comment),
            ],
            ADD_MEDIA: [
                MessageHandler(filters.PHOTO, callback_add_media),
                CallbackQueryHandler(callback_skip_media),
            ],
        },
        fallbacks=[CommandHandler("cancel", cancel_command)],
        per_message=False,
        allow_reentry=True,
    )

    # register handlers
    application.add_handler(CommandHandler("help", help_command))
    application.add_handler(checkin_handler)
    application.add_handler(TypeHandler(type=FediLoginCallbackUpdate, callback=process_oauth_login_callback))

    # Pass webhook settings to telegram
    await application.bot.set_webhook(url=f"{BOT_DOMAIN}{TELEGRAM_WEBHOOK_URL}")

    # Set up webserver
    async def telegram_webhook(request: Request) -> Response:
        """Handle incoming Telegram updates by putting them into the `update_queue`"""
        await application.update_queue.put(
            Update.de_json(data=await request.json(), bot=application.bot)
        )
        return Response()

    async def fedi_oauth_login_callback(request: Request) -> PlainTextResponse:
        """
        Handle the Mastodon OAuth redirect by putting a ``FediLoginCallbackUpdate``
        into the `update_queue` if the required parameters were passed correctly.

        Responds 400 when ``code`` or ``state`` is missing, or when no user
        matches ``state``.
        """
        try:
            code = request.query_params["code"]
            # state is mandatory: without it the lookup would be by a NULL state and
            # could match an unrelated pending user.
            state = request.query_params["state"]
            user = get_user_by_state(state)
        except KeyError:
            return PlainTextResponse(
                status_code=HTTPStatus.BAD_REQUEST,
                content="Mastodon callback request doesn't contain a valid OAuth code",
            )
        # NOTE(review): assumes get_user_by_state returns None for an unknown state — confirm in dbstore.
        if user is None:
            return PlainTextResponse(
                status_code=HTTPStatus.BAD_REQUEST,
                content="Mastodon callback request doesn't match any pending login",
            )
        await application.update_queue.put(
            FediLoginCallbackUpdate(state=state, code=code, user_id=user["telegram_user_id"])
        )
        return PlainTextResponse("Thank you for login! Now you can close the browser")

    async def healthcheck(_: Request) -> PlainTextResponse:
        """Liveness probe: always answers "OK"."""
        return PlainTextResponse(content="OK")

    starlette_app = Starlette(
        routes=[
            Route(TELEGRAM_WEBHOOK_URL, telegram_webhook, methods=["POST"]),
            Route(HEALTHCHECK_URL, healthcheck, methods=["GET"]),
            Route(FEDI_LOGIN_CALLBACK_URL, fedi_oauth_login_callback, methods=["POST", "GET"]),
        ]
    )
    webserver = uvicorn.Server(
        config=uvicorn.Config(
            app=starlette_app,
            port=BOT_PORT,
            use_colors=False,
            host="127.0.0.1",
        )
    )

    # Run application and webserver together
    async with application:
        await application.start()
        await webserver.serve()
        await application.stop()


if __name__ == "__main__":
    asyncio.run(main())