[GPU] Added tests for LoRA with empty adapters and handling of incorrect fusings
Lyamin-Roman committed Oct 15, 2024
1 parent 9b4a8ef commit 21b20b5
Showing 4 changed files with 54 additions and 9 deletions.
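For context on the commit subject: a LoRA layer adds a low-rank adapter path to the main weight path, y = W·x + B·(A·x). When the adapters are empty (rank 0), A·x has zero elements and the whole adapter branch contributes nothing, which is the case the changes below guard against. A minimal standalone sketch of that arithmetic (not part of this commit; all names are hypothetical):

#include <cstddef>
#include <iostream>
#include <vector>

// Row-major (rows x cols) matrix times vector; empty operands yield an empty result.
static std::vector<float> matvec(const std::vector<float>& m, std::size_t rows,
                                 std::size_t cols, const std::vector<float>& x) {
    std::vector<float> y(rows, 0.0f);
    for (std::size_t r = 0; r < rows; ++r)
        for (std::size_t c = 0; c < cols; ++c)
            y[r] += m[r * cols + c] * x[c];
    return y;
}

int main() {
    const std::size_t in = 4, out = 3, rank = 0;     // rank 0 models an empty adapter
    std::vector<float> x(in, 1.0f), W(out * in, 0.5f);
    std::vector<float> A(rank * in), B(out * rank);  // both adapters have count() == 0

    auto y = matvec(W, out, in, x);    // main path (fully_connected / gemm / convolution)
    auto ax = matvec(A, rank, in, x);  // A*x is empty when rank == 0
    if (!ax.empty()) {                 // the plugin can skip the whole branch
        auto delta = matvec(B, out, rank, ax);
        for (std::size_t i = 0; i < out; ++i) y[i] += delta[i];
    }
    std::cout << y[0] << "\n";  // prints 2: the empty adapter contributed nothing
}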
22 changes: 15 additions & 7 deletions src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp
@@ -1048,17 +1048,25 @@ void prepare_primitive_fusing::fuse_simple_primitives(program &p) {
             std::swap(fused_idx, peer_idx);
         }

+        auto fused_node = parents[fused_idx].first;
+        auto peer_node = parents[peer_idx].first;
+
         // Avoid fusing with GEMM from the LoRA pattern, that can be optimized in case of empty adapters
-        if (parents[fused_idx].first->is_type<gemm>()) {
-            if (parents[peer_idx].first->is_type<fully_connected>() ||
-                (parents[peer_idx].first->is_type<crop>() &&
-                 parents[peer_idx].first->get_dependency(0).is_type<fully_connected>())) {
-                std::swap(fused_idx, peer_idx);
+        if (fused_node->is_type<gemm>()) {
+            bool is_fc_lora = peer_node->is_type<fully_connected>() ||
+                              (peer_node->is_type<crop>() &&
+                               peer_node->get_dependency(0).is_type<fully_connected>());
+
+            bool is_conv_lora = peer_node->is_type<convolution>();
+
+            bool is_gemm_lora = peer_node->is_type<gemm>() &&
+                                fused_node->get_input_pshape().rbegin()->is_dynamic();
+
+            if (is_fc_lora || is_conv_lora || is_gemm_lora) {
+                std::swap(peer_node, fused_node);
             }
         }

-        auto fused_node = parents[fused_idx].first;
-        auto peer_node = parents[peer_idx].first;
         if (lo.get_optimization_attributes().use_onednn_impls && lo.is_primitive_implemented_for_onednn(*fused_node)) {
             auto eltw_in_size = peer_node->get_output_layout();
             if (eltw_in_size.is_dynamic()
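The hunk above widens the old FC-only check to convolution- and gemm-based LoRA patterns, and swaps the local node pointers so the eltwise is fused into the peer branch, leaving the adapter gemm free to be optimized out when the adapters are empty. A simplified standalone model of that selection (toy types; the real code operates on program_node):

#include <utility>

enum class kind { gemm, fully_connected, convolution, crop_over_fc, other };

struct node { kind k; bool last_dim_dynamic; };  // toy stand-in for program_node

// Mirrors the new checks: never fuse the eltwise into a LoRA-pattern gemm;
// swap so the peer branch becomes the fusion target instead.
void pick_fusion_target(node*& fused, node*& peer) {
    if (fused->k != kind::gemm) return;
    const bool is_fc_lora   = peer->k == kind::fully_connected || peer->k == kind::crop_over_fc;
    const bool is_conv_lora = peer->k == kind::convolution;
    const bool is_gemm_lora = peer->k == kind::gemm && fused->last_dim_dynamic;
    if (is_fc_lora || is_conv_lora || is_gemm_lora)
        std::swap(peer, fused);
}

int main() {
    node g{kind::gemm, false}, fc{kind::fully_connected, false};
    node* fused = &g;
    node* peer  = &fc;
    pick_fusion_target(fused, peer);  // afterwards fused points at the FC branch
    return fused == &fc ? 0 : 1;
}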
6 changes: 5 additions & 1 deletion src/plugins/intel_gpu/src/graph/input_layout.cpp
@@ -37,7 +37,11 @@ input_layout_inst::typed_primitive_inst(network& network, input_layout_node const
 event::ptr input_layout_inst::set_data(memory::ptr mem) {
     auto ol = get_node_output_layout();

-    check_memory_to_set(*mem, ol);
+    bool empty_mem = mem->size() == 0 && (ol.is_dynamic() || ol.count() == 0);
+    if (!empty_mem) {
+        check_memory_to_set(*mem, ol);
+    }
+
     event::ptr ev = nullptr;
     auto& engine = get_network().get_engine();
     auto& stream = get_network().get_stream();
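The guard above skips check_memory_to_set only for genuinely empty inputs: a zero-byte buffer whose layout is still dynamic or statically resolves to zero elements. Restated as a standalone predicate (hypothetical signature, not the plugin's API):

#include <cstddef>

// True when a zero-sized buffer is legitimately paired with the node's layout:
// the shape is still dynamic, or it statically resolves to zero elements.
bool is_empty_input(std::size_t mem_bytes, bool layout_is_dynamic, std::size_t layout_count) {
    return mem_bytes == 0 && (layout_is_dynamic || layout_count == 0);
}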
14 changes: 13 additions & 1 deletion src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -1553,8 +1553,13 @@ event::ptr primitive_inst::execute(const std::vector<event::ptr>& events) {
             auto allocated_mem = d.first->output_memory_ptr();
             auto actual_input_layout = d.first->get_output_layout();
             auto& engine = _network.get_engine();
+            cldnn::memory_ptr actual_mem = nullptr;
             // Need to use actual layout, not the fake aligned memory layout
-            auto actual_mem = engine.reinterpret_buffer(*allocated_mem, actual_input_layout);
+            if (actual_input_layout.count() != 0) {
+                actual_mem = engine.reinterpret_buffer(*allocated_mem, actual_input_layout);
+            } else {
+                actual_mem = engine.allocate_memory(actual_input_layout);
+            }
             subgraph->set_input_data(d.first->id(), std::move(actual_mem));
         }
     }
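reinterpret_buffer presumably needs a real allocation to wrap, so for a zero-element layout the change falls back to allocating a fresh (empty) memory object before handing it to set_input_data. A toy model of that decision (all types here are stand-ins, not cldnn's):

#include <cstddef>
#include <memory>

struct layout { std::size_t count; };  // toy: element count only
struct memory { layout l; };           // toy stand-in for cldnn::memory
using memory_ptr = std::shared_ptr<memory>;

struct engine {
    // In the real engine, reinterpret_buffer wraps an existing allocation with a
    // new layout, while allocate_memory creates a fresh (possibly empty) buffer.
    memory_ptr reinterpret_buffer(const memory&, layout l) { return std::make_shared<memory>(memory{l}); }
    memory_ptr allocate_memory(layout l)                   { return std::make_shared<memory>(memory{l}); }
};

// Mirrors the new fallback: only reinterpret when there are elements to view.
memory_ptr pick_input_memory(engine& e, const memory& allocated, layout actual) {
    return actual.count != 0 ? e.reinterpret_buffer(allocated, actual)
                             : e.allocate_memory(actual);
}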
@@ -2324,6 +2329,13 @@ bool primitive_inst::is_valid_fusion() const {
     if (fused_eltwise_prims.empty())
         return true;

+    if (_node->is_type<fully_connected>() || _node->is_type<gemm>() || _node->is_type<convolution>()) {
+        if (_impl_params->input_layouts[0].count() == 0 ||
+            _impl_params->input_layouts[1].count() == 0) {
+            return false;
+        }
+    }
+
     if (_node->is_type<fully_connected>() && _node->get_preferred_impl_type() == impl_types::ocl) {
         // TODO: Only fc_bf_tiled_kernel & ref kernel are verified for fused eltwise. To support more fc kernels for eltwise fusion
         if (!_node->get_selected_impl())
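The new early-out in is_valid_fusion declares a fused eltwise invalid whenever either data input of a fully_connected, gemm, or convolution has zero elements, so execution falls back to the unfused path instead of running a fused kernel against empty operands. Restated standalone (hypothetical names):

#include <cstddef>
#include <vector>

// Toy restatement of the new guard: a matmul-like primitive (FC / gemm / convolution)
// with an empty data input cannot keep its fused eltwise. Assumes at least two inputs.
bool fusion_is_valid(bool is_matmul_like, const std::vector<std::size_t>& input_counts) {
    if (is_matmul_like && (input_counts[0] == 0 || input_counts[1] == 0))
        return false;
    return true;  // the remaining kernel-specific checks would follow here
}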
21 changes: 21 additions & 0 deletions src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/lora_pattern.cpp
@@ -0,0 +1,21 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "subgraph_tests/lora_pattern.hpp"
+
+using namespace ov::test;
+
+namespace {
+
+INSTANTIATE_TEST_SUITE_P(smoke,
+                         LoraPatternConvolution,
+                         ::testing::Values(ov::test::utils::DEVICE_GPU),
+                         LoraPatternBase::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(smoke,
+                         LoraPatternMatmul,
+                         ::testing::Values(ov::test::utils::DEVICE_GPU),
+                         LoraPatternBase::getTestCaseName);
+
+} // namespace
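Each INSTANTIATE_TEST_SUITE_P call above stamps out the shared LoRA pattern suites for the GPU device; gtest derives the test names from the smoke prefix and LoraPatternBase::getTestCaseName. A minimal self-contained example of the same parameterization pattern (hypothetical suite, not from this commit):

#include <gtest/gtest.h>
#include <string>

class DevicePattern : public ::testing::TestWithParam<std::string> {};

TEST_P(DevicePattern, runs_on_device) {
    EXPECT_FALSE(GetParam().empty());  // placeholder check against the device name
}

// Without a name generator the parameter index is used as the suffix, so this
// produces a test named smoke/DevicePattern.runs_on_device/0 (link with gtest_main).
INSTANTIATE_TEST_SUITE_P(smoke, DevicePattern, ::testing::Values("GPU"));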
