aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorclarkzjw <[email protected]>2023-02-23 00:30:52 -0800
committerclarkzjw <[email protected]>2023-02-23 12:07:32 -0800
commitadb574ab528fbc4401fb7cd50b8748d1fb529450 (patch)
tree9cdcad48cc8963d1b2fccb8b4e32cd5a0b6877bc
parent770cd0ec5eb556d912bd9d200b2da76c7e3bd7c8 (diff)
downloadswarm2fediverse-adb574ab528fbc4401fb7cd50b8748d1fb529450.tar.gz
bot: multi user working
-rw-r--r--bot.py41
-rw-r--r--command.py14
-rw-r--r--dbstore/peewee_store.py9
3 files changed, 42 insertions, 22 deletions
diff --git a/bot.py b/bot.py
index a059c9c..35884ff 100644
--- a/bot.py
+++ b/bot.py
@@ -5,32 +5,24 @@ import logging
5from dataclasses import dataclass 5from dataclasses import dataclass
6from http import HTTPStatus 6from http import HTTPStatus
7from config import BOT_TOKEN, TELEGRAM_WEBHOOK_URL, HEALTHCHECK_URL, FEDI_LOGIN_CALLBACK_URL, BOT_DOMAIN, BOT_PORT 7from config import BOT_TOKEN, TELEGRAM_WEBHOOK_URL, HEALTHCHECK_URL, FEDI_LOGIN_CALLBACK_URL, BOT_DOMAIN, BOT_PORT
8
9import uvicorn 8import uvicorn
10from starlette.applications import Starlette 9from starlette.applications import Starlette
11from starlette.requests import Request 10from starlette.requests import Request
12from starlette.responses import PlainTextResponse, Response 11from starlette.responses import PlainTextResponse, Response
13from starlette.routing import Route 12from starlette.routing import Route
14
15
16from telegram import Update 13from telegram import Update
17from telegram.ext import ( 14from telegram.ext import (
18 Application, 15 Application,
19 CallbackContext, 16 CallbackContext,
20 ContextTypes, 17 ContextTypes,
21 ExtBot, 18 ExtBot,
22 TypeHandler,
23)
24
25from telegram.ext import (
26 Application,
27 CallbackQueryHandler, 19 CallbackQueryHandler,
28 CommandHandler, 20 CommandHandler,
29 MessageHandler, 21 MessageHandler,
30 filters, 22 filters,
31 ConversationHandler 23 ConversationHandler,
24 TypeHandler,
32) 25)
33
34from callback import ( 26from callback import (
35 callback_generate_fedi_login_url, 27 callback_generate_fedi_login_url,
36 callback_skip_media, 28 callback_skip_media,
@@ -52,18 +44,20 @@ from command import (
52from config import ( 44from config import (
53 FEDI_LOGIN, 45 FEDI_LOGIN,
54 WAIT_LOCATION, 46 WAIT_LOCATION,
47 PROMPT_FEDI_LOGIN,
55 LOCATION_SEARCH_KEYWORD, 48 LOCATION_SEARCH_KEYWORD,
56 LOCATION_CONFIRMATION, 49 LOCATION_CONFIRMATION,
57 ADD_MEDIA, 50 ADD_MEDIA,
58 ADD_COMMENT, 51 ADD_COMMENT,
59 BOT_TOKEN, 52 BOT_TOKEN,
60 BOT_SCOPE 53 BOT_SCOPE,
54 MAIN_MENU
61) 55)
62from mastodon import Mastodon
63 56
64from dbstore.peewee_store import db, User 57from prompt.string import PROMPT_CHOOSE_ACTION
58from mastodon import Mastodon
59from dbstore.peewee_store import db, User, get_user_by_state
65 60
66# Enable logging
67logging.basicConfig( 61logging.basicConfig(
68 format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO 62 format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
69) 63)
@@ -74,10 +68,19 @@ logger = logging.getLogger(__name__)
74class FediLoginCallbackUpdate: 68class FediLoginCallbackUpdate:
75 code: str 69 code: str
76 state: str 70 state: str
71 user_id: str
77 72
78 73
79class FediLoginCallbackContext(CallbackContext[ExtBot, dict, dict, dict]): 74class FediLoginCallbackContext(CallbackContext[ExtBot, dict, dict, dict]):
80 pass 75 """
76 Custom CallbackContext class that makes `user_data` available for updates of type
77 `WebhookUpdate`.
78 """
79 @classmethod
80 def from_update(cls, update: object, application: "Application") -> "FediLoginCallbackContext":
81 if isinstance(update, FediLoginCallbackUpdate):
82 return cls(application=application, user_id=int(update.user_id))
83 return super().from_update(update, application)
81 84
82 85
83async def process_oauth_login_callback(update: FediLoginCallbackUpdate, context: FediLoginCallbackContext) -> None: 86async def process_oauth_login_callback(update: FediLoginCallbackUpdate, context: FediLoginCallbackContext) -> None:
@@ -101,13 +104,13 @@ async def process_oauth_login_callback(update: FediLoginCallbackUpdate, context:
101 user.save() 104 user.save()
102 105
103 text = "You have successfully logged in to your Mastodon account!" 106 text = "You have successfully logged in to your Mastodon account!"
107 await context.bot.delete_message(chat_id=user.telegram_user_id, message_id=context.user_data[PROMPT_FEDI_LOGIN])
104 await context.bot.send_message(chat_id=user.telegram_user_id, text=text) 108 await context.bot.send_message(chat_id=user.telegram_user_id, text=text)
109 await context.bot.send_message(chat_id=user.telegram_user_id, text=PROMPT_CHOOSE_ACTION, reply_markup=MAIN_MENU)
105 110
106 111
107async def main() -> None: 112async def main() -> None:
108 context_types = ContextTypes(context=FediLoginCallbackContext) 113 context_types = ContextTypes(context=FediLoginCallbackContext)
109 # Here we set updater to None because we want our custom webhook server to handle the updates
110 # and hence we don't need an Updater instance
111 application = ( 114 application = (
112 Application.builder().updater(None).token(BOT_TOKEN).context_types(context_types).build() 115 Application.builder().updater(None).token(BOT_TOKEN).context_types(context_types).build()
113 ) 116 )
@@ -169,13 +172,15 @@ async def main() -> None:
169 try: 172 try:
170 code = request.query_params["code"] 173 code = request.query_params["code"]
171 state = request.query_params.get("state") 174 state = request.query_params.get("state")
175 user = get_user_by_state(state)
172 except KeyError: 176 except KeyError:
173 return PlainTextResponse( 177 return PlainTextResponse(
174 status_code=HTTPStatus.BAD_REQUEST, 178 status_code=HTTPStatus.BAD_REQUEST,
175 content="Mastodon callback request doesn't contain a valid OAuth code", 179 content="Mastodon callback request doesn't contain a valid OAuth code",
176 ) 180 )
177 181
178 await application.update_queue.put(FediLoginCallbackUpdate(state=state, code=code)) 182 await application.update_queue.put(FediLoginCallbackUpdate(state=state, code=code,
183 user_id=user["telegram_user_id"]))
179 return PlainTextResponse("Thank you for login! Now you can close the browser") 184 return PlainTextResponse("Thank you for login! Now you can close the browser")
180 185
181 async def healthcheck(_: Request) -> PlainTextResponse: 186 async def healthcheck(_: Request) -> PlainTextResponse:
diff --git a/command.py b/command.py
index 901b792..7426901 100644
--- a/command.py
+++ b/command.py
@@ -2,15 +2,21 @@ from telegram import Update
2from telegram.constants import ParseMode 2from telegram.constants import ParseMode
3from telegram.error import BadRequest 3from telegram.error import BadRequest
4from telegram.ext import ContextTypes, ConversationHandler 4from telegram.ext import ContextTypes, ConversationHandler
5 5from dbstore.peewee_store import get_user_access_key
6from config import * 6from config import *
7 7
8 8
9async def start_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: 9async def start_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int:
10 await update.message.reply_text(PROMPT_START, parse_mode=ParseMode.MARKDOWN) 10 await update.message.reply_text(PROMPT_START, parse_mode=ParseMode.MARKDOWN)
11 await update.message.reply_text(PROMPT_CHOOSE_ACTION, reply_markup=MAIN_MENU) 11 user_access_key = get_user_access_key(str(update.effective_user.id))
12 12 # TODO
13 return WAIT_LOCATION 13 # verify user access key still valid
14 if len(user_access_key) == 0:
15 await update.message.reply_text(PROMPT_FEDI_LOGIN_WHERE_IS_INSTANCE, parse_mode=ParseMode.MARKDOWN)
16 return FEDI_LOGIN
17 else:
18 await update.message.reply_text(PROMPT_CHOOSE_ACTION, reply_markup=MAIN_MENU)
19 return WAIT_LOCATION
14 20
15 21
16async def fedi_login_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: 22async def fedi_login_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int:
diff --git a/dbstore/peewee_store.py b/dbstore/peewee_store.py
index 3d37140..272e230 100644
--- a/dbstore/peewee_store.py
+++ b/dbstore/peewee_store.py
@@ -41,6 +41,15 @@ def get_user_by_state(state: str) -> dict:
41 return {} 41 return {}
42 42
43 43
44def get_user_access_key(telegram_user_id: str) -> str:
45 with db.connection_context():
46 try:
47 user = User.get(User.telegram_user_id == telegram_user_id)
48 return user.access_key
49 except DoesNotExist:
50 return ""
51
52
44class Location(BaseModel): 53class Location(BaseModel):
45 fsq_id = CharField(unique=True, primary_key=True) 54 fsq_id = CharField(unique=True, primary_key=True)
46 name = CharField(max_length=128) 55 name = CharField(max_length=128)
Powered by cgit v1.2.3 (git 2.41.0)