Skip to content

Commit e728569

Browse files
Merge pull request #28 from MITLibraries/tco-157
Implement prediction logic via neural network algorithm
2 parents 4339a9b + 3e6e138 commit e728569

File tree

11 files changed

+419
-40
lines changed

11 files changed

+419
-40
lines changed

Makefile

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
SHELL=/bin/bash
22
DATETIME:=$(shell date -u +%Y%m%dT%H%M%SZ)
3+
PAYLOAD:=tests/sam/citation.json
4+
35
### This is the Terraform-generated header for tacos-detectors-lambdas-dev. If ###
46
### this is a Lambda repo, uncomment the FUNCTION line below ###
57
### and review the other commented lines in the document. ###
@@ -77,33 +79,13 @@ sam-http-run: # Run lambda locally as an HTTP server
7779

7880
sam-http-ping: # Send curl command to SAM HTTP server using the ping action
7981
curl --location 'http://localhost:3000/foo' \
80-
--header 'Content-Type: application\json' \
82+
--header 'Content-Type: application/json' \
8183
--data '{"action":"ping", "challenge_secret": "secret_phrase"}'
8284

8385
sam-http-predict: # Send curl command to SAM HTTP server using the predict action (next step - take file argument?)
8486
curl --location 'http://localhost:3000/foo' \
85-
--header 'Content-Type: application\json' \
86-
--data '{ \
87-
"action": "predict", \
88-
"challenge_secret": "secret_phrase", \
89-
"features": { \
90-
"apa": 0, \
91-
"brackets": 0, \
92-
"colons": 0, \
93-
"commas": 0, \
94-
"lastnames": 0, \
95-
"no": 0, \
96-
"pages": 0, \
97-
"periods": 0, \
98-
"pp": 0, \
99-
"quotes": 0, \
100-
"semicolons": 0, \
101-
"vol": 0, \
102-
"words": 0, \
103-
"year":0 \
104-
} \
105-
}'
106-
87+
--header 'Content-Type: application/json' \
88+
--data '@$(PAYLOAD)'
10789

10890
### Terraform-generated Developer Deploy Commands for Dev environment ###
10991
dist-dev: ## Build docker container (intended for developer-based manual build)

Pipfile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ name = "pypi"
66
[packages]
77
sentry-sdk = "*"
88
jsonschema = "*"
9+
pandas = "*"
10+
scikit-learn = "*"
11+
pandas-stubs = "*"
912

1013
[dev-packages]
1114
black = "*"

Pipfile.lock

Lines changed: 247 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

README.md

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,24 +86,51 @@ the lambda does not utilize them in request payload.
8686

8787
3. In another terminal, perform an HTTP request via another `Makefile` command:
8888

89+
The server's baseline readiness can be confirmed via the ping action:
8990
```shell
9091
make sam-http-ping
9192
```
9293

93-
Response should have an HTTP status of `200` and respond with:
94+
The response should have an HTTP status of `200` and respond with:
9495

9596
```json
9697
{
9798
"response": "pong"
9899
}
99100
```
100101

102+
Actual predictions can be sent in via the predict action:
103+
104+
```shell
105+
make sam-http-predict
106+
```
107+
108+
```json
109+
{
110+
"response": "True"
111+
}
112+
```
113+
114+
Custom payloads can be found in the `tests/sam` directory, and the default payload overridden via the `PAYLOAD` Makefile
115+
argument:
116+
117+
```shell
118+
make sam-http-predict PAYLOAD=tests/sam/noncitation.json
119+
```
120+
121+
```json
122+
{
123+
"response": "False"
124+
}
125+
```
126+
101127
### Invoking lambda directly
102128

103129
While lambdas can be invoked via HTTP methods (ALB, Function URL, etc), they are also often invoked directly with an
104130
`event` payload. To do so with SAM, you do **not** need to first start an HTTP server with `make sam-run`, you can
105131
invoke the function image directly:
106132

133+
#### Example 1: ping
107134
```shell
108135
echo '{"action": "ping", "challenge_secret": "secret_phrase"}' | sam local invoke --env-vars tests/sam/env.json -e -
109136
```
@@ -115,7 +142,23 @@ Response:
115142
false, "body": "{\"response\": \"pong\"}"}
116143
```
117144

118-
As you can see from this response, the lambda is still returning a dictionary that _would_ work for an HTTP response,
145+
#### Example 2: predict
146+
147+
The JSON files with example payloads in `tests/sam` can be helpful for working with the `predict` action, rather than
148+
trying to include all features and values directly within an echo command:
149+
150+
```shell
151+
echo "$(cat tests/sam/citation.json)" | sam local invoke --env-vars tests/sam/env.json -e -
152+
```
153+
154+
Response:
155+
156+
```text
157+
{"statusCode": 200, "statusDescription": "200 OK", "headers": {"Content-Type": "application/json"}, "isBase64Encoded":
158+
false, "body": "{\"response\": \"True\"}"}
159+
```
160+
161+
As you can see from these responses, the lambda is still returning a dictionary that _would_ work for an HTTP response,
119162
but is actually just a dictionary with the required information.
120163

121164
It's unknown at this time if this lambda will get invoked via non-HTTP methods, but SAM will be helpful for testing and

lambdas/models/neural.pkl

15.3 KB
Binary file not shown.

lambdas/predict.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from abc import ABC, abstractmethod
44
from dataclasses import asdict, dataclass
55
from http import HTTPStatus
6+
from pickle import load
67

8+
import pandas as pd
79
from jsonschema import ValidationError, validate
810

911
from lambdas.config import Config, configure_sentry
@@ -41,13 +43,28 @@ def handle(self, _payload: InputPayload) -> dict:
4143
class PredictHandler(RequestHandler):
4244
"""Handle prediction requests."""
4345

46+
def load_model(self) -> None:
47+
"""Load the machine learning model, and confirm it is fitted.
48+
49+
Please note that this method does not have a return value. It populates
50+
the `self.model` attribute with the loaded model.
51+
"""
52+
path = "lambdas/models/neural.pkl"
53+
with open(path, "rb") as f:
54+
self.model = load(f) # noqa: S301
55+
4456
def handle(self, payload: InputPayload) -> dict:
45-
# validate payload against a JSONSchema
57+
"""Validate received payload, load model, and generate prediction."""
4658
with open("lambdas/schemas/features_schema.json") as f:
4759
schema = json.load(f)
48-
logger.debug(payload.to_dict())
4960
validate(instance=payload.to_dict(), schema=schema)
50-
return {"response": "true"}
61+
62+
self.load_model()
63+
64+
data = pd.DataFrame(payload.features, index=[0])
65+
prediction = self.model.predict(data)
66+
67+
return {"response": bool(prediction[0])}
5168

5269

5370
class LambdaProcessor:

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ disallow_untyped_calls = true
1010
disallow_untyped_defs = true
1111
exclude = ["tests/"]
1212

13+
[[tool.mypy.overrides]]
14+
module = ["sklearn.*"]
15+
ignore_missing_imports = true
16+
1317
[tool.pytest.ini_options]
1418
log_level = "INFO"
1519

@@ -41,7 +45,6 @@ ignore = [
4145
"PLR0912",
4246
"PLR0913",
4347
"PLR0915",
44-
"S320",
4548
"S321",
4649
]
4750

tests/conftest.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,38 @@ def valid_ping_event():
1919

2020

2121
@pytest.fixture
22-
def valid_predict_event():
23-
"""Valid event payload for an HTTP invocation."""
22+
def valid_predict_event_citation():
23+
"""Valid event payload with features extracted from a known citation."""
24+
return {
25+
"body": json.dumps(
26+
{
27+
"action": "predict",
28+
"challenge_secret": "secret_phrase",
29+
"features": {
30+
"apa": 0,
31+
"brackets": 0,
32+
"colons": 0,
33+
"commas": 5,
34+
"lastnames": 4,
35+
"no": 0,
36+
"pages": 0,
37+
"periods": 7,
38+
"pp": 0,
39+
"quotes": 0,
40+
"semicolons": 1,
41+
"vol": 0,
42+
"words": 12,
43+
"year": 0,
44+
},
45+
}
46+
),
47+
"requestContext": {"http": {"method": "POST"}},
48+
}
49+
50+
51+
@pytest.fixture
52+
def valid_predict_event_noncitation():
53+
"""Valid event payload with features extracted from a non-citation."""
2454
return {
2555
"body": json.dumps(
2656
{
@@ -39,7 +69,7 @@ def valid_predict_event():
3969
"quotes": 0,
4070
"semicolons": 0,
4171
"vol": 0,
42-
"words": 0,
72+
"words": 1,
4373
"year": 0,
4474
},
4575
}

tests/sam/citation.json

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
{
2+
"action": "predict",
3+
"challenge_secret": "secret_phrase",
4+
"features": {
5+
"apa": 0,
6+
"brackets": 0,
7+
"colons": 0,
8+
"commas": 5,
9+
"lastnames": 4,
10+
"no": 0,
11+
"pages": 0,
12+
"periods": 7,
13+
"pp": 0,
14+
"quotes": 0,
15+
"semicolons": 1,
16+
"vol": 0,
17+
"words": 12,
18+
"year": 0
19+
}
20+
}

tests/sam/noncitation.json

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
{
2+
"action": "predict",
3+
"challenge_secret": "secret_phrase",
4+
"features": {
5+
"apa": 0,
6+
"brackets": 0,
7+
"colons": 0,
8+
"commas": 0,
9+
"lastnames": 0,
10+
"no": 0,
11+
"pages": 0,
12+
"periods": 0,
13+
"pp": 0,
14+
"quotes": 0,
15+
"semicolons": 0,
16+
"vol": 0,
17+
"words": 1,
18+
"year": 0
19+
}
20+
}

0 commit comments

Comments
 (0)