[GPU] Skip redundant gather in stateful model (#21681)
* Skip redundant gather in stateful model

* Fix memory reuse issue for nodes skipped at runtime.
  If a node is not marked as can_be_optimized at build time, the memory
  dependency is not applied properly, which can cause incorrect memory reuse.
yeonbok authored Dec 19, 2023
1 parent c59498b commit b770780
Showing 9 changed files with 267 additions and 1 deletion.
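
In short, this change targets stateful (KV-cache) models where the beam-index gather is frequently an identity mapping: the indices are exactly 0..N-1 over the full gather axis, so the copy is pointless and the output can simply alias the input buffer. A minimal standalone sketch of the skip condition (a hypothetical helper for illustration, not code from this commit):

#include <cstdint>
#include <vector>

// A 1-D gather is a no-op when its indices are exactly [0, 1, ..., N-1]
// and N equals the input extent along the gather axis: the output is then
// identical to the input, so the copy can be skipped.
bool is_identity_gather(const std::vector<int32_t>& indices, int64_t axis_extent) {
    if (static_cast<int64_t>(indices.size()) != axis_extent)
        return false;
    for (int64_t i = 0; i < axis_extent; ++i) {
        if (indices[i] != i)
            return false;
    }
    return true;
}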
19 changes: 19 additions & 0 deletions src/plugins/intel_gpu/src/graph/gather.cpp
@@ -132,6 +132,25 @@ std::string gather_inst::to_string(gather_node const& node) {
return primitive_description.str();
}

void gather_inst::on_execute() {
update_output_memory();
}

void gather_inst::update_output_memory() {
if (!can_be_optimized())
return;
// Output already points at the input buffer; nothing to update.
if (static_cast<bool>(_outputs[0]) && _network.get_engine().is_the_same_buffer(output_memory(), input_memory()))
return;

if (_node != nullptr)
build_deps();

GPU_DEBUG_TRACE_DETAIL << id() << " : update_output_memory with mem of input " << get_node().get_dependency(0).id()
<< " : " << input_memory_ptr()->buffer_ptr() << std::endl;
// Alias the input buffer as this primitive's output and record that no
// dedicated output memory is owned by this instance.
_outputs[0] = input_memory_ptr();
_mem_allocated = false;
}

gather_inst::typed_primitive_inst(network& network, gather_node const& node) : parent(network, node) {}

} // namespace cldnn
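
When a marked gather is actually skipped at runtime, on_execute() above re-points the primitive's output at its input buffer on every execution, so downstream nodes read the source memory directly; clearing _mem_allocated records that the instance no longer owns a dedicated output allocation, which presumably is what keeps the memory pool from reusing the aliased buffer (the issue described in the commit message).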
40 changes: 40 additions & 0 deletions src/plugins/intel_gpu/src/graph/graph_optimizer/dynamic_shape_gather_opts.cpp
@@ -0,0 +1,40 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "pass_manager.h"
#include "gather_inst.h"
#include "program_helpers.h"

using namespace cldnn;

void dynamic_shape_gather_opts::run(program& p) {
auto itr = p.get_processing_order().begin();
// Mark gathers that might be skipped at runtime as can_be_optimized.
// Without this flag, memory dependencies are not applied to nodes that are skipped at runtime.
while (itr != p.get_processing_order().end()) {
auto& node = *itr++;
if (!node->is_type<gather>())
continue;
auto& gather_node = node->as<gather>();
// Check pattern
auto impl_params = gather_node.get_kernel_impl_params();
if (gather_node.has_fused_primitives() ||
(impl_params->get_input_layout(0).data_type != impl_params->get_output_layout().data_type) ||
gather_node.get_dependency(1).is_constant() || gather_node.get_dependency(1).is_type<data>())
continue;
auto idx_rank = impl_params->get_input_layout(1).get_partial_shape().size();

if (idx_rank > 1) {
continue;
}
auto axis = impl_params->typed_desc<gather>()->axis;
if (impl_params->get_input_layout(0).get_partial_shape()[axis] == -1
|| impl_params->get_input_layout(1).get_partial_shape()[0] == -1
|| impl_params->get_input_layout(0).get_partial_shape()[axis] == impl_params->get_input_layout(1).get_partial_shape()[0]) {
// This gather may be skipped at runtime
gather_node.can_be_optimized(true);
GPU_DEBUG_TRACE_DETAIL << "[dynamic_shape_gather_opts] : " << gather_node.id() << " can_be_optimized" << std::endl;
}
}
}
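
Note that this pass only marks candidates: a gather qualifies if it has no fused primitives, matching input/output data types, non-constant 1-D indices, and an axis extent that is dynamic or already equal to the index count. Because the relevant dimensions are typically dynamic (-1), whether the gather is actually an identity mapping can only be decided per inference; marking at build time exists so that memory dependencies are set up correctly for nodes that may be skipped at runtime (see do_runtime_skip_gather() below).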
3 changes: 2 additions & 1 deletion src/plugins/intel_gpu/src/graph/impls/ocl/primitive_base.hpp
@@ -18,6 +18,7 @@
#include "register.hpp"
#include "implementation_map.hpp"
#include "concatenation_inst.h"
#include "gather_inst.h"

#include <vector>
#include <list>
@@ -81,7 +82,7 @@ struct typed_primitive_impl_ocl : public typed_primitive_impl<PType> {
template<typename ImplType>
static std::unique_ptr<primitive_impl> create(const typed_program_node<PType>& arg, const kernel_impl_params& impl_param) {
// buffer fusing for dynamic shape concat/gather is adaptively applied at runtime. So we need to build dynamic impl at build time.
- if (impl_param.can_be_optimized() && !(impl_param.is_type<concatenation>() && impl_param.is_dynamic())) {
+ if (impl_param.can_be_optimized() && !((impl_param.is_type<concatenation>() || impl_param.is_type<gather>()) && impl_param.is_dynamic())) {
return make_unique<ImplType>(kernel_selector::kernel_data{});
}
auto kernel_params = ImplType::get_kernel_params(ImplType::static_canonicalize_shapes(impl_param));
5 changes: 5 additions & 0 deletions src/plugins/intel_gpu/src/graph/include/gather_inst.h
@@ -34,6 +34,11 @@ class typed_primitive_inst<gather> : public typed_primitive_inst_base<gather> {
static std::string to_string(gather_node const& node);

typed_primitive_inst(network& network, gather_node const& desc);

void update_output_memory() override;

private:
void on_execute() override;
};

using gather_inst = typed_primitive_inst<gather>;
8 changes: 8 additions & 0 deletions src/plugins/intel_gpu/src/graph/include/pass_manager.h
@@ -414,4 +414,12 @@ class reorder_transfer : public base_pass {
void run(program& p) override;
};

class dynamic_shape_gather_opts : public base_pass {
public:
dynamic_shape_gather_opts() : base_pass("dynamic_shape_gather_opts") {}

private:
void run(program& p) override;
};

} // namespace cldnn
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/graph/include/primitive_inst.h
@@ -230,6 +230,7 @@ class primitive_inst {

void build_deps();
void do_runtime_skip_reorder();
void do_runtime_skip_gather();
void do_runtime_in_place_concat();
void configure_shape_of_dependencies();

60 changes: 60 additions & 0 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -1,4 +1,5 @@
// Copyright (C) 2018-2023 Intel Corporation

// SPDX-License-Identifier: Apache-2.0
//
#include "program_helpers.h"
@@ -25,6 +26,7 @@
#include "assign_inst.h"
#include "read_value_inst.h"
#include "condition_inst.h"
#include "gather_inst.h"
#include "experimental_detectron_roi_feature_extractor_inst.hpp"
#include "implementation_map.hpp"
#include "graph_optimizer/prepare_buffer_fusing.h"
@@ -787,6 +789,63 @@ void primitive_inst::do_runtime_skip_reorder() {
}
}

void primitive_inst::do_runtime_skip_gather() {
// Check pattern
if (!get_node().is_type<gather>()
|| _impl_params->has_fused_primitives()
|| _impl_params->get_input_layout(0).data_type != _impl_params->get_output_layout().data_type
|| get_node().get_dependency(1).is_constant() || get_node().get_dependency(1).is_type<data>())
return;

GPU_DEBUG_TRACE_DETAIL << "[do_runtime_skip_gather] " << id() << " : check optimizability" << std::endl;
auto input_shape = _impl_params->get_input_layout(0).get_shape();
auto axis = _impl_params->typed_desc<gather>()->axis;
auto idx_id = get_node().get_dependency(1).id();
auto idx_shape = _impl_params->get_input_layout(1).get_shape();
auto idx_rank = idx_shape.size();

if (idx_rank > 1) {
GPU_DEBUG_TRACE_DETAIL << "-- Cannot optimize because indices rank is " << idx_shape.size() << std::endl;
return;
}

// Check runtime shape (need to reset can_be_optimized)
if (idx_shape[0] != input_shape[axis]) {
set_can_be_optimized(false);
GPU_DEBUG_TRACE_DETAIL << "--- Cannot optimize because idx_shape[0] (" << idx_shape[0] << ") != input_shape[axis] (" << input_shape[axis] << ")" << std::endl;
return;
}

// If the overhead of checking the indices exceeds the cost of the gather itself, skipping is not worthwhile
const int MAX_INDICES_SIZE = 10*1024;
if (input_shape[axis] > MAX_INDICES_SIZE) {
GPU_DEBUG_TRACE_DETAIL << "--- Cannot optimize because data length along the axis is too large: " << input_shape[axis] << std::endl;
set_can_be_optimized(false);
return;
}
if (input_shape[axis] != 1) {
auto queue_type = get_network().get_stream().get_queue_type();
if (queue_type == QueueTypes::out_of_order)
get_network().get_stream().wait_for_events({_network.get_primitive_event(idx_id)});
else
_network.get_stream().finish();
mem_lock<int32_t, mem_lock_type::read> idx_data(dep_memory_ptr(1), _network.get_stream());
for (int64_t i = 0; i < static_cast<int64_t>(idx_shape[0]); ++i) {
if (idx_data[i] != i) {
GPU_DEBUG_TRACE_DETAIL << "--- Cannot optimize because idx_data[" << i << "] (" << idx_data[i] << ") != " << i << std::endl;
set_can_be_optimized(false);
return;
}
}
}
GPU_DEBUG_TRACE_DETAIL << "[do_runtime_skip_gather] " << id() << " : can_be_optimized" << std::endl;
GPU_DEBUG_TRACE_DETAIL << " - Input layout : " << _impl_params->get_input_layout(0).to_short_string() << std::endl;
GPU_DEBUG_TRACE_DETAIL << " - Indices layout : " << _impl_params->get_input_layout(1).to_short_string() << std::endl;
GPU_DEBUG_TRACE_DETAIL << " - Gather axis : " << axis << std::endl;
GPU_DEBUG_TRACE_DETAIL << " - Output layout : " << _impl_params->get_output_layout().to_short_string() << std::endl;
set_can_be_optimized(true);
}
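
// Note on the checks above: MAX_INDICES_SIZE (10K elements) bounds the
// host-side scan so that validating the indices never costs more than the
// gather it would skip, and the indices buffer is synchronized before the
// host read: an out-of-order queue waits only on the producer event of the
// indices primitive, while an in-order queue must finish the whole stream.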

void primitive_inst::do_runtime_in_place_concat() {
OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("do_runtime_in_place_concat: " + id()));
GPU_DEBUG_GET_INSTANCE(debug_config);
@@ -871,6 +930,7 @@ event::ptr primitive_inst::execute(const std::vector<event::ptr>& events) {
// Need to set can_be_optimized for the user reorder at the predecessor because,
// if that user is can_be_optimized and is an output node, the current node's output should be allocated on host.
do_runtime_skip_reorder();
do_runtime_skip_gather();
if (_impl_params->output_layouts[0].count() == 0) {
GPU_DEBUG_TRACE_DETAIL << id() << " : Skipping because output data is empty " << std::endl;
auto ev = get_network().get_stream().create_user_event(true);
3 changes: 3 additions & 0 deletions src/plugins/intel_gpu/src/graph/program.cpp
@@ -601,6 +601,9 @@ void program::pre_optimize_graph(bool is_internal) {
// Call shape_of subgraphs markup a second time to update newly added nodes after graph
// optimization passes
apply_opt_pass<mark_shape_of_subgraphs>(true);

// Set gathers that might be skipped at runtime as can_be_optimized.
apply_opt_pass<dynamic_shape_gather_opts>();
}

void program::post_optimize_graph(bool is_internal) {
129 changes: 129 additions & 0 deletions src/plugins/intel_gpu/tests/unit/dynamic_execution/stateful_model.cpp
@@ -0,0 +1,129 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "test_utils.h"

#include <intel_gpu/primitives/input_layout.hpp>
#include <intel_gpu/primitives/reorder.hpp>
#include <intel_gpu/primitives/data.hpp>
#include <intel_gpu/primitives/gather.hpp>
#include <intel_gpu/primitives/concatenation.hpp>

#include "program_wrapper.h"

#include <cmath>
#include <algorithm>

using namespace cldnn;
using namespace ::tests;

namespace stateful_model_tests {
TEST(stateful_model, skip_gather_at_runtime) {
auto& engine = get_test_engine();

auto input_kv_lay = layout{ov::PartialShape{-1, 32, -1, 128}, data_types::f32, format::bfyx};
auto input_present_lay = layout{ov::PartialShape{-1, 32, -1, 128}, data_types::f32, format::bfyx};
auto input_beam_idx_lay = layout{ov::PartialShape{-1}, data_types::i32, format::bfyx};

topology topology(input_layout("kv_cache", input_kv_lay),
input_layout("beam_idx", input_beam_idx_lay),
input_layout("present", input_present_lay),
gather("gather",
input_info("kv_cache"),
input_info("beam_idx"),
0, // axis
input_kv_lay.get_partial_shape().size(), // input rank
ov::Shape{}, // output shape
0, // batch_dim
true), // support_neg_ind
concatenation("concat", {input_info("gather"), input_info("present")}, 0),
reorder("reorder", input_info("concat"), format::bfyx, data_types::f32)); /*output padding*/

ExecutionConfig config = get_test_default_config(engine);
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));

network network(engine, topology, config);
auto gather_inst = network.get_primitive("gather");
ASSERT_EQ(gather_inst->get_node().can_be_optimized(), true);
ASSERT_EQ(gather_inst->can_be_optimized(), true);

auto KV_SIZE = 24;
auto BATCH_SIZE = 1;
auto kv_cache_mem = engine.allocate_memory({{KV_SIZE, 32, BATCH_SIZE, 128}, data_types::f32, format::bfyx});
auto present_mem = engine.allocate_memory({{1, 32, BATCH_SIZE, 128}, data_types::f32, format::bfyx});
auto beam_idx_mem = engine.allocate_memory({{KV_SIZE}, data_types::i32, format::bfyx});
std::vector<float> kv_input_data(kv_cache_mem->get_layout().count());
std::vector<float> present_input_data(present_mem->get_layout().count());
std::vector<int32_t> beam_idx_input_data(beam_idx_mem->get_layout().count());
std::iota(kv_input_data.begin(), kv_input_data.end(), 0.f);
std::iota(present_input_data.begin(), present_input_data.end(), 0.f);
std::iota(beam_idx_input_data.begin(), beam_idx_input_data.end(), 0);
set_values(kv_cache_mem, kv_input_data);
set_values(present_mem, present_input_data);
set_values(beam_idx_mem, beam_idx_input_data);

network.set_input_data("kv_cache", kv_cache_mem);
network.set_input_data("present", present_mem);
network.set_input_data("beam_idx", beam_idx_mem);
network.execute();
ASSERT_EQ(gather_inst->can_be_optimized(), true);
auto gather_output_mem = network.get_output_memory("gather");
cldnn::mem_lock<float, mem_lock_type::read> gather_output_ptr(gather_output_mem, get_test_stream());
for (size_t i = 0; i < gather_output_mem->get_layout().count(); ++i) {
ASSERT_EQ(gather_output_ptr[i], kv_input_data[i]);
}
}

TEST(stateful_model, not_skip_gather_at_runtime) {
auto& engine = get_test_engine();

auto input_kv_lay = layout{ov::PartialShape{-1, 32, -1, 128}, data_types::f32, format::bfyx};
auto input_present_lay = layout{ov::PartialShape{-1, 32, -1, 128}, data_types::f32, format::bfyx};
auto input_beam_idx_lay = layout{ov::PartialShape{-1}, data_types::i32, format::bfyx};

topology topology(input_layout("kv_cache", input_kv_lay),
input_layout("beam_idx", input_beam_idx_lay),
input_layout("present", input_present_lay),
gather("gather",
input_info("kv_cache"),
input_info("beam_idx"),
0, // axis
input_kv_lay.get_partial_shape().size(), // input rank
ov::Shape{}, // output shape
0, // batch_dim
true), // support_neg_ind
concatenation("concat", {input_info("gather"), input_info("present")}, 0),
reorder("reorder", input_info("concat"), format::bfyx, data_types::f32)); /*output padding*/

ExecutionConfig config = get_test_default_config(engine);
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));

network network(engine, topology, config);
auto gather_inst = network.get_primitive("gather");
ASSERT_EQ(gather_inst->get_node().can_be_optimized(), true);
ASSERT_EQ(gather_inst->can_be_optimized(), true);

auto KV_SIZE = 24;
auto BATCH_SIZE = 1;
auto kv_cache_mem = engine.allocate_memory({{KV_SIZE, 32, BATCH_SIZE, 128}, data_types::f32, format::bfyx});
auto present_mem = engine.allocate_memory({{1, 32, BATCH_SIZE, 128}, data_types::f32, format::bfyx});
auto beam_idx_mem = engine.allocate_memory({{KV_SIZE}, data_types::i32, format::bfyx});
std::vector<float> kv_input_data(kv_cache_mem->get_layout().count());
std::vector<float> present_input_data(present_mem->get_layout().count());
std::vector<int32_t> beam_idx_input_data(beam_idx_mem->get_layout().count());
std::iota(kv_input_data.begin(), kv_input_data.end(), 0.f);
std::iota(present_input_data.begin(), present_input_data.end(), 0.f);
std::iota(beam_idx_input_data.begin(), beam_idx_input_data.end(), 0);
std::swap(beam_idx_input_data[0], beam_idx_input_data[1]);
set_values(kv_cache_mem, kv_input_data);
set_values(present_mem, present_input_data);
set_values(beam_idx_mem, beam_idx_input_data);

network.set_input_data("kv_cache", kv_cache_mem);
network.set_input_data("present", present_mem);
network.set_input_data("beam_idx", beam_idx_mem);
network.execute();
ASSERT_EQ(gather_inst->can_be_optimized(), false);
}
} // stateful_model_tests
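
The two tests above differ only in the beam_idx contents: an iota sequence (0, 1, ..., 23) keeps the gather optimized, and the first test checks that the gather output matches the kv_cache input element-for-element, while swapping the first two indices (a beam reorder) makes do_runtime_skip_gather() reset can_be_optimized and run the actual kernel.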
