Skip to content

Commit d8ffd4c

Browse files
committed
bug fixes
1 parent 6712bb5 commit d8ffd4c

File tree

11 files changed

+71
-163
lines changed

11 files changed

+71
-163
lines changed

src/app/api/__init__.py

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
from fastapi.middleware.cors import CORSMiddleware
44
from fastapi.middleware.gzip import GZipMiddleware
55
from fastapi.openapi.utils import get_openapi
6+
from fastapi.exception_handlers import http_exception_handler
67
from fastapi.openapi.docs import (
78
get_redoc_html,
89
get_swagger_ui_html
910
)
1011
from fastapi.staticfiles import StaticFiles
1112
from starlette.middleware.base import BaseHTTPMiddleware
1213
from starlette.responses import PlainTextResponse
14+
from starlette.exceptions import HTTPException
1315
from sqlalchemy.exc import TimeoutError, DatabaseError
1416

1517
from app.api.endpoints import router
@@ -23,7 +25,10 @@ async def db_middleware(request: Request, call_next) -> None:
2325
"""
2426
ROUTE_PREFIX_WHITELIST = ("/api/v1/firehose")
2527
if request.url.path.startswith(ROUTE_PREFIX_WHITELIST):
26-
return await call_next(request)
28+
try:
29+
return await call_next(request)
30+
except HTTPException as e:
31+
return await http_exception_handler(request, e)
2732
else:
2833
try:
2934
with SessionLocal() as db_session:
@@ -41,7 +46,7 @@ async def db_middleware(request: Request, call_next) -> None:
4146
except TimeoutError:
4247
return PlainTextResponse("Timeout when connecting to database", status_code=503)
4348
except DatabaseError:
44-
raise
49+
return PlainTextResponse("Error when connecting to database", status_code=500)
4550

4651
return response
4752

@@ -51,31 +56,32 @@ def create_app() -> FastAPI:
5156
:return:
5257
"""
5358
# Set up logging format
54-
# dictConfig({
55-
# "version": 1,
56-
# "disable_existing_loggers": False,
57-
# "formatters": {
58-
# "access": {
59-
# "()": "uvicorn.logging.AccessFormatter",
60-
# "fmt": '%(asctime)s - %(levelprefix)s %(client_addr)s - \"%(request_line)s\" %(status_code)s',
61-
# "use_colors": True
62-
# },
63-
# },
64-
# "handlers": {
65-
# "access": {
66-
# "formatter": "access",
67-
# "class": "logging.StreamHandler",
68-
# "stream": "ext://sys.stdout",
69-
# },
70-
# },
71-
# "loggers": {
72-
# "uvicorn.access": {
73-
# "handlers": ["access"],
74-
# "level": "INFO",
75-
# "propagate": False
76-
# },
77-
# },
78-
# })
59+
dictConfig({
60+
"version": 1,
61+
"disable_existing_loggers": False,
62+
"formatters": {
63+
"access": {
64+
"()": "uvicorn.logging.AccessFormatter",
65+
"fmt": "%(levelprefix)s %(asctime)s.%(msecs)03d - %(client_addr)s - \"%(request_line)s\" %(status_code)s",
66+
"datefmt": "%Y-%m-%dT%H:%M:%S",
67+
"use_colors": True
68+
},
69+
},
70+
"handlers": {
71+
"access": {
72+
"formatter": "access",
73+
"class": "logging.StreamHandler",
74+
"stream": "ext://sys.stdout",
75+
},
76+
},
77+
"loggers": {
78+
"uvicorn.access": {
79+
"handlers": ["access"],
80+
"level": "INFO",
81+
"propagate": False
82+
},
83+
},
84+
})
7985

8086
app = FastAPI(
8187
debug=settings.DEBUG,
@@ -95,6 +101,7 @@ def create_app() -> FastAPI:
95101
allow_credentials=True,
96102
allow_methods=["*"],
97103
allow_headers=["*"],
104+
expose_headers=["Content-Disposition"], # allows for frontend to get download file names
98105
)
99106

100107
app.add_middleware(

src/app/api/deps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def log(self, action, thing, thing_type=None, thing_pk=None, log_thing=True,
174174
# permissions (Role) - contains all permissions the role is a
175175
# part of (only for cascade purposes)
176176
excluded_fields = ["schema_column", "entity_type", "tag_type",
177-
"permissions"]
177+
"permissions", "entity"]
178178
data = json.dumps(
179179
jsonable_encoder(thing.as_dict(exclude_keys=excluded_fields))
180180
)

src/app/api/endpoints/audit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def search_audits(
7373
- Non-numeric/non-date fields can't use the range filters, for example `subject` or `description`. If range filters are provided the system will treat them as a list filter instead.
7474
- If none of the range or list filters work it will attempt to do a normal search
7575
- Datetimes are parsed using the [dateutil module](https://dateutil.readthedocs.io/en/stable/parser.html#dateutil.parser.parse)
76+
- Some fields (e.g. event subjects and entity values) default to "contains" string searches, where any item containing the searched string will match (so searching for `example.com` would match both `example.com` and `foo.example.com`). Searching these fields with list searches (with square brackets) disables this feature for all list items, but "range" searches (with parentheses) search normally. For example, searching `[example.com]` wouldn't match `foo.example.com` but searching `(example.com)` would match.
7677
"""
7778
try:
7879
filter_dict = get_search_filters(search_schema)

src/app/api/endpoints/file.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,12 @@ def _download(db: Session, audit_logger: deps.AuditLogger, ids: list[int], passw
122122

123123
# if only one file to download and no password was provided send the file
124124
if len(ids) == 1 and password is None:
125+
filename = quote(fileobj.filename.encode("utf-8"))
125126
return StreamingResponse(
126127
filestream,
127128
media_type=fileobj.content_type,
128129
headers={
129-
"content-disposition": f"attachment; filename*=UTF-8''{quote(fileobj.filename.encode('utf-8'))}"
130+
"content-disposition": f'attachment; filename="{filename}"'
130131
},
131132
)
132133
# otherwise add file to zip

src/app/api/endpoints/firehose.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,10 @@
1414
from app.schemas import TokenPayload
1515

1616
router = APIRouter()
17-
18-
# Close all connections on SIGINT, SIGTERM (more graceful shutdown)
17+
shutdown_signal_set = False
1918
shutdown = False
2019

2120

22-
def signal_shutdown(*args):
23-
global shutdown
24-
shutdown = True
25-
for task in asyncio.all_tasks():
26-
task.cancel()
27-
28-
2921
@router.get('/', summary="Stream audit events")
3022
async def stream_audits(
3123
*,
@@ -36,9 +28,24 @@ async def stream_audits(
3628
Stream a list of changes to SCOT data in real time (e.g. creation and
3729
modification events), as well as notifications for the user
3830
"""
39-
if signal.getsignal(signal.SIGINT) != signal_shutdown:
40-
signal.signal(signal.SIGINT, signal_shutdown)
41-
signal.signal(signal.SIGTERM, signal_shutdown)
31+
# Close all connections on SIGINT, SIGTERM (more graceful shutdown)
32+
def get_shutdown_signal(signal_type):
33+
oldsignal = signal.getsignal(signal_type)
34+
if oldsignal == signal.SIG_IGN or oldsignal == signal.SIG_DFL:
35+
oldsignal = None
36+
37+
def signal_shutdown(*args):
38+
global shutdown
39+
shutdown = True
40+
if oldsignal:
41+
oldsignal(*args)
42+
return signal_shutdown
43+
44+
global shutdown_signal_set
45+
if not shutdown_signal_set:
46+
signal.signal(signal.SIGINT, get_shutdown_signal(signal.SIGINT))
47+
signal.signal(signal.SIGTERM, get_shutdown_signal(signal.SIGTERM))
48+
shutdown_signal_set = True
4249

4350
async def event_generator():
4451
audit_checkpoint = None
@@ -103,6 +110,6 @@ async def event_generator():
103110
yield json.dumps({'what': 'create', 'when': record[1].strftime('%Y-%m-%d %H:%M:%s'), 'element_type': 'notification', 'element_id': record[0], 'username': current_user.username})
104111
await asyncio.sleep(2)
105112
except asyncio.CancelledError as e:
106-
pass
113+
raise
107114

108115
return EventSourceResponse(event_generator())

src/app/api/endpoints/generic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -923,6 +923,7 @@ def search_object(
923923
- Non-numeric/non-date fields can't use the range filters, for example `subject` or `description`. If range filters are provided the system will treat them as a list filter instead.
924924
- If none of the range or list filters work it will attempt to do a normal search
925925
- Datetimes are parsed using the [dateutil module](https://dateutil.readthedocs.io/en/stable/parser.html#dateutil.parser.parse)
926+
- Some fields (e.g. event subjects and entity values) default to "contains" string searches, where any item containing the searched string will match (so searching for `example.com` would match both `example.com` and `foo.example.com`). Searching these fields with list searches (with square brackets) disables this feature for all list items, but "range" searches (with parentheses) search normally. For example, searching `[example.com]` wouldn't match `foo.example.com` but searching `(example.com)` would match.
926927
"""
927928

928929
if target_type in deps.PermissionCheck.type_allow_whitelist:

src/app/crud/base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,14 +215,20 @@ def _str_filter(self, query: Query, filter_dict: dict, column: str, escape: bool
215215
column_obj = filter_dict.pop(column, None)
216216
if column_obj is not None:
217217
# does this even make sense? raise exception or just treat it like range?
218-
if isinstance(column_obj, tuple) or isinstance(column_obj, list):
218+
if isinstance(column_obj, tuple):
219219
condition = []
220220
for item in column_obj:
221221
if escape:
222222
condition.append(model.like(f"%{escape_sql_like(item)}%"))
223223
else:
224224
condition.append(model.like(f"%{item}%"))
225225
query = query.filter(or_(*condition))
226+
# If it's a list, disable string contains querying
227+
elif isinstance(column_obj, list):
228+
condition = []
229+
for item in column_obj:
230+
condition.append(model == item)
231+
query = query.filter(or_(*condition))
226232
else:
227233
if escape:
228234
t = escape_sql_like(column_obj)

src/app/crud/crud_alertgroup.py

Lines changed: 0 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -151,111 +151,6 @@ def create(
151151
audit_logger.log("create", db_obj)
152152
return db_obj
153153

154-
def create_with_permissions(
155-
self,
156-
db_session: Session,
157-
*,
158-
obj_in: Union[AlertGroupCreate, AlertGroupDetailedCreate],
159-
perm_in: dict[PermissionEnum, list],
160-
audit_logger=None,
161-
) -> AlertGroup:
162-
"""Create an alertgroup given an AlertGroupCreate Object
163-
First create the alertgroup. Get back the alertgroup ID.
164-
Next, create the alertgroup schema
165-
Next, create the alerts, with the given alertgroup ID.
166-
Fin
167-
168-
"""
169-
170-
if PermissionEnum.admin in perm_in:
171-
raise ValueError("Users cannot assign admin permissions")
172-
db_obj = AlertGroup(
173-
owner=obj_in.owner,
174-
tlp=obj_in.tlp,
175-
view_count=obj_in.view_count,
176-
first_view=obj_in.first_view,
177-
message_id=obj_in.message_id,
178-
subject=obj_in.subject,
179-
)
180-
181-
db_session.add(db_obj)
182-
db_session.flush()
183-
db_session.refresh(db_obj)
184-
185-
tt = self.model.target_type_enum()
186-
# Assign permissions (if applicable)
187-
if tt:
188-
# need to import here to avoid circular dependency
189-
from app.crud import permission, role
190-
191-
for perm in perm_in:
192-
new_perm = {
193-
"permission": perm,
194-
"target_type": tt,
195-
"target_id": db_obj.id,
196-
}
197-
for r in perm_in[perm]:
198-
role_id = r
199-
if not isinstance(r, int):
200-
role_id = role.get_role_by_name(db_session, r).id
201-
new_perm["role_id"] = role_id
202-
permission.create(
203-
db_session, obj_in=new_perm, audit_logger=audit_logger
204-
)
205-
206-
# Also create the schema and add alerts if detailed create
207-
if isinstance(obj_in, AlertGroupDetailedCreate):
208-
schema = None
209-
if obj_in.alert_schema:
210-
schema = obj_in.alert_schema
211-
# Create the schema if one wasn't given
212-
elif obj_in.alerts:
213-
schema = self.validate_alerts(obj_in.alerts)
214-
if schema is not None:
215-
# We have a valid schema, now let's add it
216-
for schema_column in schema:
217-
schema_column.alertgroup_id = db_obj.id
218-
crud.alert_group_schema.create(
219-
db_session=db_session, obj_in=schema_column
220-
)
221-
# We've now validated and added the schema, lets add the alerts now
222-
for alert in obj_in.alerts:
223-
if alert.owner is None:
224-
alert.owner = obj_in.owner
225-
if alert.tlp == TlpEnum.unset:
226-
alert.tlp = obj_in.tlp
227-
alert.alertgroup_id = db_obj.id
228-
crud.alert.create_with_permissions(db_session=db_session, obj_in=alert, perm_in=perm_in)
229-
self.publish("create", db_obj)
230-
# Last thing, check to see if the alert group subject is contained in any existing signatures, if so, add links.
231-
looking_for = re.sub(r"\([^()]*\)", "", obj_in.subject)
232-
if "*" in looking_for or "_" in looking_for:
233-
looking_for = looking_for.lstrip().replace("\\", "\\\\")
234-
looking_for = (
235-
looking_for.replace("_", "__").replace("*", "%").replace("?", "_")
236-
)
237-
else:
238-
looking_for = looking_for.lstrip()
239-
looking_for = "%{0}%".format(escape_sql_like(looking_for))
240-
sigs_to_link_query = db_session.query(Signature).filter(
241-
Signature.name.ilike(looking_for)
242-
)
243-
results = sigs_to_link_query.all()
244-
for sig_to_link in results:
245-
crud.link.create(
246-
db_session=db_session,
247-
obj_in=LinkCreate(
248-
v0_type=TargetTypeEnum.alertgroup,
249-
v0_id=db_obj.id,
250-
v1_type=TargetTypeEnum.signature,
251-
v1_id=sig_to_link.id,
252-
context=f"Automatically linked from alertgroup subject matching: {looking_for}",
253-
),
254-
)
255-
if audit_logger is not None:
256-
audit_logger.log("create", db_obj)
257-
return db_obj
258-
259154
def create_with_owner(
260155
self,
261156
db_session: Session,

src/app/crud/crud_entity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def add_enrichment(
384384
db_session.refresh(entity)
385385
db_session.flush()
386386
if audit_logger is not None:
387-
audit_logger.log("create", enrichment)
387+
audit_logger.log("create", enrichment, log_thing=False)
388388
return entity
389389

390390
def add_entity_classes(

src/app/db/base_class.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class Base:
2020
def __tablename__(cls) -> str:
2121
return cls.__name__.lower()
2222

23-
def as_dict(self, exclude_keys=["entity_type", "tag_type", "entity_types", "entity_classes", "permissions"], pretty_keys: bool = False, enum_value: bool = False):
23+
def as_dict(self, exclude_keys=["entity_type", "tag_type", "entity_types", "entity_classes", "permissions", "entity"], pretty_keys: bool = False, enum_value: bool = False):
2424
"""
2525
Serializes this model to a dictionary as accurately as possible
2626
Contains logic to prevent infinite recursion for circular references
@@ -132,4 +132,4 @@ def get_model_by_target_type(cls, target_type: TargetTypeEnum):
132132
registry_instance = getattr(cls, "registry")
133133
for mapper_ in registry_instance.mappers:
134134
if (mapper_.class_.__tablename__ == table_name):
135-
return mapper_.class_
135+
return mapper_.class_

0 commit comments

Comments
 (0)