Skip to content

Commit

Permalink
Merge pull request #375 from Anko59/374-case-templates-endpoint
Browse files Browse the repository at this point in the history
Case template endpoint
  • Loading branch information
Kamforka authored Jan 15, 2025
2 parents 82ad1ab + 55fdf07 commit 71fcdf0
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ repos:
language: system
pass_filenames: false
always_run: true
stages: [push]
stages: [pre-push]
25 changes: 25 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from thehive4py.helpers import now_to_ts
from thehive4py.types.alert import InputAlert, OutputAlert
from thehive4py.types.case import InputCase, OutputCase
from thehive4py.types.case_template import InputCaseTemplate, OutputCaseTemplate
from thehive4py.types.comment import OutputComment
from thehive4py.types.custom_field import OutputCustomField
from thehive4py.types.observable import InputObservable, OutputObservable
Expand Down Expand Up @@ -113,6 +114,30 @@ def test_cases(thehive: TheHiveApi) -> List[OutputCase]:
return [thehive.case.create(case=case) for case in cases]


@pytest.fixture
def test_case_template(thehive: TheHiveApi) -> OutputCaseTemplate:
name = "my first case template"
return thehive.case_template.create(
case_template={
"name": name,
"description": "...",
"tags": ["whatever"],
}
)


@pytest.fixture
def test_case_templates(thehive: TheHiveApi) -> List[OutputCaseTemplate]:
case_templates: List[InputCaseTemplate] = [
{"name": "my first case template", "description": "..."},
{"name": "my second case template", "description": "..."},
]
return [
thehive.case_template.create(case_template=case_template)
for case_template in case_templates
]


@pytest.fixture
def test_observable(thehive: TheHiveApi, test_case: OutputCase) -> OutputObservable:
return thehive.observable.create_in_case(
Expand Down
50 changes: 50 additions & 0 deletions tests/test_case_template_endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import List

import pytest
from thehive4py.client import TheHiveApi
from thehive4py.errors import TheHiveError
from thehive4py.types.case_template import InputCaseTemplate, OutputCaseTemplate


class TestCaseTemplateEndpoint:
def test_create_and_get(self, thehive: TheHiveApi):
created_case_template = thehive.case_template.create(
case_template={
"name": "my first template",
"description": "Template description",
}
)
fetched_case_template = thehive.case_template.get(created_case_template["_id"])
assert created_case_template == fetched_case_template

def test_update(self, thehive: TheHiveApi, test_case_template: OutputCaseTemplate):
case_template_id = test_case_template["_id"]
update_fields: InputCaseTemplate = {
"name": "updated template name",
"description": "updated template description",
}
thehive.case_template.update(
case_template_id=case_template_id, fields=update_fields
)
updated_case_template = thehive.case_template.get(
case_template_id=case_template_id
)

for key, value in update_fields.items():
assert updated_case_template.get(key) == value

def test_delete(self, thehive: TheHiveApi, test_case_template: OutputCaseTemplate):
case_template_id = test_case_template["_id"]
thehive.case_template.delete(case_template_id=case_template_id)
with pytest.raises(TheHiveError):
thehive.case_template.get(case_template_id=case_template_id)

def test_find(
self,
thehive: TheHiveApi,
test_case_templates: List[OutputCaseTemplate],
):
found_templates = thehive.case_template.find()
original_ids = [template["_id"] for template in test_case_templates]
found_ids = [template["_id"] for template in found_templates]
assert sorted(found_ids) == sorted(original_ids)
5 changes: 5 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,15 @@ def _reset_hive_org(hive_url: str, test_config: TestConfig, organisation: str) -

alerts = client.alert.find()
cases = client.case.find()
case_templates = client.case_template.find()

with ThreadPoolExecutor() as executor:
executor.map(client.alert.delete, [alert["_id"] for alert in alerts])
executor.map(client.case.delete, [case["_id"] for case in cases])
executor.map(
client.case_template.delete,
[case_template["_id"] for case_template in case_templates],
)


def _reset_hive_admin_org(hive_url: str, test_config: TestConfig) -> None:
Expand Down
2 changes: 2 additions & 0 deletions thehive4py/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from thehive4py.endpoints import (
AlertEndpoint,
CaseEndpoint,
CaseTemplateEndpoint,
CommentEndpoint,
ObservableEndpoint,
OrganisationEndpoint,
Expand Down Expand Up @@ -62,6 +63,7 @@ def __init__(
# case management endpoints
self.alert = AlertEndpoint(self.session)
self.case = CaseEndpoint(self.session)
self.case_template = CaseTemplateEndpoint(self.session)
self.comment = CommentEndpoint(self.session)
self.observable = ObservableEndpoint(self.session)
self.procedure = ProcedureEndpoint(self.session)
Expand Down
1 change: 1 addition & 0 deletions thehive4py/endpoints/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .alert import AlertEndpoint
from .case import CaseEndpoint
from .case_template import CaseTemplateEndpoint
from .comment import CommentEndpoint
from .cortex import CortexEndpoint
from .custom_field import CustomFieldEndpoint
Expand Down
47 changes: 47 additions & 0 deletions thehive4py/endpoints/case_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from thehive4py.endpoints._base import EndpointBase
from thehive4py.query import QueryExpr
from thehive4py.query.filters import FilterExpr
from thehive4py.query.page import Paginate
from thehive4py.query.sort import SortExpr
from thehive4py.types.case_template import OutputCaseTemplate, InputCaseTemplate
from typing import List, Optional


class CaseTemplateEndpoint(EndpointBase):
def find(
self,
filters: Optional[FilterExpr] = None,
sortby: Optional[SortExpr] = None,
paginate: Optional[Paginate] = None,
) -> List[OutputCaseTemplate]:
query: QueryExpr = [
{"_name": "listCaseTemplate"},
*self._build_subquery(filters=filters, sortby=sortby, paginate=paginate),
]

return self._session.make_request(
"POST",
path="/api/v1/query",
json={"query": query},
params={"name": "caseTemplate"},
)

def get(self, case_template_id: str) -> OutputCaseTemplate:
return self._session.make_request(
"GET", path=f"/api/v1/caseTemplate/{case_template_id}"
)

def create(self, case_template: InputCaseTemplate) -> OutputCaseTemplate:
return self._session.make_request(
"POST", path="/api/v1/caseTemplate", json=case_template
)

def delete(self, case_template_id: str) -> None:
return self._session.make_request(
"DELETE", path=f"/api/v1/caseTemplate/{case_template_id}"
)

def update(self, case_template_id: str, fields: InputCaseTemplate) -> None:
return self._session.make_request(
"PATCH", path=f"/api/v1/caseTemplate/{case_template_id}", json=fields
)
52 changes: 52 additions & 0 deletions thehive4py/types/case_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import List, Literal, TypedDict, Union

from .custom_field import InputCustomFieldValue
from .task import InputTask, OutputTask

SeverityValue = Literal[1, 2, 3, 4]
TlpValue = Literal[0, 1, 2, 3, 4]
PapValue = Literal[0, 1, 2, 3]


class InputCaseTemplateRequired(TypedDict):
name: str


class InputCaseTemplate(InputCaseTemplateRequired, total=False):
displayName: str
titlePrefix: str
description: str
severity: SeverityValue
tags: List[str]
flag: bool
tlp: TlpValue
pap: PapValue
summary: str
tasks: List[InputTask]
pageTemplateIds: List[str]
customFields: Union[dict, List[InputCustomFieldValue]]


class OutputCaseTemplateRequired(TypedDict):
_id: str
_type: str
_createdBy: str
_createdAt: int
name: str


class OutputCaseTemplate(OutputCaseTemplateRequired, total=False):
_updatedBy: str
_updatedAt: int
displayName: str
titlePrefix: str
description: str
severity: SeverityValue
tags: List[str]
flag: bool
tlp: TlpValue
pap: PapValue
summary: str
tasks: List[OutputTask]
pageTemplateIds: List[str]
customFields: Union[dict, List[InputCustomFieldValue]]

0 comments on commit 71fcdf0

Please sign in to comment.