From a9fb6f252553ae49a2ba372434073824babe31e4 Mon Sep 17 00:00:00 2001 From: clarkzjw Date: Thu, 23 Feb 2023 21:46:27 -0800 Subject: bot: support changing default toot visibility --- dbstore/peewee_store.py | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) (limited to 'dbstore/peewee_store.py') diff --git a/dbstore/peewee_store.py b/dbstore/peewee_store.py index f10627b..c8d9c22 100644 --- a/dbstore/peewee_store.py +++ b/dbstore/peewee_store.py @@ -25,6 +25,31 @@ class User(BaseModel): toot_visibility = CharField(max_length=128, default=TOOT_VISIBILITY_PRIVATE) +def update_user_visibility(telegram_user_id: str, visibility: str) -> int: + with db.connection_context(): + return User.update(toot_visibility=visibility).where( + User.telegram_user_id == telegram_user_id + ).execute() + + +def get_user_by_id(telegram_user_id: str) -> dict: + with db.connection_context(): + try: + user = User.get(User.telegram_user_id == telegram_user_id) + return { + "telegram_user_id": user.telegram_user_id, + "access_key": user.access_key, + "home_instance": user.home_instance, + "home_instance_type": user.home_instance_type, + "state": user.state, + "client_id": user.client_id, + "client_secret": user.client_secret, + "toot_visibility": user.toot_visibility, + } + except DoesNotExist: + return {} + + def get_user_by_state(state: str) -> dict: with db.connection_context(): try: @@ -33,6 +58,7 @@ def get_user_by_state(state: str) -> dict: "telegram_user_id": user.telegram_user_id, "access_key": user.access_key, "home_instance": user.home_instance, + "home_instance_type": user.home_instance_type, "state": user.state, "client_id": user.client_id, "client_secret": user.client_secret, @@ -78,10 +104,6 @@ class Location(BaseModel): longitude = CharField(max_length=128) -with db.connection_context(): - db.create_tables([User, Location]) - - def get_poi_by_fsq_id(fsq_id) -> dict: with db.connection_context(): try: @@ -107,3 +129,7 @@ def create_or_update_poi(poi: dict) -> int: latitude=poi["latitude"], longitude=poi["longitude"], ).on_conflict_replace().execute() + + +with db.connection_context(): + db.create_tables([User, Location]) -- cgit v1.2.3