@@ -8,14 +8,11 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
     if (GGML_CUDA_CC_IS_AMD(cc)) {
         switch (D) {
             case 64:
-                return 64;
+                return ncols <= 16 ? 32 : 64;
             case 128:
+                return ncols <= 16 ? 64 : warp_size;
             case 256:
-                if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
-                    return ncols <= 16 ? 64 : 32;
-                } else {
-                    return 64;
-                }
+                return 64;
             default:
                 GGML_ABORT("fatal error");
                 return -1;
@@ -44,26 +41,17 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
             GGML_ABORT("fatal error");
             return -1;
     }
-    GGML_UNUSED(warp_size);
 }
 
 static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
 #ifdef GGML_USE_HIP
     switch (D) {
         case 64:
-            return 64;
+            return ncols <= 16 ? 32 : 64;
         case 128:
-#if defined(GCN) || defined(CDNA)
-            return ncols <= 16 ? 64 : 32;
-#else
-            return 64;
-#endif // defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 64 : warp_size;
        case 256:
-#if defined(GCN) || defined(CDNA)
-            return ncols <= 16 ? 64 : 32;
-#else
             return 64;
-#endif // defined(GCN) || defined(CDNA)
        default:
            return -1;
    }
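Note: a minimal host-side sketch of the new stride selection for the AMD path, kept here only to make the table above easier to read. It assumes warp_size is the wavefront width reported by the device (64 on GCN/CDNA, 32 on RDNA); the helper name is illustrative and not part of the change.

// Illustrative sketch, not part of the patch: mirrors the AMD branch of
// fattn_tile_get_kq_stride_host after this change.
static int kq_stride_amd_sketch(int D, int ncols, int warp_size) {
    switch (D) {
        case  64: return ncols <= 16 ? 32 : 64;         // narrower K tile for small batches
        case 128: return ncols <= 16 ? 64 : warp_size;  // 64 on wave64, 32 on wave32
        case 256: return 64;
        default:  return -1;
    }
}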
@@ -100,17 +88,9 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
         case 64:
             return 64;
         case 128:
-#if defined(GCN) || defined(CDNA)
-            return ncols <= 16 ? 64 : 128;
-#else
-            return 64;
-#endif // defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 2*warp_size : 128;
         case 256:
-#if defined(GCN) || defined(CDNA)
-            return ncols <= 16 ? 64 : 128;
-#else
-            return ncols <= 16 ? 64 : 256;
-#endif // defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 128 : 2*warp_size;
         default:
             return -1;
     }
@@ -216,21 +196,14 @@ static __global__ void flash_attn_tile(
 
     const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
-#if defined(GGML_USE_HIP)
-    constexpr int cpy_nb = 16;
-#else
-    constexpr int cpy_nb = 8;
-#endif // defined(GGML_USE_HIP) && defined(GCN)
-    constexpr int cpy_ne = cpy_nb / 4;
-
     __shared__ float KQ[ncols][kq_stride];
 #ifdef FAST_FP16_AVAILABLE
     __shared__ half2 Q_tmp[ncols][D/2];
-    __shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
+    __shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + 1)]; // Padded to avoid memory bank conflicts.
     half2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
 #else
     __shared__ float Q_tmp[ncols][D];
-    __shared__ float KV_tmp_f[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
+    __shared__ float KV_tmp_f[kq_stride * (kq_nbatch + 1)]; // Padded to avoid memory bank conflicts.
     float2 * KV_tmp_f2 = (float2 *) KV_tmp_f;
     float2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
 #endif // FAST_FP16_AVAILABLE
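Note: the padding constant changes from cpy_ne to a fixed +1 because the staged K rows are now consumed one element at a time. With a power-of-two row width, threads reading the same column of successive rows would all hit the same shared-memory bank; one extra element per row staggers the rows across banks. A rough footprint check under assumed parameters (HIP wave64, D = 128, ncols > 16, FAST_FP16 path), illustrative only:

// Shared-memory budget sketch for KV_tmp_h2 under the assumptions stated above.
constexpr int kq_stride = 64;                            // fattn_tile_get_kq_stride_device(128, ncols, 64)
constexpr int kq_nbatch = 128;                           // fattn_tile_get_kq_nbatch_device(128, ncols, 64)
constexpr int n_half2   = kq_stride*(kq_nbatch/2 + 1);   // 64 * 65 = 4160 half2 elements
constexpr int n_bytes   = n_half2*4;                     // sizeof(half2) == 4 -> 16640 bytes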
@@ -283,11 +256,11 @@ static __global__ void flash_attn_tile(
             for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size) {
                 const half2 tmp_h2 = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x];
 #ifdef FAST_FP16_AVAILABLE
-                KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x] = tmp_h2;
+                KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1 + threadIdx.x] = tmp_h2;
 #else
                 const float2 tmp_f2 = __half22float2(tmp_h2);
-                KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + threadIdx.x] = tmp_f2.x;
-                KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
+                KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + threadIdx.x] = tmp_f2.x;
+                KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
 #endif // FAST_FP16_AVAILABLE
             }
         }
@@ -296,45 +269,42 @@ static __global__ void flash_attn_tile(
 
 #ifdef FAST_FP16_AVAILABLE
 #pragma unroll
-        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
-            half2 K_k[kq_stride/warp_size][cpy_ne];
-            half2 Q_k[ncols/nwarps][cpy_ne];
+        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; ++k_KQ_1) {
+            half2 K_k[kq_stride/warp_size];
+            half2 Q_k[ncols/nwarps];
 #else
 #pragma unroll
-        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
-            float K_k[kq_stride/warp_size][cpy_ne];
-            float Q_k[ncols/nwarps][cpy_ne];
+        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; ++k_KQ_1) {
+            float K_k[kq_stride/warp_size];
+            float Q_k[ncols/nwarps];
 #endif // FAST_FP16_AVAILABLE
 
 #pragma unroll
             for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
                 const int i_KQ = i_KQ_0 + threadIdx.x;
 
 #ifdef FAST_FP16_AVAILABLE
-                ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
+                K_k[i_KQ_0/warp_size] = KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1];
 #else
-                ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
+                K_k[i_KQ_0/warp_size] = KV_tmp_f[i_KQ*(kq_nbatch + 1) + k_KQ_1];
 #endif // FAST_FP16_AVAILABLE
             }
 #pragma unroll
             for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
                 const int j_KQ = j_KQ_0 + threadIdx.y;
 
 #ifdef FAST_FP16_AVAILABLE
-                ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
+                Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1];
 #else
-                ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
+                Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0 + k_KQ_1];
 #endif // FAST_FP16_AVAILABLE
             }
 
 #pragma unroll
             for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
 #pragma unroll
                 for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
-#pragma unroll
-                    for (int k = 0; k < cpy_ne; ++k) {
-                        ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0/nwarps][k]);
-                    }
+                    ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size], Q_k[j_KQ_0/nwarps]);
                 }
             }
         }
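Note: with the cpy_ne-wide register copies gone, each k_KQ_1 step loads one K and one Q element (one half2 pair in the FP16 path) straight from shared memory and accumulates a single product. A self-contained scalar sketch of what this reduces to, assuming ggml_cuda_mad(acc, a, b) computes acc += a*b (and does so on both halves in the half2 overload):

// Illustrative only: the per-element dot-product accumulation that replaces
// the unrolled cpy_ne loop above (FP32 path).
static float kq_dot_sketch(const float * K_row, const float * Q_row, int kq_nbatch) {
    float sum = 0.0f;
    for (int k = 0; k < kq_nbatch; ++k) {
        sum += K_row[k]*Q_row[k];   // one ggml_cuda_mad per k_KQ_1 iteration
    }
    return sum;
}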
@@ -375,54 +345,14 @@ static __global__ void flash_attn_tile(
         kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
 
         float kqsum_add = 0.0f;
-        if (kq_stride % (4*warp_size) == 0 && cpy_ne % 4 == 0) {
 #pragma unroll
-            for (int i0 = 0; i0 < kq_stride; i0 += 4*warp_size) {
-                const int i = i0 + 4*threadIdx.x;
+        for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
+            const int i = i0 + threadIdx.x;
 
-                float4 val = *(const float4 *) &KQ[j][i];
-                val.x = expf(val.x - kqmax[j0/nwarps]);
-                val.y = expf(val.y - kqmax[j0/nwarps]);
-                val.z = expf(val.z - kqmax[j0/nwarps]);
-                val.w = expf(val.w - kqmax[j0/nwarps]);
-                kqsum_add += val.x + val.y + val.z + val.w;
-
-#ifdef FAST_FP16_AVAILABLE
-                const half2 tmp[2] = {make_half2(val.x, val.y), make_half2(val.z, val.w)};
-                ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
-#else
-                ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
-#endif // FAST_FP16_AVAILABLE
-            }
-        } else if (kq_stride % (2*warp_size) == 0 && cpy_ne % 2 == 0) {
-#pragma unroll
-            for (int i0 = 0; i0 < kq_stride; i0 += 2*warp_size) {
-                const int i = i0 + 2*threadIdx.x;
-
-                float2 val = *(const float2 *) &KQ[j][i];
-                val.x = expf(val.x - kqmax[j0/nwarps]);
-                val.y = expf(val.y - kqmax[j0/nwarps]);
-                kqsum_add += val.x + val.y;
-#ifdef FAST_FP16_AVAILABLE
-                const half2 tmp = make_half2(val.x, val.y);
-                ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
-#else
-                ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
-#endif // FAST_FP16_AVAILABLE
-            }
-        } else {
-            for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
-                const int i = i0 + threadIdx.x;
-
-                const float diff = KQ[j][i] - kqmax[j0/nwarps];
-                const float val = expf(diff);
-                kqsum_add += val;
-#ifdef FAST_FP16_AVAILABLE
-                ((half *) KQ[j])[i] = val;
-#else
-                KQ[j][i] = val;
-#endif // FAST_FP16_AVAILABLE
-            }
+            const float diff = KQ[j][i] - kqmax[j0/nwarps];
+            const float val = expf(diff);
+            kqsum_add += val;
+            KQ[j][i] = val;
         }
         kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
 
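Note: the float4/float2 fast paths are dropped; every element now goes through one scalar expf, and KQ stays in FP32 shared memory for both the FP16 and FP32 builds. The last context line is the usual online-softmax bookkeeping. A hedged standalone sketch, assuming KQ_max_scale = expf(kqmax_old - kqmax_new) as in the other fattn kernels:

#include <cmath>

// Online softmax running-sum update, host-style and illustrative only:
// rescale the old sum by the shift in the running maximum, then add the
// freshly exponentiated scores, matching kqsum = kqsum*KQ_max_scale + kqsum_add.
static float kqsum_update_sketch(float kqsum, float m_old, float m_new, const float * kq, int n) {
    const float scale = expf(m_old - m_new);   // KQ_max_scale
    float add = 0.0f;
    for (int i = 0; i < n; ++i) {
        add += expf(kq[i] - m_new);            // kqsum_add
    }
    return kqsum*scale + add;
}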
@@ -489,7 +419,8 @@ static __global__ void flash_attn_tile(
                 const int j = j0 + threadIdx.y;
 
 #ifdef FAST_FP16_AVAILABLE
-                KQ_k[j0/nwarps] = __half2half2(((const half *)KQ[j])[k0 + k1]);
+                const float tmp = KQ[j][k0 + k1];
+                KQ_k[j0/nwarps] = make_half2(tmp, tmp);
 #else
                 KQ_k[j0/nwarps] = KQ[j][k0 + k1];
 #endif // FAST_FP16_AVAILABLE