/*************************************************************** * _ * * | | * * __ _ ___ _ __ ___ _ __ ___ ___ | | ___ __ _ _ _ * * / _` |/ _ \ '_ ` _ \| '_ ` _ \ / _ \| |/ _ \ / _` | | | | * * | (_| | __/ | | | | | | | | | | (_) | | (_) | (_| | |_| | * * \__, |\___|_| |_| |_|_| |_| |_|\___/|_|\___/ \__, |\__, | * * __/ | __/ | __/ | * * |___/ |___/ |___/ * * * * version 0.1 * ***************************************************************/ #ifndef GEMMOLOGY_FWD_H #define GEMMOLOGY_FWD_H #include #include #include #include namespace gemmology { namespace callbacks { struct Unquantize { float unquant_mult; template xsimd::batch operator()(xsimd::batch total, size_t, size_t, size_t); template std::tuple, xsimd::batch> operator()( std::tuple, xsimd::batch> total, size_t, size_t, size_t); }; struct AddBias { const float *bias_addr; template xsimd::batch operator()(xsimd::batch total, size_t, size_t col_idx, size_t); template std::tuple, xsimd::batch> operator()( std::tuple, xsimd::batch> total, size_t, size_t col_idx, size_t); }; struct Write { float *output_addr; Write(float *o) : output_addr(o) {} template void operator()(xsimd::batch result, size_t row_idx, size_t col_idx, size_t col_size); template void operator()(xsimd::batch result, size_t row_idx, size_t col_idx, size_t col_size); template void operator()( std::tuple, xsimd::batch> result, size_t row_idx, size_t col_idx, size_t col_size); template void operator()( std::tuple, xsimd::batch> result, size_t row_idx, size_t col_idx, size_t col_size); }; struct UnquantizeAndWrite { Unquantize unquantize; Write write; UnquantizeAndWrite(float factor, float *output) : unquantize{factor}, write{output} {} template void operator()(T const &total, size_t row_idx, size_t col_idx, size_t col_size); }; struct UnquantizeAndAddBiasAndWrite { Unquantize unquantize; AddBias add_bias; Write write; UnquantizeAndAddBiasAndWrite(float factor, const float *bias, float *output) : unquantize{factor}, add_bias{bias}, write{output} {} template void operator()(T const &total, size_t row_idx, size_t col_idx, size_t col_size); }; } // namespace callbacks // // Arch-specific implementation of each routine // template struct Engine { static void QuantizeU(const float *input, uint8_t *output, float quant_mult, size_t size); static void Quantize(const float *const input, int8_t *const output, float quant_mult, size_t size); template static void SelectColumnsB(const int8_t *input, int8_t *output, size_t rows, const IntegerTy *cols_begin, const IntegerTy *cols_end); static void PrepareBTransposed(const float *input, int8_t *output, float quant_mult, size_t cols, size_t rows); static void PrepareBQuantizedTransposed(const int8_t *input, int8_t *output, size_t cols, size_t rows); static void PrepareB(const float *input, int8_t *output_shadow, float quant_mult, size_t rows, size_t cols); static void PrepareA(const float *input, int8_t *output, float quant_mult, size_t rows, size_t cols); struct Shift { static void PrepareA(const float *input, uint8_t *output, float quant_mult, size_t rows, size_t cols); template static void Multiply(const uint8_t *A, const int8_t *B, size_t A_rows, size_t width, size_t B_cols, Callback callback); template static void PrepareBias(const int8_t *B, size_t width, size_t B_cols, Callback C); }; }; // // Top-level wrappers that mostly match intgemm API // template inline void QuantizeU(const float *input, uint8_t *output, float quant_mult, size_t size) { return Engine::QuantizeU(input, output, quant_mult, size); } template inline void Quantize(const float *const input, int8_t *const output, float quant_mult, size_t size) { return Engine::Quantize(input, output, quant_mult, size); } template inline void SelectColumnsB(const int8_t *input, int8_t *output, size_t rows, const IntegerTy *cols_begin, const IntegerTy *cols_end) { return Engine::SelectColumnsB(input, output, rows, cols_begin, cols_end); } template inline void PrepareBTransposed(const float *input, int8_t *output, float quant_mult, size_t cols, size_t rows) { return Engine::PrepareBTransposed(input, output, quant_mult, cols, rows); } template inline void PrepareBQuantizedTransposed(const int8_t *input, int8_t *output, size_t cols, size_t rows) { return Engine::PrepareBQuantizedTransposed(input, output, cols, rows); } template inline void PrepareB(const float *input, int8_t *output_shadow, float quant_mult, size_t rows, size_t cols) { return Engine::PrepareB(input, output_shadow, quant_mult, rows, cols); } template inline void PrepareA(const float *input, int8_t *output, float quant_mult, size_t rows, size_t cols) { return Engine::PrepareA(input, output, quant_mult, rows, cols); } namespace Shift { template inline void PrepareA(const float *input, uint8_t *output, float quant_mult, size_t rows, size_t cols) { return Engine::Shift::PrepareA(input, output, quant_mult, rows, cols); } template inline void Multiply(const uint8_t *A, const int8_t *B, size_t A_rows, size_t width, size_t B_cols, Callback C) { return Engine::Shift::Multiply(A, B, A_rows, width, B_cols, C); } template inline void PrepareBias(const int8_t *B, size_t width, size_t B_cols, Callback C) { return Engine::Shift::PrepareBias(B, width, B_cols, C); } } // namespace Shift } // namespace gemmology #endif