Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Auditability (Closes #916) #960

Merged
merged 23 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
YETI_AUTH_ACCESS_TOKEN_EXPIRE_MINUTES: 30
YETI_AUTH_ENABLED: False
YETI_SYSTEM_PLUGINS_PATH: ./plugins
YETI_SYSTEM_AUDIT_LOGFILE: /tmp/yeti_audit.log
strategy:
matrix:
os: [ubuntu-latest]
Expand Down
121 changes: 121 additions & 0 deletions core/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import datetime
import json
import logging
import os
import queue
from logging import Formatter
from logging.handlers import QueueHandler, QueueListener

from core.config.config import yeti_config
from core.schemas.audit import AuditLog

# Inspired by
# * https://www.sheshbabu.com/posts/fastapi-structured-json-logging/
# * https://rob-blackbourn.medium.com/how-to-use-python-logging-queuehandler-with-dictconfig-1e8b1284e27a


class ArangoHandler(logging.Handler):

actions = {
"GET": "read",
"POST": "create",
"PATCH": "update",
"DELETE": "delete",
}

def __init__(self, level=logging.NOTSET):
super().__init__(level)

def emit(self, record):
if "type" not in record.__dict__:
return
if record.__dict__["type"] != "audit.log":
return
target = record.__dict__["path"]
if "/auth/" in target or target.endswith("/search"):
return
action = self.actions.get(record.__dict__["method"], "unknown")
if record.__dict__["status_code"] == 200:
status = "succeeded"
else:
status = "failed"
udgover marked this conversation as resolved.
Show resolved Hide resolved

if "body" in record.__dict__ and record.__dict__["body"]:
content = json.loads(record.__dict__["body"].decode("utf-8"))
else:
content = {}
AuditLog(
created = datetime.datetime.fromtimestamp(record.created),
username = record.__dict__["username"],
action = action,
status = status,
target = target,
content = content,
status_code = record.__dict__["status_code"],
ip = record.__dict__["client"],
).save()


class JsonFormatter(Formatter):
def __init__(self):
super(JsonFormatter, self).__init__()

def format(self, record):
json_record = {}
json_record["message"] = record.getMessage()
if "username" in record.__dict__:
json_record["username"] = record.__dict__["username"]
if "path" in record.__dict__:
json_record["path"] = record.__dict__["path"]
if "method" in record.__dict__:
json_record["method"] = record.__dict__["method"]
if "body" in record.__dict__ and record.__dict__["body"]:
if record.__dict__["content-type"] == "application/json":
json_record["body"] = json.loads(record.__dict__["body"].decode("utf-8"))
else:
json_record["body"] = record.__dict__["body"].decode("utf-8")
if "client" in record.__dict__:
json_record["client"] = record.__dict__["client"]
if "status_code" in record.__dict__:
json_record["status_code"] = record.__dict__["status_code"]
if record.levelno == logging.ERROR and record.exc_info:
json_record["err"] = self.formatException(record.exc_info)
return json.dumps(json_record)


logger = logging.getLogger("yeti.audit.log")
logger.setLevel(logging.INFO)
logger.propagate = False

log_queue = queue.Queue(-1)
queue_handler = QueueHandler(log_queue)
logger.addHandler(queue_handler)

json_formatter = JsonFormatter()
console_formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s - %(username)s - %(path)s - %(method)s - %(body)s - %(client)s - %(status_code)s"
)

handlers = list()

console_handler = logging.StreamHandler()
console_handler.setFormatter(console_formatter)
handlers.append(console_handler)

audit_logfile = yeti_config.get('system', 'audit_logfile')

if audit_logfile:
if os.access(audit_logfile, os.W_OK):
file_handler = logging.FileHandler(audit_logfile)
file_handler.setFormatter(json_formatter)
handlers.append(file_handler)
else:
logging.getLogger().warning("Audit log file not writable, using console only")
else:
logging.getLogger().warning("Audit log file not configured, using console only")

arango_handler = ArangoHandler()
handlers.append(arango_handler)

listener = QueueListener(log_queue, *handlers)
listener.start()
30 changes: 30 additions & 0 deletions core/schemas/audit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import datetime
from typing import ClassVar, Literal

from core import database_arango
from core.schemas.model import YetiModel
from pydantic import computed_field


class AuditLog(YetiModel, database_arango.ArangoYetiConnector):
_collection_name: ClassVar[str] = "auditlog"
_type_filter: ClassVar[str | None] = None
_root_type: Literal["auditlog"] = "auditlog"

created: datetime.datetime
username: str
action: str
status: str
target: str
content: dict = {}
ip: str
status_code: int

@computed_field(return_type=Literal["auditlog"])
@property
def root_type(self):
return self._root_type

@classmethod
def load(cls, object: dict) -> "AuditLog":
return cls(**object)
40 changes: 20 additions & 20 deletions core/web/apiv2/auth.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
import datetime

from authlib.integrations.starlette_client import OAuth, OAuthError
from fastapi import APIRouter, Depends, HTTPException, Response, Security, status
from core.config.config import yeti_config
from core.schemas.user import User, UserSensitive
from fastapi import (APIRouter, Depends, HTTPException, Response, Security,
status)
from fastapi.responses import RedirectResponse
from fastapi.security import (
APIKeyCookie,
APIKeyHeader,
OAuth2PasswordBearer,
OAuth2PasswordRequestForm,
)
from fastapi.security import (APIKeyCookie, APIKeyHeader, OAuth2PasswordBearer,
OAuth2PasswordRequestForm)
from jose import JWTError, jwt
from starlette.requests import Request

from core.config.config import yeti_config
from core.schemas.user import User, UserSensitive

ACCESS_TOKEN_EXPIRE_MINUTES = datetime.timedelta(
minutes=yeti_config.get('auth', "access_token_expire_minutes")
)
Expand Down Expand Up @@ -62,21 +58,15 @@ def create_access_token(data: dict, expires_delta: datetime.timedelta | None = N
return encoded_jwt


async def get_current_user(
async def get_current_user(request: Request,
token: str = Depends(oauth2_scheme), cookie: str = Security(cookie_scheme)
) -> UserSensitive:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)

disabled_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User account disabled. Please contact your server admin.",
headers={"WWW-Authenticate": "Bearer"},
)

request.state.username = None
if not token and not cookie:
raise credentials_exception

Expand All @@ -93,11 +83,21 @@ async def get_current_user(
user = UserSensitive.find(username=username)
if user is None:
raise credentials_exception
if not user.enabled:
raise disabled_exception
request.state.username = user.username
return user


async def get_current_active_user(
current_user: User = Security(get_current_user)
):
if not current_user.enabled:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
detail="User account disabled. Please contact your server admin.",
headers={"WWW-Authenticate": "Bearer"}
)
return current_user


class GetCurrentUserWithPermissions:
"""Helper class to manage a layer of user permissions.
Expand Down
7 changes: 4 additions & 3 deletions core/web/apiv2/system.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from core.config.config import yeti_config
from core.taskscheduler import app
from fastapi import APIRouter
from core.web.apiv2.auth import get_current_active_user
from fastapi import APIRouter, Depends
from pydantic import BaseModel, ConfigDict

# API endpoints
Expand Down Expand Up @@ -37,7 +38,7 @@ async def get_config() -> SystemConfigResponse:
return config


@router.get("/workers")
@router.get("/workers", dependencies=[Depends(get_current_active_user)])
async def get_worker_status() -> WorkerStatusResponse:
inspect = app.control.inspect(timeout=5, destination=None)

Expand All @@ -56,7 +57,7 @@ async def get_worker_status() -> WorkerStatusResponse:
active=active_tasks,
)

@router.post("/restartworker/{worker_name}")
@router.post("/restartworker/{worker_name}", dependencies=[Depends(get_current_active_user)])
async def restart_worker(worker_name: str) -> WorkerRestartResponse:
"""Restarts a single or all Celery workers."""
destination = [worker_name] if worker_name != "all" else None
Expand Down
105 changes: 82 additions & 23 deletions core/web/webapp.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,96 @@
from fastapi import FastAPI
from fastapi import APIRouter
from starlette.middleware.sessions import SessionMiddleware
import logging

from core.web.apiv2 import observables
from core.web.apiv2 import entities
from core.web.apiv2 import indicators
from core.web.apiv2 import tag
from core.web.apiv2 import graph
from core.web.apiv2 import auth
from core.web.apiv2 import tasks
from core.web.apiv2 import templates
from core.web.apiv2 import users
from core.web.apiv2 import system
from core.config.config import yeti_config
from core.logger import logger
from core.web.apiv2 import (auth, entities, graph, indicators, observables,
system, tag, tasks, templates, users)
from fastapi import APIRouter, Depends, FastAPI, Request
from starlette.middleware.sessions import SessionMiddleware
from starlette.types import Message

SECRET_KEY = yeti_config.get('auth', "secret_key")

app = FastAPI()

app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)

api_router = APIRouter()

api_router.include_router(auth.router, prefix="/auth", tags=["auth"])

api_router.include_router(
observables.router, prefix="/observables", tags=["observables"]
observables.router, prefix="/observables", tags=["observables"],
dependencies=[Depends(auth.get_current_active_user)]
)
api_router.include_router(
entities.router, prefix="/entities", tags=["entities"],
dependencies=[Depends(auth.get_current_active_user)]
)
api_router.include_router(
indicators.router, prefix="/indicators", tags=["indicators"],
dependencies=[Depends(auth.get_current_active_user)]
)
api_router.include_router(
tag.router, prefix="/tags", tags=["tags"],
dependencies=[Depends(auth.get_current_active_user)]
)
api_router.include_router(
tasks.router, prefix="/tasks", tags=["tasks"],
dependencies=[Depends(auth.get_current_active_user)]
)
api_router.include_router(
graph.router, prefix="/graph", tags=["graph"],
dependencies=[Depends(auth.get_current_active_user)]
)
api_router.include_router(
templates.router, prefix="/templates", tags=["templates"],
dependencies=[Depends(auth.get_current_active_user)]
)
api_router.include_router(
users.router, prefix="/users", tags=["users"],
dependencies=[Depends(auth.get_current_active_user)]
)
# Dependencies are set in system endpoints
api_router.include_router(
system.router, prefix="/system", tags=["system"],
)
api_router.include_router(entities.router, prefix="/entities", tags=["entities"])
api_router.include_router(indicators.router, prefix="/indicators", tags=["indicators"])
api_router.include_router(tag.router, prefix="/tags", tags=["tags"])
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
api_router.include_router(graph.router, prefix="/graph", tags=["graph"])
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
api_router.include_router(templates.router, prefix="/templates", tags=["templates"])
api_router.include_router(users.router, prefix="/users", tags=["users"])
api_router.include_router(system.router, prefix="/system", tags=["system"])

app.include_router(api_router, prefix="/api/v2")

async def set_body(request: Request, body: bytes):
async def receive() -> Message:
return {'type': 'http.request', 'body': body}
request._receive = receive

@app.middleware("http")
async def log_requests(request: Request, call_next):
req_body = await request.body()
await set_body(request, req_body)
response = await call_next(request)
try:
extra = {
"type": "audit.log",
"path": request.url.path,
"method": request.method,
"username": "anonymous",
# When behind a proxy, we should start uvicorn with --proxy-headers
# and use request.headers.get('x-forwarded-for') instead.
"client": request.client.host,
"status_code": response.status_code,
"content-type": request.headers.get("content-type", ""),
"body": b""
}
if getattr(request.state, 'username', None):
extra["username"] = request.state.username
if req_body:
extra["body"] = req_body
if response.status_code == 200:
logger.info("Authorized request", extra=extra)
elif response.status_code == 401:
logger.warning("Unauthorized request", extra=extra)
else:
logger.error("Bad request", extra=extra)
except Exception as e:
err_logger = logging.getLogger("webapp.log_requests")
err_logger.exception("Error while logging request")
return response
Loading