Skip to content

Commit

Permalink
Release networkStorage and compiledNetwork for import_model path
Browse files Browse the repository at this point in the history
  • Loading branch information
MirceaDan99 committed Oct 29, 2024
1 parent 8749462 commit af080e0
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class ZeGraphExtWrappers final : public ZeGraphExtWrappersInterface {

void setGraphArgumentValue(ze_graph_handle_t graphHandle, uint32_t argi_, const void* argv) const override;

void initializeGraph(ze_graph_handle_t graphHandle, const Config& config) const override;
void initializeGraph(ze_graph_handle_t graphHandle, const Config& config, std::optional<std::vector<uint8_t>> /* unusedNetworkStorageOpt */) const override;

private:
template <ze_graph_ext_version_t T = TableExtension, std::enable_if_t<!NotSupportQuery(T), bool> = true>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class ZeGraphExtWrappersInterface {

virtual void setGraphArgumentValue(ze_graph_handle_t graphHandle, uint32_t argi_, const void* argv) const = 0;

virtual void initializeGraph(ze_graph_handle_t graphHandle, const Config& config) const = 0;
virtual void initializeGraph(ze_graph_handle_t graphHandle, const Config& config, std::optional<std::vector<uint8_t>> /* unusedNetworkStorageOpt */) const = 0;

virtual ~ZeGraphExtWrappersInterface() = default;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ void DriverGraph::initialize(const Config& config) {
set_workload_type(config.get<WORKLOAD_TYPE>());
}

_zeGraphExt->initializeGraph(_handle, config);
_zeGraphExt->initializeGraph(_handle, config, std::move(_networkStorage));

_logger.debug("Graph initialize finish");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ PluginGraph::PluginGraph(const std::shared_ptr<ZeGraphExtWrappersInterface>& zeG
uint32_t groupOrdinal,
ze_graph_handle_t graphHandle,
NetworkMetadata metadata,
const std::vector<uint8_t> compiledNetwork,
std::vector<uint8_t> compiledNetwork,
const Config& config)
: IGraph(graphHandle, std::move(metadata)),
_zeGraphExt(zeGraphExt),
Expand Down Expand Up @@ -98,7 +98,7 @@ void PluginGraph::initialize(const Config& config) {
set_workload_type(config.get<WORKLOAD_TYPE>());
}

_zeGraphExt->initializeGraph(_handle, config);
_zeGraphExt->initializeGraph(_handle, config, std::move(_compiledNetwork));

_logger.debug("Graph initialize finish");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ void ZeGraphExtWrappers<TableExtension>::setGraphArgumentValue(ze_graph_handle_t
}

template <ze_graph_ext_version_t TableExtension>
void ZeGraphExtWrappers<TableExtension>::initializeGraph(ze_graph_handle_t graphHandle, const Config& config) const {
void ZeGraphExtWrappers<TableExtension>::initializeGraph(ze_graph_handle_t graphHandle, const Config& config, std::optional<std::vector<uint8_t>> /* unusedNetworkStorageOpt */) const {
if (_graphDdiTableExt.version() < ZE_GRAPH_EXT_VERSION_1_8) {
initialize_graph_through_command_list(graphHandle, config);
} else {
Expand Down

0 comments on commit af080e0

Please sign in to comment.