Skip to content
This repository was archived by the owner on Jul 20, 2023. It is now read-only.

Commit 03366dd

Browse files
authored
refactor!: rewrite http ratelimiting (#29)
* refactor!: rewrite http ratelimiting
* docs: update according to new changes
* refactor: bucket should not be changed if hashes are same
1 parent 4968322 commit 03366dd

File tree

5 files changed

+70
-97
lines changed

5 files changed

+70
-97
lines changed

discatcore/errors.py

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88
__all__ = (
99
"DisCatCoreException",
1010
"HTTPException",
11-
"BucketMigrated",
1211
"UnsupportedAPIVersionWarning",
1312
"GatewayReconnect",
1413
)
@@ -81,13 +80,6 @@ def __init__(
8180
super().__init__(format.format(response.status, response.reason, self.code, self.text)) # type: ignore
8281

8382

84-
class BucketMigrated(DisCatCoreException):
85-
"""Represents an internal exception for when a bucket migrates."""
86-
87-
def __init__(self, discord_hash: str) -> None:
88-
super().__init__(f"This bucket has been migrated to a bucket located at {discord_hash}")
89-
90-
9183
class UnsupportedAPIVersionWarning(Warning):
9284
"""Represents a warning for unsupported API versions."""
9385

discatcore/http/client.py

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import discord_typings as dt
1313

1414
from .. import __version__
15-
from ..errors import BucketMigrated, HTTPException, UnsupportedAPIVersionWarning
15+
from ..errors import HTTPException, UnsupportedAPIVersionWarning
1616
from ..file import BasicFile
1717
from ..types import Unset, UnsetOr
1818
from ..utils.json import dumps, loads
@@ -226,7 +226,8 @@ async def request(
226226
(like Create Guild Sticker).
227227
228228
Returns:
229-
If this route returns t.Any content, it will be processed and returned.
229+
If this route returns any content, it will be processed and returned. If the request fails after 5 tries,
230+
Unset will be returned instead.
230231
"""
231232
self._request_id += 1
232233
rid = self._request_id
@@ -236,6 +237,7 @@ async def request(
236237
query_params = _filter_dict_for_unset(query_params or {})
237238
max_tries = 5
238239
headers: dict[str, str] = self.default_headers
240+
bucket_hash: t.Optional[str] = None
239241

240242
if reason:
241243
headers["X-Audit-Log-Reason"] = _urlquote(reason, safe="/ ")
@@ -249,12 +251,13 @@ async def request(
249251
if data.multipart_content is not Unset:
250252
kwargs["data"] = data.multipart_content
251253

252-
bucket = self._ratelimiter.get_bucket(route.bucket)
254+
for try_ in range(max_tries):
255+
bucket = self._ratelimiter.get_bucket((route.bucket, bucket_hash))
253256

254-
for tries in range(max_tries):
255257
async with self._ratelimiter.global_bucket:
258+
_log.debug("REQUEST:%d The global ratelimit bucket has been acquired!", rid)
256259
async with bucket:
257-
_log.debug("REQUEST:%d All ratelimit buckets have been acquired!", rid)
260+
_log.debug("REQUEST:%d The route ratelimit bucket has been acquired!", rid)
258261

259262
response = await self._session.request(
260263
route.method,
@@ -264,24 +267,28 @@ async def request(
264267
**kwargs,
265268
)
266269
_log.debug(
267-
"REQUEST:%d Made request to %s with method %s.",
270+
"REQUEST:%d Made request to %s with method %s and got status code %d.",
268271
rid,
269272
f"{self._api_url}{url}",
270273
route.method,
274+
response.status,
271275
)
272276

273-
URL_BUCKET = bucket.bucket is None
274-
bucket.update_info(response)
275-
await bucket.acquire()
276-
277-
if URL_BUCKET and bucket.bucket is not None:
278-
try:
279-
self._ratelimiter.migrate_bucket(route.bucket, bucket.bucket)
280-
except BucketMigrated:
281-
bucket = self._ratelimiter.get_bucket(route.bucket)
277+
bucket_hash = response.headers.get("X-RateLimit-Bucket")
278+
if bucket_hash is not None and bucket_hash != bucket.bucket:
279+
_log.debug(
280+
"REQUEST:%d Migrating from bucket (%s, %s) to bucket (%s, %s).",
281+
rid,
282+
route.bucket,
283+
bucket.bucket,
284+
route.bucket,
285+
bucket_hash,
286+
)
287+
bucket = self._ratelimiter.get_bucket((route.bucket, bucket_hash))
282288

283289
# Everything is ok
284290
if 200 <= response.status < 300:
291+
bucket.update_info(response)
285292
return await self._text_or_json(response)
286293

287294
# Ratelimited
@@ -302,13 +309,23 @@ async def request(
302309
)
303310
self._ratelimiter.global_bucket.lock_for(retry_after)
304311
await self._ratelimiter.global_bucket.acquire()
312+
else:
313+
_log.info(
314+
"REQUEST:%d All requests with bucket (%s, %s) have hit a ratelimit! Retrying in %f.",
315+
rid,
316+
route.bucket,
317+
bucket.bucket,
318+
retry_after,
319+
)
320+
bucket.lock_for(retry_after)
321+
await bucket.acquire()
305322

306323
_log.info("REQUEST:%d Ratelimit is over. Continuing with the request.", rid)
307324
continue
308325

309326
# Specific Server Errors, retry after some time
310327
if response.status in {500, 502, 504}:
311-
wait_time = 1 + tries * 2
328+
wait_time = 1 + try_ * 2
312329
_log.info("REQUEST:%d Got a server error! Retrying in %d.", rid, wait_time)
313330
await asyncio.sleep(wait_time)
314331
continue
@@ -317,9 +334,14 @@ async def request(
317334
if response.status >= 400:
318335
raise HTTPException(response, await self._text_or_json(response))
319336

320-
raise RuntimeError(
321-
f'REQUEST:{rid} Tried sending request to "{url}" with method {route.method} {max_tries} times.'
337+
_log.error(
338+
'REQUEST:%d Tried sending request to "%s" with method %s %d times.',
339+
rid,
340+
url,
341+
route.method,
342+
max_tries,
322343
)
344+
return Unset
323345

324346
async def get_gateway_bot(self) -> dt.GetGatewayBotData:
325347
"""Fetches the gateway information from the Discord API.

discatcore/http/ratelimiter.py

Lines changed: 11 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
# SPDX-License-Identifier: MIT
22

3-
import logging
43
import typing as t
54
from datetime import datetime, timezone
65

76
from aiohttp import ClientResponse
87

9-
from ..errors import BucketMigrated
108
from ..utils.ratelimit import BurstRatelimiter, ManualRatelimiter
119

12-
_log = logging.getLogger(__name__)
13-
1410
__all__ = (
1511
"Bucket",
1612
"Ratelimiter",
@@ -31,7 +27,6 @@ class Bucket(BurstRatelimiter):
3127
"reset",
3228
"bucket",
3329
"_first_update",
34-
"_migrated",
3530
)
3631

3732
def __init__(self) -> None:
@@ -40,7 +35,6 @@ def __init__(self) -> None:
4035
self.reset: t.Optional[datetime] = None
4136
self.bucket: t.Optional[str] = None
4237
self._first_update: bool = True
43-
self._migrated: bool = False
4438

4539
def update_info(self, response: ClientResponse) -> None:
4640
"""Updates the bucket's underlying information via the new headers.
@@ -86,76 +80,30 @@ def update_info(self, response: ClientResponse) -> None:
8680
if self._first_update:
8781
self._first_update = False
8882

89-
if self.reset_after is not None and self.remaining == 0 and not self.is_locked():
90-
self.lock_for(self.reset_after)
91-
92-
def migrate_to(self, discord_hash: str) -> t.NoReturn:
93-
"""Migrates this bucket to a new one provided by the Discord API.
94-
95-
Raises:
96-
BucketMigrated: An internal exception for the request function to change buckets.
97-
"""
98-
self._migrated = True
99-
raise BucketMigrated(discord_hash)
100-
101-
@property
102-
def migrated(self) -> bool:
103-
return self._migrated
104-
10583

10684
class Ratelimiter:
10785
"""Represents the global ratelimiter.
10886
10987
Attributes:
110-
discord_buckets (dict[str, Bucket]): A mapping that maps Discord hashes to bucket objects.
111-
url_buckets (dict[str, Bucket]): A mapping that maps psuedo-buckets to bucket objects.
112-
This is primarily used by new requests that do not know their Discord hashes.
113-
url_to_discord_hash (dict[str, str]): A mapping that maps psuedo-buckets to Discord hashes.
114-
This is primarily set by requests that have just discovered their Discord hashes.
115-
global_bucket (ManualRatelimiter): The global bucket. Used for requests that involve global 429s.
88+
buckets: A mapping of route urls and bucket hashes to buckets.
89+
global_bucket: The global bucket. Used for requests that involve global 429s.
11690
"""
11791

118-
__slots__ = ("discord_buckets", "url_buckets", "url_to_discord_hash", "global_bucket")
92+
__slots__ = ("buckets", "global_bucket")
11993

12094
def __init__(self) -> None:
121-
self.discord_buckets: dict[str, Bucket] = {}
122-
self.url_buckets: dict[str, Bucket] = {}
123-
self.url_to_discord_hash: dict[str, str] = {}
95+
self.buckets: dict[tuple[str, t.Optional[str]], Bucket] = {}
12496
self.global_bucket = ManualRatelimiter()
12597

126-
def get_bucket(self, url: str) -> Bucket:
127-
"""Gets a bucket object from the providing url.
98+
def get_bucket(self, key: tuple[str, t.Optional[str]]) -> Bucket:
99+
"""Gets a bucket object with the provided key.
128100
129101
Args:
130-
url (str): The url to grab the bucket with.
131-
This can either be a pseudo-bucket or an actual Discord hash.
102+
key: The key to grab the bucket with. This key is in the format (route_url, bucket_hash).
132103
"""
133-
if url not in self.url_to_discord_hash:
134-
# this serves as a temporary bucket until further ratelimiting info is provided
135-
# or since some routes have no ratelimiting, they have to backfire to this bucket instead
104+
if key not in self.buckets:
136105
new_bucket = Bucket()
137-
self.url_buckets[url] = new_bucket
138-
return new_bucket
139-
140-
discord_hash = self.url_to_discord_hash[url]
141-
return self.discord_buckets[discord_hash]
142-
143-
def migrate_bucket(self, url: str, discord_hash: str) -> t.NoReturn:
144-
"""Migrates a bucket object from a pseudo-bucket to a Discord hash.
145-
Not only will this call :meth:`Bucket.migrate_to`, but the mappings will be updated.
146-
147-
Args:
148-
url (str): The pseudo-bucket of the bucket to migrate.
149-
discord_hash (str): The Discord hash to migrate the bucket to.
150-
"""
151-
self.url_to_discord_hash[url] = discord_hash
152-
153-
cur_bucket = self.url_buckets.get(url)
154-
155-
if not cur_bucket:
156-
raise BucketMigrated(discord_hash)
157-
158-
self.discord_buckets[discord_hash] = cur_bucket
159-
del self.url_buckets[url]
106+
self.buckets[key] = new_bucket
107+
new_bucket.bucket = key[1]
160108

161-
cur_bucket.migrate_to(discord_hash)
109+
return self.buckets[key]

discatcore/http/route.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,31 @@ def __init__(self, method: str, url: str, **params: t.Any) -> None:
3636
self.url: str = url
3737

3838
# top-level resource parameters
39-
self.guild_id: t.Optional[dt.Snowflake] = params.get("guild_id")
40-
self.channel_id: t.Optional[dt.Snowflake] = params.get("channel_id")
41-
self.webhook_id: t.Optional[dt.Snowflake] = params.get("webhook_id")
42-
self.webhook_token: t.Optional[str] = params.get("webhook_token")
39+
self.guild_id: t.Optional[dt.Snowflake] = params.pop("guild_id", None)
40+
self.channel_id: t.Optional[dt.Snowflake] = params.pop("channel_id", None)
41+
self.webhook_id: t.Optional[dt.Snowflake] = params.pop("webhook_id", None)
42+
self.webhook_token: t.Optional[str] = params.pop("webhook_token", None)
4343

4444
@property
4545
def endpoint(self) -> str:
4646
"""The formatted url for this route."""
47-
return self.url.format_map({k: _urlquote(str(v)) for k, v in self.params.items()})
47+
top_level_params = {
48+
k: getattr(self, k)
49+
for k in ("guild_id", "channel_id", "webhook_id", "webhook_token")
50+
if getattr(self, k) is not None
51+
}
52+
other_params = {k: _urlquote(str(v)) for k, v in self.params.items()}
53+
54+
return self.url.format_map({**top_level_params, **other_params})
4855

4956
@property
5057
def bucket(self) -> str:
51-
"""The pseudo-bucket that represents this route. This is generated via the method, raw url and top level parameters."""
58+
"""The pseudo-bucket that represents this route. This is generated with the method and top level parameters filled into the raw url."""
5259
top_level_params = {
5360
k: getattr(self, k)
5461
for k in ("guild_id", "channel_id", "webhook_id", "webhook_token")
5562
if getattr(self, k) is not None
5663
}
57-
other_params = {k: None for k in self.params.keys() if k not in top_level_params.keys()}
64+
other_params = {k: None for k in self.params.keys()}
5865

5966
return f"{self.method}:{self.url.format_map({**top_level_params, **other_params})}"

discatcore/utils/ratelimit.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: MIT
22

33
import asyncio
4+
import logging
45
import typing as t
56

67
__all__ = (
@@ -9,6 +10,8 @@
910
"BurstRatelimiter",
1011
)
1112

13+
_log = logging.getLogger(__name__)
14+
1215

1316
class BaseRatelimiter:
1417
"""The base class for all ratelimiters. Locking algorithms are up to the subclassed Ratelimiter."""
@@ -35,7 +38,7 @@ async def __aexit__(self, *args: t.Any) -> None:
3538

3639

3740
class ManualRatelimiter(BaseRatelimiter):
38-
"""A simple ratelimiter that simply locks at the command of t.Anything."""
41+
"""A simple ratelimiter that simply locks at the command of anything."""
3942

4043
async def _unlock(self, delay: float) -> None:
4144
await asyncio.sleep(delay)
@@ -74,6 +77,7 @@ def __init__(self) -> None:
7477

7578
async def acquire(self) -> None:
7679
if self.reset_after is not None and self.remaining == 0 and not self.is_locked():
80+
_log.info("Auto-locking for %f seconds.", self.reset_after)
7781
self.lock_for(self.reset_after)
7882

7983
return await super().acquire()

0 commit comments

Comments
 (0)