diff options
-rw-r--r-- | bot.py | 41 | ||||
-rw-r--r-- | command.py | 14 | ||||
-rw-r--r-- | dbstore/peewee_store.py | 9 |
3 files changed, 42 insertions, 22 deletions
@@ -5,32 +5,24 @@ import logging | |||
5 | from dataclasses import dataclass | 5 | from dataclasses import dataclass |
6 | from http import HTTPStatus | 6 | from http import HTTPStatus |
7 | from config import BOT_TOKEN, TELEGRAM_WEBHOOK_URL, HEALTHCHECK_URL, FEDI_LOGIN_CALLBACK_URL, BOT_DOMAIN, BOT_PORT | 7 | from config import BOT_TOKEN, TELEGRAM_WEBHOOK_URL, HEALTHCHECK_URL, FEDI_LOGIN_CALLBACK_URL, BOT_DOMAIN, BOT_PORT |
8 | |||
9 | import uvicorn | 8 | import uvicorn |
10 | from starlette.applications import Starlette | 9 | from starlette.applications import Starlette |
11 | from starlette.requests import Request | 10 | from starlette.requests import Request |
12 | from starlette.responses import PlainTextResponse, Response | 11 | from starlette.responses import PlainTextResponse, Response |
13 | from starlette.routing import Route | 12 | from starlette.routing import Route |
14 | |||
15 | |||
16 | from telegram import Update | 13 | from telegram import Update |
17 | from telegram.ext import ( | 14 | from telegram.ext import ( |
18 | Application, | 15 | Application, |
19 | CallbackContext, | 16 | CallbackContext, |
20 | ContextTypes, | 17 | ContextTypes, |
21 | ExtBot, | 18 | ExtBot, |
22 | TypeHandler, | ||
23 | ) | ||
24 | |||
25 | from telegram.ext import ( | ||
26 | Application, | ||
27 | CallbackQueryHandler, | 19 | CallbackQueryHandler, |
28 | CommandHandler, | 20 | CommandHandler, |
29 | MessageHandler, | 21 | MessageHandler, |
30 | filters, | 22 | filters, |
31 | ConversationHandler | 23 | ConversationHandler, |
24 | TypeHandler, | ||
32 | ) | 25 | ) |
33 | |||
34 | from callback import ( | 26 | from callback import ( |
35 | callback_generate_fedi_login_url, | 27 | callback_generate_fedi_login_url, |
36 | callback_skip_media, | 28 | callback_skip_media, |
@@ -52,18 +44,20 @@ from command import ( | |||
52 | from config import ( | 44 | from config import ( |
53 | FEDI_LOGIN, | 45 | FEDI_LOGIN, |
54 | WAIT_LOCATION, | 46 | WAIT_LOCATION, |
47 | PROMPT_FEDI_LOGIN, | ||
55 | LOCATION_SEARCH_KEYWORD, | 48 | LOCATION_SEARCH_KEYWORD, |
56 | LOCATION_CONFIRMATION, | 49 | LOCATION_CONFIRMATION, |
57 | ADD_MEDIA, | 50 | ADD_MEDIA, |
58 | ADD_COMMENT, | 51 | ADD_COMMENT, |
59 | BOT_TOKEN, | 52 | BOT_TOKEN, |
60 | BOT_SCOPE | 53 | BOT_SCOPE, |
54 | MAIN_MENU | ||
61 | ) | 55 | ) |
62 | from mastodon import Mastodon | ||
63 | 56 | ||
64 | from dbstore.peewee_store import db, User | 57 | from prompt.string import PROMPT_CHOOSE_ACTION |
58 | from mastodon import Mastodon | ||
59 | from dbstore.peewee_store import db, User, get_user_by_state | ||
65 | 60 | ||
66 | # Enable logging | ||
67 | logging.basicConfig( | 61 | logging.basicConfig( |
68 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO | 62 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO |
69 | ) | 63 | ) |
@@ -74,10 +68,19 @@ logger = logging.getLogger(__name__) | |||
74 | class FediLoginCallbackUpdate: | 68 | class FediLoginCallbackUpdate: |
75 | code: str | 69 | code: str |
76 | state: str | 70 | state: str |
71 | user_id: str | ||
77 | 72 | ||
78 | 73 | ||
79 | class FediLoginCallbackContext(CallbackContext[ExtBot, dict, dict, dict]): | 74 | class FediLoginCallbackContext(CallbackContext[ExtBot, dict, dict, dict]): |
80 | pass | 75 | """ |
76 | Custom CallbackContext class that makes `user_data` available for updates of type | ||
77 | `WebhookUpdate`. | ||
78 | """ | ||
79 | @classmethod | ||
80 | def from_update(cls, update: object, application: "Application") -> "FediLoginCallbackContext": | ||
81 | if isinstance(update, FediLoginCallbackUpdate): | ||
82 | return cls(application=application, user_id=int(update.user_id)) | ||
83 | return super().from_update(update, application) | ||
81 | 84 | ||
82 | 85 | ||
83 | async def process_oauth_login_callback(update: FediLoginCallbackUpdate, context: FediLoginCallbackContext) -> None: | 86 | async def process_oauth_login_callback(update: FediLoginCallbackUpdate, context: FediLoginCallbackContext) -> None: |
@@ -101,13 +104,13 @@ async def process_oauth_login_callback(update: FediLoginCallbackUpdate, context: | |||
101 | user.save() | 104 | user.save() |
102 | 105 | ||
103 | text = "You have successfully logged in to your Mastodon account!" | 106 | text = "You have successfully logged in to your Mastodon account!" |
107 | await context.bot.delete_message(chat_id=user.telegram_user_id, message_id=context.user_data[PROMPT_FEDI_LOGIN]) | ||
104 | await context.bot.send_message(chat_id=user.telegram_user_id, text=text) | 108 | await context.bot.send_message(chat_id=user.telegram_user_id, text=text) |
109 | await context.bot.send_message(chat_id=user.telegram_user_id, text=PROMPT_CHOOSE_ACTION, reply_markup=MAIN_MENU) | ||
105 | 110 | ||
106 | 111 | ||
107 | async def main() -> None: | 112 | async def main() -> None: |
108 | context_types = ContextTypes(context=FediLoginCallbackContext) | 113 | context_types = ContextTypes(context=FediLoginCallbackContext) |
109 | # Here we set updater to None because we want our custom webhook server to handle the updates | ||
110 | # and hence we don't need an Updater instance | ||
111 | application = ( | 114 | application = ( |
112 | Application.builder().updater(None).token(BOT_TOKEN).context_types(context_types).build() | 115 | Application.builder().updater(None).token(BOT_TOKEN).context_types(context_types).build() |
113 | ) | 116 | ) |
@@ -169,13 +172,15 @@ async def main() -> None: | |||
169 | try: | 172 | try: |
170 | code = request.query_params["code"] | 173 | code = request.query_params["code"] |
171 | state = request.query_params.get("state") | 174 | state = request.query_params.get("state") |
175 | user = get_user_by_state(state) | ||
172 | except KeyError: | 176 | except KeyError: |
173 | return PlainTextResponse( | 177 | return PlainTextResponse( |
174 | status_code=HTTPStatus.BAD_REQUEST, | 178 | status_code=HTTPStatus.BAD_REQUEST, |
175 | content="Mastodon callback request doesn't contain a valid OAuth code", | 179 | content="Mastodon callback request doesn't contain a valid OAuth code", |
176 | ) | 180 | ) |
177 | 181 | ||
178 | await application.update_queue.put(FediLoginCallbackUpdate(state=state, code=code)) | 182 | await application.update_queue.put(FediLoginCallbackUpdate(state=state, code=code, |
183 | user_id=user["telegram_user_id"])) | ||
179 | return PlainTextResponse("Thank you for login! Now you can close the browser") | 184 | return PlainTextResponse("Thank you for login! Now you can close the browser") |
180 | 185 | ||
181 | async def healthcheck(_: Request) -> PlainTextResponse: | 186 | async def healthcheck(_: Request) -> PlainTextResponse: |
@@ -2,15 +2,21 @@ from telegram import Update | |||
2 | from telegram.constants import ParseMode | 2 | from telegram.constants import ParseMode |
3 | from telegram.error import BadRequest | 3 | from telegram.error import BadRequest |
4 | from telegram.ext import ContextTypes, ConversationHandler | 4 | from telegram.ext import ContextTypes, ConversationHandler |
5 | 5 | from dbstore.peewee_store import get_user_access_key | |
6 | from config import * | 6 | from config import * |
7 | 7 | ||
8 | 8 | ||
9 | async def start_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: | 9 | async def start_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: |
10 | await update.message.reply_text(PROMPT_START, parse_mode=ParseMode.MARKDOWN) | 10 | await update.message.reply_text(PROMPT_START, parse_mode=ParseMode.MARKDOWN) |
11 | await update.message.reply_text(PROMPT_CHOOSE_ACTION, reply_markup=MAIN_MENU) | 11 | user_access_key = get_user_access_key(str(update.effective_user.id)) |
12 | 12 | # TODO | |
13 | return WAIT_LOCATION | 13 | # verify user access key still valid |
14 | if len(user_access_key) == 0: | ||
15 | await update.message.reply_text(PROMPT_FEDI_LOGIN_WHERE_IS_INSTANCE, parse_mode=ParseMode.MARKDOWN) | ||
16 | return FEDI_LOGIN | ||
17 | else: | ||
18 | await update.message.reply_text(PROMPT_CHOOSE_ACTION, reply_markup=MAIN_MENU) | ||
19 | return WAIT_LOCATION | ||
14 | 20 | ||
15 | 21 | ||
16 | async def fedi_login_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: | 22 | async def fedi_login_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: |
diff --git a/dbstore/peewee_store.py b/dbstore/peewee_store.py index 3d37140..272e230 100644 --- a/dbstore/peewee_store.py +++ b/dbstore/peewee_store.py | |||
@@ -41,6 +41,15 @@ def get_user_by_state(state: str) -> dict: | |||
41 | return {} | 41 | return {} |
42 | 42 | ||
43 | 43 | ||
44 | def get_user_access_key(telegram_user_id: str) -> str: | ||
45 | with db.connection_context(): | ||
46 | try: | ||
47 | user = User.get(User.telegram_user_id == telegram_user_id) | ||
48 | return user.access_key | ||
49 | except DoesNotExist: | ||
50 | return "" | ||
51 | |||
52 | |||
44 | class Location(BaseModel): | 53 | class Location(BaseModel): |
45 | fsq_id = CharField(unique=True, primary_key=True) | 54 | fsq_id = CharField(unique=True, primary_key=True) |
46 | name = CharField(max_length=128) | 55 | name = CharField(max_length=128) |