aboutsummaryrefslogblamecommitdiff
path: root/bot.py
blob: 364a85f90b5e86ed8609e3cbbff6822c96bb258e (plain) (tree)
1
2
3
4
5
6
7
8
9
10
                     
 
              
              

                                 
                                                                                                                               


                                            
                                                                         
                                   





                           



                         

                        
 
                      
                                     




                                     
                                          





                          
                       
                   


                              
                               

                   

                    
               
                  
                      



                            
              
              

                   
 
 


                                                            
                                 
 




                                                                                     
 


                              
              
                


                                                                          








                                                                                                   


                                                                                                                   



                                            
 




                                          

                                                                                             




                                                                                


                                                               
                                                                
                       
 



                                                                                                                            



                                                                  
                   
                                                                                                 
     
 

                                                      
                                          
                      
                                                   
                                                        
                                                                        

                


                                                                                                  

                                                                            
              
                                      
                                                                                                  
                                                                            
              

                                                                     
                                                                                         
              
                          
                                                                                      
                                                            
              

                                                                          
          
                                                             
                          
                           

     
                       
                                                               












                                                                


                                                                     

                                                                                                             


                                                               



                                                                                
                                                                 



                                                                                      
                                        







                                                                                       
                                                     
                                           





                                                                                       

                                                                                                     








                                                                                      
                                                                                       






                               
                           







                                            


                          
                       
#!/usr/bin/env python

import asyncio
import logging
from dataclasses import dataclass
from http import HTTPStatus
from config import BOT_TOKEN, TELEGRAM_WEBHOOK_URL, HEALTHCHECK_URL, FEDI_LOGIN_CALLBACK_URL, BOT_DOMAIN, BOT_PORT, ENCRYPT_KEY
import uvicorn
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response, JSONResponse
from starlette.routing import Route
from telegram import Update
from telegram.ext import (
    Application,
    CallbackContext,
    ContextTypes,
    ExtBot,
    CallbackQueryHandler,
    CommandHandler,
    MessageHandler,
    filters,
    ConversationHandler,
    TypeHandler,
)
from callback import (
    callback_generate_fedi_login_url,
    callback_skip_media,
    callback_location_sharing,
    callback_manual_location,
    callback_location_confirmation,
    callback_location_keyword_search,
    callback_skip_location_keyword_search,
    callback_add_comment,
    callback_skip_comment,
    callback_add_media
)
from command import (
    start_command,
    fedi_login_command,
    cancel_command,
    help_command,
    tos_command,
    toggle_visibility_command,
    callback_toggle_visibility,
    logout_command,
    list_command
)
from config import (
    FEDI_LOGIN,
    WAIT_LOCATION,
    PROMPT_FEDI_LOGIN,
    LOCATION_SEARCH_KEYWORD,
    LOCATION_CONFIRMATION,
    ADD_MEDIA,
    ADD_COMMENT,
    BOT_TOKEN,
    BOT_SCOPE,
    MAIN_MENU,
    WAIT_VISIBILITY
)

from prompt.string import PROMPT_CHOOSE_ACTION
from mastodon import Mastodon
from dbstore.peewee_store import db, User, get_user_by_state
from util import encrypt, decrypt

# Root logging: INFO and above, each line prefixed with time, logger name and level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)


@dataclass
class FediLoginCallbackUpdate:
    """Custom (non-Telegram) update carrying the result of a Fediverse OAuth redirect.

    Instances are built by the OAuth callback web route and put on the
    application's ``update_queue``, where a ``TypeHandler`` routes them to
    ``process_oauth_login_callback``.
    """

    # OAuth authorization code taken from the redirect's ``code`` query parameter.
    code: str
    # OAuth ``state`` value; used to look up the ``User`` row that started the login.
    state: str
    # Telegram user id of that user; FediLoginCallbackContext.from_update passes it
    # through ``int()``. NOTE(review): the producer passes the raw DB value, which
    # may be an int despite the ``str`` annotation — confirm.
    user_id: str


class FediLoginCallbackContext(CallbackContext[ExtBot, dict, dict, dict]):
    """
    CallbackContext that can also be built for `FediLoginCallbackUpdate` objects.

    Such updates are not `telegram.Update` instances, so the context is bound to
    the Telegram user id they carry; this is what makes `context.user_data`
    usable in the OAuth login handler.
    """
    @classmethod
    def from_update(cls, update: object, application: "Application") -> "FediLoginCallbackContext":
        """Create the context for *update*.

        Ordinary updates are delegated to ``CallbackContext.from_update``; OAuth
        login updates get a context tied to ``int(update.user_id)``.
        """
        if not isinstance(update, FediLoginCallbackUpdate):
            return super().from_update(update, application)
        telegram_user_id = int(update.user_id)
        return cls(application=application, user_id=telegram_user_id)


async def process_oauth_login_callback(update: FediLoginCallbackUpdate, context: FediLoginCallbackContext) -> None:
    """Complete a Fediverse OAuth login and confirm it to the user on Telegram.

    Finds the ``User`` whose ``state`` matches the callback, exchanges the
    authorization code for an access token, flags Pleroma home instances,
    stores the encrypted token, then deletes the login prompt (when its message
    id is still known) and sends a confirmation followed by the main menu.

    Users who already have an access key stored are left untouched and are not
    messaged.

    :param update: code/state/user id queued by the OAuth callback web route.
    :param context: context bound to the Telegram user who started the login.
    """
    with db.connection_context():
        user = User.get(User.state == update.state)

        # Truthiness also covers a NULL column, which `len()` would crash on.
        if user.access_key:
            return

        mastodon_client = Mastodon(
            client_id=user.client_id,
            client_secret=user.client_secret,
            api_base_url=user.home_instance,
            version_check_mode="none",
        )
        # NOTE(review): Mastodon.py is synchronous, so these two HTTP calls block
        # the event loop (and every other update) until they return.
        access_token = mastodon_client.log_in(
            code=update.code,
            redirect_uri="{}{}".format(BOT_DOMAIN, FEDI_LOGIN_CALLBACK_URL),
            scopes=BOT_SCOPE,
        )
        instance_info = mastodon_client.instance_nodeinfo()
        if instance_info["software"]["name"] == "pleroma":
            user.home_instance_type = "pleroma"
        user.access_key = encrypt(access_token, ENCRYPT_KEY)
        user.save()
        chat_id = user.telegram_user_id

    # Telegram calls run after the DB connection is released, so it is not held
    # open across network round-trips.
    # The prompt's message id lives only in the in-memory user_data, so it can be
    # missing (e.g. after a restart between /login and the redirect). The token
    # is already saved at this point, so a missing id must not stop the
    # confirmation from being sent.
    prompt_message_id = context.user_data.get(PROMPT_FEDI_LOGIN)
    if prompt_message_id is not None:
        await context.bot.delete_message(chat_id=chat_id, message_id=prompt_message_id)

    text = "You have successfully logged in to your Mastodon account!"
    await context.bot.send_message(chat_id=chat_id, text=text)
    await context.bot.send_message(chat_id=chat_id, text=PROMPT_CHOOSE_ACTION, reply_markup=MAIN_MENU)


async def main() -> None:
    """Start the bot: register handlers, point Telegram's webhook here, and serve HTTP.

    One Starlette app (served by uvicorn on ``BOT_PORT``) exposes three routes:

    * ``TELEGRAM_WEBHOOK_URL`` (POST): updates pushed by Telegram;
    * ``FEDI_LOGIN_CALLBACK_URL`` (GET): OAuth redirects from the user's instance;
    * ``HEALTHCHECK_URL`` (GET): always answers ``OK``.

    Telegram updates and OAuth callbacks are both put on
    ``application.update_queue`` and dispatched by the PTB ``Application``,
    which runs for as long as the web server does.
    """
    import secrets  # only needed here, to mint and check the webhook secret

    context_types = ContextTypes(context=FediLoginCallbackContext)
    application = (
        Application.builder().updater(None).token(BOT_TOKEN).context_types(context_types).build()
    )
    _register_handlers(application)

    # Telegram echoes this value in the X-Telegram-Bot-Api-Secret-Token header of
    # every webhook request, so requests without it can be rejected as forged.
    # A fresh secret is minted on each start; set_webhook replaces the old one.
    webhook_secret = secrets.token_urlsafe(32)
    await application.bot.set_webhook(
        url=f"{BOT_DOMAIN}{TELEGRAM_WEBHOOK_URL}",
        secret_token=webhook_secret,
    )

    # Set up webserver
    async def telegram_webhook(request: Request) -> Response:
        """Put an authenticated Telegram update into the `update_queue`."""
        received_secret = request.headers.get("X-Telegram-Bot-Api-Secret-Token", "")
        # Compare as bytes: compare_digest rejects non-ASCII str input, and the
        # header value is attacker-controlled.
        if not secrets.compare_digest(received_secret.encode("utf-8"),
                                      webhook_secret.encode("utf-8")):
            return Response(status_code=HTTPStatus.FORBIDDEN)
        try:
            payload = await request.json()
        except ValueError:  # json.JSONDecodeError / UnicodeDecodeError on a bad body
            return Response(status_code=HTTPStatus.BAD_REQUEST)
        await application.update_queue.put(Update.de_json(data=payload, bot=application.bot))
        return JSONResponse({'OK': 200})

    async def fedi_oauth_login_callback(request: Request) -> PlainTextResponse:
        """
        Queue a `FediLoginCallbackUpdate` for an OAuth redirect that carries an
        authorization code and a `state` belonging to a known user; answer
        400 Bad Request otherwise.
        """
        code = request.query_params.get("code")
        if not code:
            return PlainTextResponse(
                status_code=HTTPStatus.BAD_REQUEST,
                content="Mastodon callback request doesn't contain a valid OAuth code",
            )

        state = request.query_params.get("state")
        user = None
        if state:
            try:
                user = get_user_by_state(state)
            except (KeyError, User.DoesNotExist):
                # NOTE(review): unclear which of these get_user_by_state raises
                # for an unknown state; both are treated as "no such login".
                user = None
        if not user:
            return PlainTextResponse(
                status_code=HTTPStatus.BAD_REQUEST,
                content="Mastodon callback request doesn't contain a valid login state",
            )

        await application.update_queue.put(
            FediLoginCallbackUpdate(state=state, code=code, user_id=user["telegram_user_id"])
        )
        return PlainTextResponse("Thank you for login! Now you can close the browser")

    async def healthcheck(_: Request) -> PlainTextResponse:
        return PlainTextResponse(content="OK")

    starlette_app = Starlette(
        routes=[
            Route(TELEGRAM_WEBHOOK_URL, telegram_webhook, methods=["POST"]),
            Route(HEALTHCHECK_URL, healthcheck, methods=["GET"]),
            Route(FEDI_LOGIN_CALLBACK_URL, fedi_oauth_login_callback, methods=["GET"]),
        ]
    )
    webserver = uvicorn.Server(
        config=uvicorn.Config(
            app=starlette_app,
            port=BOT_PORT,
            use_colors=False,
            host="0.0.0.0",
        )
    )

    # Run application and webserver together
    async with application:
        await application.start()
        try:
            await webserver.serve()
        finally:
            # Leaving `async with` calls Application.shutdown(), which raises if
            # the application is still running; stop it even when serve() fails
            # so the original error is not masked.
            await application.stop()


def _build_checkin_handler() -> ConversationHandler:
    """Build the conversation for Fediverse login and the check-in flow (location, comment, media)."""
    # TODO:
    # check user login status before invoking commands
    return ConversationHandler(
        entry_points=[
            CommandHandler("start", start_command),
            CommandHandler("login", fedi_login_command),
            MessageHandler(filters.LOCATION, callback_location_sharing),
        ],
        states={
            FEDI_LOGIN: [
                MessageHandler(filters.TEXT & ~filters.COMMAND, callback_generate_fedi_login_url),
            ],
            WAIT_LOCATION: [
                MessageHandler(filters.LOCATION, callback_location_sharing),
            ],
            LOCATION_SEARCH_KEYWORD: [
                MessageHandler(filters.TEXT & ~filters.COMMAND, callback_location_keyword_search),
                CallbackQueryHandler(callback_skip_location_keyword_search),
            ],
            LOCATION_CONFIRMATION: [
                CallbackQueryHandler(callback_location_confirmation),
                MessageHandler(filters.TEXT & ~filters.COMMAND, callback_manual_location),
            ],
            ADD_COMMENT: [
                MessageHandler(filters.TEXT & ~filters.COMMAND, callback_add_comment),
                CallbackQueryHandler(callback_skip_comment),
            ],
            ADD_MEDIA: [
                MessageHandler(filters.PHOTO, callback_add_media),
                CallbackQueryHandler(callback_skip_media),
            ],
        },
        fallbacks=[CommandHandler("cancel", cancel_command)],
        per_message=False,
        allow_reentry=True,
    )


def _build_visibility_handler() -> ConversationHandler:
    """Build the `/vis` conversation handled by toggle_visibility_command / callback_toggle_visibility."""
    return ConversationHandler(
        entry_points=[CommandHandler("vis", toggle_visibility_command)],
        states={
            WAIT_VISIBILITY: [CallbackQueryHandler(callback_toggle_visibility)],
        },
        fallbacks=[CommandHandler("cancel", cancel_command)],
        per_message=False,
        allow_reentry=True,
    )


def _register_handlers(application: Application) -> None:
    """Attach every command, conversation and OAuth-callback handler to *application*."""
    # Group 0: stand-alone commands and the OAuth login callback.
    application.add_handler(CommandHandler("tos", tos_command))
    application.add_handler(CommandHandler("logout", logout_command))
    application.add_handler(CommandHandler("list", list_command))
    # PTB matches commands case-insensitively, so this also answers /Help.
    application.add_handler(CommandHandler("help", help_command))
    application.add_handler(TypeHandler(type=FediLoginCallbackUpdate, callback=process_oauth_login_callback))

    # Each conversation sits in its own group; PTB lets one handler per group
    # process an update, so they don't block group 0 or each other.
    application.add_handler(_build_visibility_handler(), 2)
    application.add_handler(_build_checkin_handler(), 1)


if __name__ == "__main__":
    # Script entry point: run main() (webhook setup + web server) until the server exits.
    asyncio.run(main())
Powered by cgit v1.2.3 (git 2.41.0)