@@ -41,6 +41,54 @@ static __device__ __forceinline__ void trellis_accum(uint32_t& val1, uint32_t& v
41
41
#endif
42
42
}
43
43
44
+ // static __device__ __forceinline__ void trellis_accum(uint32_t& val1, uint32_t& val2, uint32_t* s, const dfloat2* y, dfloat2& bdot1, dfloat2& bdot2) {
45
+ // const half * h = (const half *)s;
46
+ // s[0] = trellis_next(val1);
47
+ // s[1] = trellis_next(val1);
48
+ // s[2] = trellis_next(val1);
49
+ // s[3] = trellis_next(val1);
50
+ // #ifdef GGML_CUDA_F16
51
+ // bdot1 = __hfma2(y[ 0], {h[0]+h[1]+h[2]+h[3], h[4]+h[5]+h[6]+h[7]}, bdot1);
52
+ // #else
53
+ // bdot1.x += y[ 0].x * (float)(h[0] + h[1] + h[2] + h[3]);
54
+ // bdot1.y += y[ 0].y * (float)(h[4] + h[5] + h[6] + h[7]);
55
+ // #endif
56
+ // s[0] = trellis_next(val2);
57
+ // s[1] = trellis_next(val2);
58
+ // s[2] = trellis_next(val2);
59
+ // s[3] = trellis_next(val2);
60
+ // #ifdef GGML_CUDA_F16
61
+ // bdot2 = __hfma2(y[64], {h[0]+h[1]+h[2]+h[3], h[4]+h[5]+h[6]+h[7]}, bdot2);
62
+ // #else
63
+ // bdot2.x += y[64].x * (float)(h[0] + h[1] + h[2] + h[3]);
64
+ // bdot2.y += y[64].y * (float)(h[4] + h[5] + h[6] + h[7]);
65
+ // #endif
66
+ // }
67
+
68
+ // Accumulate two sign-applied dot-product partials from two trellis PRNG states.
+ // Each state (val1, val2) is advanced twice via trellis_next(), the four 32-bit
+ // results stored in s[0..3] and reinterpreted as 8 halves through h. Adjacent
+ // half pairs are summed, the absolute value taken, and a sign applied from the
+ // packed sign bytes: signs1 drives the .x lanes, signs2 the .y lanes; mask1
+ // selects the sign bit for the bdot1 (y[0]) term, mask2 for the bdot2 (y[64])
+ // term. NOTE(review): y[64] presumably addresses the second half of a QK_K
+ // block, 64 dfloat2 further on — confirm against the caller's y layout.
+ static __device__ __forceinline__ void trellis_accum_abs (uint8_t signs1, uint8_t signs2, uint8_t mask1, uint8_t mask2,
69
+ uint32_t & val1, uint32_t & val2, uint32_t * s, const dfloat2* y, dfloat2& bdot1, dfloat2& bdot2) {
70
+ // View the four generated 32-bit words as 8 fp16 values.
+ const half * h = (const half *)s;
71
+ s[0 ] = trellis_next (val1);
72
+ s[1 ] = trellis_next (val1);
73
+ s[2 ] = trellis_next (val2);
74
+ s[3 ] = trellis_next (val2);
75
+ // fp16 build: do the abs/sign and fused multiply-add in half2 precision.
+ #ifdef GGML_CUDA_F16
76
+ half h00 = __habs (h[0 ]+h[1 ]), h01 = __habs (h[2 ]+h[3 ]);
77
+ half h10 = __habs (h[4 ]+h[5 ]), h11 = __habs (h[6 ]+h[7 ]);
78
+ // Negate per-lane according to the masked sign bits (branchless select).
+ half2 h1 = {signs1 & mask1 ? -h00 : h00, signs2 & mask1 ? -h01 : h01};
79
+ half2 h2 = {signs1 & mask2 ? -h10 : h10, signs2 & mask2 ? -h11 : h11};
80
+ // half2 h1 = __hmul2(__habs2({h[0]+h[1], h[2]+h[3]}), {signs1 & mask1 ? -1 : 1, signs2 & mask1 ? -1 : 1});
81
+ // half2 h2 = __hmul2(__habs2({h[4]+h[5], h[6]+h[7]}), {signs1 & mask2 ? -1 : 1, signs2 & mask2 ? -1 : 1});
82
+ bdot1 = __hfma2 (y[ 0 ], h1, bdot1);
83
+ bdot2 = __hfma2 (y[64 ], h2, bdot2);
84
+ // Float fallback: pair-sums still come from the fp16 view of s; the abs,
+ // sign application, and accumulation happen in float.
+ #else
85
+ bdot1.x += y[ 0 ].x * fabsf ((float )(h[0 ] + h[1 ])) * (signs1 & mask1 ? -1 : 1 );
86
+ bdot1.y += y[ 0 ].y * fabsf ((float )(h[2 ] + h[3 ])) * (signs2 & mask1 ? -1 : 1 );
87
+ bdot2.x += y[64 ].x * fabsf ((float )(h[4 ] + h[5 ])) * (signs1 & mask2 ? -1 : 1 );
88
+ bdot2.y += y[64 ].y * fabsf ((float )(h[6 ] + h[7 ])) * (signs2 & mask2 ? -1 : 1 );
89
+ #endif
90
+ }
91
+
44
92
static __device__ __forceinline__ void trellis_accum (const dfloat2& dl1, const dfloat2& dl2, const dfloat2& bdot1, const dfloat2& bdot2, dfloat2& tmp) {
45
93
#ifdef GGML_CUDA_F16
46
94
tmp = __hfma2 (dl1, bdot1, tmp);
@@ -114,25 +162,23 @@ static __global__ void dequantize_mul_mat_vec_iq3_kt(const void * __restrict__ v
114
162
115
163
uint32_t s[4 ];
116
164
165
+ uint8_t mask1 = 1 << (it/4 );
166
+ uint8_t mask2 = mask1 << 4 ;
167
+
117
168
for (int i = ix; i < num_blocks_per_row; i += 2 ) {
118
169
const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8 *it);
119
- const uint8_t * ql = x[i].ql ;
120
- const uint8_t * qh = x[i].qh ;
121
- const dfloat scale1 = iq4k_values[ (x[i].scales [it/4 ] & 0xf )+ 16 ] ;
122
- const dfloat scale2 = iq4k_values[ (x[i].scales [it/4 ] >> 4 )+ 16 ] ;
170
+ const uint16_t * ql = ( const uint16_t *) x[i].ql ;
171
+ const uint8_t * qh = x[i].qh ;
172
+ const dfloat scale1 = (x[i].scales [it/4 ] & 0xf );
173
+ const dfloat scale2 = (x[i].scales [it/4 ] >> 4 );
123
174
const dfloat2 dl1 = {scale1, scale1};
124
175
const dfloat2 dl2 = {scale2, scale2};
125
176
dfloat2 bdot1 = {0 , 0 };
126
177
dfloat2 bdot2 = {0 , 0 };
127
- uint32_t val1 = ql[2 *it+ 0 ] + ((qh[2 *it+0 ] << 8 ) & 0xf00 ) + 4096 ;
128
- uint32_t val2 = ql[2 *it+32 ] + ((qh[2 *it+0 ] << 4 ) & 0xf00 ) + 4096 ;
129
- for (int k = 0 ; k < 2 ; ++k) {
130
- trellis_accum (val1, val2, s, y+k, bdot1, bdot2);
131
- }
132
- val1 = ql[2 *it+ 1 ] + ((qh[2 *it+1 ] << 8 ) & 0xf00 ) + 4096 ;
133
- val2 = ql[2 *it+33 ] + ((qh[2 *it+1 ] << 4 ) & 0xf00 ) + 4096 ;
134
- for (int k = 2 ; k < 4 ; ++k) {
135
- trellis_accum (val1, val2, s, y+k, bdot1, bdot2);
178
+ uint32_t val1 = ql[it+ 0 ] + 4096 ;
179
+ uint32_t val2 = ql[it+16 ] + 4096 ;
180
+ for (int k = 0 ; k < 4 ; ++k) {
181
+ trellis_accum_abs (qh[(8 *it+2 *k+0 )%32 ], qh[(8 *it+2 *k+1 )%32 ], mask1, mask2, val1, val2, s, y+k, bdot1, bdot2);
136
182
}
137
183
trellis_accum (dl1, dl2, bdot1, bdot2, tmp);
138
184
}
0 commit comments