Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove JSON Dependency #1812

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions mlx/distributed/ring/ring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
#include <unistd.h>

#include <chrono>
#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include <thread>

#include <json.hpp>

#include "mlx/backend/common/copy.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/io/json.h"
#include "mlx/threadpool.h"

#define SWITCH_TYPE(x, ...) \
Expand Down Expand Up @@ -81,7 +81,6 @@ constexpr const int CONN_ATTEMPTS = 5;
constexpr const int CONN_WAIT = 1000;

using GroupImpl = mlx::core::distributed::detail::GroupImpl;
using json = nlohmann::json;

namespace {

Expand Down Expand Up @@ -212,11 +211,11 @@ std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
std::vector<std::vector<address_t>> nodes;
std::ifstream f(hostfile);

json hosts = json::parse(f);
io::json hosts = io::parse_json(f);
for (auto& h : hosts) {
std::vector<address_t> host;
for (auto& ips : h) {
host.push_back(std::move(parse_address(ips.get<std::string>())));
for (std::string ips : h) {
host.push_back(std::move(parse_address(ips)));
}
nodes.push_back(std::move(host));
}
Expand Down
1 change: 1 addition & 0 deletions mlx/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <variant>

#include "mlx/array.h"
#include "mlx/io/json.h"
#include "mlx/io/load.h"
#include "mlx/stream.h"
#include "mlx/utils.h"
Expand Down
8 changes: 1 addition & 7 deletions mlx/io/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/json.cpp)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp)

if(MLX_BUILD_SAFETENSORS)
message(STATUS "Downloading json")
FetchContent_Declare(
json
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)
target_include_directories(
mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/safetensors.cpp)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_safetensors.cpp)
Expand Down
281 changes: 281 additions & 0 deletions mlx/io/json.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
// Copyright © 2025 Apple Inc.

#include "mlx/io/json.h"

#include <sstream>

namespace mlx::core {

namespace io {

// Collects the longest run of characters that may appear in a JSON number
// (digits, '-', '+', '.', 'e', 'E') from the front of `s` and returns it.
//
// The first character that is not part of the run is left unread in the
// stream. No validation of the number's shape happens here; the caller is
// responsible for converting the text and rejecting malformed input.
//
// Returns an empty string if the stream is at end-of-input or the next
// character cannot be part of a number.
std::string read_digits(std::istream& s) {
  std::string num;
  while (true) {
    // Look at the next character without consuming it. This avoids the
    // get()-then-seekg(-1) pattern, which needs a seekable stream and leaves
    // failbit set (with eofbit cleared) when the number ends the input.
    const auto next = s.peek();
    if (next == std::istream::traits_type::eof()) {
      break;
    }
    const char ch = std::istream::traits_type::to_char_type(next);
    // std::isdigit has undefined behaviour for negative values other than
    // EOF, so the character is widened through unsigned char first.
    const bool is_number_char =
        std::isdigit(static_cast<unsigned char>(ch)) || ch == '-' ||
        ch == '+' || ch == '.' || ch == 'e' || ch == 'E';
    if (!is_number_char) {
      break;
    }
    num += ch;
    s.get(); // consume the character that was just accepted
  }
  return num;
}

// Parses a JSON number from `s` and returns it wrapped in a `json` value.
//
// The candidate text is collected with read_digits(). If it contains '.',
// 'e' or 'E' it is converted with std::stod, otherwise with std::stol.
//
// NOTE(review): neither conversion passes a position out-parameter, so the
// number of characters actually consumed is not compared against the length
// of the collected text.
json parse_json_number(std::istream& s) {
auto num = read_digits(s);
if (num.find_first_of(".eE") != std::string::npos) {
return json(std::stod(num));
} else {
return json(std::stol(num));
}
Comment on lines +24 to +28
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (num.find_first_of(".eE") != std::string::npos) {
return json(std::stod(num));
} else {
return json(std::stol(num));
}
size_t pos;
json ret;
if (num.find_first_of(".eE") != std::string::npos) {
ret = json(std::stod(num, &pos));
} else {
ret = json(std::stol(num, &pos));
}
if (pos != num.size()) {
throw std::invalid_argument(...);
}
return std::move(ret);

The above is a bit of a problem because read_digits reads any character in [0-9-eE.] which means it could be not a number eg 1.2.3.4 and in this case we should fail. std::stod and friends return the number of characters processed so it is a fairly easy change to check for correctness.

}

// Reads the body of a JSON string from `s` and returns its decoded contents.
//
// The opening quote must already have been consumed by the caller; reading
// stops after the matching unescaped closing quote. The simple escapes
// (\" \\ \/ \b \f \n \r \t) are decoded. A \u escape is copied through
// verbatim (backslash, 'u' and the next four characters) without decoding.
//
// Throws std::invalid_argument for an unknown escape character or when the
// stream ends before the closing quote.
std::string parse_json_string(std::istream& s) {
  std::string out;
  bool escaped = false;
  char c = s.get();
  while (escaped || c != '"') {
    if (!escaped) {
      if (c == '\\') {
        escaped = true;
      } else {
        out += c;
      }
    } else {
      switch (c) {
        case '"':
        case '\\':
        case '/':
          out += c;
          break;
        case 'b':
          out += '\b';
          break;
        case 'f':
          out += '\f';
          break;
        case 'n':
          out += '\n';
          break;
        case 'r':
          out += '\r';
          break;
        case 't':
          out += '\t';
          break;
        case 'u':
          // Basic unicode support -- the escape is kept exactly as written.
          out += "\\u";
          for (int k = 0; k < 4; ++k) {
            out += s.get();
          }
          break;
        default:
          throw std::invalid_argument("[json] Invalid escape sequence.");
      }
      escaped = false;
    }

    c = s.get();
    if (s.eof()) {
      throw std::invalid_argument("[json] Unfinished string value.");
    }
  }
  return out;
}

// Recursively parses a single JSON value from `s` and returns it.
//
// Leading whitespace is skipped, then the first character selects the kind
// of value: '{' starts an object, '[' an array, 'n' / 't' / 'f' the keywords
// null / true / false, '"' a string, and '-' or a digit a number.
//
// Throws std::invalid_argument when the input does not match what the
// current branch expects.
json parse_json_helper(std::istream& s) {
char ch;
s >> std::ws >> ch;
// object
if (ch == '{') {
json::json_object object;
while (true) {
// Each iteration starts by accepting either '}' (end of object) or the
// opening quote of a key. Because this also runs right after a ',', a
// trailing comma before '}' is accepted.
s >> std::ws >> ch;
if (ch == '}') {
break;
} else if (ch != '"') {
throw std::invalid_argument("[json] Invalid json: expected '\"'.");
}
std::string key = parse_json_string(s);
s >> std::ws >> ch;
// NOTE(review): this check fires when the ':' separator is missing, although
// the message text mentions '\"'.
if (ch != ':') {
throw std::invalid_argument("[json] Invalid json: expected '\"'.");
}
json value = parse_json_helper(s);
object[key] = value;

s >> std::ws >> ch;
if (ch == '}') {
break;
} else if (ch != ',') {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is something that can be fixed later but unfortunately this is a tiny bit out of spec. It is also one of the most annoying parts of the json spec so I don't really mind if we don't fix it immediately but we should at least put a TODO unfortunately.

The following is unfortunately not valid JSON ["a", "b",] (notice the extra comma) but it would be parsed without error here. Both for object and array.

throw std::invalid_argument("[json] Invalid json: expected ','.");
}
}
return object;
// array
} else if (ch == '[') {
json::json_array array;
s >> std::ws;
while (true) {
// A ']' seen at the start of an iteration ends the array; this covers the
// empty array and also a ']' that directly follows a ','.
if (s.peek() == ']') {
s.get();
break;
}
json value = parse_json_helper(s);
array.push_back(value);
s >> std::ws >> ch;
if (ch == ']') {
break;
} else if (ch != ',') {
throw std::invalid_argument("[json] Invalid json: expected ','.");
}
}
return array;
// null
} else if (ch == 'n') {
// The leading 'n' is already consumed; read the remaining three characters.
std::string str = "";
for (int i = 0; i < 3; i++) {
str += s.get();
}
if (str != "ull") {
Comment on lines +123 to +127
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
std::string str = "";
for (int i = 0; i < 3; i++) {
str += s.get();
}
if (str != "ull") {
char buff[3] = {0};
if (s.read(buff, 3); !std::equal(buff, buff+3, "ull")) {

This is a total nitpick but the string may do a heap allocation there (may even do more than 1 given the +=) and it really is just to check 3 characters so I think it is a bit of a an overkill. You could even do if (s.get() != 'u' || s.get() != 'l' || s.get() != 'l') {...

throw std::invalid_argument("[json] Invalid keyword.");
}
return json(nullptr);
// true
} else if (ch == 't') {
std::string str = "";
for (int i = 0; i < 3; i++) {
str += s.get();
}
if (str != "rue") {
throw std::invalid_argument("[json] Invalid keyword.");
}
return json(true);
// false
} else if (ch == 'f') {
std::string str = "";
for (int i = 0; i < 4; i++) {
str += s.get();
}
if (str != "alse") {
throw std::invalid_argument("[json] Invalid keyword.");
}
return json(false);
// string
} else if (ch == '"') {
return json(parse_json_string(s));
// number
} else if (ch == '-' || std::isdigit(ch)) {
// Step back one character so that the number parser sees the sign/digit
// that was consumed above.
s.seekg(-1, std::ios::cur);
return parse_json_number(s);
} else {
throw std::invalid_argument("[json] Invalid json: Unrecognized value.");
}
}

// Writes the padding string to `os` once per indentation level.
// Nothing is written when `indent` is zero or negative.
void apply_indent(std::ostream& os, int indent) {
  int remaining = indent;
  while (remaining > 0) {
    os << " ";
    --remaining;
  }
}

// Returns a copy of `raw` in which backslash, double quote, forward slash
// and the control characters \b \f \n \r \t are replaced by their
// two-character JSON escape sequences. Every position is examined, including
// the first character of the string.
static std::string escape_json_string(const std::string& raw) {
  std::string escaped;
  escaped.reserve(raw.size());
  for (const char c : raw) {
    switch (c) {
      case '\\':
        escaped += "\\\\";
        break;
      case '"':
        escaped += "\\\"";
        break;
      case '/':
        escaped += "\\/";
        break;
      case '\b':
        escaped += "\\b";
        break;
      case '\f':
        escaped += "\\f";
        break;
      case '\n':
        escaped += "\\n";
        break;
      case '\r':
        escaped += "\\r";
        break;
      case '\t':
        escaped += "\\t";
        break;
      default:
        escaped += c;
        break;
    }
  }
  return escaped;
}

// Pretty-prints `obj` to `os`.
//
// Arrays and objects are written one element per line, with nested elements
// indented by two more levels than `indent`; the closing bracket is indented
// by `indent`. Numbers are written with the stream's current formatting,
// booleans as `true` / `false`, null as `null`, and strings (including
// object keys) inside double quotes with special characters escaped.
//
// The stream's format flags are left untouched and the stream is not
// flushed; callers that need a flush should do it themselves.
void print_json(std::ostream& os, const json& obj, int indent) {
  if (obj.is<json::json_array>()) {
    os << "[" << '\n';
    bool first = true;
    for (const json& val : obj) {
      if (!first) {
        os << "," << '\n';
      }
      first = false;
      apply_indent(os, indent + 2);
      print_json(os, val, indent + 2);
    }
    os << '\n';
    apply_indent(os, indent);
    os << "]";
  } else if (obj.is<json::json_object>()) {
    os << "{" << '\n';
    bool first = true;
    for (const auto& [key, val] : obj.items()) {
      if (!first) {
        os << "," << '\n';
      }
      first = false;
      apply_indent(os, indent + 2);
      // Keys are strings too, so they need the same escaping as values to
      // keep the output well-formed.
      os << '"' << escape_json_string(key) << '"' << ": ";
      print_json(os, val, indent + 2);
    }
    os << '\n';
    apply_indent(os, indent);
    os << "}";
  } else if (obj.is<double>()) {
    double val = obj;
    os << val;
  } else if (obj.is<long>()) {
    long val = obj;
    os << val;
  } else if (obj.is<bool>()) {
    bool val = obj;
    // Written as text directly instead of setting std::boolalpha, which
    // would stay set on the caller's stream after this function returns.
    os << (val ? "true" : "false");
  } else if (obj.is<std::string>()) {
    std::string val = obj;
    os << '"' << escape_json_string(val) << '"';
  } else if (obj.is<std::nullptr_t>()) {
    os << "null";
  }
}

// Stream insertion for json values: pretty-prints `obj` starting at
// indentation level zero and returns the stream so calls can be chained.
std::ostream& operator<<(std::ostream& os, const json& obj) {
  constexpr int top_level_indent = 0;
  print_json(os, obj, top_level_indent);
  return os;
}

// Parses exactly one JSON value from `s` and returns it.
//
// Whitespace after the value is allowed (JSON permits insignificant
// whitespace around a value, so a trailing newline or padding is fine).
// Any other trailing content is an error.
//
// Throws std::invalid_argument if the value is malformed or if
// non-whitespace characters follow it.
json parse_json(std::istream& s) {
  json result = parse_json_helper(s);
  // Skip trailing whitespace, then require end-of-input. peek() reports
  // end-of-input without consuming anything and also when the stream is no
  // longer readable.
  s >> std::ws;
  if (s.peek() != std::istream::traits_type::eof()) {
    throw std::invalid_argument(
        "[json] json finished before the end of the stream.");
  }
  return result;
}

// Stream buffer that exposes an existing character range as a read-only
// input sequence, so it can back a std::istream without a separate stream
// buffer allocation.
struct membuf : std::streambuf {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

// Sets the get area to the `length` characters starting at `s`.
membuf(char* s, int length) {
this->setg(s, s, s + length);
}
// Relative seek support. Only `std::ios_base::cur` moves the read position
// (by `off` characters, without bounds checking); for any other direction
// the position is left where it is. Always returns the current offset from
// the start of the buffer.
pos_type seekoff(
off_type off,
std::ios_base::seekdir dir,
std::ios_base::openmode which = std::ios_base::in) {
if (dir == std::ios_base::cur) {
gbump(off);
}
return gptr() - eback();
}
};

// Parses one JSON value from the `length` characters starting at `s`.
// The range is wrapped in a membuf-backed std::istream and handed to the
// stream overload of parse_json.
json parse_json(char* s, int length) {
membuf sbuf(s, length);
std::istream stream(&sbuf);
// NOTE(review): `os` is constructed here but never read afterwards in this
// function.
std::string os(s, length);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Forgotten.

return parse_json(stream);
}

// Convenience overload: parses the JSON text held in `s` by forwarding its
// character buffer and size to the (char*, int) overload.
json parse_json(std::string& s) {
  char* begin = s.data();
  const auto length = s.size();
  return parse_json(begin, length);
}

} // namespace io

} // namespace mlx::core
Loading