11import asyncio
22import json
3- from enum import Enum
43from uuid import UUID
54
65from aiohttp import web
76from aiohttp_sse import sse_response
8-
9- # import grpc
107from forestadmin .agent_rpc .options import RpcOptions
11-
12- # from forestadmin.agent_rpc.services.datasource import DatasourceService
138from forestadmin .agent_toolkit .agent import Agent
9+ from forestadmin .agent_toolkit .options import Options
1410from forestadmin .datasource_toolkit .interfaces .fields import PrimitiveType
1511from forestadmin .datasource_toolkit .utils .schema import SchemaUtils
1612from forestadmin .rpc_common .hmac import is_valid_hmac
3026from forestadmin .rpc_common .serializers .schema .schema import SchemaSerializer
3127from forestadmin .rpc_common .serializers .utils import CallerSerializer
3228
33- # from concurrent import futures
34-
35-
36- # from forestadmin.rpc_common.proto import datasource_pb2_grpc
37-
38-
39- class RcpJsonEncoder (json .JSONEncoder ):
40- def default (self , o ):
41- if isinstance (o , Enum ):
42- return o .value
43- if isinstance (o , set ):
44- return list (sorted (o , key = lambda x : x .value if isinstance (x , Enum ) else str (x )))
45- if isinstance (o , set ):
46- return list (sorted (o , key = lambda x : x .value if isinstance (x , Enum ) else str (x )))
47-
48- try :
49- return super ().default (o )
50- except Exception as exc :
51- print (f"error on seriliaze { o } , { type (o )} : { exc } " )
52-
5329
5430class RpcAgent (Agent ):
55- # TODO: options to add:
56- # * listen addr
5731 def __init__ (self , options : RpcOptions ):
58- # self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
59- # self.server = grpc.aio.server()
6032 self .listen_addr , self .listen_port = options ["listen_addr" ].rsplit (":" , 1 )
33+ agent_options : Options = {** options } # type:ignore
34+ agent_options ["skip_schema_update" ] = True
35+ agent_options ["env_secret" ] = "f" * 64
36+ agent_options ["server_url" ] = "http://fake"
37+ agent_options ["schema_path" ] = "./.forestadmin-schema.json"
38+ super ().__init__ (agent_options )
39+
6140 self .app = web .Application (middlewares = [self .hmac_middleware ])
62- # self.server.add_insecure_port(options["listen_addr"])
63- options ["skip_schema_update" ] = True
64- options ["env_secret" ] = "f" * 64
65- options ["server_url" ] = "http://fake"
66- # options["auth_secret"] = "f48186505a3c5d62c27743126d6a76c1dd8b3e2d8897de19"
67- options ["schema_path" ] = "./.forestadmin-schema.json"
68- super ().__init__ (options )
69-
70- self .aes_key = self .options ["auth_secret" ][:16 ].encode ()
71- self .aes_iv = self .options ["auth_secret" ][- 16 :].encode ()
72- self ._server_stop = False
7341 self .setup_routes ()
74- # signal.signal(signal.SIGUSR1, self.stop_handler)
7542
7643 @web .middleware
7744 async def hmac_middleware (self , request : web .Request , handler ):
@@ -80,11 +47,10 @@ async def hmac_middleware(self, request: web.Request, handler):
8047 if not is_valid_hmac (
8148 self .options ["auth_secret" ].encode (), body , request .headers .get ("X-FOREST-HMAC" , "" ).encode ("utf-8" )
8249 ):
83- return web .Response (status = 401 )
50+ return web .Response (status = 401 , text = "Unauthorized from HMAC verification" )
8451 return await handler (request )
8552
8653 def setup_routes (self ):
87- # self.app.middlewares.append(self.hmac_middleware)
8854 self .app .router .add_route ("GET" , "/sse" , self .sse_handler )
8955 self .app .router .add_route ("GET" , "/schema" , self .schema )
9056 self .app .router .add_route ("POST" , "/collection/list" , self .collection_list )
@@ -98,11 +64,11 @@ def setup_routes(self):
9864
9965 self .app .router .add_route ("POST" , "/execute-native-query" , self .native_query )
10066 self .app .router .add_route ("POST" , "/render-chart" , self .render_chart )
101- self .app .router .add_route ("GET" , "/" , lambda _ : web .Response (text = "OK" ))
67+ self .app .router .add_route ("GET" , "/" , lambda _ : web .Response (text = "OK" )) # type: ignore
10268
10369 async def sse_handler (self , request : web .Request ) -> web .StreamResponse :
10470 async with sse_response (request ) as resp :
105- while resp .is_connected () and not self . _server_stop :
71+ while resp .is_connected ():
10672 await resp .send ("" , event = "heartbeat" )
10773 await asyncio .sleep (1 )
10874 data = json .dumps ({"event" : "RpcServerStop" })
@@ -112,44 +78,42 @@ async def sse_handler(self, request: web.Request) -> web.StreamResponse:
11278 async def schema (self , request ):
11379 await self .customizer .get_datasource ()
11480
115- return web .Response ( text = json . dumps ( await SchemaSerializer (await self .customizer .get_datasource ()).serialize () ))
81+ return web .json_response ( await SchemaSerializer (await self .customizer .get_datasource ()).serialize ())
11682
11783 async def collection_list (self , request : web .Request ):
11884 body_params = await request .json ()
11985 ds = await self .customizer .get_datasource ()
12086 collection = ds .get_collection (body_params ["collectionName" ])
12187 caller = CallerSerializer .deserialize (body_params ["caller" ])
122- filter_ = PaginatedFilterSerializer .deserialize (body_params ["filter" ], collection )
88+ filter_ = PaginatedFilterSerializer .deserialize (body_params ["filter" ], collection ) # type:ignore
12389 projection = ProjectionSerializer .deserialize (body_params ["projection" ])
12490
12591 records = await collection .list (caller , filter_ , projection )
12692 records = [RecordSerializer .serialize (record ) for record in records ]
127- return web .Response ( text = aes_encrypt ( json . dumps ( records ), self . aes_key , self . aes_iv ) )
93+ return web .json_response ( records )
12894
12995 async def collection_create (self , request : web .Request ):
13096 body_params = await request .text ()
131- body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
13297 body_params = json .loads (body_params )
13398 ds = await self .customizer .get_datasource ()
13499
135100 collection = ds .get_collection (body_params ["collectionName" ])
136101 caller = CallerSerializer .deserialize (body_params ["caller" ])
137- data = [RecordSerializer .deserialize (r , collection ) for r in body_params ["data" ]]
102+ data = [RecordSerializer .deserialize (r , collection ) for r in body_params ["data" ]] # type:ignore
138103
139104 records = await collection .create (caller , data )
140105 records = [RecordSerializer .serialize (record ) for record in records ]
141106 return web .json_response (records )
142107
143108 async def collection_update (self , request : web .Request ):
144109 body_params = await request .text ()
145- body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
146110 body_params = json .loads (body_params )
147111
148112 ds = await self .customizer .get_datasource ()
149113 collection = ds .get_collection (body_params ["collectionName" ])
150114 caller = CallerSerializer .deserialize (body_params ["caller" ])
151- filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection )
152- patch = RecordSerializer .deserialize (body_params ["patch" ], collection )
115+ filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection ) # type:ignore
116+ patch = RecordSerializer .deserialize (body_params ["patch" ], collection ) # type:ignore
153117
154118 await collection .update (caller , filter_ , patch )
155119 return web .Response (text = "OK" )
@@ -159,7 +123,7 @@ async def collection_delete(self, request: web.Request):
159123 ds = await self .customizer .get_datasource ()
160124 collection = ds .get_collection (body_params ["collectionName" ])
161125 caller = CallerSerializer .deserialize (body_params ["caller" ])
162- filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection )
126+ filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection ) # type:ignore
163127
164128 await collection .delete (caller , filter_ )
165129 return web .Response (text = "OK" )
@@ -169,53 +133,51 @@ async def collection_aggregate(self, request: web.Request):
169133 ds = await self .customizer .get_datasource ()
170134 collection = ds .get_collection (body_params ["collectionName" ])
171135 caller = CallerSerializer .deserialize (body_params ["caller" ])
172- filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection )
136+ filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection ) # type:ignore
173137 aggregation = AggregationSerializer .deserialize (body_params ["aggregation" ])
174138
175139 records = await collection .aggregate (caller , filter_ , aggregation )
176- # records = [RecordSerializer.serialize(record) for record in records]
177- return web .Response (text = aes_encrypt (json .dumps (records ), self .aes_key , self .aes_iv ))
140+ return web .json_response (records )
178141
179142 async def collection_get_form (self , request : web .Request ):
180143 body_params = await request .text ()
181- body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
182144 body_params = json .loads (body_params )
183145
184146 ds = await self .customizer .get_datasource ()
185147 collection = ds .get_collection (body_params ["collectionName" ])
186148
187- caller = CallerSerializer .deserialize (body_params ["caller" ]) if body_params [ "caller" ] else None
149+ caller = CallerSerializer .deserialize (body_params ["caller" ])
188150 action_name = body_params ["actionName" ]
189- filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection ) if body_params ["filter" ] else None
151+ if body_params ["filter" ]:
152+ filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection ) # type:ignore
153+ else :
154+ filter_ = None
190155 data = ActionFormValuesSerializer .deserialize (body_params ["data" ])
191156 meta = body_params ["meta" ]
192157
193158 form = await collection .get_form (caller , action_name , data , filter_ , meta )
194- return web .Response (
195- text = aes_encrypt (json .dumps (ActionFormSerializer .serialize (form )), self .aes_key , self .aes_iv )
196- )
159+ return web .json_response (ActionFormSerializer .serialize (form ))
197160
198161 async def collection_execute (self , request : web .Request ):
199162 body_params = await request .text ()
200- body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
201163 body_params = json .loads (body_params )
202164
203165 ds = await self .customizer .get_datasource ()
204166 collection = ds .get_collection (body_params ["collectionName" ])
205167
206- caller = CallerSerializer .deserialize (body_params ["caller" ]) if body_params [ "caller" ] else None
168+ caller = CallerSerializer .deserialize (body_params ["caller" ])
207169 action_name = body_params ["actionName" ]
208- filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection ) if body_params ["filter" ] else None
170+ if body_params ["filter" ]:
171+ filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection ) # type:ignore
172+ else :
173+ filter_ = None
209174 data = ActionFormValuesSerializer .deserialize (body_params ["data" ])
210175
211176 result = await collection .execute (caller , action_name , data , filter_ )
212- return web .Response (
213- text = aes_encrypt (json .dumps (ActionResultSerializer .serialize (result )), self .aes_key , self .aes_iv )
214- )
177+ return web .json_response (ActionResultSerializer .serialize (result )) # type:ignore
215178
216179 async def collection_render_chart (self , request : web .Request ):
217180 body_params = await request .text ()
218- body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
219181 body_params = json .loads (body_params )
220182
221183 ds = await self .customizer .get_datasource ()
@@ -244,11 +206,10 @@ async def collection_render_chart(self, request: web.Request):
244206 ret .append (value )
245207
246208 result = await collection .render_chart (caller , name , record_id )
247- return web .Response ( text = aes_encrypt ( json . dumps ( result ), self . aes_key , self . aes_iv ) )
209+ return web .json_response ( result )
248210
249211 async def render_chart (self , request : web .Request ):
250212 body_params = await request .text ()
251- body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
252213 body_params = json .loads (body_params )
253214
254215 ds = await self .customizer .get_datasource ()
@@ -257,11 +218,10 @@ async def render_chart(self, request: web.Request):
257218 name = body_params ["name" ]
258219
259220 result = await ds .render_chart (caller , name )
260- return web .Response ( text = aes_encrypt ( json . dumps ( result ), self . aes_key , self . aes_iv ) )
221+ return web .json_response ( result )
261222
262223 async def native_query (self , request : web .Request ):
263224 body_params = await request .text ()
264- body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
265225 body_params = json .loads (body_params )
266226
267227 ds = await self .customizer .get_datasource ()
@@ -270,7 +230,7 @@ async def native_query(self, request: web.Request):
270230 parameters = body_params ["parameters" ]
271231
272232 result = await ds .execute_native_query (connection_name , native_query , parameters )
273- return web .Response ( text = aes_encrypt ( json . dumps ( result ), self . aes_key , self . aes_iv ) )
233+ return web .json_response ( result )
274234
275235 def start (self ):
276236 web .run_app (self .app , host = self .listen_addr , port = int (self .listen_port ))
0 commit comments