aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'mastodon/internals.py')
-rw-r--r--mastodon/internals.py658
1 file changed, 658 insertions, 0 deletions
diff --git a/mastodon/internals.py b/mastodon/internals.py
new file mode 100644
index 0000000..415e22d
--- /dev/null
+++ b/mastodon/internals.py
@@ -0,0 +1,658 @@
1import datetime
2from contextlib import closing
3import mimetypes
4import threading
5import six
6import uuid
7import dateutil.parser
8import time
9import copy
10import requests
11import re
12import collections
13import base64
14import os
15
16from .utility import AttribAccessDict, AttribAccessList
17from .error import MastodonNetworkError, MastodonIllegalArgumentError, MastodonRatelimitError, MastodonNotFoundError, \
18 MastodonUnauthorizedError, MastodonInternalServerError, MastodonBadGatewayError, MastodonServiceUnavailableError, \
19 MastodonGatewayTimeoutError, MastodonServerError, MastodonAPIError, MastodonMalformedEventError
20from .compat import urlparse, magic, PurePath
21from .defaults import _DEFAULT_STREAM_TIMEOUT, _DEFAULT_STREAM_RECONNECT_WAIT_SEC
22
23###
24# Internal helpers, dragons probably
25###
26class Mastodon():
27 def __datetime_to_epoch(self, date_time):
28 """
29 Converts a python datetime to unix epoch, accounting for
30 time zones and such.
31
32 Assumes UTC if timezone is not given.
33 """
34 if date_time.tzinfo is None:
35 date_time = date_time.replace(tzinfo=datetime.timezone.utc)
36 return date_time.timestamp()
37
38
39 def __get_logged_in_id(self):
40 """
41 Fetch the logged in user's ID, with caching. ID is reset on calls to log_in.
42 """
43 if self.__logged_in_id is None:
44 self.__logged_in_id = self.account_verify_credentials().id
45 return self.__logged_in_id
46
47 @staticmethod
48 def __json_allow_dict_attrs(json_object):
49 """
50 Makes it possible to use attribute notation to access a dicts
51 elements, while still allowing the dict to act as a dict.
52 """
53 if isinstance(json_object, dict):
54 return AttribAccessDict(json_object)
55 return json_object
56
57 @staticmethod
58 def __json_date_parse(json_object):
59 """
60 Parse dates in certain known json fields, if possible.
61 """
62 known_date_fields = ["created_at", "week", "day", "expires_at", "scheduled_at",
63 "updated_at", "last_status_at", "starts_at", "ends_at", "published_at", "edited_at"]
64 mark_delete = []
65 for k, v in json_object.items():
66 if k in known_date_fields:
67 if v is not None:
68 try:
69 if isinstance(v, int):
70 json_object[k] = datetime.datetime.fromtimestamp(v, datetime.timezone.utc)
71 else:
72 json_object[k] = dateutil.parser.parse(v)
73 except:
74 # When we can't parse a date, we just leave the field out
75 mark_delete.append(k)
76 # Two step process because otherwise python gets very upset
77 for k in mark_delete:
78 del json_object[k]
79 return json_object
80
81 @staticmethod
82 def __json_truefalse_parse(json_object):
83 """
84 Parse 'True' / 'False' strings in certain known fields
85 """
86 for key in ('follow', 'favourite', 'reblog', 'mention'):
87 if (key in json_object and isinstance(json_object[key], six.text_type)):
88 if json_object[key].lower() == 'true':
89 json_object[key] = True
90 if json_object[key].lower() == 'false':
91 json_object[key] = False
92 return json_object
93
94 @staticmethod
95 def __json_strnum_to_bignum(json_object):
96 """
97 Converts json string numerals to native python bignums.
98 """
99 for key in ('id', 'week', 'in_reply_to_id', 'in_reply_to_account_id', 'logins', 'registrations', 'statuses', 'day', 'last_read_id'):
100 if (key in json_object and isinstance(json_object[key], six.text_type)):
101 try:
102 json_object[key] = int(json_object[key])
103 except ValueError:
104 pass
105
106 return json_object
107
108 @staticmethod
109 def __json_hooks(json_object):
110 """
111 All the json hooks. Used in request parsing.
112 """
113 json_object = Mastodon.__json_strnum_to_bignum(json_object)
114 json_object = Mastodon.__json_date_parse(json_object)
115 json_object = Mastodon.__json_truefalse_parse(json_object)
116 json_object = Mastodon.__json_allow_dict_attrs(json_object)
117 return json_object
118
119 @staticmethod
120 def __consistent_isoformat_utc(datetime_val):
121 """
122 Function that does what isoformat does but it actually does the same
123 every time instead of randomly doing different things on some systems
124 and also it represents that time as the equivalent UTC time.
125 """
126 isotime = datetime_val.astimezone(datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%S%z")
127 if isotime[-2] != ":":
128 isotime = isotime[:-2] + ":" + isotime[-2:]
129 return isotime
130
131 def __api_request(self, method, endpoint, params={}, files={}, headers={}, access_token_override=None, base_url_override=None,
132 do_ratelimiting=True, use_json=False, parse=True, return_response_object=False, skip_error_check=False, lang_override=None):
133 """
134 Internal API request helper.
135 """
136 response = None
137 remaining_wait = 0
138
139 # Add language to params if not None
140 lang = self.lang
141 if lang_override is not None:
142 lang = lang_override
143 if lang is not None:
144 params["lang"] = lang
145
146 # "pace" mode ratelimiting: Assume constant rate of requests, sleep a little less long than it
147 # would take to not hit the rate limit at that request rate.
148 if do_ratelimiting and self.ratelimit_method == "pace":
149 if self.ratelimit_remaining == 0:
150 to_next = self.ratelimit_reset - time.time()
151 if to_next > 0:
152 # As a precaution, never sleep longer than 5 minutes
153 to_next = min(to_next, 5 * 60)
154 time.sleep(to_next)
155 else:
156 time_waited = time.time() - self.ratelimit_lastcall
157 time_wait = float(self.ratelimit_reset - time.time()) / float(self.ratelimit_remaining)
158 remaining_wait = time_wait - time_waited
159
160 if remaining_wait > 0:
161 to_next = remaining_wait / self.ratelimit_pacefactor
162 to_next = min(to_next, 5 * 60)
163 time.sleep(to_next)
164
165 # Generate request headers
166 headers = copy.deepcopy(headers)
167 if self.access_token is not None:
168 headers['Authorization'] = 'Bearer ' + self.access_token
169 if access_token_override is not None:
170 headers['Authorization'] = 'Bearer ' + access_token_override
171
172 # Add user-agent
173 if self.user_agent:
174 headers['User-Agent'] = self.user_agent
175
176 # Determine base URL
177 base_url = self.api_base_url
178 if base_url_override is not None:
179 base_url = base_url_override
180
181 if self.debug_requests:
182 print('Mastodon: Request to endpoint "' + base_url +
183 endpoint + '" using method "' + method + '".')
184 print('Parameters: ' + str(params))
185 print('Headers: ' + str(headers))
186 print('Files: ' + str(files))
187
188 # Make request
189 request_complete = False
190 while not request_complete:
191 request_complete = True
192
193 response_object = None
194 try:
195 kwargs = dict(headers=headers, files=files, timeout=self.request_timeout)
196 if use_json:
197 kwargs['json'] = params
198 elif method == 'GET':
199 kwargs['params'] = params
200 else:
201 kwargs['data'] = params
202
203 response_object = self.session.request(method, base_url + endpoint, **kwargs)
204 except Exception as e:
205 raise MastodonNetworkError("Could not complete request: %s" % e)
206
207 if response_object is None:
208 raise MastodonIllegalArgumentError("Illegal request.")
209
210 # Parse rate limiting headers
211 if 'X-RateLimit-Remaining' in response_object.headers and do_ratelimiting:
212 self.ratelimit_remaining = int(
213 response_object.headers['X-RateLimit-Remaining'])
214 self.ratelimit_limit = int(
215 response_object.headers['X-RateLimit-Limit'])
216
217 # For gotosocial, we need an int representation, but for non-ints this would crash
218 try:
219 ratelimit_intrep = str(
220 int(response_object.headers['X-RateLimit-Reset']))
221 except:
222 ratelimit_intrep = None
223
224 try:
225 if ratelimit_intrep is not None and ratelimit_intrep == response_object.headers['X-RateLimit-Reset']:
226 self.ratelimit_reset = int(
227 response_object.headers['X-RateLimit-Reset'])
228 else:
229 ratelimit_reset_datetime = dateutil.parser.parse(response_object.headers['X-RateLimit-Reset'])
230 self.ratelimit_reset = self.__datetime_to_epoch(ratelimit_reset_datetime)
231
232 # Adjust server time to local clock
233 if 'Date' in response_object.headers:
234 server_time_datetime = dateutil.parser.parse(response_object.headers['Date'])
235 server_time = self.__datetime_to_epoch(server_time_datetime)
236 server_time_diff = time.time() - server_time
237 self.ratelimit_reset += server_time_diff
238 self.ratelimit_lastcall = time.time()
239 except Exception as e:
240 raise MastodonRatelimitError("Rate limit time calculations failed: %s" % e)
241
242 # Handle response
243 if self.debug_requests:
244 print('Mastodon: Response received with code ' + str(response_object.status_code) + '.')
245 print('response headers: ' + str(response_object.headers))
246 print('Response text content: ' + str(response_object.text))
247
248 if not response_object.ok:
249 try:
250 response = response_object.json(object_hook=self.__json_hooks)
251 if isinstance(response, dict) and 'error' in response:
252 error_msg = response['error']
253 elif isinstance(response, str):
254 error_msg = response
255 else:
256 error_msg = None
257 except ValueError:
258 error_msg = None
259
260 # Handle rate limiting
261 if response_object.status_code == 429:
262 if self.ratelimit_method == 'throw' or not do_ratelimiting:
263 raise MastodonRatelimitError('Hit rate limit.')
264 elif self.ratelimit_method in ('wait', 'pace'):
265 to_next = self.ratelimit_reset - time.time()
266 if to_next > 0:
267 # As a precaution, never sleep longer than 5 minutes
268 to_next = min(to_next, 5 * 60)
269 time.sleep(to_next)
270 request_complete = False
271 continue
272
273 if not skip_error_check:
274 if response_object.status_code == 404:
275 ex_type = MastodonNotFoundError
276 if not error_msg:
277 error_msg = 'Endpoint not found.'
278 # this is for compatibility with older versions
279 # which raised MastodonAPIError('Endpoint not found.')
280 # on any 404
281 elif response_object.status_code == 401:
282 ex_type = MastodonUnauthorizedError
283 elif response_object.status_code == 500:
284 ex_type = MastodonInternalServerError
285 elif response_object.status_code == 502:
286 ex_type = MastodonBadGatewayError
287 elif response_object.status_code == 503:
288 ex_type = MastodonServiceUnavailableError
289 elif response_object.status_code == 504:
290 ex_type = MastodonGatewayTimeoutError
291 elif response_object.status_code >= 500 and response_object.status_code <= 511:
292 ex_type = MastodonServerError
293 else:
294 ex_type = MastodonAPIError
295
296 raise ex_type('Mastodon API returned error', response_object.status_code, response_object.reason, error_msg)
297
298 if return_response_object:
299 return response_object
300
301 if parse:
302 try:
303 response = response_object.json(object_hook=self.__json_hooks)
304 except:
305 raise MastodonAPIError(
306 "Could not parse response as JSON, response code was %s, "
307 "bad json content was '%s'" % (response_object.status_code,
308 response_object.content))
309 else:
310 response = response_object.content
311
312 # Parse link headers
313 if isinstance(response, list) and \
314 'Link' in response_object.headers and \
315 response_object.headers['Link'] != "":
316 response = AttribAccessList(response)
317 tmp_urls = requests.utils.parse_header_links(
318 response_object.headers['Link'].rstrip('>').replace('>,<', ',<'))
319 for url in tmp_urls:
320 if 'rel' not in url:
321 continue
322
323 if url['rel'] == 'next':
324 # Be paranoid and extract max_id specifically
325 next_url = url['url']
326 matchgroups = re.search(r"[?&]max_id=([^&]+)", next_url)
327
328 if matchgroups:
329 next_params = copy.deepcopy(params)
330 next_params['_pagination_method'] = method
331 next_params['_pagination_endpoint'] = endpoint
332 max_id = matchgroups.group(1)
333 if max_id.isdigit():
334 next_params['max_id'] = int(max_id)
335 else:
336 next_params['max_id'] = max_id
337 if "since_id" in next_params:
338 del next_params['since_id']
339 if "min_id" in next_params:
340 del next_params['min_id']
341 response._pagination_next = next_params
342
343 # Maybe other API users rely on the pagination info in the last item
344 # Will be removed in future
345 if isinstance(response[-1], AttribAccessDict):
346 response[-1]._pagination_next = next_params
347
348 if url['rel'] == 'prev':
349 # Be paranoid and extract since_id or min_id specifically
350 prev_url = url['url']
351
352 # Old and busted (pre-2.6.0): since_id pagination
353 matchgroups = re.search(
354 r"[?&]since_id=([^&]+)", prev_url)
355 if matchgroups:
356 prev_params = copy.deepcopy(params)
357 prev_params['_pagination_method'] = method
358 prev_params['_pagination_endpoint'] = endpoint
359 since_id = matchgroups.group(1)
360 if since_id.isdigit():
361 prev_params['since_id'] = int(since_id)
362 else:
363 prev_params['since_id'] = since_id
364 if "max_id" in prev_params:
365 del prev_params['max_id']
366 response._pagination_prev = prev_params
367
368 # Maybe other API users rely on the pagination info in the first item
369 # Will be removed in future
370 if isinstance(response[0], AttribAccessDict):
371 response[0]._pagination_prev = prev_params
372
373 # New and fantastico (post-2.6.0): min_id pagination
374 matchgroups = re.search(
375 r"[?&]min_id=([^&]+)", prev_url)
376 if matchgroups:
377 prev_params = copy.deepcopy(params)
378 prev_params['_pagination_method'] = method
379 prev_params['_pagination_endpoint'] = endpoint
380 min_id = matchgroups.group(1)
381 if min_id.isdigit():
382 prev_params['min_id'] = int(min_id)
383 else:
384 prev_params['min_id'] = min_id
385 if "max_id" in prev_params:
386 del prev_params['max_id']
387 response._pagination_prev = prev_params
388
389 # Maybe other API users rely on the pagination info in the first item
390 # Will be removed in future
391 if isinstance(response[0], AttribAccessDict):
392 response[0]._pagination_prev = prev_params
393
394 return response
395
396 def __get_streaming_base(self):
397 """
398 Internal streaming API helper.
399
400 Returns the correct URL for the streaming API.
401 """
402 instance = self.instance()
403 if "streaming_api" in instance["urls"] and instance["urls"]["streaming_api"] != self.api_base_url:
404 # This is probably a websockets URL, which is really for the browser, but requests can't handle it
405 # So we do this below to turn it into an HTTPS or HTTP URL
406 parse = urlparse(instance["urls"]["streaming_api"])
407 if parse.scheme == 'wss':
408 url = "https://" + parse.netloc
409 elif parse.scheme == 'ws':
410 url = "http://" + parse.netloc
411 else:
412 raise MastodonAPIError(
413 "Could not parse streaming api location returned from server: {}.".format(
414 instance["urls"]["streaming_api"]))
415 else:
416 url = self.api_base_url
417 return url
418
    def __stream(self, endpoint, listener, params={}, run_async=False, timeout=_DEFAULT_STREAM_TIMEOUT, reconnect_async=False, reconnect_async_wait_sec=_DEFAULT_STREAM_RECONNECT_WAIT_SEC):
        """
        Internal streaming API helper.

        Connects to `endpoint` on the streaming base URL and feeds the response
        to `listener.handle_stream`. With run_async=True a daemon thread is
        started and a handle object (close / is_alive / is_receiving) is
        returned; otherwise the call blocks and only leaves via an exception.

        Returns a handle to the open connection that the user can close if they
        wish to terminate it.
        """
        # NOTE(review): `params={}` is a shared mutable default. It is not
        # mutated here (only passed through to requests), but callers should
        # not rely on that default object.

        # Check if we have to redirect
        url = self.__get_streaming_base()

        # The streaming server can't handle two slashes in a path, so remove trailing slashes
        if url[-1] == '/':
            url = url[:-1]

        # Connect function (called and then potentially passed to async handler)
        def connect_func():
            # Build auth / UA headers fresh on every (re)connect so token changes take effect
            headers = {"Authorization": "Bearer " +
                       self.access_token} if self.access_token else {}
            if self.user_agent:
                headers['User-Agent'] = self.user_agent
            # NOTE(review): params are sent as `data` on a GET here, not as query
            # params - looks intentional for the streaming server, but confirm.
            connection = self.session.get(url + endpoint, headers=headers, data=params, stream=True,
                                          timeout=(self.request_timeout, timeout))

            if connection.status_code != 200:
                raise MastodonNetworkError(
                    "Could not connect to streaming server: %s" % connection.reason)
            return connection
        connection = None

        # Async stream handler. Runs the listener on a background thread and
        # optionally reconnects (with a polite delay) on stream errors.
        class __stream_handle():
            def __init__(self, connection, connect_func, reconnect_async, reconnect_async_wait_sec):
                self.closed = False
                self.running = True
                self.connection = connection
                self.connect_func = connect_func
                self.reconnect_async = reconnect_async
                self.reconnect_async_wait_sec = reconnect_async_wait_sec
                self.reconnecting = False

            def close(self):
                # Signal the worker thread to stop and tear down the open connection
                self.closed = True
                if self.connection is not None:
                    self.connection.close()

            def is_alive(self):
                return self._thread.is_alive()

            def is_receiving(self):
                # "Receiving" means: not closed, still running, not mid-reconnect,
                # and the worker thread is actually alive.
                if self.closed or not self.running or self.reconnecting or not self.is_alive():
                    return False
                else:
                    return True

            def _sleep_attentive(self):
                # Sleep for the reconnect delay in small steps so close() is
                # noticed quickly. Only valid on the worker thread itself.
                if self._thread != threading.current_thread():
                    raise RuntimeError(
                        "Illegal call from outside the stream_handle thread")
                time_remaining = self.reconnect_async_wait_sec
                while time_remaining > 0 and not self.closed:
                    time.sleep(0.5)
                    time_remaining -= 0.5

            def _threadproc(self):
                self._thread = threading.current_thread()

                # Run until closed or until error if not autoreconnecting
                while self.running:
                    if self.connection is not None:
                        with closing(self.connection) as r:
                            try:
                                listener.handle_stream(r)
                            except (AttributeError, MastodonMalformedEventError, MastodonNetworkError) as e:
                                # Swallow stream errors when closing or when we
                                # are going to reconnect anyway
                                if not (self.closed or self.reconnect_async):
                                    raise e
                                else:
                                    if self.closed:
                                        self.running = False

                    # Reconnect loop. Try immediately once, then with delays on error.
                    if (self.reconnect_async and not self.closed) or self.connection is None:
                        self.reconnecting = True
                        connect_success = False
                        while not connect_success:
                            if self.closed:
                                # Someone from outside stopped the streaming
                                self.running = False
                                break
                            try:
                                the_connection = self.connect_func()
                                if the_connection.status_code != 200:
                                    exception = MastodonNetworkError(f"Could not connect to server. "
                                                                     f"HTTP status: {the_connection.status_code}")
                                    listener.on_abort(exception)
                                    self._sleep_attentive()
                                if self.closed:
                                    # Here we have maybe a rare race condition. Exactly on connect, someone
                                    # stopped the streaming before. We close the previous established connection:
                                    the_connection.close()
                                else:
                                    self.connection = the_connection
                                    connect_success = True
                            except:
                                # NOTE(review): bare except - also swallows
                                # KeyboardInterrupt/SystemExit; consider Exception.
                                self._sleep_attentive()
                                connect_success = False
                        self.reconnecting = False
                    else:
                        self.running = False
                return 0

        if run_async:
            handle = __stream_handle(
                connection, connect_func, reconnect_async, reconnect_async_wait_sec)
            # Daemon thread so a hanging stream never blocks interpreter shutdown
            t = threading.Thread(args=(), target=handle._threadproc)
            t.daemon = True
            t.start()
            return handle
        else:
            # Blocking, never returns (can only leave via exception)
            connection = connect_func()
            with closing(connection) as r:
                listener.handle_stream(r)
542
543 def __generate_params(self, params, exclude=[]):
544 """
545 Internal named-parameters-to-dict helper.
546
547 Note for developers: If called with locals() as params,
548 as is the usual practice in this code, the __generate_params call
549 (or at least the locals() call) should generally be the first thing
550 in your function.
551 """
552 params = collections.OrderedDict(params)
553
554 if 'self' in params:
555 del params['self']
556
557 param_keys = list(params.keys())
558 for key in param_keys:
559 if isinstance(params[key], bool):
560 params[key] = '1' if params[key] else '0'
561
562 for key in param_keys:
563 if params[key] is None or key in exclude:
564 del params[key]
565
566 param_keys = list(params.keys())
567 for key in param_keys:
568 if isinstance(params[key], list):
569 params[key + "[]"] = params[key]
570 del params[key]
571
572 return params
573
574 def __unpack_id(self, id, dateconv=False):
575 """
576 Internal object-to-id converter
577
578 Checks if id is a dict that contains id and
579 returns the id inside, otherwise just returns
580 the id straight.
581
582 Also unpacks datetimes to snowflake IDs if requested.
583 """
584 if isinstance(id, dict) and "id" in id:
585 id = id["id"]
586 if dateconv and isinstance(id, datetime.datetime):
587 id = (int(id.timestamp()) << 16) * 1000
588 return id
589
590 def __decode_webpush_b64(self, data):
591 """
592 Re-pads and decodes urlsafe base64.
593 """
594 missing_padding = len(data) % 4
595 if missing_padding != 0:
596 data += '=' * (4 - missing_padding)
597 return base64.urlsafe_b64decode(data)
598
599 def __get_token_expired(self):
600 """Internal helper for oauth code"""
601 return self._token_expired < datetime.datetime.now()
602
603 def __set_token_expired(self, value):
604 """Internal helper for oauth code"""
605 self._token_expired = datetime.datetime.now() + datetime.timedelta(seconds=value)
606 return
607
    def __get_refresh_token(self):
        """Internal helper for oauth code: return the stored OAuth refresh token (set via __set_refresh_token; may be None)."""
        return self._refresh_token
611
612 def __set_refresh_token(self, value):
613 """Internal helper for oauth code"""
614 self._refresh_token = value
615 return
616
617 def __guess_type(self, media_file):
618 """Internal helper to guess media file type"""
619 mime_type = None
620 try:
621 mime_type = magic.from_file(media_file, mime=True)
622 except AttributeError:
623 mime_type = mimetypes.guess_type(media_file)[0]
624 return mime_type
625
626 def __load_media_file(self, media_file, mime_type=None, file_name=None):
627 if isinstance(media_file, PurePath):
628 media_file = str(media_file)
629 if isinstance(media_file, str) and os.path.isfile(media_file):
630 mime_type = self.__guess_type(media_file)
631 media_file = open(media_file, 'rb')
632 elif isinstance(media_file, str) and os.path.isfile(media_file):
633 media_file = open(media_file, 'rb')
634 if mime_type is None:
635 raise MastodonIllegalArgumentError('Could not determine mime type or data passed directly without mime type.')
636 if file_name is None:
637 random_suffix = uuid.uuid4().hex
638 file_name = "mastodonpyupload_" + str(time.time()) + "_" + str(random_suffix) + mimetypes.guess_extension(mime_type)
639 return (file_name, media_file, mime_type)
640
641 @staticmethod
642 def __protocolize(base_url):
643 """Internal add-protocol-to-url helper"""
644 if not base_url.startswith("http://") and not base_url.startswith("https://"):
645 base_url = "https://" + base_url
646
647 # Some API endpoints can't handle extra /'s in path requests
648 base_url = base_url.rstrip("/")
649 return base_url
650
651 @staticmethod
652 def __deprotocolize(base_url):
653 """Internal helper to strip http and https from a URL"""
654 if base_url.startswith("http://"):
655 base_url = base_url[7:]
656 elif base_url.startswith("https://") or base_url.startswith("onion://"):
657 base_url = base_url[8:]
658 return base_url
Powered by cgit v1.2.3 (git 2.41.0)