diff --git a/src/main/scala/radiance/core/TensorDPU.scala b/src/main/scala/radiance/core/TensorDPU.scala index 6050963..8ba13ef 100644 --- a/src/main/scala/radiance/core/TensorDPU.scala +++ b/src/main/scala/radiance/core/TensorDPU.scala @@ -8,9 +8,12 @@ import chisel3.util._ import freechips.rocketchip.tile // Implements the four-element dot product (FEDP) unit in Volta Tensor Cores. -class TensorDotProductUnit extends Module with tile.HasFPUParameters { - val fLen = 32 - val minFLen = 32 +// `half`: if True, generate fp16 MACs; if False fp32. +class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUParameters { + val t = if (half) tile.FType.H else tile.FType.S + + val fLen = t.ieeeWidth + val minFLen = 16 // fp16 def xLen = 32 val dotProductDim = 4 @@ -26,12 +29,12 @@ class TensorDotProductUnit extends Module with tile.HasFPUParameters { }) }) - val t = tile.FType.S - - // IEEE -> recode() -> unbox() -> Hardfloat -> box() -> ieee() -> IEEE - val in1 = io.in.bits.a.map(x => unbox(recode(x, S), S, Some(tile.FType.S))) - val in2 = io.in.bits.b.map(x => unbox(recode(x, S), S, Some(tile.FType.S))) - val in3 = unbox(recode(io.in.bits.c, S), S, Some(tile.FType.S)) + // [IEEE] -> recode() -> unbox() -> [Hardfloat] -> box() -> ieee() -> [IEEE] + // make sure recoding/uncoding happens only at the edge, not at every + // pipeline stage inside the dpu + val in1 = io.in.bits.a.map(x => unbox(recode(x, S), S, Some(t))) + val in2 = io.in.bits.b.map(x => unbox(recode(x, S), S, Some(t))) + val in3 = unbox(recode(io.in.bits.c, S), S, Some(t)) val dpu = Module(new DotProductPipe(dotProductDim, t.exp, t.sig)) dpu.io.in.valid := io.in.valid @@ -88,8 +91,8 @@ object StallingPipe { } } -// Computes d = a(0)*b(0) + ... + a(3)*b(3) + c. -// Fully pipelined with a fixed latency of 4 cycles. +// Computes d = a(0)*b(0) + ... + a(`dim`-1)*b(`dim`-1) + c. +// Fully pipelined with a fixed latency determined by `dim`. class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { require(dim == 4, "DPU currently only supports dimension 4") diff --git a/src/test/scala/radiance/TensorDPUTest.scala b/src/test/scala/radiance/TensorDPUTest.scala index b31e3fe..53cffc3 100644 --- a/src/test/scala/radiance/TensorDPUTest.scala +++ b/src/test/scala/radiance/TensorDPUTest.scala @@ -46,23 +46,70 @@ class TensorDotProductUnitTest extends AnyFlatSpec with ChiselScalatestTester { implicit val p: Parameters = Parameters.empty - it should "pass" in { - test(new TensorDotProductUnit) + it should "pass fp16" in { + test(new TensorDotProductUnit(half = true)) // .withAnnotations(Seq(VerilatorBackendAnnotation)) // .withAnnotations(Seq(WriteVcdAnnotation)) { c => c.io.in.valid.poke(true.B) c.io.stall.poke(false.B) // (1,3,5,7)*(2,4,6,8) + 9 = 109 - c.io.in.bits.a(0).poke(0x3f800000L.U(64.W)) - c.io.in.bits.a(1).poke(0x40400000L.U(64.W)) - c.io.in.bits.a(2).poke(0x40a00000L.U(64.W)) - c.io.in.bits.a(3).poke(0x40e00000L.U(64.W)) - c.io.in.bits.b(0).poke(0x40000000L.U(64.W)) - c.io.in.bits.b(1).poke(0x40800000L.U(64.W)) - c.io.in.bits.b(2).poke(0x40c00000L.U(64.W)) - c.io.in.bits.b(3).poke(0x41000000L.U(64.W)) - c.io.in.bits.c .poke(0x41100000L.U(64.W)) + c.io.in.bits.a(0).poke(0x3c00.U(16.W)) + c.io.in.bits.a(1).poke(0x4200.U(16.W)) + c.io.in.bits.a(2).poke(0x4500.U(16.W)) + c.io.in.bits.a(3).poke(0x4700.U(16.W)) + c.io.in.bits.b(0).poke(0x4000.U(16.W)) + c.io.in.bits.b(1).poke(0x4400.U(16.W)) + c.io.in.bits.b(2).poke(0x4600.U(16.W)) + c.io.in.bits.b(3).poke(0x4800.U(16.W)) + c.io.in.bits.c .poke(0x4880.U(16.W)) + + c.io.out.valid.expect(false.B) + + c.clock.step() + c.io.in.valid.poke(false.B) + c.io.out.valid.expect(false.B) + + // stall the pipeline + c.io.stall.poke(true.B) + c.clock.step() + c.io.stall.poke(true.B) + c.clock.step() + c.io.stall.poke(true.B) + c.clock.step() + c.io.stall.poke(false.B) + + c.clock.step() + c.clock.step() + c.clock.step() + // 4-cycle latency + stalls + + c.io.out.valid.expect(true.B) + c.io.out.bits.data.expect(0x56d0.U) + + c.clock.step() + + c.io.out.valid.expect(false.B) + } + } + + it should "pass fp32" in { + test(new TensorDotProductUnit(half = false)) + // .withAnnotations(Seq(VerilatorBackendAnnotation)) + // .withAnnotations(Seq(WriteVcdAnnotation)) + { c => + c.io.in.valid.poke(true.B) + c.io.stall.poke(false.B) + // (1,3,5,7)*(2,4,6,8) + 9 = 109 + c.io.in.bits.a(0).poke(0x3f800000L.U(32.W)) + c.io.in.bits.a(1).poke(0x40400000L.U(32.W)) + c.io.in.bits.a(2).poke(0x40a00000L.U(32.W)) + c.io.in.bits.a(3).poke(0x40e00000L.U(32.W)) + c.io.in.bits.b(0).poke(0x40000000L.U(32.W)) + c.io.in.bits.b(1).poke(0x40800000L.U(32.W)) + c.io.in.bits.b(2).poke(0x40c00000L.U(32.W)) + c.io.in.bits.b(3).poke(0x41000000L.U(32.W)) + c.io.in.bits.c .poke(0x41100000L.U(32.W)) c.io.out.valid.expect(false.B)