Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v
}

static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
#if defined(GGML_USE_HIP) && defined(GCN)
#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
#else
#ifdef FAST_FP16_AVAILABLE
Expand All @@ -567,7 +567,21 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v,
acc += tmpv.x * tmpu.x;
acc += tmpv.y * tmpu.y;
#endif // FAST_FP16_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(GCN)
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA))
}

// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
template <int nbytes>
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
if constexpr (nbytes == 4) {
*(int *) dst = *(const int *) src;
} else if constexpr (nbytes == 8) {
*(int2 *) dst = *(const int2 *) src;
} else if constexpr (nbytes == 16) {
*(int4 *) dst = *(const int4 *) src;
} else {
static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't this work?

Suggested change
static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
static_assert(false, "bad nbytes");

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried this first, it failed during the host pass.

}
}

static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
Expand Down
131 changes: 100 additions & 31 deletions ggml/src/ggml-cuda/fattn-tile.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
if (GGML_CUDA_CC_IS_AMD(cc)) {
switch (D) {
case 64:
return ncols <= 16 ? 32 : 64;
return 64;
case 128:
return ncols <= 16 ? 64 : warp_size;
case 256:
return 64;
if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
return ncols <= 16 ? 64 : 32;
} else {
return 64;
}
default:
GGML_ABORT("fatal error");
return -1;
Expand Down Expand Up @@ -41,17 +44,26 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
GGML_ABORT("fatal error");
return -1;
}
GGML_UNUSED(warp_size);
}

static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
#ifdef GGML_USE_HIP
switch (D) {
case 64:
return ncols <= 16 ? 32 : 64;
return 64;
case 128:
return ncols <= 16 ? 64 : warp_size;
#if defined(GCN) || defined(CDNA)
return ncols <= 16 ? 64 : 32;
#else
return 64;
#endif // defined(GCN) || defined(CDNA)
case 256:
#if defined(GCN) || defined(CDNA)
return ncols <= 16 ? 64 : 32;
#else
return 64;
#endif // defined(GCN) || defined(CDNA)
default:
return -1;
}
Expand Down Expand Up @@ -88,9 +100,17 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
case 64:
return 64;
case 128:
return ncols <= 16 ? 2*warp_size : 128;
#if defined(GCN) || defined(CDNA)
return ncols <= 16 ? 64 : 128;
#else
return 64;
#endif // defined(GCN) || defined(CDNA)
case 256:
return ncols <= 16 ? 128 : 2*warp_size;
#if defined(GCN) || defined(CDNA)
return ncols <= 16 ? 64 : 128;
#else
return ncols <= 16 ? 64 : 256;
#endif // defined(GCN) || defined(CDNA)
default:
return -1;
}
Expand Down Expand Up @@ -196,14 +216,21 @@ static __global__ void flash_attn_tile(

const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);

#if defined(GGML_USE_HIP)
constexpr int cpy_nb = 16;
#else
constexpr int cpy_nb = 8;
#endif // defined(GGML_USE_HIP) && defined(GCN)
constexpr int cpy_ne = cpy_nb / 4;

__shared__ float KQ[ncols][kq_stride];
#ifdef FAST_FP16_AVAILABLE
__shared__ half2 Q_tmp[ncols][D/2];
__shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + 1)]; // Padded to avoid memory bank conflicts.
__shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
half2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
#else
__shared__ float Q_tmp[ncols][D];
__shared__ float KV_tmp_f[kq_stride * (kq_nbatch + 1)]; // Padded to avoid memory bank conflicts.
__shared__ float KV_tmp_f[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
float2 * KV_tmp_f2 = (float2 *) KV_tmp_f;
float2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
#endif // FAST_FP16_AVAILABLE
Expand Down Expand Up @@ -256,11 +283,11 @@ static __global__ void flash_attn_tile(
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size) {
const half2 tmp_h2 = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x];
#ifdef FAST_FP16_AVAILABLE
KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1 + threadIdx.x] = tmp_h2;
KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x] = tmp_h2;
#else
const float2 tmp_f2 = __half22float2(tmp_h2);
KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + threadIdx.x] = tmp_f2.x;
KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + threadIdx.x] = tmp_f2.x;
KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
#endif // FAST_FP16_AVAILABLE
}
}
Expand All @@ -269,42 +296,45 @@ static __global__ void flash_attn_tile(

#ifdef FAST_FP16_AVAILABLE
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; ++k_KQ_1) {
half2 K_k[kq_stride/warp_size];
half2 Q_k[ncols/nwarps];
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
half2 K_k[kq_stride/warp_size][cpy_ne];
half2 Q_k[ncols/nwarps][cpy_ne];
#else
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; ++k_KQ_1) {
float K_k[kq_stride/warp_size];
float Q_k[ncols/nwarps];
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
float K_k[kq_stride/warp_size][cpy_ne];
float Q_k[ncols/nwarps][cpy_ne];
#endif // FAST_FP16_AVAILABLE

#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
const int i_KQ = i_KQ_0 + threadIdx.x;

#ifdef FAST_FP16_AVAILABLE
K_k[i_KQ_0/warp_size] = KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1];
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
#else
K_k[i_KQ_0/warp_size] = KV_tmp_f [i_KQ*(kq_nbatch + 1) + k_KQ_1];
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_f [i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
#endif // FAST_FP16_AVAILABLE
}
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;

#ifdef FAST_FP16_AVAILABLE
Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1];
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
#else
Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0 + k_KQ_1];
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
#endif // FAST_FP16_AVAILABLE
}

#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size], Q_k[j_KQ_0/nwarps]);
#pragma unroll
for (int k = 0; k < cpy_ne; ++k) {
ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0/nwarps][k]);
}
}
}
}
Expand Down Expand Up @@ -345,14 +375,54 @@ static __global__ void flash_attn_tile(
kqmax[j0/nwarps] = kqmax_new[j0/nwarps];

float kqsum_add = 0.0f;
if (kq_stride % (4*warp_size) == 0 && cpy_ne % 4 == 0) {
#pragma unroll
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
const int i = i0 + threadIdx.x;
for (int i0 = 0; i0 < kq_stride; i0 += 4*warp_size) {
const int i = i0 + 4*threadIdx.x;

const float diff = KQ[j][i] - kqmax[j0/nwarps];
const float val = expf(diff);
kqsum_add += val;
KQ[j][i] = val;
float4 val = *(const float4 *) &KQ[j][i];
val.x = expf(val.x - kqmax[j0/nwarps]);
val.y = expf(val.y - kqmax[j0/nwarps]);
val.z = expf(val.z - kqmax[j0/nwarps]);
val.w = expf(val.w - kqmax[j0/nwarps]);
kqsum_add += val.x + val.y + val.z + val.w;

#ifdef FAST_FP16_AVAILABLE
const half2 tmp[2] = {make_half2(val.x, val.y), make_half2(val.z, val.w)};
ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
#else
ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
#endif // FAST_FP16_AVAILABLE
}
} else if (kq_stride % (2*warp_size) == 0 && cpy_ne % 2 == 0) {
#pragma unroll
for (int i0 = 0; i0 < kq_stride; i0 += 2*warp_size) {
const int i = i0 + 2*threadIdx.x;

float2 val = *(const float2 *) &KQ[j][i];
val.x = expf(val.x - kqmax[j0/nwarps]);
val.y = expf(val.y - kqmax[j0/nwarps]);
kqsum_add += val.x + val.y;
#ifdef FAST_FP16_AVAILABLE
const half2 tmp = make_half2(val.x, val.y);
ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
#else
ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
#endif // FAST_FP16_AVAILABLE
}
} else {
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
const int i = i0 + threadIdx.x;

const float diff = KQ[j][i] - kqmax[j0/nwarps];
const float val = expf(diff);
kqsum_add += val;
#ifdef FAST_FP16_AVAILABLE
((half *) KQ[j])[i] = val;
#else
KQ[j][i] = val;
#endif // FAST_FP16_AVAILABLE
}
}
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;

Expand Down Expand Up @@ -419,8 +489,7 @@ static __global__ void flash_attn_tile(
const int j = j0 + threadIdx.y;

#ifdef FAST_FP16_AVAILABLE
const float tmp = KQ[j][k0 + k1];
KQ_k[j0/nwarps] = make_half2(tmp, tmp);
KQ_k[j0/nwarps] = __half2half2(((const half *)KQ[j])[k0 + k1]);
#else
KQ_k[j0/nwarps] = KQ[j][k0 + k1];
#endif // FAST_FP16_AVAILABLE
Expand Down
8 changes: 8 additions & 0 deletions ggml/src/ggml-cuda/vendors/hip.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,14 @@
#define GCN
#endif

#if defined(__gfx900__) || defined(__gfx906__)
#define GCN5
#endif

#if defined(__gfx803__)
#define GCN4
#endif

#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
#define CDNA // For the entire family
#endif
Expand Down
Loading