diff --git a/include/neural-graphics-primitives/testbed.h b/include/neural-graphics-primitives/testbed.h index 75aa028dd..fc9b8992f 100644 --- a/include/neural-graphics-primitives/testbed.h +++ b/include/neural-graphics-primitives/testbed.h @@ -359,6 +359,7 @@ class Testbed { ); void visualize_nerf_cameras(ImDrawList* list, const mat4& world2proj); fs::path find_network_config(const fs::path& network_config_path); + nlohmann::json load_network_config(std::istream& stream, bool is_compressed); nlohmann::json load_network_config(const fs::path& network_config_path); void reload_network_from_file(const fs::path& path = ""); void reload_network_from_json(const nlohmann::json& json, const std::string& config_base_path=""); // config_base_path is needed so that if the passed in json uses the 'parent' feature, we know where to look... be sure to use a filename, or if a directory, end with a trailing slash @@ -484,7 +485,9 @@ class Testbed { vec2 fov_xy() const ; void set_fov_xy(const vec2& val); void save_snapshot(const fs::path& path, bool include_optimizer_state, bool compress); + void load_snapshot(nlohmann::json config); void load_snapshot(const fs::path& path); + void load_snapshot(std::istream& stream, bool is_compressed = true); CameraKeyframe copy_camera_to_keyframe() const; void set_camera_from_keyframe(const CameraKeyframe& k); void set_camera_from_time(float t); diff --git a/src/main.cu b/src/main.cu index 3494f94ff..c7c8bdee6 100644 --- a/src/main.cu +++ b/src/main.cu @@ -159,7 +159,7 @@ int main_func(const std::vector& arguments) { } if (snapshot_flag) { - testbed.load_snapshot(get(snapshot_flag)); + testbed.load_snapshot(static_cast(get(snapshot_flag))); } else if (network_config_flag) { testbed.reload_network_from_file(get(network_config_flag)); } diff --git a/src/python_api.cu b/src/python_api.cu index 851d7c9c4..b696d57da 100644 --- a/src/python_api.cu +++ b/src/python_api.cu @@ -443,7 +443,7 @@ PYBIND11_MODULE(pyngp, m) { .def("n_params", &Testbed::n_params, "Number of trainable parameters") .def("n_encoding_params", &Testbed::n_encoding_params, "Number of trainable parameters in the encoding") .def("save_snapshot", &Testbed::save_snapshot, py::arg("path"), py::arg("include_optimizer_state")=false, py::arg("compress")=true, "Save a snapshot of the currently trained model. Optionally compressed (only when saving '.ingp' files).") - .def("load_snapshot", &Testbed::load_snapshot, py::arg("path"), "Load a previously saved snapshot") + .def("load_snapshot", py::overload_cast(&Testbed::load_snapshot), py::arg("path"), "Load a previously saved snapshot") .def("load_camera_path", &Testbed::load_camera_path, py::arg("path"), "Load a camera path") .def("load_file", &Testbed::load_file, py::arg("path"), "Load a file and automatically determine how to handle it. Can be a snapshot, dataset, network config, or camera path.") .def_property("loop_animation", &Testbed::loop_animation, &Testbed::set_loop_animation) diff --git a/src/testbed.cu b/src/testbed.cu index 78b000cc6..a42734175 100644 --- a/src/testbed.cu +++ b/src/testbed.cu @@ -233,6 +233,14 @@ fs::path Testbed::find_network_config(const fs::path& network_config_path) { return network_config_path; } +json Testbed::load_network_config(std::istream& stream, bool is_compressed) { + if (is_compressed) { + zstr::istream zstream{stream}; + return json::from_msgpack(zstream); + } + return json::from_msgpack(stream); +} + json Testbed::load_network_config(const fs::path& network_config_path) { bool is_snapshot = equals_case_insensitive(network_config_path.extension(), "msgpack") || equals_case_insensitive(network_config_path.extension(), "ingp"); if (network_config_path.empty() || !network_config_path.exists()) { @@ -1543,7 +1551,7 @@ void Testbed::imgui() { ImGui::SameLine(); if (ImGui::Button("Load")) { try { - load_snapshot(m_imgui.snapshot_path); + load_snapshot(static_cast(m_imgui.snapshot_path)); } catch (const std::exception& e) { imgui_error_string = fmt::format("Failed to load snapshot: {}", e.what()); ImGui::OpenPopup("Error"); @@ -2339,14 +2347,14 @@ void Testbed::SecondWindow::draw(GLuint texture) { } void Testbed::init_opengl_shaders() { - static const char* shader_vert = R"(#version 140 + static const char* shader_vert = R"glsl(#version 140 out vec2 UVs; void main() { UVs = vec2((gl_VertexID << 1) & 2, gl_VertexID & 2); gl_Position = vec4(UVs * 2.0 - 1.0, 0.0, 1.0); - })"; + })glsl"; - static const char* shader_frag = R"(#version 140 + static const char* shader_frag = R"glsl(#version 140 in vec2 UVs; out vec4 frag_color; uniform sampler2D rgba_texture; @@ -2386,7 +2394,7 @@ void Testbed::init_opengl_shaders() { //Uncomment the following line of code to visualize debug the depth buffer for debugging. // frag_color = vec4(vec3(texture(depth_texture, tex_coords.xy).r), 1.0); gl_FragDepth = texture(depth_texture, tex_coords.xy).r; - })"; + })glsl"; GLuint vert = glCreateShader(GL_VERTEX_SHADER); glShaderSource(vert, 1, &shader_vert, NULL); @@ -4746,12 +4754,7 @@ void Testbed::save_snapshot(const fs::path& path, bool include_optimizer_state, tlog::success() << "Saved snapshot '" << path.str() << "'"; } -void Testbed::load_snapshot(const fs::path& path) { - auto config = load_network_config(path); - if (!config.contains("snapshot")) { - throw std::runtime_error{fmt::format("File '{}' does not contain a snapshot.", path.str())}; - } - +void Testbed::load_snapshot(nlohmann::json config) { const auto& snapshot = config["snapshot"]; if (snapshot.value("version", 0) < SNAPSHOT_FORMAT_VERSION) { throw std::runtime_error{"Snapshot uses an old format and can not be loaded."}; @@ -4841,7 +4844,6 @@ void Testbed::load_snapshot(const fs::path& path) { m_render_aabb = snapshot.value("render_aabb", m_render_aabb); if (snapshot.contains("up_dir")) from_json(snapshot.at("up_dir"), m_up_dir); - m_network_config_path = path; m_network_config = std::move(config); reset_network(false); @@ -4868,6 +4870,29 @@ void Testbed::load_snapshot(const fs::path& path) { set_all_devices_dirty(); } +void Testbed::load_snapshot(const fs::path& path) { + auto config = load_network_config(path); + if (!config.contains("snapshot")) { + throw std::runtime_error{fmt::format("File '{}' does not contain a snapshot.", path.str())}; + } + + load_snapshot(std::move(config)); + + m_network_config_path = path; +} + +void Testbed::load_snapshot(std::istream& stream, bool is_compressed) { + auto config = load_network_config(stream, is_compressed); + if (!config.contains("snapshot")) { + throw std::runtime_error{"Given stream does not contain a snapshot."}; + } + + load_snapshot(std::move(config)); + + // Network config path is unknown. + m_network_config_path = ""; +} + Testbed::CudaDevice::CudaDevice(int id, bool is_primary) : m_id{id}, m_is_primary{is_primary} { auto guard = device_guard(); m_stream = std::make_unique();