Diffstat (limited to 'mastodon/internals.py')
 -rw-r--r--   mastodon/internals.py   658
1 file changed, 658 insertions, 0 deletions
diff --git a/mastodon/internals.py b/mastodon/internals.py
new file mode 100644
index 0000000..415e22d
--- /dev/null
+++ b/mastodon/internals.py
@@ -0,0 +1,658 @@
import datetime
from contextlib import closing
import mimetypes
import threading
import six
import uuid
import dateutil.parser
import time
import copy
import requests
import re
import collections
import base64
import os

from .utility import AttribAccessDict, AttribAccessList
from .error import MastodonNetworkError, MastodonIllegalArgumentError, MastodonRatelimitError, MastodonNotFoundError, \
    MastodonUnauthorizedError, MastodonInternalServerError, MastodonBadGatewayError, MastodonServiceUnavailableError, \
    MastodonGatewayTimeoutError, MastodonServerError, MastodonAPIError, MastodonMalformedEventError
from .compat import urlparse, magic, PurePath
from .defaults import _DEFAULT_STREAM_TIMEOUT, _DEFAULT_STREAM_RECONNECT_WAIT_SEC

###
# Internal helpers, dragons probably
###
class Mastodon():
    def __datetime_to_epoch(self, date_time):
        """
        Converts a python datetime to unix epoch, accounting for
        time zones and such.

        Assumes UTC if timezone is not given.
        """
        if date_time.tzinfo is None:
            date_time = date_time.replace(tzinfo=datetime.timezone.utc)
        return date_time.timestamp()

    def __get_logged_in_id(self):
        """
        Fetch the logged in user's ID, with caching. ID is reset on calls to log_in.
        """
        if self.__logged_in_id is None:
            self.__logged_in_id = self.account_verify_credentials().id
        return self.__logged_in_id

    @staticmethod
    def __json_allow_dict_attrs(json_object):
        """
        Makes it possible to use attribute notation to access a dict's
        elements, while still allowing the dict to act as a dict.
        """
        if isinstance(json_object, dict):
            return AttribAccessDict(json_object)
        return json_object

    @staticmethod
    def __json_date_parse(json_object):
        """
        Parse dates in certain known json fields, if possible.
        """
        known_date_fields = ["created_at", "week", "day", "expires_at", "scheduled_at",
                             "updated_at", "last_status_at", "starts_at", "ends_at", "published_at", "edited_at"]
        mark_delete = []
        for k, v in json_object.items():
            if k in known_date_fields:
                if v is not None:
                    try:
                        if isinstance(v, int):
                            json_object[k] = datetime.datetime.fromtimestamp(v, datetime.timezone.utc)
                        else:
                            json_object[k] = dateutil.parser.parse(v)
                    except:
                        # When we can't parse a date, we just leave the field out
                        mark_delete.append(k)
        # Two step process because otherwise python gets very upset
        for k in mark_delete:
            del json_object[k]
        return json_object

    @staticmethod
    def __json_truefalse_parse(json_object):
        """
        Parse 'True' / 'False' strings in certain known fields
        """
        for key in ('follow', 'favourite', 'reblog', 'mention'):
            if (key in json_object and isinstance(json_object[key], six.text_type)):
                if json_object[key].lower() == 'true':
                    json_object[key] = True
                if json_object[key].lower() == 'false':
                    json_object[key] = False
        return json_object

    @staticmethod
    def __json_strnum_to_bignum(json_object):
        """
        Converts json string numerals to native python bignums.
        """
        for key in ('id', 'week', 'in_reply_to_id', 'in_reply_to_account_id', 'logins', 'registrations', 'statuses', 'day', 'last_read_id'):
            if (key in json_object and isinstance(json_object[key], six.text_type)):
                try:
                    json_object[key] = int(json_object[key])
                except ValueError:
                    pass

        return json_object

    @staticmethod
    def __json_hooks(json_object):
        """
        All the json hooks. Used in request parsing.
        """
        json_object = Mastodon.__json_strnum_to_bignum(json_object)
        json_object = Mastodon.__json_date_parse(json_object)
        json_object = Mastodon.__json_truefalse_parse(json_object)
        json_object = Mastodon.__json_allow_dict_attrs(json_object)
        return json_object

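    # Illustrative sketch: __api_request below decodes every response with
    # response_object.json(object_hook=self.__json_hooks), so each JSON object runs through the
    # hooks above. The same mechanism, shown with the standard-library json module and a made-up
    # payload (the payload and the name-mangled access are for illustration only):
    #
    #     import json
    #     raw = '{"id": "12345", "created_at": "2023-01-01T10:00:00.000Z", "favourite": "true"}'
    #     parsed = json.loads(raw, object_hook=Mastodon._Mastodon__json_hooks)
    #     parsed.id          # 12345 (string numeral -> int)
    #     parsed.created_at  # datetime for 2023-01-01 10:00 UTC (string -> datetime)
    #     parsed.favourite   # True ('true' -> bool); attribute access works via AttribAccessDict
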
    @staticmethod
    def __consistent_isoformat_utc(datetime_val):
        """
        Does what isoformat does, but consistently: the output is the same on every
        system, and the time is always represented as the equivalent UTC time.
        """
        isotime = datetime_val.astimezone(datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%S%z")
        if isotime[-2] != ":":
            isotime = isotime[:-2] + ":" + isotime[-2:]
        return isotime

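    # Worked example (values made up): a datetime two hours east of UTC comes out as its UTC
    # equivalent, with the colon that strftime's %z omits inserted into the offset:
    #
    #     dt = datetime.datetime(2023, 1, 1, 12, 0, tzinfo=datetime.timezone(datetime.timedelta(hours=2)))
    #     Mastodon._Mastodon__consistent_isoformat_utc(dt)  # -> "2023-01-01T10:00:00+00:00"
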
    def __api_request(self, method, endpoint, params={}, files={}, headers={}, access_token_override=None, base_url_override=None,
                      do_ratelimiting=True, use_json=False, parse=True, return_response_object=False, skip_error_check=False, lang_override=None):
        """
        Internal API request helper.
        """
        response = None
        remaining_wait = 0

        # Add language to params if not None
        lang = self.lang
        if lang_override is not None:
            lang = lang_override
        if lang is not None:
            params["lang"] = lang

        # "pace" mode ratelimiting: Assume constant rate of requests, sleep a little less long than it
        # would take to not hit the rate limit at that request rate.
        if do_ratelimiting and self.ratelimit_method == "pace":
            if self.ratelimit_remaining == 0:
                to_next = self.ratelimit_reset - time.time()
                if to_next > 0:
                    # As a precaution, never sleep longer than 5 minutes
                    to_next = min(to_next, 5 * 60)
                    time.sleep(to_next)
            else:
                time_waited = time.time() - self.ratelimit_lastcall
                time_wait = float(self.ratelimit_reset - time.time()) / float(self.ratelimit_remaining)
                remaining_wait = time_wait - time_waited

            if remaining_wait > 0:
                to_next = remaining_wait / self.ratelimit_pacefactor
                to_next = min(to_next, 5 * 60)
                time.sleep(to_next)

        # Generate request headers
        headers = copy.deepcopy(headers)
        if self.access_token is not None:
            headers['Authorization'] = 'Bearer ' + self.access_token
        if access_token_override is not None:
            headers['Authorization'] = 'Bearer ' + access_token_override

        # Add user-agent
        if self.user_agent:
            headers['User-Agent'] = self.user_agent

        # Determine base URL
        base_url = self.api_base_url
        if base_url_override is not None:
            base_url = base_url_override

        if self.debug_requests:
            print('Mastodon: Request to endpoint "' + base_url + endpoint + '" using method "' + method + '".')
            print('Parameters: ' + str(params))
            print('Headers: ' + str(headers))
            print('Files: ' + str(files))

        # Make request
        request_complete = False
        while not request_complete:
            request_complete = True

            response_object = None
            try:
                kwargs = dict(headers=headers, files=files, timeout=self.request_timeout)
                if use_json:
                    kwargs['json'] = params
                elif method == 'GET':
                    kwargs['params'] = params
                else:
                    kwargs['data'] = params

                response_object = self.session.request(method, base_url + endpoint, **kwargs)
            except Exception as e:
                raise MastodonNetworkError("Could not complete request: %s" % e)

            if response_object is None:
                raise MastodonIllegalArgumentError("Illegal request.")

            # Parse rate limiting headers
            if 'X-RateLimit-Remaining' in response_object.headers and do_ratelimiting:
                self.ratelimit_remaining = int(response_object.headers['X-RateLimit-Remaining'])
                self.ratelimit_limit = int(response_object.headers['X-RateLimit-Limit'])

                # For gotosocial, we need an int representation, but for non-ints this would crash
                try:
                    ratelimit_intrep = str(int(response_object.headers['X-RateLimit-Reset']))
                except:
                    ratelimit_intrep = None

                try:
                    if ratelimit_intrep is not None and ratelimit_intrep == response_object.headers['X-RateLimit-Reset']:
                        self.ratelimit_reset = int(response_object.headers['X-RateLimit-Reset'])
                    else:
                        ratelimit_reset_datetime = dateutil.parser.parse(response_object.headers['X-RateLimit-Reset'])
                        self.ratelimit_reset = self.__datetime_to_epoch(ratelimit_reset_datetime)

                    # Adjust server time to local clock
                    if 'Date' in response_object.headers:
                        server_time_datetime = dateutil.parser.parse(response_object.headers['Date'])
                        server_time = self.__datetime_to_epoch(server_time_datetime)
                        server_time_diff = time.time() - server_time
                        self.ratelimit_reset += server_time_diff
                        self.ratelimit_lastcall = time.time()
                except Exception as e:
                    raise MastodonRatelimitError("Rate limit time calculations failed: %s" % e)

            # Handle response
            if self.debug_requests:
                print('Mastodon: Response received with code ' + str(response_object.status_code) + '.')
                print('response headers: ' + str(response_object.headers))
                print('Response text content: ' + str(response_object.text))

            if not response_object.ok:
                try:
                    response = response_object.json(object_hook=self.__json_hooks)
                    if isinstance(response, dict) and 'error' in response:
                        error_msg = response['error']
                    elif isinstance(response, str):
                        error_msg = response
                    else:
                        error_msg = None
                except ValueError:
                    error_msg = None

                # Handle rate limiting
                if response_object.status_code == 429:
                    if self.ratelimit_method == 'throw' or not do_ratelimiting:
                        raise MastodonRatelimitError('Hit rate limit.')
                    elif self.ratelimit_method in ('wait', 'pace'):
                        to_next = self.ratelimit_reset - time.time()
                        if to_next > 0:
                            # As a precaution, never sleep longer than 5 minutes
                            to_next = min(to_next, 5 * 60)
                            time.sleep(to_next)
                            request_complete = False
                            continue

                if not skip_error_check:
                    if response_object.status_code == 404:
                        ex_type = MastodonNotFoundError
                        if not error_msg:
                            error_msg = 'Endpoint not found.'
                            # this is for compatibility with older versions
                            # which raised MastodonAPIError('Endpoint not found.')
                            # on any 404
                    elif response_object.status_code == 401:
                        ex_type = MastodonUnauthorizedError
                    elif response_object.status_code == 500:
                        ex_type = MastodonInternalServerError
                    elif response_object.status_code == 502:
                        ex_type = MastodonBadGatewayError
                    elif response_object.status_code == 503:
                        ex_type = MastodonServiceUnavailableError
                    elif response_object.status_code == 504:
                        ex_type = MastodonGatewayTimeoutError
                    elif response_object.status_code >= 500 and response_object.status_code <= 511:
                        ex_type = MastodonServerError
                    else:
                        ex_type = MastodonAPIError

                    raise ex_type('Mastodon API returned error', response_object.status_code, response_object.reason, error_msg)

            if return_response_object:
                return response_object

            if parse:
                try:
                    response = response_object.json(object_hook=self.__json_hooks)
                except:
                    raise MastodonAPIError(
                        "Could not parse response as JSON, response code was %s, "
                        "bad json content was '%s'" % (response_object.status_code, response_object.content))
            else:
                response = response_object.content

            # Parse link headers
            if isinstance(response, list) and \
                    'Link' in response_object.headers and \
                    response_object.headers['Link'] != "":
                response = AttribAccessList(response)
                tmp_urls = requests.utils.parse_header_links(
                    response_object.headers['Link'].rstrip('>').replace('>,<', ',<'))
                for url in tmp_urls:
                    if 'rel' not in url:
                        continue

                    if url['rel'] == 'next':
                        # Be paranoid and extract max_id specifically
                        next_url = url['url']
                        matchgroups = re.search(r"[?&]max_id=([^&]+)", next_url)

                        if matchgroups:
                            next_params = copy.deepcopy(params)
                            next_params['_pagination_method'] = method
                            next_params['_pagination_endpoint'] = endpoint
                            max_id = matchgroups.group(1)
                            if max_id.isdigit():
                                next_params['max_id'] = int(max_id)
                            else:
                                next_params['max_id'] = max_id
                            if "since_id" in next_params:
                                del next_params['since_id']
                            if "min_id" in next_params:
                                del next_params['min_id']
                            response._pagination_next = next_params

                            # Maybe other API users rely on the pagination info in the last item
                            # Will be removed in future
                            if isinstance(response[-1], AttribAccessDict):
                                response[-1]._pagination_next = next_params

                    if url['rel'] == 'prev':
                        # Be paranoid and extract since_id or min_id specifically
                        prev_url = url['url']

                        # Old and busted (pre-2.6.0): since_id pagination
                        matchgroups = re.search(r"[?&]since_id=([^&]+)", prev_url)
                        if matchgroups:
                            prev_params = copy.deepcopy(params)
                            prev_params['_pagination_method'] = method
                            prev_params['_pagination_endpoint'] = endpoint
                            since_id = matchgroups.group(1)
                            if since_id.isdigit():
                                prev_params['since_id'] = int(since_id)
                            else:
                                prev_params['since_id'] = since_id
                            if "max_id" in prev_params:
                                del prev_params['max_id']
                            response._pagination_prev = prev_params

                            # Maybe other API users rely on the pagination info in the first item
                            # Will be removed in future
                            if isinstance(response[0], AttribAccessDict):
                                response[0]._pagination_prev = prev_params

                        # New and fantastico (post-2.6.0): min_id pagination
                        matchgroups = re.search(r"[?&]min_id=([^&]+)", prev_url)
                        if matchgroups:
                            prev_params = copy.deepcopy(params)
                            prev_params['_pagination_method'] = method
                            prev_params['_pagination_endpoint'] = endpoint
                            min_id = matchgroups.group(1)
                            if min_id.isdigit():
                                prev_params['min_id'] = int(min_id)
                            else:
                                prev_params['min_id'] = min_id
                            if "max_id" in prev_params:
                                del prev_params['max_id']
                            response._pagination_prev = prev_params

                            # Maybe other API users rely on the pagination info in the first item
                            # Will be removed in future
                            if isinstance(response[0], AttribAccessDict):
                                response[0]._pagination_prev = prev_params

        return response

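    # Illustrative sketch (hypothetical wrapper, not taken from this module): public endpoint
    # methods elsewhere in the library delegate to this helper roughly like so:
    #
    #     def timeline_home(self, max_id=None, min_id=None, since_id=None, limit=None):
    #         params = self.__generate_params(locals())
    #         return self.__api_request('GET', '/api/v1/timelines/home', params)
    #
    # Paginated list responses then carry _pagination_next / _pagination_prev, which is what
    # fetch-next / fetch-previous style helpers would consume.
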
    def __get_streaming_base(self):
        """
        Internal streaming API helper.

        Returns the correct URL for the streaming API.
        """
        instance = self.instance()
        if "streaming_api" in instance["urls"] and instance["urls"]["streaming_api"] != self.api_base_url:
            # This is probably a websockets URL, which is really for the browser, but requests can't handle it
            # So we do this below to turn it into an HTTPS or HTTP URL
            parse = urlparse(instance["urls"]["streaming_api"])
            if parse.scheme == 'wss':
                url = "https://" + parse.netloc
            elif parse.scheme == 'ws':
                url = "http://" + parse.netloc
            else:
                raise MastodonAPIError(
                    "Could not parse streaming api location returned from server: {}.".format(
                        instance["urls"]["streaming_api"]))
        else:
            url = self.api_base_url
        return url

    def __stream(self, endpoint, listener, params={}, run_async=False, timeout=_DEFAULT_STREAM_TIMEOUT, reconnect_async=False, reconnect_async_wait_sec=_DEFAULT_STREAM_RECONNECT_WAIT_SEC):
        """
        Internal streaming API helper.

        Returns a handle to the open connection that the user can close if they
        wish to terminate it.
        """

        # Check if we have to redirect
        url = self.__get_streaming_base()

        # The streaming server can't handle two slashes in a path, so remove trailing slashes
        if url[-1] == '/':
            url = url[:-1]

        # Connect function (called and then potentially passed to async handler)
        def connect_func():
            headers = {"Authorization": "Bearer " + self.access_token} if self.access_token else {}
            if self.user_agent:
                headers['User-Agent'] = self.user_agent
            connection = self.session.get(url + endpoint, headers=headers, data=params, stream=True,
                                          timeout=(self.request_timeout, timeout))

            if connection.status_code != 200:
                raise MastodonNetworkError("Could not connect to streaming server: %s" % connection.reason)
            return connection
        connection = None

        # Async stream handler
        class __stream_handle():
            def __init__(self, connection, connect_func, reconnect_async, reconnect_async_wait_sec):
                self.closed = False
                self.running = True
                self.connection = connection
                self.connect_func = connect_func
                self.reconnect_async = reconnect_async
                self.reconnect_async_wait_sec = reconnect_async_wait_sec
                self.reconnecting = False

            def close(self):
                self.closed = True
                if self.connection is not None:
                    self.connection.close()

            def is_alive(self):
                return self._thread.is_alive()

            def is_receiving(self):
                if self.closed or not self.running or self.reconnecting or not self.is_alive():
                    return False
                else:
                    return True

            def _sleep_attentive(self):
                if self._thread != threading.current_thread():
                    raise RuntimeError("Illegal call from outside the stream_handle thread")
                time_remaining = self.reconnect_async_wait_sec
                while time_remaining > 0 and not self.closed:
                    time.sleep(0.5)
                    time_remaining -= 0.5

            def _threadproc(self):
                self._thread = threading.current_thread()

                # Run until closed or until error if not autoreconnecting
                while self.running:
                    if self.connection is not None:
                        with closing(self.connection) as r:
                            try:
                                listener.handle_stream(r)
                            except (AttributeError, MastodonMalformedEventError, MastodonNetworkError) as e:
                                if not (self.closed or self.reconnect_async):
                                    raise e
                                else:
                                    if self.closed:
                                        self.running = False

                    # Reconnect loop. Try immediately once, then with delays on error.
                    if (self.reconnect_async and not self.closed) or self.connection is None:
                        self.reconnecting = True
                        connect_success = False
                        while not connect_success:
                            if self.closed:
                                # Someone from outside stopped the streaming
                                self.running = False
                                break
                            try:
                                the_connection = self.connect_func()
                                if the_connection.status_code != 200:
                                    exception = MastodonNetworkError(f"Could not connect to server. "
                                                                     f"HTTP status: {the_connection.status_code}")
                                    listener.on_abort(exception)
                                    self._sleep_attentive()
                                if self.closed:
                                    # Here we have maybe a rare race condition. Exactly on connect, someone
                                    # stopped the streaming before. We close the previously established connection:
                                    the_connection.close()
                                else:
                                    self.connection = the_connection
                                    connect_success = True
                            except:
                                self._sleep_attentive()
                                connect_success = False
                        self.reconnecting = False
                    else:
                        self.running = False
                return 0

        if run_async:
            handle = __stream_handle(connection, connect_func, reconnect_async, reconnect_async_wait_sec)
            t = threading.Thread(args=(), target=handle._threadproc)
            t.daemon = True
            t.start()
            return handle
        else:
            # Blocking, never returns (can only leave via exception)
            connection = connect_func()
            with closing(connection) as r:
                listener.handle_stream(r)

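    # Illustrative sketch of the returned handle (run_async=True path), assuming a public streaming
    # wrapper such as stream_user() that forwards to this helper:
    #
    #     handle = api.stream_user(listener, run_async=True, reconnect_async=True)
    #     handle.is_receiving()   # False while reconnecting, after close(), or if the thread died
    #     handle.close()          # ends the background thread's read/reconnect loop
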
    def __generate_params(self, params, exclude=[]):
        """
        Internal named-parameters-to-dict helper.

        Note for developers: If called with locals() as params,
        as is the usual practice in this code, the __generate_params call
        (or at least the locals() call) should generally be the first thing
        in your function.
        """
        params = collections.OrderedDict(params)

        if 'self' in params:
            del params['self']

        param_keys = list(params.keys())
        for key in param_keys:
            if isinstance(params[key], bool):
                params[key] = '1' if params[key] else '0'

        for key in param_keys:
            if params[key] is None or key in exclude:
                del params[key]

        param_keys = list(params.keys())
        for key in param_keys:
            if isinstance(params[key], list):
                params[key + "[]"] = params[key]
                del params[key]

        return params

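    # Worked example (hypothetical values): called with a function's locals(), e.g.
    #
    #     {'self': <Mastodon>, 'limit': 10, 'only_media': True, 'max_id': None, 'exclude_types': ['follow']}
    #
    # this returns {'limit': 10, 'only_media': '1', 'exclude_types[]': ['follow']}:
    # 'self' and None-valued keys are dropped, booleans become '1' / '0', and list keys get a
    # trailing "[]" so requests encodes them as repeated form fields.
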
    def __unpack_id(self, id, dateconv=False):
        """
        Internal object-to-id converter

        Checks if id is a dict that contains id and
        returns the id inside, otherwise just returns
        the id straight.

        Also unpacks datetimes to snowflake IDs if requested.
        """
        if isinstance(id, dict) and "id" in id:
            id = id["id"]
        if dateconv and isinstance(id, datetime.datetime):
            id = (int(id.timestamp()) << 16) * 1000
        return id

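    # Worked examples (values made up):
    #
    #     __unpack_id({"id": 110854321, "content": "..."})   # -> 110854321
    #     __unpack_id(datetime.datetime(2023, 1, 1, tzinfo=datetime.timezone.utc), dateconv=True)
    #     # -> (1672531200 << 16) * 1000, a snowflake-style ID whose high bits encode the timestamp
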
    def __decode_webpush_b64(self, data):
        """
        Re-pads and decodes urlsafe base64.
        """
        missing_padding = len(data) % 4
        if missing_padding != 0:
            data += '=' * (4 - missing_padding)
        return base64.urlsafe_b64decode(data)

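    # Worked example: urlsafe base64 from Web Push payloads often arrives unpadded; "YWJjZA" has
    # length 6, so two '=' are appended before decoding:
    #
    #     __decode_webpush_b64("YWJjZA")  # == base64.urlsafe_b64decode("YWJjZA==") == b'abcd'
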
    def __get_token_expired(self):
        """Internal helper for oauth code"""
        return self._token_expired < datetime.datetime.now()

    def __set_token_expired(self, value):
        """Internal helper for oauth code"""
        self._token_expired = datetime.datetime.now() + datetime.timedelta(seconds=value)
        return

    def __get_refresh_token(self):
        """Internal helper for oauth code"""
        return self._refresh_token

    def __set_refresh_token(self, value):
        """Internal helper for oauth code"""
        self._refresh_token = value
        return

    def __guess_type(self, media_file):
        """Internal helper to guess media file type"""
        mime_type = None
        try:
            mime_type = magic.from_file(media_file, mime=True)
        except AttributeError:
            mime_type = mimetypes.guess_type(media_file)[0]
        return mime_type

    def __load_media_file(self, media_file, mime_type=None, file_name=None):
        if isinstance(media_file, PurePath):
            media_file = str(media_file)
        if isinstance(media_file, str) and os.path.isfile(media_file):
            mime_type = self.__guess_type(media_file)
            media_file = open(media_file, 'rb')
        if mime_type is None:
            raise MastodonIllegalArgumentError('Could not determine mime type or data passed directly without mime type.')
        if file_name is None:
            random_suffix = uuid.uuid4().hex
            file_name = "mastodonpyupload_" + str(time.time()) + "_" + str(random_suffix) + mimetypes.guess_extension(mime_type)
        return (file_name, media_file, mime_type)

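    # Illustrative sketch (hypothetical wrapper, endpoint shown for illustration only): the returned
    # (file_name, file_object, mime_type) tuple is shaped so it can be passed straight to requests
    # as a multipart file field, roughly:
    #
    #     def media_post_sketch(self, media_file, mime_type=None, file_name=None):
    #         media_file_description = self.__load_media_file(media_file, mime_type, file_name)
    #         return self.__api_request('POST', '/api/v1/media', files={'file': media_file_description})
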
    @staticmethod
    def __protocolize(base_url):
        """Internal add-protocol-to-url helper"""
        if not base_url.startswith("http://") and not base_url.startswith("https://"):
            base_url = "https://" + base_url

        # Some API endpoints can't handle extra /'s in path requests
        base_url = base_url.rstrip("/")
        return base_url

    @staticmethod
    def __deprotocolize(base_url):
        """Internal helper to strip http and https from a URL"""
        if base_url.startswith("http://"):
            base_url = base_url[7:]
        elif base_url.startswith("https://") or base_url.startswith("onion://"):
            base_url = base_url[8:]
        return base_url