@@ -70,6 +70,20 @@ namespace refactor::kernel {
         }
     }
 
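+    // Appends this step's K or V tensor to its cache. Each thread moves one
+    // float4 (8 half values): `value` is read densely ([batch * nHead, seqLen,
+    // headDim] lines), and the destination index re-bases each page from the
+    // input page stride (seqLen lines) onto the cache page stride (cacheLen
+    // lines), shifted by `pastOffset` lines already cached.
+    // Worked example (assumed shapes): seqLen = 2, cacheLen = 8, headDim = 64,
+    // so itemsPerLine = 8, pageStrideI = 16, pageStrideO = 64; for tid = 35:
+    // page = 35 / 16 = 2, r = 35 % 16 = 3, dst = 2 * 64 + past * 8 + 3.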
+    static __global__ void concatCache(
+        void *__restrict__ cache,
+        void const *__restrict__ value,
+        dim_t pageStrideI,
+        dim_t pageStrideO,
+        dim_t lineStride,// currently unused; line granularity is already folded into the page strides
+        dim_t pastOffset) {
+
+        auto tid = blockIdx.x * blockDim.x + threadIdx.x,
+             dst = tid / pageStrideI * pageStrideO + pastOffset + tid % pageStrideI;
+        reinterpret_cast<float4 *>(cache)[dst] = reinterpret_cast<float4 const *>(value)[tid];
+    }
+    constexpr uint64_t DYNAMIC_WORKSPACE_SIZE = 40 << 20;// determined by trial: 40 MiB is enough
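+    // Workspace layout used below: bytes [0, DYNAMIC_WORKSPACE_SIZE) serve as
+    // cublasLt scratch, and the attention score buffer starts at offset
+    // DYNAMIC_WORKSPACE_SIZE (see `att` in the routine), so the total
+    // allocation is DYNAMIC_WORKSPACE_SIZE + maxAttSize().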
+
     RoutineWorkspace K::lower(Resources &res) const {
         auto handle = res.fetchOrStore<CublasLtContext>()->handle;
 
@@ -125,8 +139,8 @@ namespace refactor::kernel {
                               .batchCount = static_cast<int32_t>(info.batch * info.nHead),
                               .batchStride = static_cast<int64_t>(info.seqLen * info.seqLen),
                           }) {
-                        auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att);
-                        auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q);
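+                        // Pass an explicit workspace cap to tune, presumably so the
+                        // chosen cublasLt algorithms fit the preallocated
+                        // DYNAMIC_WORKSPACE_SIZE budget: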
+                        auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att, DYNAMIC_WORKSPACE_SIZE);
+                        auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q, DYNAMIC_WORKSPACE_SIZE);
                         algoQK = algoQK_;
                         algoAV = algoAV_;
                         workspaceSizeQK = workspaceSizeQK_;
@@ -187,12 +201,146 @@ namespace refactor::kernel {
                                 &d->algoAV,
                                 workspaceAV, d->workspaceSizeAV,
                                 stream);
-                        };
+                        }
                     };
 
                 return {std::move(routine), workspaceSize};
             }
+            TODO("");
         }
+        if (info.concatCache && !info.resetCache) {
+            if (info.nHead == info.nKVHead) {
+
+                // RAII for closure
+                struct Descriptors {
+                    MatMulDescriptor mul;
+
+                    Descriptors(AttentionInfo info)
+                        : mul(computeTypeConvert(info.dataType),
+                              dataTypeConvert(info.dataType)) {}
+                };
+
+                auto const &context = *res.fetchOrStore<CublasLtContext>();
+                auto d = std::make_shared<Descriptors>(info);
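+                // The descriptors are shared_ptr-owned so the routine closure
+                // keeps them alive after lower() returns.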
+                auto attentionSize = info.maxAttSize();
+                auto workspaceSize = DYNAMIC_WORKSPACE_SIZE + attentionSize;
+
+                auto routine = [d = std::move(d), info = this->info]//
+                    (Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
+                        auto handle = res.fetchOrStore<CublasLtContext>()->handle;
+                        auto q = inputs[0];
+                        auto k = inputs[1];
+                        auto v = inputs[2];
+                        auto past = *reinterpret_cast<int64_t const *>(inputs[3]);
+                        auto attLen = info.attLen(past);
+                        auto o = reinterpret_cast<half *>(outputs[0]);
+                        auto kCache = reinterpret_cast<half *>(outputs[1]);
+                        auto vCache = reinterpret_cast<half *>(outputs[2]);
+                        auto att = reinterpret_cast<half *>(reinterpret_cast<uint8_t *>(workspace) + DYNAMIC_WORKSPACE_SIZE);
+                        auto stream = cudaStreamLegacy;
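+                        // Step 1: append this step's K/V (seqLen new lines per
+                        // head) into kCache/vCache at line offset `past`; one
+                        // thread per float4, grid rounded up to full blocks
+                        // (assumes headDim * sizeof(half) is a multiple of
+                        // sizeof(float4)).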
+                        {
+                            auto itemsPerLine = info.headDim * sizeof(half) / sizeof(float4);
+                            auto threads = info.batch * info.nHead * info.seqLen * itemsPerLine;
+                            auto blocks = (threads + 1023) / 1024;
+
+                            concatCache<<<blocks, 1024, 0, stream>>>(
+                                kCache, k,
+                                info.seqLen * itemsPerLine,
+                                info.cacheLen * itemsPerLine,
+                                itemsPerLine,
+                                past * itemsPerLine);
+                            concatCache<<<blocks, 1024, 0, stream>>>(
+                                vCache, v,
+                                info.seqLen * itemsPerLine,
+                                info.cacheLen * itemsPerLine,
+                                itemsPerLine,
+                                past * itemsPerLine);
+                        }
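+                        // Step 2: operand layouts. K is declared column-major
+                        // headDim x attLen, so the product is effectively Q x K^T;
+                        // K/V batch strides step by cacheLen lines (a cache page),
+                        // and att keeps a row stride of cacheLen so the same
+                        // buffer fits any attLen up to cacheLen.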
+                        MatrixDescriptor
+                            q_(MatrixLayout{
+                                .dataType = dataTypeConvert(info.dataType),
+                                .rows = static_cast<uint64_t>(info.seqLen),
+                                .cols = static_cast<uint64_t>(info.headDim),
+                                .majorStride = static_cast<int64_t>(info.headDim),
+                                .order = ROW_MAJOR,
+                                .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                                .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
+                            }),
+                            k_(MatrixLayout{
+                                .dataType = dataTypeConvert(info.dataType),
+                                .rows = static_cast<uint64_t>(info.headDim),
+                                .cols = static_cast<uint64_t>(attLen),
+                                .majorStride = static_cast<int64_t>(info.headDim),
+                                .order = COL_MAJOR,
+                                .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                                .batchStride = static_cast<int64_t>(info.cacheLen * info.headDim),
+                            }),
+                            v_(MatrixLayout{
+                                .dataType = dataTypeConvert(info.dataType),
+                                .rows = static_cast<uint64_t>(attLen),
+                                .cols = static_cast<uint64_t>(info.headDim),
+                                .majorStride = static_cast<int64_t>(info.headDim),
+                                .order = ROW_MAJOR,
+                                .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                                .batchStride = static_cast<int64_t>(info.cacheLen * info.headDim),
+                            }),
+                            att_(MatrixLayout{
+                                .dataType = dataTypeConvert(info.dataType),
+                                .rows = static_cast<uint64_t>(info.seqLen),
+                                .cols = static_cast<uint64_t>(attLen),
+                                .majorStride = static_cast<int64_t>(info.cacheLen),
+                                .order = ROW_MAJOR,
+                                .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                                .batchStride = static_cast<int64_t>(info.cacheLen * info.seqLen),
+                            });
+                        {
+                            auto [algo, workspaceSize] = tune(
+                                handle, d->mul,
+                                q_, k_, att_,
+                                DYNAMIC_WORKSPACE_SIZE);
+                            half alpha = rsqrtf(info.headDim), beta = 0;
+                            cublasLtMatmul(
+                                handle, d->mul.get(),
+                                &alpha,
+                                q, q_.get(),
+                                kCache, k_.get(),
+                                &beta,
+                                att, att_.get(),
+                                att, att_.get(),
+                                &algo,
+                                workspace, workspaceSize,
+                                stream);
+                        }
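+                        // Step 3: causal-masked softmax, one block per
+                        // (batch x head, query row), a full row of attLen floats
+                        // in shared memory, rows spaced cacheLen apart.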
+                        softmax<<<dim3(info.batch * info.nHead, info.seqLen),
+                                  std::min(1024u, attLen),
+                                  attLen * sizeof(float),
+                                  stream>>>(
+                            att, AttentionCausualMask(), attLen, info.cacheLen);
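+                        // Step 4: att x V -> O, reusing q_ since the output has
+                        // Q's shape (seqLen x headDim per head).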
+                        {
+                            auto [algo, workspaceSize] = tune(
+                                handle, d->mul,
+                                att_, v_, q_,
+                                DYNAMIC_WORKSPACE_SIZE);
+                            half alpha = 1, beta = 0;
+                            cublasLtMatmul(
+                                handle, d->mul.get(),
+                                &alpha,
+                                att, att_.get(),
+                                vCache, v_.get(),
+                                &beta,
+                                o, q_.get(),
+                                o, q_.get(),
+                                &algo,
+                                workspace, workspaceSize,
+                                stream);
+                        }
+                    };
+
+                return {std::move(routine), workspaceSize};
+            }
+            TODO("");
+        }
+
         TODO("");
     }
 