diff --git a/src/main/scala/radiance/core/TensorDPU.scala b/src/main/scala/radiance/core/TensorDPU.scala index 4094af1..dd848d4 100644 --- a/src/main/scala/radiance/core/TensorDPU.scala +++ b/src/main/scala/radiance/core/TensorDPU.scala @@ -72,8 +72,8 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { m.io.b := io.in.bits.b(i) } - val mulStageOut = Pipe(io.in.valid, VecInit(mul.map(_.io.out))) - val mulStageC = Pipe(io.in.valid, io.in.bits.c) + val mulStageOut = Pipe(!io.stall && io.in.valid, VecInit(mul.map(_.io.out))) + val mulStageC = Pipe(!io.stall && io.in.valid, io.in.bits.c) // mul stage end ------------------------------------------------------------- @@ -86,8 +86,8 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { a.io.detectTininess := hardfloat.consts.tininess_afterRounding } - val add1StageOut = Pipe(mulStageOut.valid, VecInit(add1.map(_.io.out))) - val add1StageC = Pipe(mulStageC) + val add1StageOut = Pipe(!io.stall && mulStageOut.valid, VecInit(add1.map(_.io.out)), latency = 0) + val add1StageC = Pipe(!io.stall && mulStageOut.valid, mulStageC.bits, latency = 0) // add1 stage end ------------------------------------------------------------ @@ -99,8 +99,8 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { add2.io.roundingMode := hardfloat.consts.round_near_even add2.io.detectTininess := hardfloat.consts.tininess_afterRounding - val add2StageOut = Pipe(add1StageOut.valid, add2.io.out) - val add2StageC = Pipe(add1StageC) + val add2StageOut = Pipe(!io.stall && add1StageOut.valid, add2.io.out, latency = 0) + val add2StageC = Pipe(!io.stall && add1StageOut.valid, add1StageC.bits, latency = 0) // add2 stage end ------------------------------------------------------------ @@ -111,11 +111,13 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { acc.io.roundingMode := hardfloat.consts.round_near_even acc.io.detectTininess := hardfloat.consts.tininess_afterRounding - io.out.valid := Pipe(add2StageOut.valid, false.B).valid - io.out.bits.data := Pipe(add2StageOut.valid, acc.io.out).bits + val accStageOut = Pipe(!io.stall && add2StageOut.valid, acc.io.out) // FIXME: exception output ignored // acc stage end ------------------------------------------------------------- + + io.out.valid := accStageOut.valid + io.out.bits.data := accStageOut.bits } class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module { diff --git a/src/test/scala/radiance/TensorDPUTest.scala b/src/test/scala/radiance/TensorDPUTest.scala index 798676a..1fbfd4d 100644 --- a/src/test/scala/radiance/TensorDPUTest.scala +++ b/src/test/scala/radiance/TensorDPUTest.scala @@ -49,33 +49,37 @@ class TensorDotProductUnitTest extends AnyFlatSpec with ChiselScalatestTester { it should "pass" in { test(new TensorDotProductUnit) .withAnnotations(Seq(VerilatorBackendAnnotation)) - .withAnnotations(Seq(WriteVcdAnnotation)) + // .withAnnotations(Seq(WriteVcdAnnotation)) { c => c.io.in.valid.poke(true.B) c.io.stall.poke(false.B) - // (2,2,2,2)*(2,2,2,2) + 3 = 19 - c.io.in.bits.a(0).poke(0x40000000L.U(64.W)) - c.io.in.bits.a(1).poke(0x40000000L.U(64.W)) - c.io.in.bits.a(2).poke(0x40000000L.U(64.W)) - c.io.in.bits.a(3).poke(0x40000000L.U(64.W)) + // (1,3,5,7)*(2,4,6,8) + 9 = 109 + c.io.in.bits.a(0).poke(0x3f800000L.U(64.W)) + c.io.in.bits.a(1).poke(0x40400000L.U(64.W)) + c.io.in.bits.a(2).poke(0x40a00000L.U(64.W)) + c.io.in.bits.a(3).poke(0x40e00000L.U(64.W)) c.io.in.bits.b(0).poke(0x40000000L.U(64.W)) - c.io.in.bits.b(1).poke(0x40000000L.U(64.W)) - c.io.in.bits.b(2).poke(0x40000000L.U(64.W)) - c.io.in.bits.b(3).poke(0x40000000L.U(64.W)) - c.io.in.bits.c .poke(0x40400000L.U(64.W)) + c.io.in.bits.b(1).poke(0x40800000L.U(64.W)) + c.io.in.bits.b(2).poke(0x40c00000L.U(64.W)) + c.io.in.bits.b(3).poke(0x41000000L.U(64.W)) + c.io.in.bits.c .poke(0x41100000L.U(64.W)) + + c.io.out.valid.expect(false.B) c.clock.step() c.io.in.valid.poke(false.B) + c.io.out.valid.expect(false.B) // stall the pipeline // c.io.stall.poke(true.B) c.clock.step() - c.io.stall.poke(false.B) - c.clock.step() - c.clock.step() + // c.io.stall.poke(false.B) + // c.io.out.valid.expect(false.B) + // c.clock.step() + // c.clock.step() // 4-cycle latency c.io.out.valid.expect(true.B) - c.io.out.bits.data.expect(0x41980000L.U) + c.io.out.bits.data.expect(0x42da0000L.U) c.clock.step()