@@ -467,6 +467,8 @@ void reshape_to_static(std::shared_ptr<ov::Model> model,
ov::PartialShape new_shape;
if (input_name.find("input_ids") != std::string::npos) {
new_shape = ov::PartialShape({1, input_size});
} else if (input_name.find("token_type_ids") != std::string::npos) {
new_shape = ov::PartialShape({1, input_size});
} else if (input_name.find("inputs_embeds") != std::string::npos) {
// NB: VLMs case, model accepts inputs_embeds[BATCH, SEQ_LEN, EMB_SIZE]
NPUW_ASSERT(input.get_partial_shape().size() == 3u);
79 changes: 68 additions & 11 deletions src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.cpp
@@ -248,6 +248,13 @@ std::pair<uint32_t, uint32_t> get_lora_dims_by_name(const std::string& state_nam
return std::make_pair(low_rank_dim, full_rank_dim);
}

void copy_to_right(const ov::SoPtr<ov::ITensor>& src, const ov::SoPtr<ov::ITensor>& dst) {
OPENVINO_ASSERT(src->get_byte_size() <= dst->get_byte_size());
std::copy_n(reinterpret_cast<uint8_t*>(src->data()),
src->get_byte_size(),
reinterpret_cast<uint8_t*>(dst->data()) + dst->get_byte_size() - src->get_byte_size());
}

constexpr uint32_t INPUT_IDS_SEQ_LEN_DIM = 1;

constexpr std::size_t kStartOutputKVCacheLayers = 1;
@@ -472,6 +479,10 @@ void ov::npuw::LLMInferRequest::apply_lora() {

void ov::npuw::LLMInferRequest::prepare_for_new_conversation() {
fill_tensor_bytes(m_prefill_request->get_tensor(m_prefill_in_ports.at(m_input_ids_name)), 0u);
if (auto totyids_port = m_prefill_in_ports.find(layer_names::token_type_ids);
Contributor: Nitpick: suggest renaming to type_ids_port.

totyids_port != m_prefill_in_ports.end()) {
fill_tensor_bytes(m_prefill_request->get_tensor(totyids_port->second), 0u);
}
fill_tensor<int64_t>(m_prefill_request->get_tensor(m_prefill_in_ports.at(layer_names::attention_mask)), 0);
fill_tensor<int64_t>(m_prefill_request->get_tensor(m_prefill_in_ports.at(layer_names::position_ids)), 0);
m_npuw_llm_compiled_model->m_kvcache_desc.num_stored_tokens = 0u;
@@ -555,8 +566,8 @@ void ov::npuw::LLMInferRequest::copy_kvcache() {

void ov::npuw::LLMInferRequest::update_kvcache_for(
std::shared_ptr<ov::IAsyncInferRequest> request,
std::unordered_map<std::string, ov::Output<const ov::Node>> in_ports,
std::unordered_map<std::string, ov::Output<const ov::Node>> out_ports,
const std::unordered_map<std::string, ov::Output<const ov::Node>>& in_ports,
const std::unordered_map<std::string, ov::Output<const ov::Node>>& out_ports,
Contributor: Thanks!

uint32_t num_tokens) {
LOG_DEBUG("Store computed key and values for passed number of tokens in the input kv-cache"
" layers.");
@@ -629,7 +640,8 @@ void ov::npuw::LLMInferRequest::clear_chunk_prefill_kv_cache() {

void ov::npuw::LLMInferRequest::infer_chunked_prefill(ov::SoPtr<ov::ITensor> input_ids,
ov::SoPtr<ov::ITensor> attention_mask,
ov::SoPtr<ov::ITensor> position_ids) {
ov::SoPtr<ov::ITensor> position_ids,
ov::SoPtr<ov::ITensor> token_types_ids) {
LOG_DEBUG("Calling chunked inference for prefill model.");
LOG_BLOCK();

@@ -646,6 +658,12 @@ void ov::npuw::LLMInferRequest::infer_chunked_prefill(ov::SoPtr<ov::ITensor> inp
auto attn_mask_in_tensor = m_prefill_request->get_tensor(m_prefill_in_ports.at(layer_names::attention_mask));
auto pos_ids_in_tensor = m_prefill_request->get_tensor(m_prefill_in_ports.at(layer_names::position_ids));

auto to_ty_ids_in_tensor = ov::npuw::util::TensorPtr();
Contributor: Nitpick: suggest renaming to just types_ids_in_tensor.

Contributor: Nitpick: in some places it is type and not types.


if (auto ttis_port = m_prefill_in_ports.find(layer_names::token_type_ids); ttis_port != m_prefill_in_ports.end()) {
Contributor: Nitpick: suggest renaming to type_ids_port.

to_ty_ids_in_tensor = m_prefill_request->get_tensor(ttis_port->second);
}

auto& kvcache_desc = m_npuw_llm_compiled_model->m_kvcache_desc;

int64_t remaining_prompts = input_prompt_len;
@@ -663,10 +681,18 @@ void ov::npuw::LLMInferRequest::infer_chunked_prefill(ov::SoPtr<ov::ITensor> inp
// If the current prompt length is smaller than the chunk prompt length,
// clear the last chunk of the attention mask to ensure non-relevant tokens are masked
fill_tensor<int64_t>(attn_mask_in_tensor, 0, last_chunk_offset);
if (to_ty_ids_in_tensor) {
fill_tensor<int64_t>(to_ty_ids_in_tensor, 0, last_chunk_offset);
}
}
std::copy_n(attention_mask->data<int64_t>() + kvcache_desc.num_stored_tokens,
current_prompts_len,
attn_mask_in_tensor->data<int64_t>() + attn_mask_in_tensor->get_size() - current_prompts_len);
if (to_ty_ids_in_tensor) {
Contributor: We also update the attention mask after each chunk at the end of the loop, before the next iteration:

        // Update attention mask for the next iteration
        std::copy_n(attn_mask_in_tensor->data<int64_t>() + attn_mask_in_tensor->get_size() - current_prompts_len,
                    current_prompts_len,
                    attn_mask_in_tensor->data<int64_t>() + kvcache_desc.num_stored_tokens - current_prompts_len);

Do we need to do this also for token_type_ids?
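A minimal sketch of what that analogous end-of-loop update could look like, assuming token_type_ids carry per-token values for the whole context like the attention mask does (hypothetical, not code from this PR):

        // Hypothetical: mirror the attention-mask update for token_type_ids before the next chunk
        if (to_ty_ids_in_tensor) {
            std::copy_n(to_ty_ids_in_tensor->data<int64_t>() + to_ty_ids_in_tensor->get_size() - current_prompts_len,
                        current_prompts_len,
                        to_ty_ids_in_tensor->data<int64_t>() + kvcache_desc.num_stored_tokens - current_prompts_len);
        }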

std::copy_n(token_types_ids->data<int64_t>() + kvcache_desc.num_stored_tokens,
current_prompts_len,
to_ty_ids_in_tensor->data<int64_t>() + to_ty_ids_in_tensor->get_size() - current_prompts_len);
}

auto current_prefill_bytes = current_prompts_len * input_ids_elem_size;
auto prefilled_bytes = kvcache_desc.num_stored_tokens * input_ids_elem_size;
@@ -719,7 +745,8 @@ void ov::npuw::LLMInferRequest::infer_chunked_prefill(ov::SoPtr<ov::ITensor> inp

void ov::npuw::LLMInferRequest::infer_whole_prefill(ov::SoPtr<ov::ITensor> input_ids,
ov::SoPtr<ov::ITensor> attention_mask,
ov::SoPtr<ov::ITensor> position_ids) {
ov::SoPtr<ov::ITensor> position_ids,
ov::SoPtr<ov::ITensor> token_types_ids) {
LOG_DEBUG("Calling inference for prefill model in a single launch.");
LOG_BLOCK();

@@ -736,6 +763,13 @@ void ov::npuw::LLMInferRequest::infer_whole_prefill(ov::SoPtr<ov::ITensor> input
attention_mask->get_size(),
padded_attention_mask->data<int64_t>() + padded_attention_mask->get_size() - attention_mask->get_size());

if (token_types_ids) {
Contributor: Why is this check different? In infer_chunked_prefill() we check for the existence of token_type_ids via the ports, while here we check the passed data.
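One way to make the two paths consistent would be to key this off the prefill port map as well; a minimal sketch under that assumption (hypothetical, not code from this PR):

        // Hypothetical: check the prefill port map, as infer_chunked_prefill() does,
        // before touching the padded token_type_ids tensor.
        if (auto type_ids_port = m_prefill_in_ports.find(layer_names::token_type_ids);
            type_ids_port != m_prefill_in_ports.end() && token_types_ids) {
            auto padded_token_type_ids = m_prefill_request->get_tensor(type_ids_port->second);
            copy_to_right(token_types_ids, padded_token_type_ids);
        }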

auto padded_token_type_ids = m_prefill_request->get_tensor(m_prefill_in_ports.at(layer_names::token_type_ids));

std::fill_n(reinterpret_cast<uint8_t*>(padded_token_type_ids->data()), token_types_ids->get_byte_size(), 0);
Contributor: I think this is redundant; prepare_for_new_conversation() should already handle it.

copy_to_right(token_types_ids, padded_token_type_ids);
}

auto padded_position_ids = m_prefill_request->get_tensor(m_prefill_in_ports.at(layer_names::position_ids));
pad_position_ids(padded_position_ids, position_ids);

@@ -748,7 +782,8 @@ void ov::npuw::LLMInferRequest::infer_whole_prefill(ov::SoPtr<ov::ITensor> input

void ov::npuw::LLMInferRequest::infer_prefill(ov::SoPtr<ov::ITensor> input_ids,
ov::SoPtr<ov::ITensor> attention_mask,
ov::SoPtr<ov::ITensor> position_ids) {
ov::SoPtr<ov::ITensor> position_ids,
ov::SoPtr<ov::ITensor> token_types_ids) {
Contributor: Nitpick: in some places it is type and not types.

LOG_DEBUG("Calling inference for prefill model...");
LOG_BLOCK();

@@ -764,9 +799,9 @@ void ov::npuw::LLMInferRequest::infer_prefill(ov::SoPtr<ov::ITensor> input_ids,

const bool use_chunk_prefill = m_npuw_llm_compiled_model->m_use_chunk_prefill;
if (use_chunk_prefill) {
infer_chunked_prefill(input_ids, attention_mask, position_ids);
infer_chunked_prefill(input_ids, attention_mask, position_ids, token_types_ids);
} else {
infer_whole_prefill(input_ids, attention_mask, position_ids);
infer_whole_prefill(input_ids, attention_mask, position_ids, token_types_ids);
}

if (m_lm_head_request) {
@@ -784,7 +819,8 @@ void ov::npuw::LLMInferRequest::infer_prefill(ov::SoPtr<ov::ITensor> input_ids,

void ov::npuw::LLMInferRequest::infer_generate(ov::SoPtr<ov::ITensor> input_ids,
ov::SoPtr<ov::ITensor> attention_mask,
ov::SoPtr<ov::ITensor> position_ids) {
ov::SoPtr<ov::ITensor> position_ids,
ov::SoPtr<ov::ITensor> token_types_ids) {
LOG_DEBUG("Calling inference for generate model...");
LOG_BLOCK();
auto& kvcache_desc = m_npuw_llm_compiled_model->m_kvcache_desc;
@@ -823,6 +859,11 @@ void ov::npuw::LLMInferRequest::infer_generate(ov::SoPtr<ov::ITensor> input_ids,
input_ids->get_byte_size(),
reinterpret_cast<uint8_t*>(kv_input_ids->data()) + kv_input_ids->get_byte_size() - input_ids->get_byte_size());

if (token_types_ids) {
Contributor: I think we additionally need to fill that tensor with 0s under the if (!m_generate_initialized) condition above (as is done for the other inputs), if token_type_ids behave like attention_mask and contain data for the whole context.

Contributor: Do they behave like attention_mask? (I thought they do, because of the code in infer_chunked_prefill().)
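If they do, a minimal sketch of that extra reset under the !m_generate_initialized branch could look like this (hypothetical, not code from this PR):

        // Hypothetical: reset the generate-model token_type_ids once per conversation,
        // alongside the other kv-cache model inputs.
        if (auto type_ids_port = m_kvcache_in_ports.find(layer_names::token_type_ids);
            type_ids_port != m_kvcache_in_ports.end()) {
            fill_tensor<int64_t>(m_kvcache_request->get_tensor(type_ids_port->second), 0);
        }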

auto r_token_type_ids = m_kvcache_request->get_tensor(m_kvcache_in_ports.at(layer_names::token_type_ids));
Contributor: Why r?

copy_to_right(token_types_ids, r_token_type_ids);
}

// NOTE: Attention mask pattern for generate model requires the set of "1"
// units of length of the current prompt on the right (for present
// kv layers) and the set of "1" units of number of previously calculated
@@ -873,12 +914,28 @@ void ov::npuw::LLMInferRequest::infer() {
// FIXME: position_ids might be optional for some models!
auto position_ids = get_tensor(find_port_by_name(inputs, layer_names::position_ids).value());

auto token_types_ids = ov::npuw::util::TensorPtr();

if (auto ttis_port = find_port_by_name(inputs, layer_names::token_type_ids); ttis_port.has_value()) {
token_types_ids = get_tensor(ttis_port.value());
}

// NB: For VLM, the "inputs_embeds" contains float values (embeddings)
OPENVINO_ASSERT(ov::element::f32 == input_ids->get_element_type() ||
ov::element::i64 == input_ids->get_element_type());
OPENVINO_ASSERT(ov::element::i64 == attention_mask->get_element_type());
OPENVINO_ASSERT(ov::element::i64 == position_ids->get_element_type());

if (m_first_run) {
// Most of the models have position_ids->data<int64_t>()[0] == 0 for the first infer
// But gemma3 has it's == 1
Contributor: Nitpick: do we need 's here?

// We need to store original zero position id in order to distinguish between prefill and generate stage
Contributor: Maybe we can call it first position id, but feel free to skip this comment.

// While in most of the cases we need to do prefill only once, it is not true for chat mode
// where we need to do prefill on each user input.
m_zero_position_id = position_ids->data<int64_t>()[0];
m_first_run = false;
}

// NB: Check the sequence length provided for input_ids
// and start position idx in order to distinguish prefill
// and generate stages.
@@ -901,11 +958,11 @@ void ov::npuw::LLMInferRequest::infer() {
// The outcome of two items is that prefill and generate stages
// can be safely differentiated by start position id for
// both main and draft models.
if (input_ids->get_shape()[INPUT_IDS_SEQ_LEN_DIM] > 1 && position_ids->data<int64_t>()[0] == 0) {
infer_prefill(input_ids, attention_mask, position_ids);
if (input_ids->get_shape()[INPUT_IDS_SEQ_LEN_DIM] > 1 && position_ids->data<int64_t>()[0] == m_zero_position_id) {
infer_prefill(input_ids, attention_mask, position_ids, token_types_ids);
} else {
trim_kvcache_for_speculative_decoding(position_ids);
infer_generate(input_ids, attention_mask, position_ids);
infer_generate(input_ids, attention_mask, position_ids, token_types_ids);
}
}
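For reference, the right-aligned copy pattern used throughout these changes (the new copy_to_right() helper and the padded input_ids / attention_mask / token_type_ids handling) can be illustrated with a standalone sketch on plain buffers; the sizes and values below are illustrative only, not code from this PR:

    // Standalone illustration: real data is placed at the right end of a larger,
    // zero-filled buffer that matches the model's padded static input shape.
    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
        const std::vector<int64_t> src = {11, 12, 13};  // e.g. token_type_ids for a 3-token prompt
        std::vector<int64_t> dst(8, 0);                 // padded static-shape input, pre-filled with zeros

        // Same idea as copy_to_right(): the copied block ends exactly at the end of dst.
        std::copy_n(src.data(), src.size(), dst.data() + dst.size() - src.size());

        for (int64_t v : dst) {
            std::cout << v << ' ';                      // prints: 0 0 0 0 0 11 12 13
        }
        std::cout << '\n';
        return 0;
    }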

21 changes: 15 additions & 6 deletions src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.hpp
@@ -24,6 +24,7 @@ class LLMInferRequest final : public ov::ISyncInferRequest {
static constexpr const char* past_key_values = "past_key_values";
static constexpr const char* output_embeds = "npuw_output_embed";
static constexpr const char* logits = "logits";
static constexpr const char* token_type_ids = "token_type_ids";
};

explicit LLMInferRequest(const std::shared_ptr<ov::npuw::LLMCompiledModel>& compiled_model);
@@ -49,26 +50,30 @@ class LLMInferRequest final : public ov::ISyncInferRequest {
void init_tensor(const ov::Output<const ov::Node>& port);
void copy_kvcache();
void update_kvcache_for(std::shared_ptr<ov::IAsyncInferRequest> request,
std::unordered_map<std::string, ov::Output<const ov::Node>> in_ports,
std::unordered_map<std::string, ov::Output<const ov::Node>> out_ports,
const std::unordered_map<std::string, ov::Output<const ov::Node>>& in_ports,
const std::unordered_map<std::string, ov::Output<const ov::Node>>& out_ports,
uint32_t tokens);
void trim_kvcache_for_speculative_decoding(ov::SoPtr<ov::ITensor> position_ids);

void infer_chunked_prefill(ov::SoPtr<ov::ITensor> input_ids,
ov::SoPtr<ov::ITensor> attention_mask,
ov::SoPtr<ov::ITensor> position_ids);
ov::SoPtr<ov::ITensor> position_ids,
ov::SoPtr<ov::ITensor> input_token_ids);

void infer_whole_prefill(ov::SoPtr<ov::ITensor> input_ids,
ov::SoPtr<ov::ITensor> attention_mask,
ov::SoPtr<ov::ITensor> position_ids);
ov::SoPtr<ov::ITensor> position_ids,
ov::SoPtr<ov::ITensor> input_token_ids);

void infer_prefill(ov::SoPtr<ov::ITensor> input_ids,
ov::SoPtr<ov::ITensor> attention_mask,
ov::SoPtr<ov::ITensor> position_ids);
ov::SoPtr<ov::ITensor> position_ids,
ov::SoPtr<ov::ITensor> input_token_ids);

void infer_generate(ov::SoPtr<ov::ITensor> input_ids,
ov::SoPtr<ov::ITensor> attention_mask,
ov::SoPtr<ov::ITensor> position_ids);
ov::SoPtr<ov::ITensor> position_ids,
ov::SoPtr<ov::ITensor> input_token_ids);

std::shared_ptr<ov::IAsyncInferRequest> m_kvcache_request;
std::shared_ptr<ov::IAsyncInferRequest> m_prefill_request;
@@ -88,6 +93,10 @@ class LLMInferRequest final : public ov::ISyncInferRequest {

bool m_generate_initialized = false;

bool m_first_run = true;

int64_t m_zero_position_id = 0;

// Support LoRA
std::vector<ov::SoPtr<ov::IVariableState>> m_variableStates;
void init_lora_states();