Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions ggml/src/iqk/iqk_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@
#define IQK_IMPLEMENT
#endif

/*
 * IQK_API: linkage annotation placed in front of the public iqk entry points.
 *
 *   - GGML_SHARED not defined                -> expands to nothing
 *   - GGML_SHARED, Windows (but not MinGW):
 *         GGML_BUILD defined                 -> __declspec(dllexport)
 *         GGML_BUILD not defined             -> __declspec(dllimport)
 *   - GGML_SHARED, any other toolchain       -> default ELF-style visibility
 */
#if !defined(GGML_SHARED)
#  define IQK_API
#elif defined(_WIN32) && !defined(__MINGW32__)
#  if defined(GGML_BUILD)
#    define IQK_API __declspec(dllexport)
#  else
#    define IQK_API __declspec(dllimport)
#  endif
#else
#  define IQK_API __attribute__ ((visibility ("default")))
#endif

#ifdef _MSC_VER
#define IQK_NOINLINE __declspec(noinline)
#define IQK_ALWAYS_INLINE inline
Expand Down
9 changes: 5 additions & 4 deletions ggml/src/iqk/iqk_flash_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ inline uint32_t simple_gcd(uint32_t a, uint32_t b) {

// TODO: get the ggml_type enum here without pollution
//
bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
int neq3, int neq2, long nbq3, long nbq2,
int nek3, int nek2, long nbk3, long nbk2,
int nev3, int nev2, long nbv3, long nbv2,
Expand Down Expand Up @@ -258,9 +258,10 @@ bool iqk_flash_attn_noalibi([[maybe_unused]] int type_q, [[maybe_unused]] int ty
[[maybe_unused]] int nek3, [[maybe_unused]] int nek2, [[maybe_unused]] long nbk3, [[maybe_unused]] long nbk2,
[[maybe_unused]] int nev3, [[maybe_unused]] int nev2, [[maybe_unused]] long nbv3, [[maybe_unused]] long nbv2,
[[maybe_unused]] int ne2, [[maybe_unused]] int ne1, [[maybe_unused]] long nb1,
[[maybe_unused]] int int_type_k, // type of k
[[maybe_unused]] int int_type_v, // type of v
[[maybe_unused]] int D, // head size
[[maybe_unused]] int type_k, // type of k
[[maybe_unused]] int type_v, // type of v
[[maybe_unused]] int Dk, // K head size
[[maybe_unused]] int Dv, // V head size
[[maybe_unused]] int nq, // number of columns in q
[[maybe_unused]] int nk, // number of rows in k
[[maybe_unused]] int stride_q, // distance between q columns in bytes
Expand Down
16 changes: 8 additions & 8 deletions ggml/src/iqk/iqk_mul_mat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ struct MulMat {

}

bool iqk_mul_mat(long Nx, long Ny, long ne00,
extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
int typeA, const void * A, long strideA,
int typeB, const void * B, long strideB,
float * C, long stride_C, int ith, int nth) {
Expand Down Expand Up @@ -440,7 +440,7 @@ inline uint32_t simple_gcd(uint32_t a, uint32_t b) {
}
}

bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
long ne02, long ne03, long ne12, long ne13,
long nb02, long nb03, long nb12, long nb13, long nb2, long nb3,
int typeA, const void * A, long strideA,
Expand Down Expand Up @@ -545,7 +545,7 @@ bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
return true;
}

bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
int typeA, const void * A, long strideA,
int typeB, const void * B, long strideB,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) {
Expand All @@ -571,7 +571,7 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
return true;
}

bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
int typeA, const void * Aup, const void * Agate, long strideA,
int typeB, const void * B, long strideB,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) {
Expand Down Expand Up @@ -17550,11 +17550,11 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k

#else // IQK_IMPLEMENT

bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) {
// Stub compiled on the `#else // IQK_IMPLEMENT` path: ignores every argument
// and always returns false.
//
// The parameter list mirrors the declaration in iqk_mul_mat.h
// (Nx, Ny, ne00, typeA, A, strideA, typeB, B, strideB, C, stride_C, ith, nth).
// This function has C linkage, so it cannot be overloaded: a definition whose
// parameters differ from the header's declaration is a conflicting
// redeclaration rather than a separate overload.
extern "C" IQK_API bool iqk_mul_mat(long /*Nx*/, long /*Ny*/, long /*ne00*/,
        int /*typeA*/, const void * /*A*/, long /*strideA*/,
        int /*typeB*/, const void * /*B*/, long /*strideB*/,
        float * /*C*/, long /*stride_C*/, int /*ith*/, int /*nth*/) {
    return false;
}

bool iqk_mul_mat_4d(long /*Nx*/, long /*Ny*/, long /*ne00*/,
extern "C" IQK_API bool iqk_mul_mat_4d(long /*Nx*/, long /*Ny*/, long /*ne00*/,
long /*ne02*/, long /*ne03*/, long /*ne12*/, long /*ne13*/,
long /*nb02*/, long /*nb03*/, long /*nb12*/, long /*nb13*/, long /*nb2*/, long /*nb3*/,
int /*typeA*/, const void * /*A*/, long /*strideA*/,
Expand All @@ -17563,12 +17563,12 @@ bool iqk_mul_mat_4d(long /*Nx*/, long /*Ny*/, long /*ne00*/,
return false;
}

bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const void *, long, float *, long, long,
// Stub compiled on the `#else // IQK_IMPLEMENT` path: every argument is
// ignored and the call always returns false. Parameter names are kept as
// comments so the list can be checked against the declaration in iqk_mul_mat.h.
extern "C" IQK_API bool iqk_mul_mat_moe(long /*Nx*/, long /*Ny*/, long /*ne00*/, int /*ne11*/,
        int /*typeA*/, const void * /*A*/, long /*strideA*/,
        int /*typeB*/, const void * /*B*/, long /*strideB*/,
        float * /*C*/, long /*nb1*/, long /*nb2*/,
        const void * /*vrow_mapping*/, int /*ith*/, int /*nth*/) {
    return false;
}

bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*ne00*/, int /*ne11*/, int /*unary_op*/,
extern "C" IQK_API bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*ne00*/, int /*ne11*/, int /*unary_op*/,
int /*typeA*/, const void * /*Aup*/, const void * /*Agate*/, long /*strideA*/,
int /*typeB*/, const void * /*B*/, long /*strideB*/,
float * /*C*/, long /*nb1*/, long /*nb2*/, const void * /*vrow_mapping*/, int /*ith*/, int /*nth*/) {
Expand Down
11 changes: 6 additions & 5 deletions ggml/src/iqk/iqk_mul_mat.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,36 @@
#pragma once
#include <stdint.h>
#include <stdbool.h>
#include "iqk_config.h"
#ifdef __cplusplus
extern "C" {
#endif

bool iqk_mul_mat(long Nx, long Ny, long ne00,
IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
int typeA, const void * A, long strideA,
int typeB, const void * B, long strideB,
float * C, long stride_C, int ith, int nth);

bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
long ne02, long ne03, long ne12, long ne13,
long nb02, long nb03, long nb12, long nb13, long nb2, long nb3,
int typeA, const void * A, long strideA,
int typeB, const void * B, long strideB,
float * C, long stride_C, int ith, int nth);

bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
int typeA, const void * A, long strideA,
int typeB, const void * B, long strideB,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);

bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
int typeA, const void * Aup, const void * Agate, long strideA,
int typeB, const void * B, long strideB,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);

typedef void (*barrier_t) (void *);

bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
int neq3, int neq2, long nbq3, long nbq2,
int nek3, int nek2, long nbk3, long nbk2,
int nev3, int nev2, long nbv3, long nbv2,
Expand Down