Skip to content

Commit

Permalink
fix: Update for --files instead of --input, make accept option with d…
Browse files Browse the repository at this point in the history
…efault application/json
  • Loading branch information
vykozlov authored Jul 12, 2024
1 parent dd7c6ee commit 254651d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
8 changes: 4 additions & 4 deletions api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,12 @@ def main():
logger.debug("Calling method with args: %s", args)
del vars(args)["method"]
if hasattr(args, "files"):
file_extension = os.path.splitext(args.input)[1]
file_extension = os.path.splitext(args.files)[1]
args.files = UploadedFile(
"input",
"files",
args.files,
"application/octet-stream",
f"input{file_extension}",
f"files{file_extension}",
)
results = method_function(**vars(args))
print(json.dumps(results))
Expand Down Expand Up @@ -322,7 +322,7 @@ def main():
--data /srv/football-players-detection-7/data.yaml\
--Enable_MLFLOW --epochs 50
python3 api/__init__.py predict --input \
python3 api/__init__.py predict --files \
/srv/yolov8_api/tests/data/det/test/cat1.jpg\
--task_type det --accept application/json
"""
4 changes: 3 additions & 1 deletion api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class Meta:
"description": "image size as scalar or (h, w) list,"
" i.e. (640, 480)"
},
load_default=[640,480]
)

conf = fields.Float(
Expand Down Expand Up @@ -153,7 +154,8 @@ class Meta:
"description": "Return format for method response.",
"location": "headers",
},
required=True,
required=False,
load_default="application/json",
validate=validate.OneOf(responses.content_types),
)

Expand Down
6 changes: 3 additions & 3 deletions yolov8_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def predict(
print("arg of prediction are", args)

model = YOLO(args["model"])
test_image_path = args["input"]
test_image_path = args["files"]
results = []
for image_path in test_image_path:
print(
Expand All @@ -57,7 +57,7 @@ def predict(
utils.remove_keys_from_dict(
args,
[
"input",
"files",
"accept",
"task_type",
],
Expand All @@ -73,7 +73,7 @@ def predict(

if __name__ == "__main__":
args = {
"input": ["/home/se1131/cat1.jpg"],
"files": ["/home/se1131/cat1.jpg"],
"model": "yolov8n.pt",
"imgsz": [680, 512],
"conf": 0.25,
Expand Down

0 comments on commit 254651d

Please sign in to comment.