Skip to content

Commit 8ee14d5

Browse files
committed
Implement loading snapshot from stream
1 parent db56177 commit 8ee14d5

File tree

4 files changed

+42
-14
lines changed

4 files changed

+42
-14
lines changed

include/neural-graphics-primitives/testbed.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,7 @@ class Testbed {
359359
);
360360
void visualize_nerf_cameras(ImDrawList* list, const mat4& world2proj);
361361
fs::path find_network_config(const fs::path& network_config_path);
362+
nlohmann::json load_network_config(std::istream& stream, bool is_compressed);
362363
nlohmann::json load_network_config(const fs::path& network_config_path);
363364
void reload_network_from_file(const fs::path& path = "");
364365
void reload_network_from_json(const nlohmann::json& json, const std::string& config_base_path=""); // config_base_path is needed so that if the passed in json uses the 'parent' feature, we know where to look... be sure to use a filename, or if a directory, end with a trailing slash
@@ -484,7 +485,9 @@ class Testbed {
484485
vec2 fov_xy() const ;
485486
void set_fov_xy(const vec2& val);
486487
void save_snapshot(const fs::path& path, bool include_optimizer_state, bool compress);
488+
void load_snapshot(nlohmann::json config);
487489
void load_snapshot(const fs::path& path);
490+
void load_snapshot(std::istream& stream, bool is_compressed = true);
488491
CameraKeyframe copy_camera_to_keyframe() const;
489492
void set_camera_from_keyframe(const CameraKeyframe& k);
490493
void set_camera_from_time(float t);

src/main.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ int main_func(const std::vector<std::string>& arguments) {
159159
}
160160

161161
if (snapshot_flag) {
162-
testbed.load_snapshot(get(snapshot_flag));
162+
testbed.load_snapshot(static_cast<fs::path>(get(snapshot_flag)));
163163
} else if (network_config_flag) {
164164
testbed.reload_network_from_file(get(network_config_flag));
165165
}

src/python_api.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ PYBIND11_MODULE(pyngp, m) {
443443
.def("n_params", &Testbed::n_params, "Number of trainable parameters")
444444
.def("n_encoding_params", &Testbed::n_encoding_params, "Number of trainable parameters in the encoding")
445445
.def("save_snapshot", &Testbed::save_snapshot, py::arg("path"), py::arg("include_optimizer_state")=false, py::arg("compress")=true, "Save a snapshot of the currently trained model. Optionally compressed (only when saving '.ingp' files).")
446-
.def("load_snapshot", &Testbed::load_snapshot, py::arg("path"), "Load a previously saved snapshot")
446+
.def("load_snapshot", py::overload_cast<const fs::path&>(&Testbed::load_snapshot), py::arg("path"), "Load a previously saved snapshot")
447447
.def("load_camera_path", &Testbed::load_camera_path, py::arg("path"), "Load a camera path")
448448
.def("load_file", &Testbed::load_file, py::arg("path"), "Load a file and automatically determine how to handle it. Can be a snapshot, dataset, network config, or camera path.")
449449
.def_property("loop_animation", &Testbed::loop_animation, &Testbed::set_loop_animation)

src/testbed.cu

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,14 @@ fs::path Testbed::find_network_config(const fs::path& network_config_path) {
233233
return network_config_path;
234234
}
235235

236+
json Testbed::load_network_config(std::istream& stream, bool is_compressed) {
237+
if (is_compressed) {
238+
zstr::istream zstream{stream};
239+
return json::from_msgpack(zstream);
240+
}
241+
return json::from_msgpack(stream);
242+
}
243+
236244
json Testbed::load_network_config(const fs::path& network_config_path) {
237245
bool is_snapshot = equals_case_insensitive(network_config_path.extension(), "msgpack") || equals_case_insensitive(network_config_path.extension(), "ingp");
238246
if (network_config_path.empty() || !network_config_path.exists()) {
@@ -1543,7 +1551,7 @@ void Testbed::imgui() {
15431551
ImGui::SameLine();
15441552
if (ImGui::Button("Load")) {
15451553
try {
1546-
load_snapshot(m_imgui.snapshot_path);
1554+
load_snapshot(static_cast<fs::path>(m_imgui.snapshot_path));
15471555
} catch (const std::exception& e) {
15481556
imgui_error_string = fmt::format("Failed to load snapshot: {}", e.what());
15491557
ImGui::OpenPopup("Error");
@@ -2339,14 +2347,14 @@ void Testbed::SecondWindow::draw(GLuint texture) {
23392347
}
23402348

23412349
void Testbed::init_opengl_shaders() {
2342-
static const char* shader_vert = R"(#version 140
2350+
static const char* shader_vert = R"glsl(#version 140
23432351
out vec2 UVs;
23442352
void main() {
23452353
UVs = vec2((gl_VertexID << 1) & 2, gl_VertexID & 2);
23462354
gl_Position = vec4(UVs * 2.0 - 1.0, 0.0, 1.0);
2347-
})";
2355+
})glsl";
23482356

2349-
static const char* shader_frag = R"(#version 140
2357+
static const char* shader_frag = R"glsl(#version 140
23502358
in vec2 UVs;
23512359
out vec4 frag_color;
23522360
uniform sampler2D rgba_texture;
@@ -2386,7 +2394,7 @@ void Testbed::init_opengl_shaders() {
23862394
//Uncomment the following line of code to visualize debug the depth buffer for debugging.
23872395
// frag_color = vec4(vec3(texture(depth_texture, tex_coords.xy).r), 1.0);
23882396
gl_FragDepth = texture(depth_texture, tex_coords.xy).r;
2389-
})";
2397+
})glsl";
23902398

23912399
GLuint vert = glCreateShader(GL_VERTEX_SHADER);
23922400
glShaderSource(vert, 1, &shader_vert, NULL);
@@ -4746,12 +4754,7 @@ void Testbed::save_snapshot(const fs::path& path, bool include_optimizer_state,
47464754
tlog::success() << "Saved snapshot '" << path.str() << "'";
47474755
}
47484756

4749-
void Testbed::load_snapshot(const fs::path& path) {
4750-
auto config = load_network_config(path);
4751-
if (!config.contains("snapshot")) {
4752-
throw std::runtime_error{fmt::format("File '{}' does not contain a snapshot.", path.str())};
4753-
}
4754-
4757+
void Testbed::load_snapshot(nlohmann::json config) {
47554758
const auto& snapshot = config["snapshot"];
47564759
if (snapshot.value("version", 0) < SNAPSHOT_FORMAT_VERSION) {
47574760
throw std::runtime_error{"Snapshot uses an old format and can not be loaded."};
@@ -4841,7 +4844,6 @@ void Testbed::load_snapshot(const fs::path& path) {
48414844
m_render_aabb = snapshot.value("render_aabb", m_render_aabb);
48424845
if (snapshot.contains("up_dir")) from_json(snapshot.at("up_dir"), m_up_dir);
48434846

4844-
m_network_config_path = path;
48454847
m_network_config = std::move(config);
48464848

48474849
reset_network(false);
@@ -4868,6 +4870,29 @@ void Testbed::load_snapshot(const fs::path& path) {
48684870
set_all_devices_dirty();
48694871
}
48704872

4873+
void Testbed::load_snapshot(const fs::path& path) {
4874+
auto config = load_network_config(path);
4875+
if (!config.contains("snapshot")) {
4876+
throw std::runtime_error{fmt::format("File '{}' does not contain a snapshot.", path.str())};
4877+
}
4878+
4879+
load_snapshot(std::move(config));
4880+
4881+
m_network_config_path = path;
4882+
}
4883+
4884+
void Testbed::load_snapshot(std::istream& stream, bool is_compressed) {
4885+
auto config = load_network_config(stream, is_compressed);
4886+
if (!config.contains("snapshot")) {
4887+
throw std::runtime_error{"Given stream does not contain a snapshot."};
4888+
}
4889+
4890+
load_snapshot(std::move(config));
4891+
4892+
// Network config path is unknown.
4893+
m_network_config_path = "";
4894+
}
4895+
48714896
Testbed::CudaDevice::CudaDevice(int id, bool is_primary) : m_id{id}, m_is_primary{is_primary} {
48724897
auto guard = device_guard();
48734898
m_stream = std::make_unique<StreamAndEvent>();

0 commit comments

Comments
 (0)