Skip to content

Commit

Permalink
Modify iCompiler interface to avoid std::vector allocation for `e…
Browse files Browse the repository at this point in the history
…xport_model`
  • Loading branch information
MirceaDan99 committed Sep 24, 2024
1 parent e725500 commit 13aff9b
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,8 @@ class ICompiler : public std::enable_shared_from_this<ICompiler> {
// Driver compiler can use this to release graphHandle, if we do not have executor
virtual void release([[maybe_unused]] std::shared_ptr<const NetworkDescription> networkDescription){};

// Default accessor for the compiled blob stored in the given NetworkDescription.
// NOTE(review): this page is a rendered diff without +/- markers. The first signature/return
// pair below (std::vector returned by value, i.e. a copy of compiledNetwork) appears to be the
// pre-commit form; the second pair appears to be its post-commit replacement.
virtual std::vector<uint8_t> getCompiledNetwork(std::shared_ptr<const NetworkDescription> networkDescription) {
return networkDescription->compiledNetwork;
// Post-commit form: returns {pointer, size} taken from compiledNetwork.data()/size() instead of
// a vector copy. The pointer is non-owning; presumably it is only usable while the
// NetworkDescription and its compiledNetwork buffer stay alive and unmodified — confirm at call sites.
virtual std::pair<const uint8_t*, size_t> getCompiledNetwork(std::shared_ptr<const NetworkDescription> networkDescription) {
return {networkDescription->compiledNetwork.data(), networkDescription->compiledNetwork.size()};
}

protected:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class LevelZeroCompilerAdapter final : public ICompiler {

void release(std::shared_ptr<const NetworkDescription> networkDescription) override;

// Override of ICompiler::getCompiledNetwork for the adapter.
// NOTE(review): rendered diff without +/- markers — the std::vector declaration appears to be the
// pre-commit line and the std::pair<const uint8_t*, size_t> declaration its replacement.
std::vector<uint8_t> getCompiledNetwork(std::shared_ptr<const NetworkDescription> networkDescription) override;
std::pair<const uint8_t*, size_t> getCompiledNetwork(std::shared_ptr<const NetworkDescription> networkDescription) override;

private:
/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class LevelZeroCompilerInDriver final : public ICompiler {

void release(std::shared_ptr<const NetworkDescription> networkDescription) override;

// Override of ICompiler::getCompiledNetwork for the in-driver compiler.
// NOTE(review): rendered diff without +/- markers — the std::vector declaration appears to be the
// pre-commit line and the std::pair<const uint8_t*, size_t> declaration its replacement.
std::vector<uint8_t> getCompiledNetwork(std::shared_ptr<const NetworkDescription> networkDescription) override;
std::pair<const uint8_t*, size_t> getCompiledNetwork(std::shared_ptr<const NetworkDescription> networkDescription) override;

private:
NetworkMetadata getNetworkMeta(ze_graph_handle_t graphHandle) const;
Expand All @@ -131,12 +131,14 @@ class LevelZeroCompilerInDriver final : public ICompiler {
// Overload enabled only when UseCopyForNativeBinary(T) evaluates to true.
template <typename T = TableExtension, typename std::enable_if_t<UseCopyForNativeBinary(T), bool> = true>
void getNativeBinary(ze_graph_dditable_ext_curr_t& graphDdiTableExt,
ze_graph_handle_t graphHandle,
// NOTE(review): rendered diff without +/- markers — the `blob` vector out-parameter appears to be
// the pre-commit tail of this declaration.
std::vector<uint8_t>& blob) const;
// Post-commit tail: the network description plus out-parameters for the blob pointer and size.
std::shared_ptr<const NetworkDescription> networkDescription,
uint8_t** blobPtr, size_t* blobSize) const;

// Overload enabled only when UseCopyForNativeBinary(T) evaluates to false.
template <typename T = TableExtension, typename std::enable_if_t<!UseCopyForNativeBinary(T), bool> = true>
void getNativeBinary(ze_graph_dditable_ext_curr_t& graphDdiTableExt,
ze_graph_handle_t graphHandle,
// NOTE(review): rendered diff without +/- markers — the `blob` vector out-parameter appears to be
// the pre-commit tail of this declaration.
std::vector<uint8_t>& blob) const;
// Post-commit tail: the NetworkDescription parameter is left unnamed here; blobPtr/blobSize are
// out-parameters.
std::shared_ptr<const NetworkDescription>,
uint8_t** blobPtr, size_t* blobSize) const;

template <typename T = TableExtension, typename std::enable_if_t<SupportAPIGraphQueryNetworkV2(T), bool> = true>
ze_result_t seriazlideIRModelAndQueryNetworkCreateV2(const std::shared_ptr<const ov::Model>& model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ void LevelZeroCompilerAdapter::release(std::shared_ptr<const NetworkDescription>
apiAdapter->release(std::move(networkDescription));
}

std::vector<uint8_t> LevelZeroCompilerAdapter::getCompiledNetwork(
std::pair<const uint8_t*, size_t> LevelZeroCompilerAdapter::getCompiledNetwork(
std::shared_ptr<const NetworkDescription> networkDescription) {
_logger.info("getCompiledNetwork - using adapter to perform getCompiledNetwork(networkDescription)");
return apiAdapter->getCompiledNetwork(std::move(networkDescription));
Expand Down
33 changes: 15 additions & 18 deletions src/plugins/intel_npu/src/compiler/src/zero_compiler_in_driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,10 @@ template <typename TableExtension>
template <typename T, std::enable_if_t<UseCopyForNativeBinary(T), bool>>
void LevelZeroCompilerInDriver<TableExtension>::getNativeBinary(ze_graph_dditable_ext_curr_t& graphDdiTableExt,
ze_graph_handle_t graphHandle,
std::vector<uint8_t>& blob) const {
std::shared_ptr<const NetworkDescription> networkDescription,
uint8_t** blobPtr, size_t* blobSize) const {
// Get blob size first
size_t blobSize = -1;

auto result = _graphDdiTableExt.pfnGetNativeBinary(graphHandle, &blobSize, nullptr);
auto result = _graphDdiTableExt.pfnGetNativeBinary(graphHandle, blobSize, nullptr);

OPENVINO_ASSERT(result == ZE_RESULT_SUCCESS,
"Failed to compile network. L0 pfnGetNativeBinary get blob size",
Expand All @@ -382,9 +381,8 @@ void LevelZeroCompilerInDriver<TableExtension>::getNativeBinary(ze_graph_dditabl
". ",
getLatestBuildError());

blob.resize(blobSize);
// Get blob data
result = _graphDdiTableExt.pfnGetNativeBinary(graphHandle, &blobSize, blob.data());
result = _graphDdiTableExt.pfnGetNativeBinary(graphHandle, blobSize, std::const_pointer_cast<NetworkDescription>(networkDescription)->compiledNetwork.data());

OPENVINO_ASSERT(result == ZE_RESULT_SUCCESS,
"Failed to compile network. L0 pfnGetNativeBinary get blob data",
Expand All @@ -395,18 +393,18 @@ void LevelZeroCompilerInDriver<TableExtension>::getNativeBinary(ze_graph_dditabl
uint64_t(result),
". ",
getLatestBuildError());

*blobPtr = std::const_pointer_cast<NetworkDescription>(networkDescription)->compiledNetwork.data();
}

// No-copy variant (enabled when UseCopyForNativeBinary(T) is false): a single pfnGetNativeBinary2
// call fills both the blob pointer and the blob size.
// NOTE(review): this is a rendered diff without +/- markers; pre- and post-commit lines are
// interleaved and some lines are hidden behind the "Expand All" hunk header.
template <typename TableExtension>
template <typename T, std::enable_if_t<!UseCopyForNativeBinary(T), bool>>
void LevelZeroCompilerInDriver<TableExtension>::getNativeBinary(ze_graph_dditable_ext_curr_t& graphDdiTableExt,
ze_graph_handle_t graphHandle,
std::vector<uint8_t>& blob) const {
// Post-commit signature tail: the NetworkDescription argument is unnamed and therefore unused here.
std::shared_ptr<const NetworkDescription>,
uint8_t** blobPtr, size_t* blobSize) const {
// Get blob ptr and size
// (pre-commit: local pointer/size variables; post-commit: the caller's out-parameters are passed through)
uint8_t* blobPtr;
size_t blobSize = -1;

auto result = _graphDdiTableExt.pfnGetNativeBinary2(graphHandle, &blobSize, &blobPtr);
auto result = _graphDdiTableExt.pfnGetNativeBinary2(graphHandle, blobSize, blobPtr);

OPENVINO_ASSERT(result == ZE_RESULT_SUCCESS,
"Failed to compile network. L0 pfnGetNativeBinary get blob size",
Expand All @@ -417,26 +415,25 @@ void LevelZeroCompilerInDriver<TableExtension>::getNativeBinary(ze_graph_dditabl
uint64_t(result),
". ",
getLatestBuildError());

// Pre-commit only: copied the returned bytes into the caller's vector. The post-commit form hands
// the returned pointer back as-is; who owns that memory and how long it stays valid is not shown
// here — TODO confirm against the driver API.
blob.assign(blobPtr, blobPtr + blobSize);
}

// Returns the compiled blob for networkDescription. When a graph handle is present and the stored
// compiledNetwork is empty, the blob is fetched through getNativeBinary; otherwise the stored
// compiledNetwork is returned.
// NOTE(review): this is a rendered diff without +/- markers; the std::vector signature, the local
// `blob` vector and the returns of `blob` / `compiledNetwork` by value appear to be pre-commit
// lines, each followed by its pointer/size replacement.
template <typename TableExtension>
std::vector<uint8_t> LevelZeroCompilerInDriver<TableExtension>::getCompiledNetwork(
std::pair<const uint8_t*, size_t> LevelZeroCompilerInDriver<TableExtension>::getCompiledNetwork(
std::shared_ptr<const NetworkDescription> networkDescription) {
if (networkDescription->metadata.graphHandle != nullptr && networkDescription->compiledNetwork.size() == 0) {
_logger.info("LevelZeroCompilerInDriver getCompiledNetwork get blob from graphHandle");
ze_graph_handle_t graphHandle = static_cast<ze_graph_handle_t>(networkDescription->metadata.graphHandle);

std::vector<uint8_t> blob;
// Out-values filled by getNativeBinary. blobPtr starts uninitialized; blobSize starts at -1
// converted to size_t, i.e. the maximum size_t value.
uint8_t* blobPtr;
size_t blobSize = -1;

getNativeBinary(_graphDdiTableExt, graphHandle, blob);
getNativeBinary(_graphDdiTableExt, graphHandle, networkDescription, &blobPtr, &blobSize);

_logger.info("LevelZeroCompilerInDriver getCompiledNetwork returning blob");
return blob;
return {blobPtr, blobSize};
} else {
_logger.info("return the blob from network description");
return networkDescription->compiledNetwork;
// Post-commit: non-owning pointer/size pair taken from the stored vector.
return {networkDescription->compiledNetwork.data(), networkDescription->compiledNetwork.size()};
}
}

Expand Down
11 changes: 6 additions & 5 deletions src/plugins/intel_npu/src/plugin/src/compiled_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ constexpr std::string_view NO_EXECUTOR_FOR_INFERENCE =
"Can't create infer request!\n"
"Please make sure that the device is available. Only exports can be made.";

// Computes a 32-bit hash of a byte sequence. Starting from 1171117, each byte updates the
// accumulator as result = (result << 7) + result + byte, which equals result * 129 + byte in
// unsigned 32-bit arithmetic (wrapping on overflow).
// NOTE(review): this is a rendered diff without +/- markers — the std::vector signature and the
// range-for loop appear to be pre-commit lines; the std::pair signature and the pointer loop
// appear to be their replacements.
std::uint32_t hash(const std::vector<uint8_t>& data) {
std::uint32_t hash(std::pair<const uint8_t*, size_t> blob) {
std::uint32_t result = 1171117u;
for (const auto& c : data)
result = ((result << 7) + result) + static_cast<uint32_t>(c);
// Post-commit loop: walks blob.second bytes starting at blob.first.
for (const uint8_t* it = blob.first; it != blob.first + blob.second; ++it) {
result = ((result << 7) + result) + static_cast<uint32_t>(*it);
}
return result;
}

Expand Down Expand Up @@ -140,11 +141,11 @@ std::shared_ptr<ov::ISyncInferRequest> CompiledModel::create_sync_infer_request(
void CompiledModel::export_model(std::ostream& stream) const {
_logger.debug("CompiledModel::export_model");
const auto&& blob = _compiler->getCompiledNetwork(_networkPtr);
stream.write(reinterpret_cast<const char*>(blob.data()), blob.size());
stream.write(reinterpret_cast<const char*>(blob.first), blob.second);

if (_logger.level() == ov::log::Level::INFO) {
std::stringstream str;
str << "Blob size: " << blob.size() << ", hash: " << std::hex << hash(blob);
str << "Blob size: " << blob.second << ", hash: " << std::hex << hash(blob);
_logger.info(str.str().c_str());

if (!stream) {
Expand Down

0 comments on commit 13aff9b

Please sign in to comment.