Skip to content

Commit 4dd3234

Browse files
authored
Bert crf (bytedance#391)
* add bert-crf files * support bert+crf weight loading * add bert crf model (with bug) * fix bug * fix bug * fix bug, restore memory.cpp to old version * update manager to latest * fix kernel bug when head_dim % 4 != 0 * delete emb lang_emb loading * modify crf bias to fp32 * fix cmake dtype bug * fix bug * rename cmake outputs
1 parent 9983de5 commit 4dd3234

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1470
-161
lines changed

CMakeLists.txt

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,15 @@ if(USE_NEW_ARCH)
4040
set(CMAKE_CUDA_ARCHITECTURES 70 75 80 86 87)
4141

4242
if(DEBUG_MODE)
43-
if(FP16_MODE)
44-
add_definitions(-DDEBUG_TYPE=FP16)
45-
else()
46-
add_definitions(-DDEBUG_TYPE=FP32)
47-
endif()
43+
add_definitions(-DDEBUG_MODE)
44+
message(STATUS "Build using debug mode")
45+
endif()
46+
47+
if(FP16_MODE)
48+
add_definitions(-DFP16_MODE)
49+
message(STATUS "Build using fp16 precision")
50+
else()
51+
message(STATUS "Build using fp32 precision")
4852
endif()
4953

5054
set(COMMON_HEADER_DIRS

lightseq/csrc/kernels/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ set(cuda_kernel_files
1313
softmax_kernels.cu
1414
softmax_kernels_new.cu
1515
transform_kernels.cu
16-
transform_kernels_new.cu)
16+
transform_kernels_new.cu
17+
crf.cu)
1718

1819
add_library(cuda_kernels STATIC ${cuda_kernel_files})
1920
target_include_directories(cuda_kernels INTERFACE includes)

lightseq/csrc/kernels/crf.cu

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ transition: [num_tags, num_tags]
4141
emission: [batch_size, seq_len, num_tags]
4242
mask: [batch_size, seq_len]
4343
0 for invalid token
44+
bias: [num_tags]
4445
best_score: [batch_size]
4546
history: [batch_size, seq_len, num_tags]:
4647
i, j, k store the tag of i-th batch, j-th step when
@@ -50,9 +51,9 @@ best_tag: [batch_size, seq_len]
5051
template <typename T>
5152
__global__ void ker_viterbi(const T* start_transition, const T* end_transition,
5253
const T* transition, const T* emission,
53-
const uint8_t* mask, float* best_score,
54-
int* history, int* best_tags, int num_tags,
55-
int seq_len) {
54+
const uint8_t* mask, const T* bias,
55+
float* best_score, int* history, int* best_tags,
56+
int num_tags, int seq_len) {
5657
cg::thread_block b = cg::this_thread_block();
5758
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
5859

@@ -63,9 +64,11 @@ __global__ void ker_viterbi(const T* start_transition, const T* end_transition,
6364
// step 1. compute first step's score
6465
if (threadIdx.y == 0) {
6566
for (int cur_tag = threadIdx.x; cur_tag < num_tags; cur_tag += blockDim.x) {
67+
float linear_bias = bias ? float(bias[cur_tag]) : float(0);
6668
s_score[cur_tag] =
67-
emission[flat_3dim(blockIdx.x, 0, cur_tag, seq_len, num_tags)] +
68-
start_transition[cur_tag];
69+
float(
70+
emission[flat_3dim(blockIdx.x, 0, cur_tag, seq_len, num_tags)]) +
71+
linear_bias + float(start_transition[cur_tag]);
6972
}
7073
}
7174
b.sync();
@@ -91,9 +94,12 @@ __global__ void ker_viterbi(const T* start_transition, const T* end_transition,
9194
g.sync();
9295
warp_reduce_max(g, &max_score, &idx);
9396
if (threadIdx.x == 0) {
97+
float linear_bias = bias ? float(bias[cur_tag]) : float(0);
9498
s_next_score[cur_tag] =
95-
max_score + (float)emission[flat_3dim(blockIdx.x, seq_idx, cur_tag,
96-
seq_len, num_tags)];
99+
max_score +
100+
float(emission[flat_3dim(blockIdx.x, seq_idx, cur_tag, seq_len,
101+
num_tags)]) +
102+
linear_bias;
97103
history[flat_3dim(blockIdx.x, seq_idx - 1, cur_tag, seq_len,
98104
num_tags)] = idx;
99105
}
@@ -144,13 +150,14 @@ void launch_viterbi<__half>(const __half* start_transition,
144150
const __half* transition, const __half* emission,
145151
const uint8_t* mask, float* best_score,
146152
int* history, int* best_tags, int num_tags,
147-
int seq_len, int batch_size, cudaStream_t stream) {
153+
int seq_len, int batch_size, cudaStream_t stream,
154+
const __half* bias) {
148155
dim3 grid_dim(batch_size);
149156
dim3 block_dim(WARP_SIZE, WARP_SIZE);
150157

151158
ker_viterbi<__half>
152159
<<<grid_dim, block_dim, 2 * num_tags * sizeof(float), stream>>>(
153-
start_transition, end_transition, transition, emission, mask,
160+
start_transition, end_transition, transition, emission, mask, bias,
154161
best_score, history, best_tags, num_tags, seq_len);
155162
}
156163

@@ -160,12 +167,12 @@ void launch_viterbi<float>(const float* start_transition,
160167
const float* emission, const uint8_t* mask,
161168
float* best_score, int* history, int* best_tags,
162169
int num_tags, int seq_len, int batch_size,
163-
cudaStream_t stream) {
170+
cudaStream_t stream, const float* bias) {
164171
dim3 grid_dim(batch_size);
165172
dim3 block_dim(WARP_SIZE, WARP_SIZE);
166173

167174
ker_viterbi<float>
168175
<<<grid_dim, block_dim, 2 * num_tags * sizeof(float), stream>>>(
169-
start_transition, end_transition, transition, emission, mask,
176+
start_transition, end_transition, transition, emission, mask, bias,
170177
best_score, history, best_tags, num_tags, seq_len);
171178
}

lightseq/csrc/kernels/includes/kernels.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ void launch_viterbi(const T *start_transition, const T *end_transition,
228228
const T *transition, const T *emission, const uint8_t *mask,
229229
float *best_score, int *history, int *best_tags,
230230
int num_tags, int seq_len, int batch_size,
231-
cudaStream_t stream);
231+
cudaStream_t stream, const T *bias = nullptr);
232232

233233
template <typename T>
234234
void launch_quantize(int8_t *q_ptr, uint8_t *clip_mask_ptr, float *alpha_ptr,

lightseq/csrc/kernels/transform_kernels.cu

Lines changed: 72 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -471,35 +471,91 @@ __global__ void transform4d_0213(T *output, const T *input, int batch_size,
471471
res4[trg_offset] = input4[offset];
472472
}
473473

474+
/**
475+
@brief: transform4d_0213_slow
476+
Reshape the input matrix to merge the heads
477+
Not use float4 for dim % 4 != 0 or dim % 8 != 0
478+
479+
@thread
480+
gridDim.x = (num_all + max_block_thread - 1) / max_block_thread
481+
blockDim.x = max_block_thread
482+
483+
@param
484+
input: [trans_count, batch_size, nhead, seq_len, head_dim]
485+
output: [batch_size, seq_len, trans_count, nhead, head_dim]
486+
batch_size: the size of the current batch
487+
seq_len: the sequence length of the current batch
488+
hidden_dim: dim of the hidden tensor
489+
nhead: number of attention heads
490+
trans_count: 1 or 3, the count of matrice need to be transformed
491+
*/
492+
template <typename T>
493+
__global__ void transform4d_0213_slow(T *output, const T *input, int batch_size,
494+
int seq_len, int trans_count, int nhead,
495+
int head_dim, int num_all) {
496+
int offset = blockIdx.x * blockDim.x + threadIdx.x;
497+
if (offset >= num_all) {
498+
return;
499+
}
500+
int trans_id, batch_id, head_id, token_id, dim_id;
501+
decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id,
502+
&batch_id, &head_id, &token_id, &dim_id);
503+
// [b, s, tc, nh, ad]
504+
int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id,
505+
seq_len, trans_count, nhead, head_dim);
506+
507+
output[trg_offset] = input[offset];
508+
}
509+
474510
// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad]
475511
template <>
476512
void launch_transform4d_0213<float>(float *output, const float *input,
477513
int batch_size, int seq_len, int hidden_dim,
478514
int nhead, int trans_count,
479515
cudaStream_t stream) {
480-
hidden_dim >>= 2;
481-
int head_dim = hidden_dim / nhead;
482-
int num_all = batch_size * seq_len * trans_count * hidden_dim;
483-
int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;
484-
485-
transform4d_0213<float><<<nblock, MAX_THREADS, 0, stream>>>(
486-
output, input, batch_size, seq_len, trans_count, nhead, head_dim,
487-
num_all);
516+
if ((hidden_dim / nhead) % 4 == 0) {
517+
hidden_dim >>= 2;
518+
int head_dim = hidden_dim / nhead;
519+
int num_all = batch_size * seq_len * trans_count * hidden_dim;
520+
int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;
521+
522+
transform4d_0213<float><<<nblock, MAX_THREADS, 0, stream>>>(
523+
output, input, batch_size, seq_len, trans_count, nhead, head_dim,
524+
num_all);
525+
} else {
526+
int head_dim = hidden_dim / nhead;
527+
int num_all = batch_size * seq_len * trans_count * hidden_dim;
528+
int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;
529+
530+
transform4d_0213_slow<float><<<nblock, MAX_THREADS, 0, stream>>>(
531+
output, input, batch_size, seq_len, trans_count, nhead, head_dim,
532+
num_all);
533+
}
488534
}
489535

490536
template <>
491537
void launch_transform4d_0213<__half>(__half *output, const __half *input,
492538
int batch_size, int seq_len,
493539
int hidden_dim, int nhead, int trans_count,
494540
cudaStream_t stream) {
495-
hidden_dim >>= 3;
496-
int head_dim = hidden_dim / nhead;
497-
int num_all = batch_size * seq_len * trans_count * hidden_dim;
498-
int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;
499-
500-
transform4d_0213<__half><<<nblock, MAX_THREADS, 0, stream>>>(
501-
output, input, batch_size, seq_len, trans_count, nhead, head_dim,
502-
num_all);
541+
if ((hidden_dim / nhead) % 8 == 0) {
542+
hidden_dim >>= 3;
543+
int head_dim = hidden_dim / nhead;
544+
int num_all = batch_size * seq_len * trans_count * hidden_dim;
545+
int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;
546+
547+
transform4d_0213<__half><<<nblock, MAX_THREADS, 0, stream>>>(
548+
output, input, batch_size, seq_len, trans_count, nhead, head_dim,
549+
num_all);
550+
} else {
551+
int head_dim = hidden_dim / nhead;
552+
int num_all = batch_size * seq_len * trans_count * hidden_dim;
553+
int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;
554+
555+
transform4d_0213_slow<__half><<<nblock, MAX_THREADS, 0, stream>>>(
556+
output, input, batch_size, seq_len, trans_count, nhead, head_dim,
557+
num_all);
558+
}
503559
}
504560

505561
/**

lightseq/csrc/kernels/transform_kernels_new.cu

Lines changed: 88 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -128,36 +128,110 @@ __global__ void bias_add_transform_20314_new<__half>(
128128
}
129129
}
130130

131+
/**
132+
@brief: bias_add_transform_20314_new_slow
133+
Add bias to input, transform from
134+
[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4]
135+
Not use float4 for dim % 4 != 0 or dim % 8 != 0
136+
137+
@thread
138+
gridDim.x = dim_0
139+
gridDim.y = dim_1
140+
gridDim.z = dim_2
141+
blockDim.x = min(dim_3 * dim_4, MAX_THREADS)
142+
143+
@param
144+
input: [dim_0, dim_1, dim_2, dim_3, dim_4]
145+
bias: [dim_2, dim_3, dim_4]
146+
output: [dim_2, dim_0, dim_3, dim_1, dim_4]
147+
*/
148+
template <typename T>
149+
__global__ void bias_add_transform_20314_new_slow(T *q_out, T *k_out, T *v_out,
150+
const T *input, const T *bias,
151+
int dim_3, int dim_4,
152+
int batch_ele) {
153+
int id0 = blockIdx.x;
154+
int id1 = blockIdx.y;
155+
int id2 = blockIdx.z;
156+
int dim_0 = gridDim.x;
157+
int dim_1 = gridDim.y;
158+
int dim_2 = gridDim.z;
159+
int dim_34 = dim_3 * dim_4;
160+
161+
int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34);
162+
int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4);
163+
int bias_offset = flat_2dim(id2, 0, dim_34);
164+
165+
float vres;
166+
167+
for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) {
168+
vres = input[src_offset + i] + bias[bias_offset + i];
169+
170+
int id3 = i / dim_4;
171+
int id4 = i % dim_4;
172+
int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4);
173+
int temp_offset = trg_offset + cur_trg_offset;
174+
if (temp_offset >= batch_ele * 2) {
175+
v_out[temp_offset - batch_ele * 2] = vres;
176+
} else if (temp_offset >= batch_ele) {
177+
k_out[temp_offset - batch_ele] = vres;
178+
} else {
179+
q_out[temp_offset] = vres;
180+
}
181+
}
182+
}
183+
131184
// [b, s, 3, h] -> [3, b, nh, s, ad]
132185
template <>
133186
void launch_bias_add_transform_20314_new<float>(
134187
float *q_out, float *k_out, float *v_out, const float *input,
135188
const float *bias, int dim_0, int dim_1, int dim_2, int dim_3, int dim_4,
136189
cudaStream_t stream) {
137-
dim_4 >>= 2;
138-
139-
dim3 grid_dim(dim_0, dim_1, dim_2);
140-
dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));
141-
int batch_ele = dim_0 * dim_1 * dim_3 * dim_4;
142-
143-
bias_add_transform_20314_new<float><<<grid_dim, block_dim, 0, stream>>>(
144-
q_out, k_out, v_out, input, bias, dim_3, dim_4, batch_ele);
190+
if (dim_4 % 4 == 0) {
191+
dim_4 >>= 2;
192+
193+
dim3 grid_dim(dim_0, dim_1, dim_2);
194+
dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));
195+
int batch_ele = dim_0 * dim_1 * dim_3 * dim_4;
196+
197+
bias_add_transform_20314_new<float><<<grid_dim, block_dim, 0, stream>>>(
198+
q_out, k_out, v_out, input, bias, dim_3, dim_4, batch_ele);
199+
} else {
200+
dim3 grid_dim(dim_0, dim_1, dim_2);
201+
dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));
202+
int batch_ele = dim_0 * dim_1 * dim_3 * dim_4;
203+
204+
bias_add_transform_20314_new_slow<float>
205+
<<<grid_dim, block_dim, 0, stream>>>(q_out, k_out, v_out, input, bias,
206+
dim_3, dim_4, batch_ele);
207+
}
145208
}
146209

147210
template <>
148211
void launch_bias_add_transform_20314_new<__half>(
149212
__half *q_out, __half *k_out, __half *v_out, const __half *input,
150213
const __half *bias, int dim_0, int dim_1, int dim_2, int dim_3, int dim_4,
151214
cudaStream_t stream) {
152-
dim_4 >>= 3;
215+
if (dim_4 % 8 == 0) {
216+
dim_4 >>= 3;
153217

154-
dim3 grid_dim(dim_0, dim_1, dim_2);
155-
dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));
218+
dim3 grid_dim(dim_0, dim_1, dim_2);
219+
dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));
156220

157-
int batch_ele = dim_0 * dim_1 * dim_3 * dim_4;
221+
int batch_ele = dim_0 * dim_1 * dim_3 * dim_4;
158222

159-
bias_add_transform_20314_new<__half><<<grid_dim, block_dim, 0, stream>>>(
160-
q_out, k_out, v_out, input, bias, dim_3, dim_4, batch_ele);
223+
bias_add_transform_20314_new<__half><<<grid_dim, block_dim, 0, stream>>>(
224+
q_out, k_out, v_out, input, bias, dim_3, dim_4, batch_ele);
225+
} else {
226+
dim3 grid_dim(dim_0, dim_1, dim_2);
227+
dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));
228+
229+
int batch_ele = dim_0 * dim_1 * dim_3 * dim_4;
230+
231+
bias_add_transform_20314_new_slow<__half>
232+
<<<grid_dim, block_dim, 0, stream>>>(q_out, k_out, v_out, input, bias,
233+
dim_3, dim_4, batch_ele);
234+
}
161235
}
162236

163237
/**

lightseq/csrc/layers_new/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
set(layers_files feed_forward_layer.cpp multihead_attention_layer.cpp
2-
transformer_encoder_layer.cpp)
2+
transformer_encoder_layer.cpp linear_layer.cpp crf_layer.cpp)
33

44
add_library(lightseq_layers STATIC ${layers_files})
55
target_link_libraries(lightseq_layers PUBLIC lightseq_operators lsflow)

0 commit comments

Comments
 (0)