Skip to content

Commit

Permalink
fix matrix3d int type error for windows (#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
xieqihui authored Feb 23, 2024
1 parent c2d2de4 commit cbde34b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions llm/src/nn_modules/Fp32GPTBigCodeDecoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ struct Fp32GPTBigCodeDecoder_output Fp32GPTBigCodeDecoder::forward(const struct
// Position embeddings
// Matrix3D<float> pos_embeds = this->get_position_embed(sqlen, past_key_values_length);
#ifdef _WIN32
std::vector<float> position_ids_buf_vec(sqlen);
float *position_ids_buf = &position_ids_buf_vec.front();
std::vector<int> position_ids_buf_vec(sqlen);
int *position_ids_buf = &position_ids_buf_vec.front();
std::vector<float> pos_embeds_buf_vec(sqlen * this->embed_dim);
float *pos_embeds_buf = &pos_embeds_buf_vec.front();
#else
Expand Down
4 changes: 2 additions & 2 deletions llm/src/nn_modules/Int4GPTBigCodeDecoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ struct Int4GPTBigCodeDecoder_output Int4GPTBigCodeDecoder::forward(const struct
// printf(("Before get_position_embed\n");
// Matrix3D<float> pos_embeds = this->get_position_embed(sqlen, past_key_values_length);
#ifdef _WIN32
std::vector<float> position_ids_buf_vec(sqlen);
float *position_ids_buf = &position_ids_buf_vec.front();
std::vector<int> position_ids_buf_vec(sqlen);
int *position_ids_buf = &position_ids_buf_vec.front();
std::vector<float> pos_embeds_buf_vec(sqlen * this->embed_dim);
float *pos_embeds_buf = &pos_embeds_buf_vec.front();
#else
Expand Down

0 comments on commit cbde34b

Please sign in to comment.