Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix kafka json byte encoding to match rest server #1622

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions mlserver/codecs/json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# seperate file to side step circular dependency on the decode_str function

from typing import Any, Union
import json

try:
import orjson
except ImportError:
orjson = None # type: ignore

from .string import decode_str


# originally taken from: mlserver/rest/responses.py
class _BytesJSONEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, bytes):
# If we get a bytes payload, try to decode it back to a string on a
# "best effort" basis
return decode_str(obj)

return super().default(self, obj)


def _encode_object_to_bytes(obj: Any) -> str:
"""
Add compatibility with `bytes` payloads to `orjson`
"""
if isinstance(obj, bytes):
# If we get a bytes payload, try to decode it back to a string on a
# "best effort" basis
return decode_str(obj)

raise TypeError


def encode_to_json_bytes(v: Any) -> bytes:
"""encodes a dict into json bytes, can deal with byte like values gracefully"""
if orjson is None:
# Original implementation of starlette's JSONResponse, using our
# custom encoder (capable of "encoding" bytes).
# Original implementation can be seen here:
# https://github.com/encode/starlette/blob/
# f53faba229e3fa2844bc3753e233d9c1f54cca52/starlette/responses.py#L173-L180
return json.dumps(
v,
ensure_ascii=False,
allow_nan=False,
indent=None,
separators=(",", ":"),
cls=_BytesJSONEncoder,
).encode("utf-8")

return orjson.dumps(v, default=_encode_object_to_bytes)


def decode_from_bytelike_json_to_dict(v: Union[bytes, str]) -> dict:
if orjson is None:
return json.loads(v)

return orjson.loads(v)
29 changes: 4 additions & 25 deletions mlserver/kafka/message.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,7 @@
import json

from typing import Dict, Optional, List, Tuple, Union
from typing import Dict, Optional, List, Tuple

from pydantic import BaseModel

try:
import orjson
except ImportError:
orjson = None # type: ignore


def _encode_value(v: dict) -> bytes:
if orjson is None:
dumped = json.dumps(v)
return dumped.encode("utf-8")

return orjson.dumps(v)


def _decode_value(v: Union[bytes, str]) -> dict:
if orjson is None:
return json.loads(v)

return orjson.loads(v)
from ..codecs.json import encode_to_json_bytes, decode_from_bytelike_json_to_dict


def _encode_headers(h: Dict[str, str]) -> List[Tuple[str, bytes]]:
Expand All @@ -48,7 +27,7 @@ def from_types(
@classmethod
def from_kafka_record(cls, kafka_record) -> "KafkaMessage":
key = kafka_record.key
value = _decode_value(kafka_record.value)
value = decode_from_bytelike_json_to_dict(kafka_record.value)
headers = _decode_headers(kafka_record.headers)
return KafkaMessage(key=key, value=value, headers=headers)

Expand All @@ -61,7 +40,7 @@ def encoded_key(self) -> bytes:

@property
def encoded_value(self) -> bytes:
return _encode_value(self.value)
return encode_to_json_bytes(self.value)

@property
def encoded_headers(self) -> List[Tuple[str, bytes]]:
Expand Down
62 changes: 6 additions & 56 deletions mlserver/rest/responses.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,22 @@
import json

from typing import Any

from pydantic import BaseModel
from starlette.responses import JSONResponse as _JSONResponse

from ..codecs.string import decode_str

try:
import orjson
except ImportError:
orjson = None # type: ignore


class BytesJSONEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, bytes):
# If we get a bytes payload, try to decode it back to a string on a
# "best effort" basis
return decode_str(obj)

return super().default(self, obj)
from ..codecs.json import encode_to_json_bytes


class Response(_JSONResponse):
"""
Custom Response class to use `orjson` if present.
Otherwise, it'll fall back to the standard JSONResponse.
Custom Response that will use the encode_to_json_bytes function to
encode given content to json based on library availability.
See mlserver/codecs/utils.py for more details
"""

media_type = "application/json"

def render(self, content: Any) -> bytes:
return _render(content)
return encode_to_json_bytes(content)


class ServerSentEvent:
Expand All @@ -45,38 +29,4 @@ def __init__(self, data: BaseModel, *args, **kwargs):

def encode(self) -> bytes:
as_dict = self.data.model_dump()
return self._pre + _render(as_dict) + self._sep


def _render(content: Any) -> bytes:
if orjson is None:
# Original implementation of starlette's JSONResponse, using our
# custom encoder (capable of "encoding" bytes).
# Original implementation can be seen here:
# https://github.com/encode/starlette/blob/
# f53faba229e3fa2844bc3753e233d9c1f54cca52/starlette/responses.py#L173-L180
return json.dumps(
content,
ensure_ascii=False,
allow_nan=False,
indent=None,
separators=(",", ":"),
cls=BytesJSONEncoder,
).encode("utf-8")

# This is equivalent to the ORJSONResponse implementation in FastAPI:
# https://github.com/tiangolo/fastapi/blob/
# 864643ef7608d28ac4ed321835a7fb4abe3dfc13/fastapi/responses.py#L32-L34
return orjson.dumps(content, default=_encode_bytes)


def _encode_bytes(obj: Any) -> str:
"""
Add compatibility with `bytes` payloads to `orjson`
"""
if isinstance(obj, bytes):
# If we get a bytes payload, try to decode it back to a string on a
# "best effort" basis
return decode_str(obj)

raise TypeError
return self._pre + encode_to_json_bytes(as_dict) + self._sep
40 changes: 40 additions & 0 deletions tests/codecs/test_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest
sakoush marked this conversation as resolved.
Show resolved Hide resolved

from typing import Any, Union

from mlserver.codecs.json import decode_from_bytelike_json_to_dict, encode_to_json_bytes


@pytest.mark.parametrize(
"input, expected",
[
(b"{}", dict()),
("{}", dict()),
('{"hello":"world"}', {"hello": "world"}),
(b'{"hello":"world"}', {"hello": "world"}),
(b'{"hello":"' + "world".encode("utf-8") + b'"}', {"hello": "world"}),
(
b'{"hello":"' + "world".encode("utf-8") + b'", "foo": { "bar": "baz" } }',
{"hello": "world", "foo": {"bar": "baz"}},
),
],
)
def test_decode_input(input: Union[str, bytes], expected: dict):
assert expected == decode_from_bytelike_json_to_dict(input)


@pytest.mark.parametrize(
# input and expected are flipped here for easier CTRL+C / V
"expected, input",
DerTiedemann marked this conversation as resolved.
Show resolved Hide resolved
[
(b"{}", dict()),
(b'{"hello":"world"}', {"hello": "world"}),
(b'{"hello":"' + "world".encode("utf-8") + b'"}', {"hello": "world"}),
(
b'{"hello":"' + "world".encode("utf-8") + b'","foo":{"bar":"baz"}}',
{"hello": b"world", "foo": {"bar": "baz"}},
),
],
)
def test_encode_input(input: Any, expected: bytes):
assert expected == encode_to_json_bytes(input)
Loading