sgemm_wg: Parameterize threadblock dimensions

This commit is contained in:
Hansung Kim
2024-02-17 18:05:59 -08:00
parent 301f1ca260
commit d2da0d3394
2 changed files with 56 additions and 24 deletions

View File

@@ -116,7 +116,7 @@ int run_test(const kernel_arg_t& kernel_arg,
for (uint32_t i = 0; i < dim_m * dim_n; ++i) {
float ref = ref_data.at(i);
float cur = buf_ptr[i];
if (cur != ref) {
if (std::abs((cur - ref) / ref) > 1e-6) {
std::cout << "error at result #" << std::dec << i
<< std::hex << ": actual=" << cur << ", expected=" << ref << std::endl;
++errors;
@@ -146,9 +146,10 @@ int main(int argc, char *argv[]) {
std::cout << "open device connection" << std::endl;
RT_CHECK(vx_dev_open(&device));
uint32_t dim_m = 4; // FIXME: hardcoded
uint32_t dim_n = 4; // FIXME: hardcoded
uint32_t dim_k = 128; // FIXME: hardcoded
// FIXME: matrix dimensions (dim_m, dim_n, dim_k) are hardcoded; they should be made configurable
uint32_t dim_m = 16;
uint32_t dim_n = 16;
uint32_t dim_k = 32;
generate_source_matrix(dim_m, dim_n, dim_k);
generate_reference_matmul(dim_m, dim_n, dim_k);