diff options
-rw-r--r-- | bot.py | 8 | ||||
-rw-r--r-- | callback.py | 49 | ||||
-rw-r--r-- | command.py | 15 | ||||
-rw-r--r-- | util.py | 27 |
4 files changed, 67 insertions, 32 deletions
@@ -114,10 +114,10 @@ async def process_oauth_login_callback(update: FediLoginCallbackUpdate, context: | |||
114 | user.access_key = encrypt(access_token, ENCRYPT_KEY) | 114 | user.access_key = encrypt(access_token, ENCRYPT_KEY) |
115 | user.save() | 115 | user.save() |
116 | 116 | ||
117 | text = "You have successfully logged in to your Mastodon account!" | 117 | text = "You have successfully logged in to your Mastodon account!" |
118 | await context.bot.delete_message(chat_id=user.telegram_user_id, message_id=context.user_data[PROMPT_FEDI_LOGIN]) | 118 | await context.bot.delete_message(chat_id=user.telegram_user_id, message_id=context.user_data[PROMPT_FEDI_LOGIN]) |
119 | await context.bot.send_message(chat_id=user.telegram_user_id, text=text) | 119 | await context.bot.send_message(chat_id=user.telegram_user_id, text=text) |
120 | await context.bot.send_message(chat_id=user.telegram_user_id, text=PROMPT_CHOOSE_ACTION, reply_markup=MAIN_MENU) | 120 | await context.bot.send_message(chat_id=user.telegram_user_id, text=PROMPT_CHOOSE_ACTION, reply_markup=MAIN_MENU) |
121 | 121 | ||
122 | 122 | ||
123 | async def main() -> None: | 123 | async def main() -> None: |
diff --git a/callback.py b/callback.py index 7530afb..5fe4593 100644 --- a/callback.py +++ b/callback.py | |||
@@ -12,7 +12,7 @@ from config import BOT_SCOPE, ENCRYPT_KEY | |||
12 | from dbstore.peewee_store import User, db, TOOT_VISIBILITY_PRIVATE, TOOT_VISIBILITY_PUBLIC, TOOT_VISIBILITY_UNLISTED | 12 | from dbstore.peewee_store import User, db, TOOT_VISIBILITY_PRIVATE, TOOT_VISIBILITY_PUBLIC, TOOT_VISIBILITY_UNLISTED |
13 | import uuid | 13 | import uuid |
14 | from mastodon import Mastodon | 14 | from mastodon import Mastodon |
15 | from util import decrypt | 15 | from util import decrypt, check_user |
16 | 16 | ||
17 | 17 | ||
18 | def generate_uuid(): | 18 | def generate_uuid(): |
@@ -118,7 +118,8 @@ async def callback_generate_fedi_login_url(update: Update, context: ContextTypes | |||
118 | return FEDI_LOGIN | 118 | return FEDI_LOGIN |
119 | 119 | ||
120 | 120 | ||
121 | async def callback_location_sharing(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: | 121 | @check_user |
122 | async def callback_location_sharing(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: | ||
122 | if update.message.venue is not None: | 123 | if update.message.venue is not None: |
123 | context.user_data["fsq_id"] = update.message.venue.foursquare_id | 124 | context.user_data["fsq_id"] = update.message.venue.foursquare_id |
124 | context.user_data["title"] = update.message.venue.title | 125 | context.user_data["title"] = update.message.venue.title |
@@ -128,11 +129,10 @@ async def callback_location_sharing(update: Update, context: ContextTypes.DEFAUL | |||
128 | poi = query_poi_by_fsq_id(context.user_data.get("fsq_id")) | 129 | poi = query_poi_by_fsq_id(context.user_data.get("fsq_id")) |
129 | content = generate_toot_text(poi["name"], poi["locality"], poi["region"], poi["latitude"], poi["longitude"]) | 130 | content = generate_toot_text(poi["name"], poi["locality"], poi["region"], poi["latitude"], poi["longitude"]) |
130 | 131 | ||
131 | u = get_user_by_id(str(update.effective_user.id)) | 132 | content_type = "text/markdown" if user["home_instance_type"] == "pleroma" else None |
132 | content_type = "text/markdown" if u["home_instance_type"] == "pleroma" else None | ||
133 | 133 | ||
134 | status = get_mastodon_client(update.effective_user.id).status_post(content, | 134 | status = get_mastodon_client(update.effective_user.id).status_post(content, |
135 | visibility=u["tool_visibility"], | 135 | visibility=user["tool_visibility"], |
136 | content_type=content_type, | 136 | content_type=content_type, |
137 | media_ids=[]) | 137 | media_ids=[]) |
138 | 138 | ||
@@ -174,7 +174,8 @@ async def _process_location_search(keyword, lat, lon) -> list: | |||
174 | return keyboard | 174 | return keyboard |
175 | 175 | ||
176 | 176 | ||
177 | async def callback_location_keyword_search(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: | 177 | @check_user |
178 | async def callback_location_keyword_search(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: | ||
178 | await context.bot.delete_message(update.effective_chat.id, context.user_data.get(PROMPT_LOCATION_KEYWORD)) | 179 | await context.bot.delete_message(update.effective_chat.id, context.user_data.get(PROMPT_LOCATION_KEYWORD)) |
179 | key = update.effective_message.text | 180 | key = update.effective_message.text |
180 | 181 | ||
@@ -191,7 +192,8 @@ async def callback_location_keyword_search(update: Update, context: ContextTypes | |||
191 | return LOCATION_CONFIRMATION | 192 | return LOCATION_CONFIRMATION |
192 | 193 | ||
193 | 194 | ||
194 | async def callback_skip_location_keyword_search(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: | 195 | @check_user |
196 | async def callback_skip_location_keyword_search(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: | ||
195 | query = update.callback_query | 197 | query = update.callback_query |
196 | await query.answer() | 198 | await query.answer() |
197 | await query.message.delete() | 199 | await query.message.delete() |
@@ -202,7 +204,7 @@ async def callback_skip_location_keyword_search(update: Update, context: Context | |||
202 | return LOCATION_CONFIRMATION | 204 | return LOCATION_CONFIRMATION |
203 | 205 | ||
204 | 206 | ||
205 | async def _process_location_selection(context: ContextTypes.DEFAULT_TYPE) -> int: | 207 | async def _process_location_selection(context: ContextTypes.DEFAULT_TYPE, user: User) -> int: |
206 | poi_name = context.user_data.get("poi_name") | 208 | poi_name = context.user_data.get("poi_name") |
207 | if context.user_data.get("fsq_id") is not None: | 209 | if context.user_data.get("fsq_id") is not None: |
208 | poi = get_poi_by_fsq_id(context.user_data.get("fsq_id")) | 210 | poi = get_poi_by_fsq_id(context.user_data.get("fsq_id")) |
@@ -212,11 +214,10 @@ async def _process_location_selection(context: ContextTypes.DEFAULT_TYPE) -> int | |||
212 | content = generate_toot_text(poi_name, "", "", context.user_data.get("latitude"), | 214 | content = generate_toot_text(poi_name, "", "", context.user_data.get("latitude"), |
213 | context.user_data.get("longitude")) | 215 | context.user_data.get("longitude")) |
214 | 216 | ||
215 | u = get_user_by_id(context.user_data["user_id"]) | 217 | content_type = "text/markdown" if user["home_instance_type"] == "pleroma" else None |
216 | content_type = "text/markdown" if u["home_instance_type"] == "pleroma" else None | ||
217 | 218 | ||
218 | status = get_mastodon_client(context.user_data["user_id"]).status_post(content, | 219 | status = get_mastodon_client(context.user_data["user_id"]).status_post(content, |
219 | visibility=u["toot_visibility"], | 220 | visibility=user["toot_visibility"], |
220 | content_type=content_type, | 221 | content_type=content_type, |
221 | media_ids=[]) | 222 | media_ids=[]) |
222 | 223 | ||
@@ -236,7 +237,8 @@ async def _process_location_selection(context: ContextTypes.DEFAULT_TYPE) -> int | |||
236 | return ADD_COMMENT | 237 | return ADD_COMMENT |
237 | 238 | ||
238 | 239 | ||
239 | async def callback_location_confirmation(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: | 240 | @check_user |
241 | async def callback_location_confirmation(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: | ||
240 | query = update.callback_query | 242 | query = update.callback_query |
241 | await query.answer() | 243 | await query.answer() |
242 | context.user_data["fsq_id"] = query.data | 244 | context.user_data["fsq_id"] = query.data |
@@ -245,15 +247,16 @@ async def callback_location_confirmation(update: Update, context: ContextTypes.D | |||
245 | 247 | ||
246 | context.user_data["chat_id"] = update.effective_chat.id | 248 | context.user_data["chat_id"] = update.effective_chat.id |
247 | 249 | ||
248 | return await _process_location_selection(context) | 250 | return await _process_location_selection(context, user) |
249 | 251 | ||
250 | 252 | ||
251 | async def callback_manual_location(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: | 253 | @check_user |
254 | async def callback_manual_location(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: | ||
252 | context.user_data["poi_name"] = update.effective_message.text | 255 | context.user_data["poi_name"] = update.effective_message.text |
253 | context.user_data["chat_id"] = update.effective_chat.id | 256 | context.user_data["chat_id"] = update.effective_chat.id |
254 | context.user_data["user_id"] = update.effective_user.id | 257 | context.user_data["user_id"] = update.effective_user.id |
255 | 258 | ||
256 | return await _process_location_selection(context) | 259 | return await _process_location_selection(context, user) |
257 | 260 | ||
258 | 261 | ||
259 | async def _process_comment(context: ContextTypes.DEFAULT_TYPE) -> int: | 262 | async def _process_comment(context: ContextTypes.DEFAULT_TYPE) -> int: |
@@ -264,14 +267,13 @@ async def _process_comment(context: ContextTypes.DEFAULT_TYPE) -> int: | |||
264 | return ADD_MEDIA | 267 | return ADD_MEDIA |
265 | 268 | ||
266 | 269 | ||
267 | async def callback_add_comment(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: | 270 | @check_user |
271 | async def callback_add_comment(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: | ||
268 | context.user_data["chat_id"] = update.effective_chat.id | 272 | context.user_data["chat_id"] = update.effective_chat.id |
269 | context.user_data["user_id"] = update.effective_user.id | 273 | context.user_data["user_id"] = update.effective_user.id |
270 | await context.bot.delete_message(update.effective_chat.id, context.user_data.get(PROMPT_ADD_COMMENT)) | 274 | await context.bot.delete_message(update.effective_chat.id, context.user_data.get(PROMPT_ADD_COMMENT)) |
271 | 275 | ||
272 | with db.connection_context(): | 276 | content_type = "text/markdown" if user["home_instance_type"] == "pleroma" else None |
273 | u = User.get(User.telegram_user_id == context.user_data["user_id"]) | ||
274 | content_type = "text/markdown" if u.home_instance_type == "pleroma" else None | ||
275 | 277 | ||
276 | comment = update.effective_message.text | 278 | comment = update.effective_message.text |
277 | get_mastodon_client(update.effective_user.id).status_update(id=context.user_data.get(KEY_TOOT_STATUS_ID), | 279 | get_mastodon_client(update.effective_user.id).status_update(id=context.user_data.get(KEY_TOOT_STATUS_ID), |
@@ -283,14 +285,16 @@ async def callback_add_comment(update: Update, context: ContextTypes.DEFAULT_TYP | |||
283 | return await _process_comment(context) | 285 | return await _process_comment(context) |
284 | 286 | ||
285 | 287 | ||
286 | async def callback_skip_comment(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: | 288 | @check_user |
289 | async def callback_skip_comment(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: | ||
287 | context.user_data["chat_id"] = update.effective_chat.id | 290 | context.user_data["chat_id"] = update.effective_chat.id |
288 | await context.bot.delete_message(update.effective_chat.id, context.user_data.get(PROMPT_ADD_COMMENT)) | 291 | await context.bot.delete_message(update.effective_chat.id, context.user_data.get(PROMPT_ADD_COMMENT)) |
289 | 292 | ||
290 | return await _process_comment(context) | 293 | return await _process_comment(context) |
291 | 294 | ||
292 | 295 | ||
293 | async def callback_add_media(update: Update, context: CallbackContext): | 296 | @check_user |
297 | async def callback_add_media(update: Update, context: CallbackContext, user: User): | ||
294 | await update.message.reply_chat_action(ChatAction.TYPING) | 298 | await update.message.reply_chat_action(ChatAction.TYPING) |
295 | 299 | ||
296 | try: | 300 | try: |
@@ -342,7 +346,8 @@ async def callback_add_media(update: Update, context: CallbackContext): | |||
342 | await update.message.reply_text(text=PROMPT_DONE, reply_markup=MAIN_MENU) | 346 | await update.message.reply_text(text=PROMPT_DONE, reply_markup=MAIN_MENU) |
343 | 347 | ||
344 | 348 | ||
345 | async def callback_skip_media(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: | 349 | @check_user |
350 | async def callback_skip_media(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: | ||
346 | query = update.callback_query | 351 | query = update.callback_query |
347 | await query.answer() | 352 | await query.answer() |
348 | 353 | ||
@@ -1,11 +1,11 @@ | |||
1 | from telegram import Update | 1 | from telegram import Update, User |
2 | from telegram.constants import ParseMode | 2 | from telegram.constants import ParseMode |
3 | from telegram.error import BadRequest | 3 | from telegram.error import BadRequest |
4 | from telegram.ext import ContextTypes, ConversationHandler | 4 | from telegram.ext import ContextTypes, ConversationHandler |
5 | from dbstore.peewee_store import get_user_access_key, get_user_home_instance, delete_user_by_id, update_user_visibility | 5 | from dbstore.peewee_store import get_user_access_key, get_user_home_instance, delete_user_by_id, update_user_visibility |
6 | from dbstore.peewee_store import get_user_by_id | ||
7 | from dbstore.peewee_store import TOOT_VISIBILITY_PRIVATE, TOOT_VISIBILITY_UNLISTED, TOOT_VISIBILITY_PUBLIC | 6 | from dbstore.peewee_store import TOOT_VISIBILITY_PRIVATE, TOOT_VISIBILITY_UNLISTED, TOOT_VISIBILITY_PUBLIC |
8 | from config import * | 7 | from config import * |
8 | from util import check_user | ||
9 | 9 | ||
10 | 10 | ||
11 | async def start_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: | 11 | async def start_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: |
@@ -34,10 +34,11 @@ async def tos_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> Non | |||
34 | await update.message.reply_text(PROMPT_TOS, parse_mode=ParseMode.HTML, reply_markup=MAIN_MENU) | 34 | await update.message.reply_text(PROMPT_TOS, parse_mode=ParseMode.HTML, reply_markup=MAIN_MENU) |
35 | 35 | ||
36 | 36 | ||
37 | @check_user | ||
37 | async def list_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: | 38 | async def list_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: |
38 | result = get_user_home_instance(str(update.effective_user.id)) | 39 | result = get_user_home_instance(str(update.effective_user.id)) |
39 | if len(result) == 0: | 40 | if len(result) == 0: |
40 | await update.message.reply_text(PROMPT_LIST_NO_RESULT, parse_mode=ParseMode.HTML) | 41 | pass |
41 | else: | 42 | else: |
42 | await update.message.reply_text(f"You are linked with the following Fediverse accounts:\n\n" | 43 | await update.message.reply_text(f"You are linked with the following Fediverse accounts:\n\n" |
43 | f"<b>Instance</b>: {result['home_instance']}\n" | 44 | f"<b>Instance</b>: {result['home_instance']}\n" |
@@ -47,26 +48,28 @@ async def list_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No | |||
47 | reply_markup=MAIN_MENU) | 48 | reply_markup=MAIN_MENU) |
48 | 49 | ||
49 | 50 | ||
51 | @check_user | ||
50 | async def logout_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: | 52 | async def logout_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: |
51 | if delete_user_by_id(str(update.effective_user.id)): | 53 | if delete_user_by_id(str(update.effective_user.id)): |
52 | await update.message.reply_text(PROMPT_LOGOUT_SUCCESS, parse_mode=ParseMode.HTML, reply_markup=LOGIN_MENU) | 54 | await update.message.reply_text(PROMPT_LOGOUT_SUCCESS, parse_mode=ParseMode.HTML, reply_markup=LOGIN_MENU) |
53 | 55 | ||
54 | 56 | ||
55 | async def toggle_visibility_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: | 57 | @check_user |
58 | async def toggle_visibility_command(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: | ||
56 | visibility_menu = InlineKeyboardMarkup([ | 59 | visibility_menu = InlineKeyboardMarkup([ |
57 | [InlineKeyboardButton("Private", callback_data=TOOT_VISIBILITY_PRIVATE)], | 60 | [InlineKeyboardButton("Private", callback_data=TOOT_VISIBILITY_PRIVATE)], |
58 | [InlineKeyboardButton("Unlisted", callback_data=TOOT_VISIBILITY_UNLISTED)], | 61 | [InlineKeyboardButton("Unlisted", callback_data=TOOT_VISIBILITY_UNLISTED)], |
59 | [InlineKeyboardButton("Public", callback_data=TOOT_VISIBILITY_PUBLIC)] | 62 | [InlineKeyboardButton("Public", callback_data=TOOT_VISIBILITY_PUBLIC)] |
60 | ]) | 63 | ]) |
61 | 64 | ||
62 | user = get_user_by_id(str(update.effective_user.id)) | ||
63 | await update.message.reply_text(PROMPT_TOGGLE_VIS.format(user["toot_visibility"]), | 65 | await update.message.reply_text(PROMPT_TOGGLE_VIS.format(user["toot_visibility"]), |
64 | parse_mode=ParseMode.HTML, | 66 | parse_mode=ParseMode.HTML, |
65 | reply_markup=visibility_menu) | 67 | reply_markup=visibility_menu) |
66 | return WAIT_VISIBILITY | 68 | return WAIT_VISIBILITY |
67 | 69 | ||
68 | 70 | ||
69 | async def callback_toggle_visibility(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: | 71 | @check_user |
72 | async def callback_toggle_visibility(update: Update, context: ContextTypes.DEFAULT_TYPE, user: User) -> int: | ||
70 | query = update.callback_query | 73 | query = update.callback_query |
71 | await query.answer() | 74 | await query.answer() |
72 | 75 | ||
@@ -1,4 +1,7 @@ | |||
1 | from cryptography.fernet import Fernet | 1 | from cryptography.fernet import Fernet |
2 | from telegram import Update | ||
3 | from dbstore.peewee_store import db, User, get_user_by_id | ||
4 | import functools | ||
2 | 5 | ||
3 | 6 | ||
4 | def encrypt(input: str, key: str) -> str: | 7 | def encrypt(input: str, key: str) -> str: |
@@ -11,6 +14,30 @@ def decrypt(input: str, key: str) -> str: | |||
11 | return f.decrypt(input).decode('utf-8') | 14 | return f.decrypt(input).decode('utf-8') |
12 | 15 | ||
13 | 16 | ||
17 | def check_user(fn): | ||
18 | """ | ||
19 | Decorator: loads User model and passes it to the function or stops the request. | ||
20 | Ref: https://shallowdepth.online/posts/2021/12/using-python-decorators-to-process-and-authorize-requests/ | ||
21 | """ | ||
22 | |||
23 | @functools.wraps(fn) | ||
24 | async def wrapper(*args, **kwargs): | ||
25 | # Expects that Update object is always the first arg | ||
26 | update: Update = args[0] | ||
27 | user = get_user_by_id(str(update.effective_user.id)) | ||
28 | |||
29 | if len(user) == 0: | ||
30 | await update.effective_message.reply_text("You are not logged in. Use `/login` to link your account first") | ||
31 | return | ||
32 | else: | ||
33 | # TODO | ||
34 | # check access_key is still valid | ||
35 | pass | ||
36 | return await fn(*args, **kwargs, user=user) | ||
37 | |||
38 | return wrapper | ||
39 | |||
40 | |||
14 | # if __name__ == "__main__": | 41 | # if __name__ == "__main__": |
15 | # key = Fernet.generate_key().decode('utf-8') | 42 | # key = Fernet.generate_key().decode('utf-8') |
16 | # print(key) | 43 | # print(key) |