Skip to content

Commit

Permalink
Update plugins api to load cached model with mmap buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
olpipi authored and MirceaDan99 committed Dec 10, 2024
1 parent f0da707 commit 8014e21
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#pragma once

#include "openvino/runtime/aligned_buffer.hpp"
#include "openvino/runtime/properties.hpp"
#include "openvino/runtime/threading/istreams_executor.hpp"

Expand Down Expand Up @@ -36,6 +37,12 @@ static constexpr Property<std::vector<PropertyName>, PropertyMutability::RO> cac
*/
static constexpr Property<bool, PropertyMutability::RO> caching_with_mmap{"CACHING_WITH_MMAP"};

/**
* @brief Read-write property holding a std::shared_ptr<ov::AlignedBuffer> with the cached model
* @note NOTE(review): presumably supplied in the import_model() configuration when the cached blob
*       is available as an in-memory (mmap) buffer - confirm against the core implementation
* @ingroup ov_dev_api_plugin_api
*/
static constexpr Property<std::shared_ptr<ov::AlignedBuffer>, PropertyMutability::RW> cached_model_buffer{"CACHED_MODEL_BUFFER"};

/**
* @brief Allow to create exclusive_async_requests with one executor
* @ingroup ov_dev_api_plugin_api
Expand Down
6 changes: 3 additions & 3 deletions src/inference/src/cache_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class ICacheManager {
/**
* @brief Function passing created input stream
*/
using StreamReader = std::function<void(std::istream&)>;
using StreamReader = std::function<void(std::istream&, std::shared_ptr<ov::AlignedBuffer>)>;

/**
* @brief Callback when OpenVINO intends to read model from cache
Expand Down Expand Up @@ -143,10 +143,10 @@ class FileStorageCacheManager final : public ICacheManager {
std::make_shared<ov::SharedBuffer<std::shared_ptr<MappedMemory>>>(mmap->data(), mmap->size(), mmap);
OwningSharedStreamBuffer buf(shared_buffer);
std::istream stream(&buf);
reader(stream);
reader(stream, shared_buffer);
} else {
std::ifstream stream(blob_file_name, std::ios_base::binary);
reader(stream);
reader(stream, nullptr);
}
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/inference/src/dev/core_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1413,7 +1413,7 @@ ov::SoPtr<ov::ICompiledModel> ov::CoreImpl::load_model_from_cache(
cacheContent.blobId,
coreConfig.get_enable_mmap() && ov::util::contains(plugin.get_property(ov::internal::supported_properties),
ov::internal::caching_with_mmap),
[&](std::istream& networkStream) {
[&](std::istream& networkStream, std::shared_ptr<ov::AlignedBuffer> model_buffer) {
OV_ITT_SCOPE(FIRST_INFERENCE,
ov::itt::domains::LoadTime,
"Core::load_model_from_cache::ReadStreamAndImport");
Expand Down Expand Up @@ -1459,6 +1459,9 @@ ov::SoPtr<ov::ICompiledModel> ov::CoreImpl::load_model_from_cache(
update_config[ov::weights_path.name()] = weights_path;
}
}
if (model_buffer) {
update_config[ov::internal::cached_model_buffer.name()] = model_buffer;
}
compiled_model = context ? plugin.import_model(networkStream, context, update_config)
: plugin.import_model(networkStream, update_config);
});
Expand Down
136 changes: 136 additions & 0 deletions src/inference/tests/functional/caching_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2424,6 +2424,142 @@ TEST_P(CachingTest, Load_threads) {
std::cout << "Caching Load multiple threads test completed. Tried " << index << " times" << std::endl;
}

// Checks that the cached model buffer reaches the plugin on import when the plugin
// advertises ov::internal::caching_with_mmap in its supported properties.
// m_testFunction is invoked twice: the expectations below require exactly one
// compile_model (plus one export_model) and exactly one import_model overall.
// NOTE(review): enable_mmap is not set here, so the test relies on its default value - confirm.
TEST_P(CachingTest, Load_mmap) {
    // Default import action: the plugin must receive a non-null cached model buffer in the config.
    ON_CALL(*mockPlugin, import_model(_, _)).WillByDefault(Invoke([&](std::istream& istr, const ov::AnyMap& config) {
        if (m_checkConfigCb) {
            m_checkConfigCb(config);
        }
        // Single lookup instead of count() followed by at().
        std::shared_ptr<ov::AlignedBuffer> model_buffer;
        const auto buffer_it = config.find(ov::internal::cached_model_buffer.name());
        if (buffer_it != config.end()) {
            model_buffer = buffer_it->second.as<std::shared_ptr<ov::AlignedBuffer>>();
        }
        EXPECT_TRUE(model_buffer);

        // The stream starts with the model name followed by one separator character.
        std::string name;
        istr >> name;
        char space;
        istr.read(&space, 1);
        std::lock_guard<std::mutex> lock(mock_creation_mutex);
        return create_mock_compiled_model(m_models[name], mockPlugin);
    }));

    // The plugin reports support for mmap-based caching.
    ON_CALL(*mockPlugin, get_property(ov::internal::supported_properties.name(), _))
        .WillByDefault(Invoke([&](const std::string&, const ov::AnyMap&) {
            return std::vector<ov::PropertyName>{ov::internal::caching_properties.name(),
                                                 ov::internal::caching_with_mmap.name()};
        }));
    EXPECT_CALL(*mockPlugin, get_property(_, _)).Times(AnyNumber());
    EXPECT_CALL(*mockPlugin, query_model(_, _)).Times(AnyNumber());
    EXPECT_CALL(*mockPlugin, get_property(ov::device::architecture.name(), _)).Times(AnyNumber());
    EXPECT_CALL(*mockPlugin, get_property(ov::internal::caching_properties.name(), _)).Times(AnyNumber());
    if (m_remoteContext) {
        return;  // skip the remote Context test for Multi plugin
    }
    m_post_mock_net_callbacks.emplace_back([&](MockICompiledModelImpl& net) {
        EXPECT_CALL(net, export_model(_)).Times(1);
    });
    MkDirGuard guard(m_cacheDir);
    EXPECT_CALL(*mockPlugin, compile_model(_, _, _)).Times(0);
    EXPECT_CALL(*mockPlugin, compile_model(A<const std::shared_ptr<const ov::Model>&>(), _)).Times(1);
    EXPECT_CALL(*mockPlugin, import_model(_, _, _)).Times(0);
    EXPECT_CALL(*mockPlugin, import_model(_, _)).Times(1);
    testLoad([&](ov::Core& core) {
        core.set_property({{ov::cache_dir.name(), m_cacheDir}});
        m_testFunction(core);
        m_testFunction(core);
    });
    // Message fixed: the previous text was copied from the multi-thread test and printed an unused counter.
    std::cout << "Caching Load mmap test completed" << std::endl;
}

// Checks that no cached model buffer is handed to the plugin on import when the
// core is configured with ov::enable_mmap(false), even though the plugin advertises
// ov::internal::caching_with_mmap in its supported properties.
// m_testFunction is invoked twice: the expectations below require exactly one
// compile_model (plus one export_model) and exactly one import_model overall.
TEST_P(CachingTest, Load_mmap_is_disabled) {
    // Default import action: the config must not carry a (non-null) cached model buffer.
    ON_CALL(*mockPlugin, import_model(_, _)).WillByDefault(Invoke([&](std::istream& istr, const ov::AnyMap& config) {
        if (m_checkConfigCb) {
            m_checkConfigCb(config);
        }
        // Single lookup instead of count() followed by at().
        std::shared_ptr<ov::AlignedBuffer> model_buffer;
        const auto buffer_it = config.find(ov::internal::cached_model_buffer.name());
        if (buffer_it != config.end()) {
            model_buffer = buffer_it->second.as<std::shared_ptr<ov::AlignedBuffer>>();
        }
        EXPECT_FALSE(model_buffer);

        // The stream starts with the model name followed by one separator character.
        std::string name;
        istr >> name;
        char space;
        istr.read(&space, 1);
        std::lock_guard<std::mutex> lock(mock_creation_mutex);
        return create_mock_compiled_model(m_models[name], mockPlugin);
    }));

    // The plugin reports support for mmap-based caching.
    ON_CALL(*mockPlugin, get_property(ov::internal::supported_properties.name(), _))
        .WillByDefault(Invoke([&](const std::string&, const ov::AnyMap&) {
            return std::vector<ov::PropertyName>{ov::internal::caching_properties.name(),
                                                 ov::internal::caching_with_mmap.name()};
        }));
    EXPECT_CALL(*mockPlugin, get_property(_, _)).Times(AnyNumber());
    EXPECT_CALL(*mockPlugin, query_model(_, _)).Times(AnyNumber());
    EXPECT_CALL(*mockPlugin, get_property(ov::device::architecture.name(), _)).Times(AnyNumber());
    EXPECT_CALL(*mockPlugin, get_property(ov::internal::caching_properties.name(), _)).Times(AnyNumber());
    if (m_remoteContext) {
        return;  // skip the remote Context test for Multi plugin
    }
    m_post_mock_net_callbacks.emplace_back([&](MockICompiledModelImpl& net) {
        EXPECT_CALL(net, export_model(_)).Times(1);
    });
    MkDirGuard guard(m_cacheDir);
    EXPECT_CALL(*mockPlugin, compile_model(_, _, _)).Times(0);
    EXPECT_CALL(*mockPlugin, compile_model(A<const std::shared_ptr<const ov::Model>&>(), _)).Times(1);
    EXPECT_CALL(*mockPlugin, import_model(_, _, _)).Times(0);
    EXPECT_CALL(*mockPlugin, import_model(_, _)).Times(1);
    testLoad([&](ov::Core& core) {
        core.set_property({{ov::cache_dir.name(), m_cacheDir}});
        core.set_property({ov::enable_mmap(false)});
        m_testFunction(core);
        m_testFunction(core);
    });
    // Message fixed: the previous text was copied from the multi-thread test and printed an unused counter.
    std::cout << "Caching Load with mmap disabled test completed" << std::endl;
}

// Checks that no cached model buffer is handed to the plugin on import when the
// core is configured with ov::enable_mmap(true) but this test installs no
// supported_properties override advertising ov::internal::caching_with_mmap.
// m_testFunction is invoked twice: the expectations below require exactly one
// compile_model (plus one export_model) and exactly one import_model overall.
TEST_P(CachingTest, Load_mmap_is_not_supported_by_plugin) {
    // Default import action: the config must not carry a (non-null) cached model buffer.
    ON_CALL(*mockPlugin, import_model(_, _)).WillByDefault(Invoke([&](std::istream& istr, const ov::AnyMap& config) {
        if (m_checkConfigCb) {
            m_checkConfigCb(config);
        }
        // Single lookup instead of count() followed by at().
        std::shared_ptr<ov::AlignedBuffer> model_buffer;
        const auto buffer_it = config.find(ov::internal::cached_model_buffer.name());
        if (buffer_it != config.end()) {
            model_buffer = buffer_it->second.as<std::shared_ptr<ov::AlignedBuffer>>();
        }
        EXPECT_FALSE(model_buffer);

        // The stream starts with the model name followed by one separator character.
        std::string name;
        istr >> name;
        char space;
        istr.read(&space, 1);
        std::lock_guard<std::mutex> lock(mock_creation_mutex);
        return create_mock_compiled_model(m_models[name], mockPlugin);
    }));
    EXPECT_CALL(*mockPlugin, get_property(_, _)).Times(AnyNumber());
    EXPECT_CALL(*mockPlugin, query_model(_, _)).Times(AnyNumber());
    EXPECT_CALL(*mockPlugin, get_property(ov::device::architecture.name(), _)).Times(AnyNumber());
    EXPECT_CALL(*mockPlugin, get_property(ov::internal::caching_properties.name(), _)).Times(AnyNumber());
    if (m_remoteContext) {
        return;  // skip the remote Context test for Multi plugin
    }
    m_post_mock_net_callbacks.emplace_back([&](MockICompiledModelImpl& net) {
        EXPECT_CALL(net, export_model(_)).Times(1);
    });
    MkDirGuard guard(m_cacheDir);
    EXPECT_CALL(*mockPlugin, compile_model(_, _, _)).Times(0);
    EXPECT_CALL(*mockPlugin, compile_model(A<const std::shared_ptr<const ov::Model>&>(), _)).Times(1);
    EXPECT_CALL(*mockPlugin, import_model(_, _, _)).Times(0);
    EXPECT_CALL(*mockPlugin, import_model(_, _)).Times(1);
    testLoad([&](ov::Core& core) {
        core.set_property({{ov::cache_dir.name(), m_cacheDir}});
        core.set_property({ov::enable_mmap(true)});
        m_testFunction(core);
        m_testFunction(core);
    });
    // Message fixed: the previous text was copied from the multi-thread test and printed an unused counter.
    std::cout << "Caching Load with mmap not supported by plugin test completed" << std::endl;
}

#if defined(ENABLE_OV_IR_FRONTEND)

static std::string getTestCaseName(const testing::TestParamInfo<std::tuple<TestParam, std::string>>& obj) {
Expand Down
5 changes: 5 additions & 0 deletions src/plugins/intel_cpu/src/plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -565,8 +565,13 @@ std::shared_ptr<ov::ICompiledModel> Plugin::import_model(std::istream& model_str
decript_from_string = true;
}

std::shared_ptr<ov::AlignedBuffer> model_buffer;
if (config.count(ov::internal::cached_model_buffer.name()))
model_buffer = config.at(ov::internal::cached_model_buffer.name()).as<std::shared_ptr<ov::AlignedBuffer>>();

ModelDeserializer deserializer(
model_stream,
model_buffer,
[this](const std::shared_ptr<ov::AlignedBuffer>& model, const std::shared_ptr<ov::AlignedBuffer>& weights) {
return get_core()->read_model(model, weights);
},
Expand Down
13 changes: 8 additions & 5 deletions src/plugins/intel_cpu/src/utils/serialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ void ModelSerializer::operator<<(const std::shared_ptr<ov::Model>& model) {

////////// ModelDeserializer //////////

ModelDeserializer::ModelDeserializer(std::istream& model_stream, ModelBuilder fn, const CacheDecrypt& decrypt_fn, bool decript_from_string)
: m_istream(model_stream), m_model_builder(std::move(fn)), m_decript_from_string(decript_from_string) {
ModelDeserializer::ModelDeserializer(std::istream& model_stream,
std::shared_ptr<ov::AlignedBuffer> model_buffer,
ModelBuilder fn,
const CacheDecrypt& decrypt_fn,
bool decript_from_string)
: m_istream(model_stream), m_model_builder(std::move(fn)), m_decript_from_string(decript_from_string), m_model_buffer(model_buffer) {
if (m_decript_from_string) {
m_cache_decrypt.m_decrypt_str = decrypt_fn.m_decrypt_str;
} else {
Expand All @@ -42,9 +46,8 @@ ModelDeserializer::ModelDeserializer(std::istream& model_stream, ModelBuilder fn
// No-op: the body is empty, so neither `root` nor `model` is read or modified.
void ModelDeserializer::set_info(pugi::xml_node& root, std::shared_ptr<ov::Model>& model) {}

void ModelDeserializer::operator>>(std::shared_ptr<ov::Model>& model) {
if (auto mmap_buffer = dynamic_cast<OwningSharedStreamBuffer*>(m_istream.rdbuf())) {
auto buffer = mmap_buffer->get_buffer();
process_mmap(model, buffer);
if (m_model_buffer) {
process_mmap(model, m_model_buffer);
} else {
process_stream(model);
}
Expand Down
7 changes: 6 additions & 1 deletion src/plugins/intel_cpu/src/utils/serialize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ class ModelDeserializer {
public:
typedef std::function<std::shared_ptr<ov::Model>(const std::shared_ptr<ov::AlignedBuffer>&, const std::shared_ptr<ov::AlignedBuffer>&)> ModelBuilder;

ModelDeserializer(std::istream& model, ModelBuilder fn, const CacheDecrypt& encrypt_fn, bool decript_from_string);
ModelDeserializer(std::istream& model,
std::shared_ptr<ov::AlignedBuffer> model_buffer,
ModelBuilder fn,
const CacheDecrypt& encrypt_fn,
bool decript_from_string);

virtual ~ModelDeserializer() = default;

Expand All @@ -48,6 +52,7 @@ class ModelDeserializer {
ModelBuilder m_model_builder;
CacheDecrypt m_cache_decrypt;
bool m_decript_from_string;
std::shared_ptr<ov::AlignedBuffer> m_model_buffer;
};

} // namespace intel_cpu
Expand Down

0 comments on commit 8014e21

Please sign in to comment.