aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'mastodon/internals.py')
-rw-r--r--mastodon/internals.py664
1 files changed, 664 insertions, 0 deletions
diff --git a/mastodon/internals.py b/mastodon/internals.py
new file mode 100644
index 0000000..a19ed77
--- /dev/null
+++ b/mastodon/internals.py
@@ -0,0 +1,664 @@
1import datetime
2from contextlib import closing
3import mimetypes
4import threading
5import six
6import uuid
7import pytz
8import dateutil.parser
9import time
10import copy
11import requests
12import re
13import collections
14import base64
15import os
16
17from .utility import AttribAccessDict, AttribAccessList
18from .error import MastodonNetworkError, MastodonIllegalArgumentError, MastodonRatelimitError, MastodonNotFoundError, \
19 MastodonUnauthorizedError, MastodonInternalServerError, MastodonBadGatewayError, MastodonServiceUnavailableError, \
20 MastodonGatewayTimeoutError, MastodonServerError, MastodonAPIError, MastodonMalformedEventError
21from .compat import urlparse, magic, PurePath
22from .defaults import _DEFAULT_STREAM_TIMEOUT, _DEFAULT_STREAM_RECONNECT_WAIT_SEC
23
24###
25# Internal helpers, dragons probably
26###
27class Mastodon():
28 def __datetime_to_epoch(self, date_time):
29 """
30 Converts a python datetime to unix epoch, accounting for
31 time zones and such.
32
33 Assumes UTC if timezone is not given.
34 """
35 date_time_utc = None
36 if date_time.tzinfo is None:
37 date_time_utc = date_time.replace(tzinfo=pytz.utc)
38 else:
39 date_time_utc = date_time.astimezone(pytz.utc)
40
41 epoch_utc = datetime.datetime.utcfromtimestamp(0).replace(tzinfo=pytz.utc)
42
43 return (date_time_utc - epoch_utc).total_seconds()
44
45 def __get_logged_in_id(self):
46 """
47 Fetch the logged in user's ID, with caching. ID is reset on calls to log_in.
48 """
49 if self.__logged_in_id is None:
50 self.__logged_in_id = self.account_verify_credentials().id
51 return self.__logged_in_id
52
53 @staticmethod
54 def __json_allow_dict_attrs(json_object):
55 """
56 Makes it possible to use attribute notation to access a dicts
57 elements, while still allowing the dict to act as a dict.
58 """
59 if isinstance(json_object, dict):
60 return AttribAccessDict(json_object)
61 return json_object
62
63 @staticmethod
64 def __json_date_parse(json_object):
65 """
66 Parse dates in certain known json fields, if possible.
67 """
68 known_date_fields = ["created_at", "week", "day", "expires_at", "scheduled_at",
69 "updated_at", "last_status_at", "starts_at", "ends_at", "published_at", "edited_at"]
70 mark_delete = []
71 for k, v in json_object.items():
72 if k in known_date_fields:
73 if v is not None:
74 try:
75 if isinstance(v, int):
76 json_object[k] = datetime.datetime.fromtimestamp(v, pytz.utc)
77 else:
78 json_object[k] = dateutil.parser.parse(v)
79 except:
80 # When we can't parse a date, we just leave the field out
81 mark_delete.append(k)
82 # Two step process because otherwise python gets very upset
83 for k in mark_delete:
84 del json_object[k]
85 return json_object
86
87 @staticmethod
88 def __json_truefalse_parse(json_object):
89 """
90 Parse 'True' / 'False' strings in certain known fields
91 """
92 for key in ('follow', 'favourite', 'reblog', 'mention'):
93 if (key in json_object and isinstance(json_object[key], six.text_type)):
94 if json_object[key].lower() == 'true':
95 json_object[key] = True
96 if json_object[key].lower() == 'false':
97 json_object[key] = False
98 return json_object
99
100 @staticmethod
101 def __json_strnum_to_bignum(json_object):
102 """
103 Converts json string numerals to native python bignums.
104 """
105 for key in ('id', 'week', 'in_reply_to_id', 'in_reply_to_account_id', 'logins', 'registrations', 'statuses', 'day', 'last_read_id'):
106 if (key in json_object and isinstance(json_object[key], six.text_type)):
107 try:
108 json_object[key] = int(json_object[key])
109 except ValueError:
110 pass
111
112 return json_object
113
114 @staticmethod
115 def __json_hooks(json_object):
116 """
117 All the json hooks. Used in request parsing.
118 """
119 json_object = Mastodon.__json_strnum_to_bignum(json_object)
120 json_object = Mastodon.__json_date_parse(json_object)
121 json_object = Mastodon.__json_truefalse_parse(json_object)
122 json_object = Mastodon.__json_allow_dict_attrs(json_object)
123 return json_object
124
125 @staticmethod
126 def __consistent_isoformat_utc(datetime_val):
127 """
128 Function that does what isoformat does but it actually does the same
129 every time instead of randomly doing different things on some systems
130 and also it represents that time as the equivalent UTC time.
131 """
132 isotime = datetime_val.astimezone(pytz.utc).strftime("%Y-%m-%dT%H:%M:%S%z")
133 if isotime[-2] != ":":
134 isotime = isotime[:-2] + ":" + isotime[-2:]
135 return isotime
136
    def __api_request(self, method, endpoint, params={}, files={}, headers={}, access_token_override=None, base_url_override=None,
                      do_ratelimiting=True, use_json=False, parse=True, return_response_object=False, skip_error_check=False, lang_override=None):
        """
        Internal API request helper.

        Performs an HTTP `method` request against `base_url + endpoint` via
        self.session, handling:
          * Authorization headers (access token, with optional override)
          * client-side ("pace") and server-side (429) rate limiting,
            according to self.ratelimit_method
          * mapping of HTTP error codes to the Mastodon* exception hierarchy
            (unless skip_error_check is set)
          * JSON parsing with the __json_hooks transformations (unless
            parse=False, which returns raw bytes, or
            return_response_object=True, which returns the requests response)
          * extraction of max_id / since_id / min_id pagination parameters
            from the Link header, attached to list responses as
            _pagination_next / _pagination_prev.

        NOTE(review): params/files/headers use mutable default arguments, and
        params is mutated below (params["lang"] = ...), so a caller-passed
        dict - or the shared default - can be modified. Presumably benign in
        practice, but worth confirming before reuse.
        """
        response = None
        remaining_wait = 0

        # Add language to params if not None
        lang = self.lang
        if lang_override is not None:
            lang = lang_override
        if lang is not None:
            params["lang"] = lang

        # "pace" mode ratelimiting: Assume constant rate of requests, sleep a little less long than it
        # would take to not hit the rate limit at that request rate.
        if do_ratelimiting and self.ratelimit_method == "pace":
            if self.ratelimit_remaining == 0:
                # Budget exhausted: wait until the reported reset time.
                to_next = self.ratelimit_reset - time.time()
                if to_next > 0:
                    # As a precaution, never sleep longer than 5 minutes
                    to_next = min(to_next, 5 * 60)
                    time.sleep(to_next)
            else:
                # Spread the remaining budget evenly over the time until reset.
                time_waited = time.time() - self.ratelimit_lastcall
                time_wait = float(self.ratelimit_reset - time.time()) / float(self.ratelimit_remaining)
                remaining_wait = time_wait - time_waited

            if remaining_wait > 0:
                to_next = remaining_wait / self.ratelimit_pacefactor
                to_next = min(to_next, 5 * 60)
                time.sleep(to_next)

        # Generate request headers
        # deepcopy so the caller's dict (or the shared default) is not mutated.
        headers = copy.deepcopy(headers)
        if self.access_token is not None:
            headers['Authorization'] = 'Bearer ' + self.access_token
        if access_token_override is not None:
            headers['Authorization'] = 'Bearer ' + access_token_override

        # Add user-agent
        if self.user_agent:
            headers['User-Agent'] = self.user_agent

        # Determine base URL
        base_url = self.api_base_url
        if base_url_override is not None:
            base_url = base_url_override

        if self.debug_requests:
            print('Mastodon: Request to endpoint "' + base_url +
                  endpoint + '" using method "' + method + '".')
            print('Parameters: ' + str(params))
            print('Headers: ' + str(headers))
            print('Files: ' + str(files))

        # Make request
        # The loop re-runs the request after sleeping when we hit a 429 in
        # 'wait'/'pace' mode; every other path sets request_complete = True.
        request_complete = False
        while not request_complete:
            request_complete = True

            response_object = None
            try:
                kwargs = dict(headers=headers, files=files, timeout=self.request_timeout)
                # JSON body, query string, or form body depending on call style.
                if use_json:
                    kwargs['json'] = params
                elif method == 'GET':
                    kwargs['params'] = params
                else:
                    kwargs['data'] = params

                response_object = self.session.request(method, base_url + endpoint, **kwargs)
            except Exception as e:
                raise MastodonNetworkError("Could not complete request: %s" % e)

            if response_object is None:
                raise MastodonIllegalArgumentError("Illegal request.")

            # Parse rate limiting headers
            if 'X-RateLimit-Remaining' in response_object.headers and do_ratelimiting:
                self.ratelimit_remaining = int(
                    response_object.headers['X-RateLimit-Remaining'])
                self.ratelimit_limit = int(
                    response_object.headers['X-RateLimit-Limit'])

                # For gotosocial, we need an int representation, but for non-ints this would crash
                try:
                    ratelimit_intrep = str(
                        int(response_object.headers['X-RateLimit-Reset']))
                except:
                    ratelimit_intrep = None

                try:
                    # Reset header may be a unix timestamp (gotosocial) or a
                    # datetime string (Mastodon proper).
                    if ratelimit_intrep is not None and ratelimit_intrep == response_object.headers['X-RateLimit-Reset']:
                        self.ratelimit_reset = int(
                            response_object.headers['X-RateLimit-Reset'])
                    else:
                        ratelimit_reset_datetime = dateutil.parser.parse(response_object.headers['X-RateLimit-Reset'])
                        self.ratelimit_reset = self.__datetime_to_epoch(ratelimit_reset_datetime)

                    # Adjust server time to local clock
                    if 'Date' in response_object.headers:
                        server_time_datetime = dateutil.parser.parse(response_object.headers['Date'])
                        server_time = self.__datetime_to_epoch(server_time_datetime)
                        server_time_diff = time.time() - server_time
                        self.ratelimit_reset += server_time_diff
                        self.ratelimit_lastcall = time.time()
                except Exception as e:
                    raise MastodonRatelimitError("Rate limit time calculations failed: %s" % e)

            # Handle response
            if self.debug_requests:
                print('Mastodon: Response received with code ' + str(response_object.status_code) + '.')
                print('response headers: ' + str(response_object.headers))
                print('Response text content: ' + str(response_object.text))

            if not response_object.ok:
                # Try to pull a human-readable error message out of the body.
                try:
                    response = response_object.json(object_hook=self.__json_hooks)
                    if isinstance(response, dict) and 'error' in response:
                        error_msg = response['error']
                    elif isinstance(response, str):
                        error_msg = response
                    else:
                        error_msg = None
                except ValueError:
                    error_msg = None

                # Handle rate limiting
                if response_object.status_code == 429:
                    if self.ratelimit_method == 'throw' or not do_ratelimiting:
                        raise MastodonRatelimitError('Hit rate limit.')
                    elif self.ratelimit_method in ('wait', 'pace'):
                        to_next = self.ratelimit_reset - time.time()
                        if to_next > 0:
                            # As a precaution, never sleep longer than 5 minutes
                            to_next = min(to_next, 5 * 60)
                            time.sleep(to_next)
                        # Retry the request from the top of the while loop.
                        request_complete = False
                        continue

                if not skip_error_check:
                    # Map status codes to the most specific exception type.
                    if response_object.status_code == 404:
                        ex_type = MastodonNotFoundError
                        if not error_msg:
                            error_msg = 'Endpoint not found.'
                            # this is for compatibility with older versions
                            # which raised MastodonAPIError('Endpoint not found.')
                            # on any 404
                    elif response_object.status_code == 401:
                        ex_type = MastodonUnauthorizedError
                    elif response_object.status_code == 500:
                        ex_type = MastodonInternalServerError
                    elif response_object.status_code == 502:
                        ex_type = MastodonBadGatewayError
                    elif response_object.status_code == 503:
                        ex_type = MastodonServiceUnavailableError
                    elif response_object.status_code == 504:
                        ex_type = MastodonGatewayTimeoutError
                    elif response_object.status_code >= 500 and response_object.status_code <= 511:
                        ex_type = MastodonServerError
                    else:
                        ex_type = MastodonAPIError

                    raise ex_type('Mastodon API returned error', response_object.status_code, response_object.reason, error_msg)

            if return_response_object:
                return response_object

            if parse:
                try:
                    response = response_object.json(object_hook=self.__json_hooks)
                except:
                    raise MastodonAPIError(
                        "Could not parse response as JSON, response code was %s, "
                        "bad json content was '%s'" % (response_object.status_code,
                                                       response_object.content))
            else:
                response = response_object.content

            # Parse link headers
            # Only list responses can be paginated; attach next/prev params.
            if isinstance(response, list) and \
                    'Link' in response_object.headers and \
                    response_object.headers['Link'] != "":
                response = AttribAccessList(response)
                tmp_urls = requests.utils.parse_header_links(
                    response_object.headers['Link'].rstrip('>').replace('>,<', ',<'))
                for url in tmp_urls:
                    if 'rel' not in url:
                        continue

                    if url['rel'] == 'next':
                        # Be paranoid and extract max_id specifically
                        next_url = url['url']
                        matchgroups = re.search(r"[?&]max_id=([^&]+)", next_url)

                        if matchgroups:
                            next_params = copy.deepcopy(params)
                            next_params['_pagination_method'] = method
                            next_params['_pagination_endpoint'] = endpoint
                            max_id = matchgroups.group(1)
                            if max_id.isdigit():
                                next_params['max_id'] = int(max_id)
                            else:
                                next_params['max_id'] = max_id
                            if "since_id" in next_params:
                                del next_params['since_id']
                            if "min_id" in next_params:
                                del next_params['min_id']
                            response._pagination_next = next_params

                            # Maybe other API users rely on the pagination info in the last item
                            # Will be removed in future
                            if isinstance(response[-1], AttribAccessDict):
                                response[-1]._pagination_next = next_params

                    if url['rel'] == 'prev':
                        # Be paranoid and extract since_id or min_id specifically
                        prev_url = url['url']

                        # Old and busted (pre-2.6.0): since_id pagination
                        matchgroups = re.search(
                            r"[?&]since_id=([^&]+)", prev_url)
                        if matchgroups:
                            prev_params = copy.deepcopy(params)
                            prev_params['_pagination_method'] = method
                            prev_params['_pagination_endpoint'] = endpoint
                            since_id = matchgroups.group(1)
                            if since_id.isdigit():
                                prev_params['since_id'] = int(since_id)
                            else:
                                prev_params['since_id'] = since_id
                            if "max_id" in prev_params:
                                del prev_params['max_id']
                            response._pagination_prev = prev_params

                            # Maybe other API users rely on the pagination info in the first item
                            # Will be removed in future
                            if isinstance(response[0], AttribAccessDict):
                                response[0]._pagination_prev = prev_params

                        # New and fantastico (post-2.6.0): min_id pagination
                        matchgroups = re.search(
                            r"[?&]min_id=([^&]+)", prev_url)
                        if matchgroups:
                            prev_params = copy.deepcopy(params)
                            prev_params['_pagination_method'] = method
                            prev_params['_pagination_endpoint'] = endpoint
                            min_id = matchgroups.group(1)
                            if min_id.isdigit():
                                prev_params['min_id'] = int(min_id)
                            else:
                                prev_params['min_id'] = min_id
                            if "max_id" in prev_params:
                                del prev_params['max_id']
                            response._pagination_prev = prev_params

                            # Maybe other API users rely on the pagination info in the first item
                            # Will be removed in future
                            if isinstance(response[0], AttribAccessDict):
                                response[0]._pagination_prev = prev_params

        return response
401
402 def __get_streaming_base(self):
403 """
404 Internal streaming API helper.
405
406 Returns the correct URL for the streaming API.
407 """
408 instance = self.instance()
409 if "streaming_api" in instance["urls"] and instance["urls"]["streaming_api"] != self.api_base_url:
410 # This is probably a websockets URL, which is really for the browser, but requests can't handle it
411 # So we do this below to turn it into an HTTPS or HTTP URL
412 parse = urlparse(instance["urls"]["streaming_api"])
413 if parse.scheme == 'wss':
414 url = "https://" + parse.netloc
415 elif parse.scheme == 'ws':
416 url = "http://" + parse.netloc
417 else:
418 raise MastodonAPIError(
419 "Could not parse streaming api location returned from server: {}.".format(
420 instance["urls"]["streaming_api"]))
421 else:
422 url = self.api_base_url
423 return url
424
    def __stream(self, endpoint, listener, params={}, run_async=False, timeout=_DEFAULT_STREAM_TIMEOUT, reconnect_async=False, reconnect_async_wait_sec=_DEFAULT_STREAM_RECONNECT_WAIT_SEC):
        """
        Internal streaming API helper.

        Opens a streaming GET against `endpoint` on the server returned by
        __get_streaming_base() and feeds the response to
        listener.handle_stream. With run_async=True the stream runs on a
        daemon thread and a handle object (with close / is_alive /
        is_receiving) is returned immediately; with reconnect_async=True the
        thread keeps reconnecting after errors, sleeping
        reconnect_async_wait_sec between attempts.

        Returns a handle to the open connection that the user can close if they
        wish to terminate it.

        NOTE(review): `params` is a mutable default argument; it is only ever
        read here, so this is presumably benign - confirm before changing.
        """

        # Check if we have to redirect
        url = self.__get_streaming_base()

        # The streaming server can't handle two slashes in a path, so remove trailing slashes
        if url[-1] == '/':
            url = url[:-1]

        # Connect function (called and then potentially passed to async handler)
        def connect_func():
            headers = {"Authorization": "Bearer " +
                       self.access_token} if self.access_token else {}
            if self.user_agent:
                headers['User-Agent'] = self.user_agent
            # timeout tuple: (connect timeout, read timeout between chunks)
            connection = self.session.get(url + endpoint, headers=headers, data=params, stream=True,
                                          timeout=(self.request_timeout, timeout))

            if connection.status_code != 200:
                raise MastodonNetworkError(
                    "Could not connect to streaming server: %s" % connection.reason)
            return connection
        connection = None

        # Async stream handler
        class __stream_handle():
            def __init__(self, connection, connect_func, reconnect_async, reconnect_async_wait_sec):
                self.closed = False
                self.running = True
                self.connection = connection
                self.connect_func = connect_func
                self.reconnect_async = reconnect_async
                self.reconnect_async_wait_sec = reconnect_async_wait_sec
                self.reconnecting = False

            def close(self):
                # Signal the worker thread to stop and drop the connection.
                self.closed = True
                if self.connection is not None:
                    self.connection.close()

            def is_alive(self):
                return self._thread.is_alive()

            def is_receiving(self):
                # "Receiving" means: open, running, not mid-reconnect, thread alive.
                if self.closed or not self.running or self.reconnecting or not self.is_alive():
                    return False
                else:
                    return True

            def _sleep_attentive(self):
                # Sleep reconnect_async_wait_sec in 0.5s slices so close()
                # is noticed promptly; only legal on the worker thread.
                if self._thread != threading.current_thread():
                    raise RuntimeError(
                        "Illegal call from outside the stream_handle thread")
                time_remaining = self.reconnect_async_wait_sec
                while time_remaining > 0 and not self.closed:
                    time.sleep(0.5)
                    time_remaining -= 0.5

            def _threadproc(self):
                self._thread = threading.current_thread()

                # Run until closed or until error if not autoreconnecting
                while self.running:
                    if self.connection is not None:
                        with closing(self.connection) as r:
                            try:
                                listener.handle_stream(r)
                            except (AttributeError, MastodonMalformedEventError, MastodonNetworkError) as e:
                                # Swallow stream errors only when we intend to
                                # reconnect (or have been closed deliberately).
                                if not (self.closed or self.reconnect_async):
                                    raise e
                            else:
                                if self.closed:
                                    self.running = False

                    # Reconnect loop. Try immediately once, then with delays on error.
                    if (self.reconnect_async and not self.closed) or self.connection is None:
                        self.reconnecting = True
                        connect_success = False
                        while not connect_success:
                            if self.closed:
                                # Someone from outside stopped the streaming
                                self.running = False
                                break
                            try:
                                the_connection = self.connect_func()
                                if the_connection.status_code != 200:
                                    exception = MastodonNetworkError(f"Could not connect to server. "
                                                                     f"HTTP status: {the_connection.status_code}")
                                    listener.on_abort(exception)
                                    self._sleep_attentive()
                                if self.closed:
                                    # Here we have maybe a rare race condition. Exactly on connect, someone
                                    # stopped the streaming before. We close the previous established connection:
                                    the_connection.close()
                                else:
                                    self.connection = the_connection
                                    connect_success = True
                            except:
                                self._sleep_attentive()
                                connect_success = False
                        self.reconnecting = False
                    else:
                        self.running = False
                return 0

        if run_async:
            handle = __stream_handle(
                connection, connect_func, reconnect_async, reconnect_async_wait_sec)
            t = threading.Thread(args=(), target=handle._threadproc)
            t.daemon = True
            t.start()
            return handle
        else:
            # Blocking, never returns (can only leave via exception)
            connection = connect_func()
            with closing(connection) as r:
                listener.handle_stream(r)
548
549 def __generate_params(self, params, exclude=[]):
550 """
551 Internal named-parameters-to-dict helper.
552
553 Note for developers: If called with locals() as params,
554 as is the usual practice in this code, the __generate_params call
555 (or at least the locals() call) should generally be the first thing
556 in your function.
557 """
558 params = collections.OrderedDict(params)
559
560 if 'self' in params:
561 del params['self']
562
563 param_keys = list(params.keys())
564 for key in param_keys:
565 if isinstance(params[key], bool):
566 params[key] = '1' if params[key] else '0'
567
568 for key in param_keys:
569 if params[key] is None or key in exclude:
570 del params[key]
571
572 param_keys = list(params.keys())
573 for key in param_keys:
574 if isinstance(params[key], list):
575 params[key + "[]"] = params[key]
576 del params[key]
577
578 return params
579
580 def __unpack_id(self, id, dateconv=False):
581 """
582 Internal object-to-id converter
583
584 Checks if id is a dict that contains id and
585 returns the id inside, otherwise just returns
586 the id straight.
587
588 Also unpacks datetimes to snowflake IDs if requested.
589 """
590 if isinstance(id, dict) and "id" in id:
591 id = id["id"]
592 if dateconv and isinstance(id, datetime.datetime):
593 id = (int(id.timestamp()) << 16) * 1000
594 return id
595
596 def __decode_webpush_b64(self, data):
597 """
598 Re-pads and decodes urlsafe base64.
599 """
600 missing_padding = len(data) % 4
601 if missing_padding != 0:
602 data += '=' * (4 - missing_padding)
603 return base64.urlsafe_b64decode(data)
604
605 def __get_token_expired(self):
606 """Internal helper for oauth code"""
607 return self._token_expired < datetime.datetime.now()
608
609 def __set_token_expired(self, value):
610 """Internal helper for oauth code"""
611 self._token_expired = datetime.datetime.now() + datetime.timedelta(seconds=value)
612 return
613
614 def __get_refresh_token(self):
615 """Internal helper for oauth code"""
616 return self._refresh_token
617
618 def __set_refresh_token(self, value):
619 """Internal helper for oauth code"""
620 self._refresh_token = value
621 return
622
623 def __guess_type(self, media_file):
624 """Internal helper to guess media file type"""
625 mime_type = None
626 try:
627 mime_type = magic.from_file(media_file, mime=True)
628 except AttributeError:
629 mime_type = mimetypes.guess_type(media_file)[0]
630 return mime_type
631
632 def __load_media_file(self, media_file, mime_type=None, file_name=None):
633 if isinstance(media_file, PurePath):
634 media_file = str(media_file)
635 if isinstance(media_file, str) and os.path.isfile(media_file):
636 mime_type = self.__guess_type(media_file)
637 media_file = open(media_file, 'rb')
638 elif isinstance(media_file, str) and os.path.isfile(media_file):
639 media_file = open(media_file, 'rb')
640 if mime_type is None:
641 raise MastodonIllegalArgumentError('Could not determine mime type or data passed directly without mime type.')
642 if file_name is None:
643 random_suffix = uuid.uuid4().hex
644 file_name = "mastodonpyupload_" + str(time.time()) + "_" + str(random_suffix) + mimetypes.guess_extension(mime_type)
645 return (file_name, media_file, mime_type)
646
647 @staticmethod
648 def __protocolize(base_url):
649 """Internal add-protocol-to-url helper"""
650 if not base_url.startswith("http://") and not base_url.startswith("https://"):
651 base_url = "https://" + base_url
652
653 # Some API endpoints can't handle extra /'s in path requests
654 base_url = base_url.rstrip("/")
655 return base_url
656
657 @staticmethod
658 def __deprotocolize(base_url):
659 """Internal helper to strip http and https from a URL"""
660 if base_url.startswith("http://"):
661 base_url = base_url[7:]
662 elif base_url.startswith("https://") or base_url.startswith("onion://"):
663 base_url = base_url[8:]
664 return base_url
Powered by cgit v1.2.3 (git 2.41.0)