aboutsummaryrefslogblamecommitdiff
path: root/bot.py
blob: 44505b124781127c300a17596bc87e212f965375 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
         

              
                
 
                     



                                                                                           
                                      
 






                                                                                     
                                                                            
                                                                                    


                                                                              
                                    





















                                                                                                          
         

              


                                                                               

                                      


                            
                                        
 





                                                                               
                                                                                           
                                                                   

                          
                                                                                                      


                   


                                                               
                                                           

                                                                                           




                             
import io
import logging
import os
import traceback

from PIL import Image
from telegram import Update
from telegram.constants import ParseMode
from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
from telegram.helpers import escape_markdown

from square import square_size_padding

_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

# Configure the root logger (INFO and above, timestamped) once at import time.
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Logger for this module's own messages.
logger = logging.getLogger(__name__)


async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Reply to ``/start`` with a one-line description of the bot.

    Uses ``update.effective_message`` rather than ``update.message``: the
    command handler can also be triggered by an edited message, in which case
    ``update.message`` is ``None``. Updates without any message are ignored.
    """
    message = update.effective_message
    if message is None:
        return
    await message.reply_text("This is a bot to output image in square shape")


# Pillow save format for each supported document extension (upper-cased).
_SAVE_FORMATS = {"JPG": "JPEG", "JPEG": "JPEG", "PNG": "PNG"}

# Pillow image modes the JPEG encoder can write; anything else (e.g. RGBA or
# palette "P") has to be converted before saving as JPEG.
_JPEG_MODES = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

# Telegram rejects texts longer than 4096 characters. MarkdownV2 escaping can at
# most double the traceback, so only its tail is sent.
_MAX_TRACEBACK_CHARS = 2000


async def process(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Pad an incoming image to a square and send it back as a document.

    Two kinds of message are handled:

    * a document named ``*.jpg`` / ``*.jpeg`` / ``*.png`` (case-insensitive);
      the result is sent as ``<base name>-result.<original extension>``;
    * a compressed photo; the last entry of ``message.photo`` is used and the
      result is sent as ``<file_unique_id>-result.JPG`` in JPEG format.

    Documents with any other extension get an "not supported" reply; every
    other message is ignored. Failures while downloading, padding, or sending
    are logged and the (escaped, truncated) traceback is replied to the chat.
    """
    # effective_message also covers edited messages / channel posts, for
    # which update.message is None.
    message = update.effective_message
    if message is None:
        return
    chat_id = message.chat_id

    if message.document is not None:
        # file_name is optional; splitext keeps multi-dot names intact
        # ("a.b.png" -> "a.b", ".png") and yields "" for a missing extension.
        filename, dotted_ext = os.path.splitext(message.document.file_name or "")
        file_ext = dotted_ext[1:]
        save_format = _SAVE_FORMATS.get(file_ext.upper())
        if save_format is None:
            shown_ext = escape_markdown(file_ext or "(none)", version=2, entity_type="code")
            await context.bot.send_message(chat_id, "Image extension `{}` not supported".format(shown_ext),
                                           parse_mode=ParseMode.MARKDOWN_V2)
            return
        attachment = message.document
    elif message.photo:
        # photo is a tuple of sizes (empty for non-photo messages); the last
        # entry is presumably the largest size.
        attachment = message.photo[-1]
        filename = attachment.file_unique_id
        file_ext = "JPG"
        save_format = "JPEG"
    else:
        return

    shown_name = escape_markdown(filename, version=2, entity_type="code")
    await context.bot.send_message(chat_id, "Processing `{}`".format(shown_name),
                                   parse_mode=ParseMode.MARKDOWN_V2)

    try:
        file = await attachment.get_file()
        img = io.BytesIO()
        await file.download_to_memory(img)

        im = Image.open(img)
        result = square_size_padding(im)
        if save_format == "JPEG" and result.mode not in _JPEG_MODES:
            # JPEG cannot store alpha or palettes; Pillow refuses to save them.
            result = result.convert("RGB")

        output = io.BytesIO()
        result.save(output, format=save_format, quality=100)

        await message.reply_markdown_v2(text="Sending processed result")

        await context.bot.send_document(chat_id=chat_id,
                                        filename="{}-result.{}".format(filename, file_ext),
                                        document=output.getvalue())

    except Exception:
        # Top-level handler boundary: log the failure and report it to the chat.
        # NOTE(review): the traceback exposes server paths to the user.
        logger.exception("Failed to process %s", filename)
        tb = escape_markdown(traceback.format_exc()[-_MAX_TRACEBACK_CHARS:],
                             version=2, entity_type="pre")
        # The newline after ``` keeps the first traceback line from being
        # taken as the pre block's language tag.
        await message.reply_markdown_v2(text="Error:\n```\n{}```".format(tb))


def main() -> None:
    """Build the bot application, register its handlers, and poll for updates.

    The bot token is read from the ``TG_TOKEN`` environment variable.

    Raises:
        RuntimeError: if ``TG_TOKEN`` is unset or empty.
    """
    tg_token = os.getenv("TG_TOKEN")
    if not tg_token:
        # Fail fast with an actionable message instead of a generic token
        # error raised from inside the application builder.
        raise RuntimeError("TG_TOKEN environment variable is not set")
    application = Application.builder().token(tg_token).build()

    application.add_handler(CommandHandler("start", start))
    # Every non-command message with an attachment is routed to process(),
    # which itself ignores attachments that are neither documents nor photos.
    application.add_handler(MessageHandler(filters.ATTACHMENT & ~filters.COMMAND, process))

    # Runs the polling loop until the bot is interrupted.
    application.run_polling()


if __name__ == "__main__":
    main()
Powered by cgit v1.2.3 (git 2.41.0)