Skip to content

Commit f8fe4e7

Browse files
authored
fix: add flash attn support check (leejet#803)
1 parent 1c07fb6 commit f8fe4e7

File tree

11 files changed

+191
-120
lines changed

11 files changed

+191
-120
lines changed

clip.hpp

Lines changed: 17 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -488,14 +488,14 @@ struct CLIPLayer : public GGMLBlock {
488488
blocks["mlp"] = std::shared_ptr<GGMLBlock>(new CLIPMLP(d_model, intermediate_size));
489489
}
490490

491-
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, bool mask = true) {
491+
struct ggml_tensor* forward(struct ggml_context* ctx, ggml_backend_t backend, struct ggml_tensor* x, bool mask = true) {
492492
// x: [N, n_token, d_model]
493493
auto self_attn = std::dynamic_pointer_cast<MultiheadAttention>(blocks["self_attn"]);
494494
auto layer_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm1"]);
495495
auto layer_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm2"]);
496496
auto mlp = std::dynamic_pointer_cast<CLIPMLP>(blocks["mlp"]);
497497

498-
x = ggml_add(ctx, x, self_attn->forward(ctx, layer_norm1->forward(ctx, x), mask));
498+
x = ggml_add(ctx, x, self_attn->forward(ctx, backend, layer_norm1->forward(ctx, x), mask));
499499
x = ggml_add(ctx, x, mlp->forward(ctx, layer_norm2->forward(ctx, x)));
500500
return x;
501501
}
@@ -517,7 +517,11 @@ struct CLIPEncoder : public GGMLBlock {
517517
}
518518
}
519519

520-
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, int clip_skip = -1, bool mask = true) {
520+
struct ggml_tensor* forward(struct ggml_context* ctx,
521+
ggml_backend_t backend,
522+
struct ggml_tensor* x,
523+
int clip_skip = -1,
524+
bool mask = true) {
521525
// x: [N, n_token, d_model]
522526
int layer_idx = n_layer - 1;
523527
// LOG_DEBUG("clip_skip %d", clip_skip);
@@ -532,7 +536,7 @@ struct CLIPEncoder : public GGMLBlock {
532536
}
533537
std::string name = "layers." + std::to_string(i);
534538
auto layer = std::dynamic_pointer_cast<CLIPLayer>(blocks[name]);
535-
x = layer->forward(ctx, x, mask); // [N, n_token, d_model]
539+
x = layer->forward(ctx, backend, x, mask); // [N, n_token, d_model]
536540
// LOG_DEBUG("layer %d", i);
537541
}
538542
return x;
@@ -712,6 +716,7 @@ class CLIPTextModel : public GGMLBlock {
712716
}
713717

714718
struct ggml_tensor* forward(struct ggml_context* ctx,
719+
ggml_backend_t backend,
715720
struct ggml_tensor* input_ids,
716721
struct ggml_tensor* tkn_embeddings,
717722
size_t max_token_idx = 0,
@@ -722,7 +727,7 @@ class CLIPTextModel : public GGMLBlock {
722727
auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);
723728

724729
auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size]
725-
x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
730+
x = encoder->forward(ctx, backend, x, return_pooled ? -1 : clip_skip, true);
726731
if (return_pooled || with_final_ln) {
727732
x = final_layer_norm->forward(ctx, x);
728733
}
@@ -775,6 +780,7 @@ class CLIPVisionModel : public GGMLBlock {
775780
}
776781

777782
struct ggml_tensor* forward(struct ggml_context* ctx,
783+
ggml_backend_t backend,
778784
struct ggml_tensor* pixel_values,
779785
bool return_pooled = true,
780786
int clip_skip = -1) {
@@ -786,7 +792,7 @@ class CLIPVisionModel : public GGMLBlock {
786792

787793
auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim]
788794
x = pre_layernorm->forward(ctx, x);
789-
x = encoder->forward(ctx, x, clip_skip, false);
795+
x = encoder->forward(ctx, backend, x, clip_skip, false);
790796
// print_ggml_tensor(x, true, "ClipVisionModel x: ");
791797
auto last_hidden_state = x;
792798
x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
@@ -855,6 +861,7 @@ class CLIPVisionModelProjection : public GGMLBlock {
855861
}
856862

857863
struct ggml_tensor* forward(struct ggml_context* ctx,
864+
ggml_backend_t backend,
858865
struct ggml_tensor* pixel_values,
859866
bool return_pooled = true,
860867
int clip_skip = -1) {
@@ -863,7 +870,7 @@ class CLIPVisionModelProjection : public GGMLBlock {
863870
auto vision_model = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
864871
auto visual_projection = std::dynamic_pointer_cast<CLIPProjection>(blocks["visual_projection"]);
865872

866-
auto x = vision_model->forward(ctx, pixel_values, return_pooled, clip_skip); // [N, hidden_size] or [N, n_token, hidden_size]
873+
auto x = vision_model->forward(ctx, backend, pixel_values, return_pooled, clip_skip); // [N, hidden_size] or [N, n_token, hidden_size]
867874

868875
if (return_pooled) {
869876
x = visual_projection->forward(ctx, x); // [N, projection_dim]
@@ -900,6 +907,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
900907
}
901908

902909
struct ggml_tensor* forward(struct ggml_context* ctx,
910+
ggml_backend_t backend,
903911
struct ggml_tensor* input_ids,
904912
struct ggml_tensor* embeddings,
905913
size_t max_token_idx = 0,
@@ -911,7 +919,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
911919
input_ids = ggml_reshape_2d(ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token);
912920
}
913921

914-
return model.forward(ctx, input_ids, embeddings, max_token_idx, return_pooled);
922+
return model.forward(ctx, backend, input_ids, embeddings, max_token_idx, return_pooled);
915923
}
916924

917925
struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
@@ -937,7 +945,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
937945
embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
938946
}
939947

940-
struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, embeddings, max_token_idx, return_pooled);
948+
struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, embeddings, max_token_idx, return_pooled);
941949

942950
ggml_build_forward_expand(gf, hidden_states);
943951

common.hpp

Lines changed: 16 additions & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -270,7 +270,10 @@ class CrossAttention : public GGMLBlock {
270270
// to_out_1 is nn.Dropout(), skip for inference
271271
}
272272

273-
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
273+
struct ggml_tensor* forward(struct ggml_context* ctx,
274+
ggml_backend_t backend,
275+
struct ggml_tensor* x,
276+
struct ggml_tensor* context) {
274277
// x: [N, n_token, query_dim]
275278
// context: [N, n_context, context_dim]
276279
// return: [N, n_token, query_dim]
@@ -288,7 +291,7 @@ class CrossAttention : public GGMLBlock {
288291
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
289292
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
290293

291-
x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false, false, flash_attn); // [N, n_token, inner_dim]
294+
x = ggml_nn_attention_ext(ctx, backend, q, k, v, n_head, NULL, false, false, flash_attn); // [N, n_token, inner_dim]
292295

293296
x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
294297
return x;
@@ -327,7 +330,10 @@ class BasicTransformerBlock : public GGMLBlock {
327330
}
328331
}
329332

330-
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
333+
struct ggml_tensor* forward(struct ggml_context* ctx,
334+
ggml_backend_t backend,
335+
struct ggml_tensor* x,
336+
struct ggml_tensor* context) {
331337
// x: [N, n_token, query_dim]
332338
// context: [N, n_context, context_dim]
333339
// return: [N, n_token, query_dim]
@@ -352,11 +358,11 @@ class BasicTransformerBlock : public GGMLBlock {
352358

353359
auto r = x;
354360
x = norm1->forward(ctx, x);
355-
x = attn1->forward(ctx, x, x); // self-attention
361+
x = attn1->forward(ctx, backend, x, x); // self-attention
356362
x = ggml_add(ctx, x, r);
357363
r = x;
358364
x = norm2->forward(ctx, x);
359-
x = attn2->forward(ctx, x, context); // cross-attention
365+
x = attn2->forward(ctx, backend, x, context); // cross-attention
360366
x = ggml_add(ctx, x, r);
361367
r = x;
362368
x = norm3->forward(ctx, x);
@@ -401,7 +407,10 @@ class SpatialTransformer : public GGMLBlock {
401407
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
402408
}
403409

404-
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
410+
virtual struct ggml_tensor* forward(struct ggml_context* ctx,
411+
ggml_backend_t backend,
412+
struct ggml_tensor* x,
413+
struct ggml_tensor* context) {
405414
// x: [N, in_channels, h, w]
406415
// context: [N, max_position(aka n_token), hidden_size(aka context_dim)]
407416
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
@@ -424,7 +433,7 @@ class SpatialTransformer : public GGMLBlock {
424433
std::string name = "transformer_blocks." + std::to_string(i);
425434
auto transformer_block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[name]);
426435

427-
x = transformer_block->forward(ctx, x, context);
436+
x = transformer_block->forward(ctx, backend, x, context);
428437
}
429438

430439
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w]

conditioner.hpp

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -639,7 +639,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
639639

640640
pixel_values = to_backend(pixel_values);
641641

642-
struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, pixel_values, return_pooled, clip_skip);
642+
struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, runtime_backend, pixel_values, return_pooled, clip_skip);
643643

644644
ggml_build_forward_expand(gf, hidden_states);
645645

control.hpp

Lines changed: 8 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -174,10 +174,11 @@ class ControlNetBlock : public GGMLBlock {
174174

175175
struct ggml_tensor* attention_layer_forward(std::string name,
176176
struct ggml_context* ctx,
177+
ggml_backend_t backend,
177178
struct ggml_tensor* x,
178179
struct ggml_tensor* context) {
179180
auto block = std::dynamic_pointer_cast<SpatialTransformer>(blocks[name]);
180-
return block->forward(ctx, x, context);
181+
return block->forward(ctx, backend, x, context);
181182
}
182183

183184
struct ggml_tensor* input_hint_block_forward(struct ggml_context* ctx,
@@ -199,6 +200,7 @@ class ControlNetBlock : public GGMLBlock {
199200
}
200201

201202
std::vector<struct ggml_tensor*> forward(struct ggml_context* ctx,
203+
ggml_backend_t backend,
202204
struct ggml_tensor* x,
203205
struct ggml_tensor* hint,
204206
struct ggml_tensor* guided_hint,
@@ -272,7 +274,7 @@ class ControlNetBlock : public GGMLBlock {
272274
h = resblock_forward(name, ctx, h, emb); // [N, mult*model_channels, h, w]
273275
if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
274276
std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
275-
h = attention_layer_forward(name, ctx, h, context); // [N, mult*model_channels, h, w]
277+
h = attention_layer_forward(name, ctx, backend, h, context); // [N, mult*model_channels, h, w]
276278
}
277279

278280
auto zero_conv = std::dynamic_pointer_cast<Conv2d>(blocks["zero_convs." + std::to_string(input_block_idx) + ".0"]);
@@ -296,9 +298,9 @@ class ControlNetBlock : public GGMLBlock {
296298
// [N, 4*model_channels, h/8, w/8]
297299

298300
// middle_block
299-
h = resblock_forward("middle_block.0", ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
300-
h = attention_layer_forward("middle_block.1", ctx, h, context); // [N, 4*model_channels, h/8, w/8]
301-
h = resblock_forward("middle_block.2", ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
301+
h = resblock_forward("middle_block.0", ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
302+
h = attention_layer_forward("middle_block.1", ctx, backend, h, context); // [N, 4*model_channels, h/8, w/8]
303+
h = resblock_forward("middle_block.2", ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
302304

303305
// out
304306
outs.push_back(middle_block_out->forward(ctx, h));
@@ -403,6 +405,7 @@ struct ControlNet : public GGMLRunner {
403405
timesteps = to_backend(timesteps);
404406

405407
auto outs = control_net.forward(compute_ctx,
408+
runtime_backend,
406409
x,
407410
hint,
408411
guided_hint_cached ? guided_hint : NULL,

0 commit comments

Comments (0)