From 3b71276c4ae41bfbca41a368f797989416975f2f Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 13 Nov 2024 16:01:11 -0800 Subject: [PATCH] tensor: Do dot-product in fp16, only do accum in fp32 This is to better match Gemmini PEs doing MACs in full fp16, and only doing accumulation in fp32. --- src/main/scala/radiance/core/TensorDPU.scala | 14 ++++++++++---- .../scala/radiance/TensorCoreDecoupledTest.scala | 2 +- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/main/scala/radiance/core/TensorDPU.scala b/src/main/scala/radiance/core/TensorDPU.scala index a4e6db0..ce131df 100644 --- a/src/main/scala/radiance/core/TensorDPU.scala +++ b/src/main/scala/radiance/core/TensorDPU.scala @@ -147,7 +147,7 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex val mulSigWidth = m.io.rawOut.sigWidth val roundRawFNToRecFN = Module(new hardfloat.RoundAnyRawFNToRecFN( - mulExpWidth, mulSigWidth, outExpWidth, outSigWidth, 0)) + mulExpWidth, mulSigWidth, expWidth, sigWidth, 0)) roundRawFNToRecFN.io.invalidExc := m.io.invalidExc roundRawFNToRecFN.io.infiniteExc := false.B roundRawFNToRecFN.io.in := m.io.rawOut @@ -169,7 +169,7 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex // instantiate wires for input values to each reduction pipeline stage val interim = (log2Dim to 0 by -1).map { i => - Wire(Valid(Vec(1 << i, Bits(recOutFLen.W)))) + Wire(Valid(Vec(1 << i, Bits(recInFLen.W)))) } // instantiate wires for pipe registers for C val interimC = (log2Dim to 0 by -1).map( _ => Wire(Valid(Bits(recOutFLen.W))) ) @@ -186,7 +186,7 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex require(inputs.bits.length == 2 * outputs.bits.length) val thisDim = inputs.bits.length val adders = Seq.fill(thisDim / 2)( - Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth)) + Module(new hardfloat.AddRecFN(expWidth, sigWidth)) ) val addOuts = adders.zipWithIndex.map { case (a, i) => a.io.subOp := 0.U // FIXME dont know what this is @@ -212,9 +212,15 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex // add stages end ------------------------------------------------------------ // add final A and B dot-product result to accumulator C + val conv = Module(new hardfloat.RecFNToRecFN(expWidth, sigWidth, outExpWidth, outSigWidth)) + conv.io.in := addStageOut.bits(0) + conv.io.roundingMode := hardfloat.consts.round_near_even + conv.io.detectTininess := hardfloat.consts.tininess_afterRounding + // assert(conv.io.exceptionFlags === 0.U) + val acc = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth)) acc.io.subOp := 0.U // FIXME - acc.io.a := addStageOut.bits(0) + acc.io.a := conv.io.out acc.io.b := addStageC.bits acc.io.roundingMode := hardfloat.consts.round_near_even acc.io.detectTininess := hardfloat.consts.tininess_afterRounding diff --git a/src/test/scala/radiance/TensorCoreDecoupledTest.scala b/src/test/scala/radiance/TensorCoreDecoupledTest.scala index 7b31eb7..619f699 100644 --- a/src/test/scala/radiance/TensorCoreDecoupledTest.scala +++ b/src/test/scala/radiance/TensorCoreDecoupledTest.scala @@ -9,7 +9,7 @@ class TensorCoreDecoupledTest extends AnyFlatSpec with ChiselScalatestTester { behavior of "TensorCoreDecoupled" it should "do the right thing" in { - test(new TensorCoreDecoupled(8, 8, numSourceIds = 4, tilingParams = TensorTilingParams())) + test(new TensorCoreDecoupled(8, 8, numSourceIds = 4, half = true)) { c => c.io.initiate.valid.poke(true.B) c.io.initiate.bits.wid.poke(0.U)