Fix gpt (#499)
* fix mirror bug

* fix launch_gpt_embedding bug
hexisyztem authored Apr 4, 2023
1 parent 9eca20a commit b538578
Showing 4 changed files with 65 additions and 15 deletions.
65 changes: 56 additions & 9 deletions lightseq/csrc/kernels/cuda/gptKernels.cc.cu
@@ -120,12 +120,12 @@ pos_offset: get real pos when decoding which gridDim.y=1
 // __half* output, int* real_seq_len, int padding_id, int pos_offset);
 
 template <typename T>
-__global__ void kernel_gpt_embedding(const T* token_emb, const T* pos_emb,
-                                     const int* token_ids, T* output,
-                                     T* pad_mask_ptr, int* left_pad_len_ptr,
-                                     int batch_size, int beam_size, int seq_len,
-                                     int hidden_dim, int padding_id,
-                                     int max_step, int step_offset) {
+__global__ void kernel_gpt_padding(const T* token_emb, const T* pos_emb,
+                                   const int* token_ids, T* output,
+                                   T* pad_mask_ptr, int* left_pad_len_ptr,
+                                   int batch_size, int beam_size, int seq_len,
+                                   int hidden_dim, int padding_id, int max_step,
+                                   int step_offset) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx >= batch_size * beam_size * seq_len * hidden_dim) {
     return;
@@ -149,9 +149,28 @@ __global__ void kernel_gpt_embedding(const T* token_emb, const T* pos_emb,
     output_val.z = 0.;
     output_val.w = 0.;
   }
+}
 
-  __syncthreads();
+template <typename T>
+__global__ void kernel_gpt_embedding(const T* token_emb, const T* pos_emb,
+                                     const int* token_ids, T* output,
+                                     T* pad_mask_ptr, int* left_pad_len_ptr,
+                                     int batch_size, int beam_size, int seq_len,
+                                     int hidden_dim, int padding_id,
+                                     int max_step, int step_offset) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx >= batch_size * beam_size * seq_len * hidden_dim) {
+    return;
+  }
+  int batch_idx, beam_idx, seq_idx, state_idx;
+  decompose_4dim(idx, beam_size, seq_len, hidden_dim, &batch_idx, &beam_idx,
+                 &seq_idx, &state_idx);
+  int token_idx = flat_3dim(batch_idx, beam_idx, seq_idx + step_offset,
+                            beam_size, max_step);
+  int token_id = token_ids[token_idx];
+  int batch_beam_idx = batch_idx * beam_size + beam_idx;
 
+  float4& output_val = ((float4*)output)[idx];
   if (token_id != padding_id) {
     if (state_idx == 0) {
       pad_mask_ptr[token_idx] = 0;
@@ -171,7 +190,7 @@ __global__ void kernel_gpt_embedding(const T* token_emb, const T* pos_emb,
 }
 
 template <>
-__global__ void kernel_gpt_embedding<__half>(
+__global__ void kernel_gpt_padding<__half>(
     const __half* token_emb, const __half* pos_emb, const int* token_ids,
     __half* output, __half* pad_mask_ptr, int* left_pad_len_ptr, int batch_size,
     int beam_size, int seq_len, int hidden_dim, int padding_id, int max_step,
@@ -199,8 +218,27 @@ __global__ void kernel_gpt_embedding<__half>(
     output_val.z = 0.f;
     output_val.w = 0.f;
   }
+}
 
-  __syncthreads();
+template <>
+__global__ void kernel_gpt_embedding<__half>(
+    const __half* token_emb, const __half* pos_emb, const int* token_ids,
+    __half* output, __half* pad_mask_ptr, int* left_pad_len_ptr, int batch_size,
+    int beam_size, int seq_len, int hidden_dim, int padding_id, int max_step,
+    int step_offset) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx >= batch_size * beam_size * seq_len * hidden_dim) {
+    return;
+  }
+  int batch_idx, beam_idx, seq_idx, state_idx;
+  decompose_4dim(idx, beam_size, seq_len, hidden_dim, &batch_idx, &beam_idx,
+                 &seq_idx, &state_idx);
+  int token_idx = flat_3dim(batch_idx, beam_idx, seq_idx + step_offset,
+                            beam_size, max_step);
+  int token_id = token_ids[token_idx];
+  int batch_beam_idx = batch_idx * beam_size + beam_idx;
+
+  float4& output_val = ((float4*)output)[idx];
 
   if (token_id != padding_id) {
     if (state_idx == 0) {
@@ -242,6 +280,11 @@ void launch_gpt_embedding<float>(const float* token_emb, const float* pos_emb,
   hidden_dim >>= 2;
   int nele = (batch_size * beam_size * seq_len * hidden_dim);
   int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS;
+  kernel_gpt_padding<float><<<nblock, MAX_THREADS, 0, stream>>>(
+      token_emb, pos_emb, tokens, output, pad_mask_ptr, left_pad_len_ptr,
+      batch_size, beam_size, seq_len, hidden_dim, padding_id, max_step,
+      step_offset);
+
   kernel_gpt_embedding<float><<<nblock, MAX_THREADS, 0, stream>>>(
       token_emb, pos_emb, tokens, output, pad_mask_ptr, left_pad_len_ptr,
       batch_size, beam_size, seq_len, hidden_dim, padding_id, max_step,
@@ -265,6 +308,10 @@ void launch_gpt_embedding<__half>(const __half* token_emb,
   hidden_dim >>= 3;
   int nele = (batch_size * beam_size * seq_len * hidden_dim);
   int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS;
+  kernel_gpt_padding<__half><<<nblock, MAX_THREADS, 0, stream>>>(
+      token_emb, pos_emb, tokens, output, pad_mask_ptr, left_pad_len_ptr,
+      batch_size, beam_size, seq_len, hidden_dim, padding_id, max_step,
+      step_offset);
   kernel_gpt_embedding<__half><<<nblock, MAX_THREADS, 0, stream>>>(
       token_emb, pos_emb, tokens, output, pad_mask_ptr, left_pad_len_ptr,
      batch_size, beam_size, seq_len, hidden_dim, padding_id, max_step,
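A note on the kernel split above: __syncthreads() only synchronizes the threads of a single block, so the old single-kernel version could read padding state and left_pad_len_ptr values that other blocks had not yet written. Replacing the intra-kernel barrier with two back-to-back launches (kernel_gpt_padding, then kernel_gpt_embedding) on the same stream provides the grid-wide ordering the code actually needs, since kernels on one stream execute in launch order. A minimal standalone sketch of that pattern follows (toy kernels and sizes, not the LightSeq code):

#include <cstdio>
#include <cuda_runtime.h>

// Phase 1: every thread initializes its own slot. Inside one kernel
// there is no cross-block ordering guarantee, so a consumer of other
// blocks' writes must not run in the same launch.
__global__ void phase1_init(int* buf, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) buf[i] = 1;
}

// Phase 2: may safely read values written by any block of phase 1,
// because the two launches on one stream execute in order.
__global__ void phase2_sum(const int* buf, int* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) atomicAdd(out, buf[i]);
}

int main() {
  const int n = 1 << 20;
  int *buf, *out;
  cudaMalloc(&buf, n * sizeof(int));
  cudaMalloc(&out, sizeof(int));
  cudaMemset(out, 0, sizeof(int));
  int threads = 256, blocks = (n + threads - 1) / threads;
  // Back-to-back launches on the default stream: the second kernel
  // starts only after the first finishes, acting as a grid-wide barrier.
  phase1_init<<<blocks, threads>>>(buf, n);
  phase2_sum<<<blocks, threads>>>(buf, out, n);
  int host_out = 0;
  cudaMemcpy(&host_out, out, sizeof(int), cudaMemcpyDeviceToHost);
  printf("sum = %d (expected %d)\n", host_out, n);
  cudaFree(buf);
  cudaFree(out);
  return 0;
}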
6 changes: 3 additions & 3 deletions lightseq/csrc/models/gpt.cc
@@ -24,7 +24,7 @@ Gpt::Gpt(const std::string weight_path, const int max_batch_size)
   tw_.print_model_config();
 
   /* --- step.3 initial input Variable node --- */
-  _inp_tokens = new Variable("inp_tokens", g_dtype<OpType_>());
+  _inp_tokens = new Variable("inp_tokens", g_dtype<int>());
 
   /* --- step.4 inital operator & layer --- */
   int max_batch_tokens = tw_._max_step * _max_batch_size;
@@ -106,8 +106,8 @@ Gpt::Gpt(const std::string weight_path, const int max_batch_size)
 
   _out_tokens = std::get<0>(gen_outs);
   _out_scores = std::get<1>(gen_outs);
-  _inp_tokens->malloc_memory(_max_batch_size * tw_._beam_size * tw_._max_step);
-  _out_tokens->malloc_memory(_max_batch_size * tw_._beam_size * tw_._max_step);
+  _inp_tokens->malloc_memory(max_batch_size * tw_._beam_size * tw_._max_step);
+  _out_tokens->malloc_memory(max_batch_size * tw_._beam_size * tw_._max_step);
 
   _context_ptr->build();
   printf("Finish construct network!\n");
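Two small fixes in gpt.cc, as read from the diff: inp_tokens holds integer token ids, so its Variable dtype should be int rather than the floating-point compute type OpType_, and the token buffers are sized from the max_batch_size constructor argument. A hedged sketch of the buffer arithmetic (the values are illustrative assumptions, not taken from the model config):

#include <cstdio>
#include <cstdlib>

int main() {
  // Illustrative values only; the real ones come from the GPT config.
  int max_batch_size = 8, beam_size = 4, max_step = 512;

  // Token ids are ints (hence g_dtype<int>() above); each of
  // _inp_tokens/_out_tokens holds one id per (batch, beam, step) slot.
  size_t n = (size_t)max_batch_size * beam_size * max_step;
  int* tokens = (int*)malloc(n * sizeof(int));
  printf("token buffer: %zu ids, %zu bytes\n", n, n * sizeof(int));
  free(tokens);
  return 0;
}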
1 change: 1 addition & 0 deletions lightseq/csrc/ops_new/includes/launch_gpt_emb.h
@@ -45,6 +45,7 @@ class LaunchGptEmbOp : public Operator {
     _batch_size = batch_size, _seq_len = seq_len, _offset = offset;
     _result->set_shape({batch_size * seq_len, _hidden_dim});
     _pad_mask->set_shape({batch_size, seq_len + offset});
+    _left_pad_len->set_shape({_batch_size, size_t(_beam_size)});
   }
 
   void forward() override;
8 changes: 5 additions & 3 deletions lightseq/csrc/ops_new/launch_gpt_emb.cpp
@@ -9,9 +9,11 @@ std::tuple<Variable*, Variable*, Variable*> LaunchGptEmbOp<T>::operator()(
 
   size_t max_size = _max_batch_tokens * _hidden_dim;
 
-  _result = new Variable("LaunchGptEmbOp_out", _max_batch_tokens * _hidden_dim,
-                         g_dtype<T>());
-  _pad_mask = new Variable("_pad_mask", _max_batch_tokens, g_dtype<T>());
+  _result =
+      new Variable("LaunchGptEmbOp_out",
+                   _max_batch_tokens * _hidden_dim * _beam_size, g_dtype<T>());
+  _pad_mask =
+      new Variable("_pad_mask", _max_batch_tokens * _beam_size, g_dtype<T>());
 
   _left_pad_len = new Variable("_left_pad_len", _max_batch_size * _beam_size,
                                g_dtype<int>(), cuda::DataType::kNotSupported,
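These allocations, together with the _left_pad_len->set_shape(...) addition in launch_gpt_emb.h above, make the embedding op beam-aware: with beam search active the kernel indexes batch * beam * seq * hidden elements, so sizing _result by _max_batch_tokens * _hidden_dim alone under-allocates by a factor of _beam_size; that is a plausible reading of the "fix launch_gpt_embedding bug" bullet in the commit message. A quick check of the element counts (assumed, illustrative sizes):

#include <cstdio>

int main() {
  // Assumed sizes for illustration; none of these come from the repo.
  size_t max_batch_tokens = 8 * 512;  // e.g. max_batch_size * max_step
  size_t hidden_dim = 768;
  size_t beam_size = 4;

  size_t before = max_batch_tokens * hidden_dim;             // old _result size
  size_t after = max_batch_tokens * hidden_dim * beam_size;  // new _result size
  printf("_result elements: %zu -> %zu (x%zu for beam search)\n",
         before, after, beam_size);
  return 0;
}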
