24 commits
31f6c19
Init PA CM Impl(1st/2nd token and kvcache update)
riverlijunjie Aug 29, 2025
af9e8e5
enabled simple pa unit tests pass
riverlijunjie Aug 31, 2025
5f7c5da
Fix 2nd_token issue
riverlijunjie Aug 31, 2025
840d221
Fixed pipeline output corruption issue
riverlijunjie Sep 2, 2025
d3c1e91
Fix 2nd non-16 alignment accuracy issue
riverlijunjie Sep 2, 2025
904da7c
Set best partition size for 2nd
riverlijunjie Sep 2, 2025
6c0ea23
update KV_BLOCK_SIZE to 256
ceciliapeng2011 Sep 3, 2025
217001a
initiate xattention integration
ceciliapeng2011 Sep 3, 2025
d7a8ec1
qwen2.5-1.5b 4k trunk works with xatten.
ceciliapeng2011 Sep 5, 2025
7858f61
4k aligned works.
ceciliapeng2011 Sep 5, 2025
9b806b0
fix block_mask not fully initialized issue.
ceciliapeng2011 Sep 5, 2025
f06a71a
fix of find_block
ceciliapeng2011 Sep 8, 2025
c702b1f
xatten: fix accuacy problem caused by debug
ceciliapeng2011 Sep 9, 2025
66b34ed
use int32 to store float INV_S to align python version accuracy
luo-cheng2021 Sep 10, 2025
b383ec8
OV_GPU_XATTN_BLOCK_SIZE and OV_GPU_XATTN_THRESH
ceciliapeng2011 Sep 10, 2025
d15f12c
fix building error on windows.
usstq Sep 10, 2025
39a7b5f
process tail in find_block
ceciliapeng2011 Sep 12, 2025
05bb45d
Fix f16 accuracy issue and optimize 2nd token to improve 5%
riverlijunjie Sep 9, 2025
238d1ff
fix waring_as_error on CI Windows.
ceciliapeng2011 Sep 15, 2025
ffdf2f1
dump block mask with DUMP_XATTN_BLOCK_MASK for debug
ceciliapeng2011 Sep 15, 2025
623f524
fix xatten case: (tails=q_len % 128) < 16
luo-cheng2021 Sep 15, 2025
384e955
Fix C220 warning as error
peterchen-intel Sep 18, 2025
64df098
Disable C2220
peterchen-intel Sep 18, 2025
536285a
C2220 to 2220
peterchen-intel Sep 18, 2025
@@ -10,6 +10,8 @@

namespace cldnn {

#define ENABLE_PA_CM_PATH 1

struct paged_attention : public primitive_base<paged_attention> {
CLDNN_DECLARE_PRIMITIVE(paged_attention)

@@ -36,7 +38,7 @@ struct paged_attention : public primitive_base<paged_attention> {
XATTENTION_STRIDE = 19,
};

static constexpr size_t block_size = 16;
static constexpr size_t block_size = 256;

paged_attention() : primitive_base("", {}) {}

@@ -172,6 +172,7 @@ static constexpr Property<bool, ov::PropertyMutability::RW> asym_dynamic_quantiz
static constexpr Property<ShapePredictor::Settings, ov::PropertyMutability::RW> shape_predictor_settings{"GPU_SHAPE_PREDICTOR_SETTINGS"};
static constexpr Property<std::vector<std::string>, ov::PropertyMutability::RW> load_dump_raw_binary{"GPU_LOAD_DUMP_RAW_BINARY"};
static constexpr Property<bool, ov::PropertyMutability::RW> could_use_flashattn_v2{"GPU_COULD_USE_FLASHATTN_V2"};
static constexpr Property<size_t, ov::PropertyMutability::RW> xattention_block_size{"GPU_XATTN_BLOCK_SIZE"};
static constexpr Property<uint64_t, PropertyMutability::RW> dynamic_quantization_group_size_max{"GPU_DYNAMIC_QUANTIZATION_GROUP_SIZE_MAX"};
} // namespace ov::intel_gpu

@@ -55,6 +55,7 @@ OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, asym_dynamic_quantization, fals
OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, could_use_flashattn_v2, true, "Enable/Disable SDPA primitive executing with FlashAttenV2 online softmax tricks.")
OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, dynamic_quantization_threshold, 64, "Apply dynamic quantization only when batch size is larger than this value in OneDNN")
OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, weightless_attr, nullptr, "Used to configure ov::WeightlessCacheAttribute for constants that are not loaded from a .bin file. This typically applies to non-IR inputs (e.g., ORT)")
OV_CONFIG_RELEASE_INTERNAL_OPTION(ov::intel_gpu, xattention_block_size, 128, "block size for X-Attention sparse.")

OV_CONFIG_DEBUG_GLOBAL_OPTION(ov::intel_gpu, help, false, "Print help message for all config options")
OV_CONFIG_DEBUG_GLOBAL_OPTION(ov::intel_gpu, verbose, 0, "Enable logging for debugging purposes. The higher value the more verbose output. 0 - Disabled, 4 - Maximum verbosity")
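GPU_XATTN_BLOCK_SIZE above is a release-internal option (default 128); commit b383ec8 points to OV_GPU_XATTN_BLOCK_SIZE and OV_GPU_XATTN_THRESH as the corresponding debug-time environment knobs. A minimal sketch of passing the option as a compile-time property; whether the public API accepts this internal key is an assumption here, not something the diff shows:

#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    auto model = core.read_model("model.xml");  // illustrative path
    // X-Attention sparse block size for the GPU paged-attention path (internal option).
    ov::AnyMap props{{"GPU_XATTN_BLOCK_SIZE", 128}};
    auto compiled = core.compile_model(model, "GPU", props);
    return 0;
}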
27 changes: 27 additions & 0 deletions src/plugins/intel_gpu/src/graph/debug_helper.cpp
@@ -491,6 +491,33 @@ NodeDebugHelper::~NodeDebugHelper() {
log_memory_to_file(output_mem, output_layout, m_stream, filename, dump_raw);
}
}

for (size_t i = 0; i < m_inst.inputs_memory_count(); i++) {
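// Dump each input buffer again after the node has executed ("updated_src_*"),
// so inputs that are modified in place (e.g. the paged-attention KV cache)
// can be inspected.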
std::string name = get_file_prefix() + "_updated_src_" + std::to_string(i);
auto output_mem = m_inst.input_memory_ptr(i);
if (output_mem == nullptr) {
GPU_DEBUG_COUT << " updated_input_mem is nullptr. Nothing to dump." << std::endl;
continue;
}

auto& output_layout = m_inst.get_input_layout(i);
if (config.get_dump_tensors_format() == ov::intel_gpu::DumpFormat::binary) {
// Binary dump : raw
auto filename = get_file_path_for_binary_dump(output_layout, name, config.get_dump_tensors_path());

mem_lock<char, mem_lock_type::read> lock(output_mem, m_stream);
ov::util::save_binary(filename, lock.data(), output_mem->size());
GPU_DEBUG_COUT << " Dump layer updated src : " << layer_name << " to " << filename << std::endl;
debug_str_for_bin_load += (filename + ",");
} else {
const bool dump_raw = config.get_dump_tensors_format() == ov::intel_gpu::DumpFormat::text_raw;
GPU_DEBUG_COUT << " Dump " << (dump_raw ? "raw " : "") << name << std::endl;
auto filename = config.get_dump_tensors_path() + get_name_for_dump(name) + ".txt";
// Text dump
log_memory_to_file(output_mem, output_layout, m_stream, filename, dump_raw);
}
}

if (config.get_dump_tensors_format() == ov::intel_gpu::DumpFormat::binary && m_inst.is_input()) {
debug_str_for_bin_load[debug_str_for_bin_load.size()-1] = '\"';
GPU_DEBUG_COUT << debug_str_for_bin_load << std::endl;
296 changes: 294 additions & 2 deletions src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp
@@ -236,7 +236,7 @@ inline matrix<float, _kv_step, _q_step> ugemm_KQ(uint slm_K, matrix_ref<half, nu
template<int num_P_tiles = REG_N/REG_M, int num_rO_tiles>
inline void ugemm_PV0(uint slm_V, matrix_ref<half, REG_N, REG_K> P, matrix_ref<float, num_rO_tiles, REG_M*REG_N> rO, uint slm_offset = 0) {
constexpr int _head_size = num_rO_tiles*REG_N/num_P_tiles;

auto P2 = P.format<half, num_P_tiles, REG_M * REG_K>();
#pragma unroll
for(int k = 0, ri = 0; k < _head_size; k += REG_N, ri += num_P_tiles) {
@@ -312,6 +312,74 @@ vector<float, cols> online_softmax_update(matrix_ref<T, rows, cols> St, vector_r
return max_comp;
}

#ifdef CM_HAS_LSC_UNTYPED_2D
#define cm_load_normal cm_load<lsc::Normal>
#define cm_load_transpose cm_load<lsc::Transpose>
#define cm_load_vnni cm_load<lsc::VNNI>
#define cm_store_normal cm_store
#else
// simulation of LSC API using SVM API
template <typename T = int, unsigned NBlocks = 1, unsigned BlockH = 1, unsigned BlockW = 1>
inline void cm_load_normal(vector_ref<T, NBlocks*BlockH*BlockW> Res, const lsc::block_2d_desc<T, NBlocks, BlockH, BlockW> &Desc, int16_t Pred = 1) {
static_assert(NBlocks == 1);
auto pitch = Desc.get_pitch() + 1;
auto base = reinterpret_cast<svmptr_t>(Desc.get_base() + Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T));
#pragma unroll
for(int i = 0; i < BlockH; i++) {
cm_svm_block_read(base + i * pitch, Res.select<BlockW, 1>(i*BlockW));
}
}

template <typename T = int, unsigned NBlocks = 1, unsigned BlockH = 1, unsigned BlockW = 1>
inline void cm_load_transpose(vector_ref<T, NBlocks*BlockW*BlockH> Res, const lsc::block_2d_desc<T, NBlocks, BlockH, BlockW> &Desc, int16_t Pred = 1) {
static_assert(NBlocks == 1);
auto pitch = Desc.get_pitch() + 1;
auto base = reinterpret_cast<svmptr_t>(Desc.get_base() + Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T));
matrix<T, BlockH, BlockW> temp;
#pragma unroll
for(int i = 0; i < BlockH; i++) {
cm_svm_block_read(base + i * pitch, temp[i]);
}
Transpose2DMatrix(temp, Res.format<T, BlockW, BlockH>());
}

// in VNNI case, NBlocks is increasing along X dimension (increase cache-line usage)
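// The repacking below interleaves row pairs: source elements (2r, c) and (2r+1, c)
// end up as the two adjacent halves of one dword in output row r, which is the
// B-operand layout cm_dpas expects for half data.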
template <typename T = int, unsigned NBlocks = 1, unsigned BlockH = 1, unsigned BlockW = 1>
inline void cm_load_vnni(vector_ref<T, NBlocks*BlockW*BlockH> Res, const lsc::block_2d_desc<T, NBlocks, BlockH, BlockW> &Desc, int16_t Pred = 1) {
static_assert(NBlocks == 1 || NBlocks == 2);
// each block must be a full XMX B matrix
static_assert(BlockH == REG_K);
static_assert(BlockW == REG_N);
auto pitch = Desc.get_pitch() + 1;
auto base = reinterpret_cast<svmptr_t>(Desc.get_base() + Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T));
matrix<T, BlockH, NBlocks * BlockW> temp;
#pragma unroll
for(int i = 0; i < BlockH; i++) {
cm_svm_block_read(base + i * pitch, temp[i]);
}

auto out_vnni = Res.format<T, NBlocks * (BlockH/2), 2*BlockW>();
#pragma unroll
for(int i = 0; i < NBlocks; i ++) {
out_vnni.select<BlockH/2, 1, BlockW, 2>(i*(BlockH/2), 0) = temp.select<BlockH/2, 2, BlockW, 1>(0, i*BlockW);
out_vnni.select<BlockH/2, 1, BlockW, 2>(i*(BlockH/2), 1) = temp.select<BlockH/2, 2, BlockW, 1>(1, i*BlockW);
}
}

template <typename T = int, unsigned NBlocks = 1, unsigned BlockH = 1, unsigned BlockW = 1>
inline void cm_store_normal(const lsc::block_2d_desc<T, NBlocks, BlockH, BlockW> &Desc, vector_ref<T, NBlocks*BlockW*BlockH> Res) {
static_assert(NBlocks == 1);
auto pitch = Desc.get_pitch() + 1;
auto base = reinterpret_cast<svmptr_t>(Desc.get_base() + Desc.get_block_y()*pitch + Desc.get_block_x() * sizeof(T));
#pragma unroll
for(int i = 0; i < BlockH; i++) {
cm_svm_block_write(base + i * pitch, Res.select<BlockW, 1>(i*BlockW));
}
}
#endif



//===============================================================================================
template <int i, int N, int M>
constexpr void apply_causal_mask(matrix_ref<float, N, M> St) {
@@ -322,6 +390,7 @@ constexpr void apply_causal_mask(matrix_ref<float, N, M> St) {
}

#ifdef CM_HAS_LSC_UNTYPED_2D

template<bool use_causal_mask, int num_heads, int num_kv_heads, int head_size, int is_qkv_fused = 0>
void sdpa_kernel_lsc(
uint slm_K,
@@ -482,7 +551,6 @@ void sdpa_kernel_lsc(
}
}


template<bool use_causal_mask, int num_heads, int num_kv_heads, int head_size, int is_qkv_fused, int wg_local_size>
void sdpa_kernel_lsc_prefetch(
int wg_local_id,
@@ -662,6 +730,230 @@ void sdpa_kernel_lsc_prefetch(
cm_store(b2dO.set_block_y(REG_M), cur_O_f16.format<half, num_P_tiles, REG_M * REG_N>().row(1));
}
}

template<bool use_causal_mask, int num_heads, int num_kv_heads, int head_size, int is_qkv_fused, int wg_local_size>
void pa_kernel_lsc_prefetch(
int wg_local_id,
int q_start,
int kv_stop, // exclusive upper bound of KV positions to process
int q_len, // query tokens handled here (clamped to q_step below)
int kv_len, // not used for now
svmptr_t q_base [[type("svmptr_t")]],
svmptr_t k_base [[type("svmptr_t")]],
svmptr_t v_base [[type("svmptr_t")]],
#if SPARSE_BLOCK_SIZE > 1
svmptr_t sparse_mask_base [[type("svmptr_t")]],
int xattn_k_block_pad,
#endif
svmptr_t o_base [[type("svmptr_t")]],
int32_t past_lens,
int32_t* block_indices [[type("svmptr_t")]]) {
constexpr uint o_pitch = (num_heads * head_size * sizeof(half));
constexpr uint q_pitch = is_qkv_fused ? ((num_heads + num_kv_heads*2) * head_size * sizeof(half)) : o_pitch;
// constexpr uint k_pitch = is_qkv_fused ? q_pitch : (num_kv_heads * head_size * sizeof(half));
// constexpr uint v_pitch = is_qkv_fused ? q_pitch : (num_kv_heads * head_size * sizeof(half));
//[block_num, kv_heads, block_size, head_size]
constexpr uint k_pitch = head_size * sizeof(half);
constexpr uint v_pitch = k_pitch;

vector<float, q_step> cur_max;
vector<float, q_step> cur_sum;

bool need_comp = false;

cur_max = -3e38f;
cur_sum = 0;
constexpr int num_P_tiles = REG_N / REG_M;
matrix<half, head_size/REG_K, REG_K*REG_N> rQ;
matrix <float, head_size/REG_N*num_P_tiles, REG_M*REG_N> rO;

auto q_tokens_left = q_len;
static_assert(q_step == REG_N);
static_assert(kv_step == REG_K);

if (q_tokens_left < 0) q_tokens_left = 0;
if (q_tokens_left > q_step) q_tokens_left = q_step;

#if SPARSE_BLOCK_SIZE > 1
// printf("wg:%d.%d q: %d, +%d kv: %d, x-attn: %p\n", 0, wg_local_id, q_start, q_tokens_left, kv_stop, reinterpret_cast<bool*>(sparse_mask_base));
#endif

if (q_tokens_left > 0) {
lsc::block_2d_desc<uint, 1, REG_N, REG_K/2> b2dQ(reinterpret_cast<uint*>(q_base), q_tokens_left - 1, head_size*sizeof(half) - 1, q_pitch - 1, 0, 0);
#pragma unroll
for(int k = 0, ri = 0; k < head_size/2; k += REG_K/2, ri++) {
cm_load<lsc::Transpose>(rQ[ri].format<uint>(), b2dQ.set_block_x(k));
rQ[ri].format<half>() = cm_mul<half>(rQ[ri].format<half>(), (half)scale_factor);
}
}

lsc::block_2d_desc<half, 1, kv_step, REG_K> b2dK(k_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, k_pitch - 1, 0, 0);
lsc::block_2d_desc<half, 1, REG_K, REG_N> b2dV(v_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, v_pitch - 1, 0, 0);

static_assert(wg_local_size == 16);
lsc::block_2d_desc<half, 1, kv_step/wg_local_size, REG_K> prefetch_K(k_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, k_pitch - 1, 0, 0);
lsc::block_2d_desc<half, 1, REG_K/wg_local_size, REG_N> prefetch_V(v_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(half) - 1, v_pitch - 1, 0, 0);
constexpr int blk_stride = CMFLA_NUM_KV_HEADS*CMFLA_HEAD_SIZE*CMPA_BLOCK_SZ;
int causal_left = q_start+past_lens;

for(int kv_pos = 0; kv_pos < kv_stop; kv_pos += kv_step) {
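// Paged KV cache: translate the logical KV position into a physical block id via
// block_indices; K/V data for that block lives in a [block_num, kv_heads, block_size,
// head_size] buffer addressed through blk_stride.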
auto cur_block_id = block_indices[kv_pos / CMPA_BLOCK_SZ];
// For the last step there is no next block, so prefetch the current one again.
uint32_t prefetch_kv_pos = (kv_pos+kv_step) >= kv_stop ? kv_pos : (kv_pos+kv_step);
auto prefetch_block_id = block_indices[prefetch_kv_pos / CMPA_BLOCK_SZ];
//# St = k @ Qt
matrix<float, kv_step, q_step> St; // = ugemm_KQ(slm_K, rQ, slm_offset);
{
constexpr int num_K = kv_step/REG_M;
auto St2 = St.format<float, num_K, REG_M*REG_N>();

matrix<half, num_K, REG_M * REG_K> Kmat;
//cm_slm_block_read(slm_K, GENX_NONE, slm_offset, Kmat.format<half>());

prefetch_K.set_base_ptr((reinterpret_cast<half*>(k_base)+prefetch_block_id*blk_stride));
prefetch_K.set_block_y((prefetch_kv_pos + wg_local_id) % CMPA_BLOCK_SZ);
cm_prefetch<CacheHint::Cached, CacheHint::Cached>(prefetch_K.set_block_x(0));

#if SPARSE_BLOCK_SIZE > 1
{
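// X-Attention sparsity: consult the precomputed per-block mask; if this
// SPARSE_BLOCK_SIZE-wide KV block was pruned, skip the whole kv_step tile while
// keeping the causal bookkeeping consistent.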
auto kv_start_block = kv_pos/ SPARSE_BLOCK_SIZE;
// when kv_len % 128 < 16 the result will be discarded; find_block does not generate a mask entry for that point
if (kv_start_block < xattn_k_block_pad) {
bool sparse_mask = *(reinterpret_cast<bool*>(sparse_mask_base) + kv_start_block);
if (!sparse_mask) {
if constexpr (use_causal_mask) {
causal_left -= kv_step;
}
continue;
}
}
}
#endif

b2dK.set_base_ptr((reinterpret_cast<half*>(k_base)+cur_block_id*blk_stride));
b2dK.set_block_y(kv_pos%CMPA_BLOCK_SZ);
cm_load<lsc::Normal>(Kmat.format<half>(), b2dK.set_block_x(0));
#pragma unroll
for(int k = 0; k < num_K; k++)
St2.row(k) = cm_dpas<CM_PRECISION_HF, CM_PRECISION_HF, SystolicDepth, RepeatCount, float>(
0,
rQ[0].format<int32_t>(),
Kmat[k].format<int32_t>());

#pragma unroll
for(int ri = 1; ri < head_size/REG_K; ri++) {
//cm_slm_block_read(slm_K, GENX_NONE, slm_offset + ri * Kmat.n_elems() * sizeof(half), Kmat.format<half>());
cm_prefetch<CacheHint::Cached, CacheHint::Cached>(prefetch_K.set_block_x(ri*REG_K));
cm_load<lsc::Normal>(Kmat.format<half>(), b2dK.set_block_x(ri*REG_K));
#pragma unroll
for(int k = 0; k < num_K; k++) {
St2.row(k) = cm_dpas<CM_PRECISION_HF, CM_PRECISION_HF, SystolicDepth, RepeatCount, float>(
St2.row(k),
rQ[ri].format<int32_t>(),
Kmat[k].format<int32_t>());
}
}
}
if constexpr (use_causal_mask) {
// since kv_step == q_step == 16, causal_left is n*kv_step
if (causal_left == 0) {
apply_causal_mask<1>(St);
} else if (causal_left < 0) {
St = -3.4e38f;
}
causal_left -= kv_step;
} else {
int kv_tokens = kv_stop - kv_pos;
// LSC guarantees no out-of-bounds access, but the attn-scores of the KV tail still have to be masked off
for(int p = kv_tokens; p < kv_step; p++) St[p] = -3.4e38f;
}

// show(St);
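// Online (flash-attention style) softmax over the KV dimension, kept per query column:
// new_max = max(cur_max, col_max(St)); St <- exp(St - new_max);
// cur_sum <- cur_sum * exp(cur_max - new_max) + col_sum(St); cur_max <- new_max.
// The returned max_comp (= exp(old_max - new_max)) rescales the rO accumulator
// before this tile's P @ V contribution is added.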
auto max_comp = online_softmax_update(St, cur_max, cur_sum);

matrix<half, REG_N, REG_K> P;
Transpose2DMatrix(St, P);

prefetch_V.set_base_ptr((reinterpret_cast<half*>(v_base)+prefetch_block_id*blk_stride));
prefetch_V.set_block_y((prefetch_kv_pos + wg_local_id) % CMPA_BLOCK_SZ);

b2dV.set_base_ptr((reinterpret_cast<half*>(v_base)+cur_block_id*blk_stride));
b2dV.set_block_y(kv_pos%CMPA_BLOCK_SZ);
if (need_comp == false) {
// ugemm_PV0(slm_V, P, rO, slm_offset);
auto P2 = P.format<half, num_P_tiles, REG_M * REG_K>();
#pragma unroll
for(int k = 0, ri = 0; k < head_size; k += REG_N, ri += num_P_tiles) {
matrix<half, REG_K/2, REG_N*2> Vmat;
cm_prefetch<CacheHint::Cached, CacheHint::Cached>(prefetch_V.set_block_x(k));
cm_load<lsc::VNNI>(Vmat.format<half>(), b2dV.set_block_x(k));
#pragma unroll
for(int p = 0; p < num_P_tiles; p++) {
rO[ri + p] = cm_dpas<CM_PRECISION_HF, CM_PRECISION_HF, SystolicDepth, RepeatCount, float>(
0,
Vmat.format<int32_t>(),
P2.row(p).format<int32_t>());
// show(rO[ri + p].format<float, REG_M, REG_N>());
}
}

need_comp = true;
}
else {
//ugemm_PV1(slm_V, P, max_comp, rO, slm_offset);
auto P2 = P.format<half, num_P_tiles, REG_M * REG_K>();
#pragma unroll
for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) {
matrix<half, REG_K/2, REG_N*2> Vmat;

cm_prefetch<CacheHint::Cached, CacheHint::Cached>(prefetch_V.set_block_x(k));
cm_load<lsc::VNNI>(Vmat.format<half>(), b2dV.set_block_x(k));

//# compensate cur_O
// matrix <float, head_size/REG_K*2, REG_M*REG_N> rO;
#pragma unroll
for(int p = 0; p < num_P_tiles; p++) {
auto cO = rO[ri + p].format<float, REG_M, REG_N>();
#pragma unroll
for(int r = 0; r < REG_M; r++)
cO.row(r) = cm_mul<float>(cO.row(r), max_comp[r + p*REG_M]);
}

#pragma unroll
for(int p = 0; p < num_P_tiles; p++) {
rO[ri + p] = cm_dpas<CM_PRECISION_HF, CM_PRECISION_HF, SystolicDepth, RepeatCount>(
rO[ri + p].format<float>(),
Vmat.format<int32_t>(),
P2.row(p).format<int32_t>());
// show(rO[ri + p].format<float, REG_M, REG_N>());
}
}
}
}
if (q_tokens_left == 0) return;

//# save cur_O/cur_sum.transpose(0, 1)
matrix<half, num_P_tiles*REG_M, REG_N> cur_O_f16;
cur_sum = cm_inv(cur_sum);
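// cur_sum now holds the reciprocal softmax denominators; scaling the accumulated
// rO rows by it finishes the normalization before the f16 store.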

lsc::block_2d_desc<half, 1, REG_M, REG_N> b2dO(o_base, q_tokens_left - 1, head_size*sizeof(half) - 1, o_pitch - 1, 0, 0);

#pragma unroll
for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) {
#pragma unroll
for(int p = 0; p < num_P_tiles; p++) {
auto cO = rO[ri + p].format<float, REG_M, REG_N>();
#pragma unroll
for(int r = 0; r < cO.n_rows(); r++) {
cur_O_f16[r + p*REG_M] = cm_mul<float>(cO.row(r), cur_sum[r + p*REG_M]);

}
}
b2dO.set_block_x(k);
cm_store(b2dO.set_block_y(0), cur_O_f16.format<half, num_P_tiles, REG_M * REG_N>().row(0));
cm_store(b2dO.set_block_y(REG_M), cur_O_f16.format<half, num_P_tiles, REG_M * REG_N>().row(1));
}
}
#endif
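The find_block step that produces the mask consumed through sparse_mask_base is not shown in this diff, so the following is only a sketch of the general thresholding idea behind it (keep a KV block when its approximate attention mass clears a threshold, in the spirit of the OV_GPU_XATTN_THRESH knob from the commit list); every name below is hypothetical:

#include <cstddef>
#include <vector>

// Hypothetical host-side sketch: one bool per SPARSE_BLOCK_SIZE-wide KV block,
// padded to xattn_k_block_pad entries as the kernel expects.
std::vector<bool> build_block_mask(const std::vector<float>& block_scores,  // approximate per-block scores
                                   std::size_t k_block_pad,
                                   float thresh) {
    std::vector<bool> keep(k_block_pad, true);  // blocks past the scored range stay enabled
    float total = 0.f;
    for (float s : block_scores) total += s;
    if (total <= 0.f) return keep;
    for (std::size_t i = 0; i < block_scores.size() && i < k_block_pad; ++i)
        keep[i] = (block_scores[i] / total) >= thresh;
    return keep;
}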

template<bool use_causal_mask, int num_heads, int num_kv_heads, int head_size, int is_qkv_fused = 0>