Skip to content

Commit d44b30d

Browse files
authored
Merge pull request #63 from wildjames/dev
Rate limit user account actions
2 parents abb2d42 + a2fab1a commit d44b30d

File tree

7 files changed

+96
-11
lines changed

7 files changed

+96
-11
lines changed

todoqueue_backend/accounts/tests.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
from rest_framework.test import APITestCase, APIClient
99
from rest_framework_simplejwt.tokens import RefreshToken
1010

11-
from logging import getLogger, basicConfig, INFO
12-
basicConfig(level=INFO)
11+
from logging import getLogger
1312
logger = getLogger(__name__)
1413

1514
from .serializers import CustomUserSerializer

todoqueue_backend/accounts/utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from django.core.cache import cache
2+
from django.utils import timezone
3+
from datetime import timedelta
4+
5+
from logging import getLogger, DEBUG, basicConfig
6+
basicConfig(level=DEBUG)
7+
logger = getLogger(__name__)
8+
9+
def is_rate_limited(identifier, request_type, max_attempts=5, period=timedelta(hours=1)):
10+
"""
11+
Check if the identifier has exceeded the maximum number of attempts within the period.
12+
"""
13+
logger.debug("Checking rate limit")
14+
15+
current_time = timezone.now()
16+
cache_key = f"password_reset_attempts_{request_type}_{identifier}"
17+
attempts = cache.get(cache_key, [])
18+
19+
# Filter out attempts older than the rate-limiting period
20+
attempts = [attempt for attempt in attempts if current_time - attempt < period]
21+
22+
logger.debug(f"Endpoint {cache_key} attempts: {len(attempts)}")
23+
if len(attempts) >= max_attempts:
24+
return True
25+
26+
# Add the current attempt and update the cache
27+
attempts.append(current_time)
28+
cache.set(cache_key, attempts, period.total_seconds())
29+
30+
return False

todoqueue_backend/accounts/views.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import timedelta
12
from logging import getLogger
23
from time import sleep
34

@@ -23,6 +24,7 @@
2324
CustomUserRegistrationSerializer,
2425
ResetPasswordSerializer,
2526
)
27+
from .utils import is_rate_limited
2628

2729
logger = getLogger(__name__)
2830

@@ -44,8 +46,10 @@ def get_queryset(self, *args, **kwargs):
4446
def check_permissions(self, request):
4547
super().check_permissions(request)
4648
# If it's a 'list' action and user is not admin, deny access
47-
if self.action == 'list' and not request.user.is_staff:
48-
raise PermissionDenied("You do not have permission to view the list of users.")
49+
if self.action == "list" and not request.user.is_staff:
50+
raise PermissionDenied(
51+
"You do not have permission to view the list of users."
52+
)
4953

5054

5155
class AuthView(APIView):
@@ -75,6 +79,15 @@ def post(self, request):
7579

7680
class RegisterView(APIView):
7781
def post(self, request):
82+
# Check if the rate limit has been exceeded
83+
if is_rate_limited(
84+
request.META['REMOTE_ADDR'], "init_registration", max_attempts=20, period=timedelta(hours=1)
85+
):
86+
return Response(
87+
{"detail": "Registration requests are limited to 20 per hour."},
88+
status=status.HTTP_429_TOO_MANY_REQUESTS,
89+
)
90+
7891
logger.info(f"Registering user with email: {request.data['email']}")
7992
serializer = CustomUserRegistrationSerializer(data=request.data)
8093

@@ -156,6 +169,15 @@ def post(self, request):
156169

157170
class ConfirmRegistrationView(APIView):
158171
def get(self, request, uidb64, token):
172+
# Check if the rate limit has been exceeded
173+
if is_rate_limited(
174+
request.META['REMOTE_ADDR'], "confirm_registration", max_attempts=50, period=timedelta(hours=1)
175+
):
176+
return Response(
177+
{"detail": "Please stop spamming the confirmation endpoint."},
178+
status=status.HTTP_429_TOO_MANY_REQUESTS,
179+
)
180+
159181
logger.info(f"Confirming registration for user with uidb64: {uidb64}")
160182
try:
161183
uid = force_str(urlsafe_base64_decode(uidb64))
@@ -179,6 +201,15 @@ def get(self, request, uidb64, token):
179201

180202
class ForgotPasswordView(APIView):
181203
def post(self, request):
204+
# Check if the rate limit has been exceeded
205+
if is_rate_limited(
206+
request.META['REMOTE_ADDR'], "forgot_password", max_attempts=5, period=timedelta(hours=1)
207+
):
208+
return Response(
209+
{"detail": "Password reset requests are limited to 5 per hour."},
210+
status=status.HTTP_429_TOO_MANY_REQUESTS,
211+
)
212+
182213
logger.info(
183214
f"Forgot password request for user with email: {request.data['email']}"
184215
)
@@ -245,6 +276,15 @@ def post(self, request):
245276

246277
class CompleteForgotPasswordView(APIView):
247278
def post(self, request, uidb64, token):
279+
# Check if the rate limit has been exceeded
280+
if is_rate_limited(
281+
request.META['REMOTE_ADDR'], "new_password", max_attempts=20, period=timedelta(hours=1)
282+
):
283+
return Response(
284+
{"detail": "Please stop spamming the password reset endpoint."},
285+
status=status.HTTP_429_TOO_MANY_REQUESTS,
286+
)
287+
248288
logger.info(f"Completing forgot password for user with uidb64: {uidb64}")
249289
try:
250290
uid = force_str(urlsafe_base64_decode(uidb64))

todoqueue_backend/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ distlib==0.3.5
55
Django==4.2.5
66
django-cors-headers==4.2.0
77
django-rest-framework==0.1.0
8+
django-redis==5.4.0
89
djangorestframework==3.14.0
910
djangorestframework-simplejwt==5.3.0
1011
filelock==3.8.0
@@ -13,6 +14,7 @@ mysqlclient==2.2.0
1314
packaging==23.1
1415
platformdirs==2.5.2
1516
pytz==2023.3.post1
17+
redis==5.0.1
1618
sqlparse==0.4.4
1719
typing_extensions==4.7.1
1820
virtualenv==20.16.3

todoqueue_backend/tasks/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from datetime import timedelta
22
from typing import List
3-
from logging import getLogger, INFO, basicConfig
3+
from logging import getLogger
44
import math
55
from profanity_check import predict as is_profane
66
import random
77

88
logger = getLogger(__name__)
9-
basicConfig(level=INFO)
109

1110

1211
def renormalize(value, old_range, new_range):

todoqueue_backend/tasks/views.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import timedelta
2-
from logging import INFO, basicConfig, getLogger
2+
from logging import getLogger
33

44
from accounts.serializers import CustomUserWithBrowniePointsSerializer
55
from django.contrib.auth import get_user_model
@@ -40,7 +40,6 @@
4040
from .utils import bp_function, parse_duration
4141

4242
logger = getLogger(__name__)
43-
basicConfig(level=INFO)
4443

4544

4645
class ScheduledTaskViewSet(viewsets.ModelViewSet):

todoqueue_backend/todoqueue_backend/settings.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
logger = getLogger(__name__)
2222

2323

24-
25-
2624
# Build paths inside the project like this: BASE_DIR / 'subdir'.
2725
BASE_DIR = Path(__file__).resolve().parent.parent
2826

@@ -33,7 +31,7 @@
3331
SECRET_KEY = config("DJANGO_SECRET", default=os.urandom(32))
3432

3533
# SECURITY WARNING: don't run with debug turned on in production!
36-
DEBUG = True
34+
DEBUG = os.environ.get("DJANGO_DEBUG", "false").lower() == "true"
3735

3836
web_port = config("DJANGO_HOST_PORT", default=8000, cast=int)
3937
logger.info("Whilelisting host for CSRF: {}".format(config("FRONTEND_URL", default=None)))
@@ -129,6 +127,24 @@
129127
}
130128
}
131129

130+
if os.environ.get("DJANGO_CACHE_BACKEND", None) == "redis":
131+
logger.info("Using RedisCache")
132+
CACHES = {
133+
"default": {
134+
"BACKEND": "django_redis.cache.RedisCache",
135+
"LOCATION": os.environ.get("DJANGO_CACHE_LOCATION", "redis://127.0.0.1:6379/1"),
136+
"OPTIONS": {
137+
"CLIENT_CLASS": "django_redis.client.DefaultClient",
138+
}
139+
}
140+
}
141+
else:
142+
logger.info("Using LocMemCache")
143+
CACHES = {
144+
'default': {
145+
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
146+
}
147+
}
132148

133149
AUTH_USER_MODEL = "accounts.CustomUser"
134150

0 commit comments

Comments
 (0)