1
1
# from typing import Any
2
2
3
3
import decimal
4
+ import importlib
4
5
import json
5
6
import re
6
7
from functools import reduce
17
18
18
19
from django .http import JsonResponse
19
20
20
-
21
- from modules .maintenance .util import find_connection
22
- from modules .orm .filters import Filter as OrmFilter
23
- from settings .context import session
24
21
from .exception import HTTPException
25
22
26
23
re_id = re .compile (r'(.*)\/(\d+)(\/.)?$' )
27
24
search_regex = re .compile (r'__isnull|__gte|__lte|__lt|__gt|__startswith' )
28
25
29
26
27
+ async def fake_tenant (* args , ** kwargs ):
28
+ return 'default'
29
+
30
+
31
+ tenant = importlib .find_loader ('tenant' )
32
+ if tenant :
33
+ tenant = importlib .import_module ('tenant' )
34
+ set_tenant = tenant .set_tenant
35
+ OrmFilter = tenant .Filter
36
+ session = tenant .session
37
+ else :
38
+ set_tenant = fake_tenant
39
+ OrmFilter = None
40
+ session = None
41
+
42
+
30
43
async def method_not_allowed (self , ** kwargs ):
31
44
raise HTTPException (405 , 'Method not allowed' )
32
45
@@ -44,6 +57,7 @@ def decoder(obj):
44
57
45
58
46
59
class BaseResource (View ):
60
+ authenticated = False
47
61
allowed_methods = ['delete' , 'get' , 'patch' , 'post' ]
48
62
routes = []
49
63
@@ -66,7 +80,6 @@ class BaseResource(View):
66
80
67
81
related_models = None
68
82
list_related_fields = None
69
- # related_fields = []
70
83
many_to_many_models = {}
71
84
72
85
filter_fields = []
@@ -88,7 +101,7 @@ class BaseResource(View):
88
101
normalize_list = False
89
102
90
103
def __init__ (self ):
91
- user_session = session .get ()
104
+ user_session = session .get () if session else None
92
105
self .account_id = user_session ['account' ].id if user_session else None
93
106
self .account = user_session ['account' ] if user_session else None
94
107
self .user = user_session ['user' ] if user_session else None
@@ -143,10 +156,11 @@ async def dispatch(self, request, *args, **kwargs) -> None:
143
156
if func :
144
157
self .allowed_methods = allowed_methods or self .allowed_methods
145
158
159
+ if self .authenticated and not self .user :
160
+ raise HTTPException (401 , 'Not authorized' )
161
+
146
162
if method not in self .allowed_methods :
147
- return HTTPException (
148
- status = 405 , detail = f'{ method } not allowed'
149
- )
163
+ raise HTTPException (405 , f'{ method } not allowed' )
150
164
151
165
if not func :
152
166
handler = getattr (self , method , method_not_allowed )
@@ -313,7 +327,7 @@ def dehydrate(self, response):
313
327
314
328
# count só se aplica a listagens
315
329
@sync_to_async
316
- def count (self ):
330
+ def count (self , account_db ):
317
331
count = 0
318
332
self .count_results = 10
319
333
if not hasattr (self .queryset , 'query' ):
@@ -326,7 +340,6 @@ def count(self):
326
340
f'SELECT count(DISTINCT { table } .id) FROM' ,
327
341
query ,
328
342
)
329
- account_db = find_connection (self .account_id )
330
343
connection = connections [account_db ]
331
344
cursor = connection .cursor ()
332
345
cursor .execute (query , params )
@@ -347,12 +360,13 @@ def get_filters(self, request):
347
360
if not conditions :
348
361
return
349
362
350
- queryset = OrmFilter (
351
- self .model ,
352
- self .user .timezone or 'UTC'
353
- )
354
- queryset = queryset .filter_by (conditions )
355
- self .queryset = queryset .distinct ()
363
+ if OrmFilter :
364
+ queryset = OrmFilter (
365
+ self .model ,
366
+ self .user .timezone if self .user else 'UTC'
367
+ )
368
+ queryset = queryset .filter_by (conditions )
369
+ self .queryset = queryset .distinct ()
356
370
357
371
async def return_results (self , results ):
358
372
if self .count_results :
@@ -392,7 +406,8 @@ async def get_objs(self, request):
392
406
self .filter_objs ()
393
407
394
408
if request .GET .get ('count' ):
395
- return await self .count ()
409
+ account_db = await set_tenant (self .account_id )
410
+ return await self .count (account_db )
396
411
397
412
if self .page > 0 :
398
413
start = (self .page - 1 ) * self .limit
@@ -600,14 +615,14 @@ async def patch(self, request):
600
615
body = request .json
601
616
except Exception :
602
617
return HTTPException (
603
- status = 400 , detail = 'Invalid body'
618
+ 400 , 'Invalid body'
604
619
)
605
620
match = re_id .match (request .path_info )
606
621
if match :
607
622
results = await self ._update_obj (match [2 ], body )
608
623
return self .serialize (results )
609
624
else :
610
- return HTTPException (status = 404 , detail = "Item not found" )
625
+ return HTTPException (404 , "Item not found" )
611
626
612
627
#########################################################
613
628
# POST
0 commit comments