aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'dbstore/peewee_store.py')
-rw-r--r--dbstore/peewee_store.py34
1 files changed, 30 insertions, 4 deletions
diff --git a/dbstore/peewee_store.py b/dbstore/peewee_store.py
index f10627b..c8d9c22 100644
--- a/dbstore/peewee_store.py
+++ b/dbstore/peewee_store.py
@@ -25,6 +25,31 @@ class User(BaseModel):
25 toot_visibility = CharField(max_length=128, default=TOOT_VISIBILITY_PRIVATE) 25 toot_visibility = CharField(max_length=128, default=TOOT_VISIBILITY_PRIVATE)
26 26
27 27
28def update_user_visibility(telegram_user_id: str, visibility: str) -> int:
29 with db.connection_context():
30 return User.update(toot_visibility=visibility).where(
31 User.telegram_user_id == telegram_user_id
32 ).execute()
33
34
35def get_user_by_id(telegram_user_id: str) -> dict:
36 with db.connection_context():
37 try:
38 user = User.get(User.telegram_user_id == telegram_user_id)
39 return {
40 "telegram_user_id": user.telegram_user_id,
41 "access_key": user.access_key,
42 "home_instance": user.home_instance,
43 "home_instance_type": user.home_instance_type,
44 "state": user.state,
45 "client_id": user.client_id,
46 "client_secret": user.client_secret,
47 "toot_visibility": user.toot_visibility,
48 }
49 except DoesNotExist:
50 return {}
51
52
28def get_user_by_state(state: str) -> dict: 53def get_user_by_state(state: str) -> dict:
29 with db.connection_context(): 54 with db.connection_context():
30 try: 55 try:
@@ -33,6 +58,7 @@ def get_user_by_state(state: str) -> dict:
33 "telegram_user_id": user.telegram_user_id, 58 "telegram_user_id": user.telegram_user_id,
34 "access_key": user.access_key, 59 "access_key": user.access_key,
35 "home_instance": user.home_instance, 60 "home_instance": user.home_instance,
61 "home_instance_type": user.home_instance_type,
36 "state": user.state, 62 "state": user.state,
37 "client_id": user.client_id, 63 "client_id": user.client_id,
38 "client_secret": user.client_secret, 64 "client_secret": user.client_secret,
@@ -78,10 +104,6 @@ class Location(BaseModel):
78 longitude = CharField(max_length=128) 104 longitude = CharField(max_length=128)
79 105
80 106
81with db.connection_context():
82 db.create_tables([User, Location])
83
84
85def get_poi_by_fsq_id(fsq_id) -> dict: 107def get_poi_by_fsq_id(fsq_id) -> dict:
86 with db.connection_context(): 108 with db.connection_context():
87 try: 109 try:
@@ -107,3 +129,7 @@ def create_or_update_poi(poi: dict) -> int:
107 latitude=poi["latitude"], 129 latitude=poi["latitude"],
108 longitude=poi["longitude"], 130 longitude=poi["longitude"],
109 ).on_conflict_replace().execute() 131 ).on_conflict_replace().execute()
132
133
134with db.connection_context():
135 db.create_tables([User, Location])
Powered by cgit v1.2.3 (git 2.41.0)