diff --git a/src/main/scala/radiance/core/TensorDPU.scala b/src/main/scala/radiance/core/TensorDPU.scala index 9978a96..4094af1 100644 --- a/src/main/scala/radiance/core/TensorDPU.scala +++ b/src/main/scala/radiance/core/TensorDPU.scala @@ -7,7 +7,8 @@ import chisel3._ import chisel3.util._ import freechips.rocketchip.tile -class DPUPipe extends Module with tile.HasFPUParameters { +// Implements the four-element dot product (FEDP) unit in Volta Tensor Cores. +class TensorDotProductUnit extends Module with tile.HasFPUParameters { val fLen = 32 val minFLen = 32 def xLen = 32 @@ -19,6 +20,7 @@ class DPUPipe extends Module with tile.HasFPUParameters { val b = Vec(dotProductDim, Bits((fLen).W)) val c = Bits((fLen).W) })) + val stall = Input(Bool()) val out = Valid(new Bundle { val data = Bits((fLen).W) }) @@ -30,20 +32,12 @@ class DPUPipe extends Module with tile.HasFPUParameters { val in2 = io.in.bits.b.map(x => unbox(recode(x, S), S, Some(tile.FType.S))) val in3 = unbox(recode(io.in.bits.c, S), S, Some(tile.FType.S)) - // val fma = Module(new MulAddRecFNPipe(2, t.exp, t.sig)) - // fma.io.validin := io.in.valid - // fma.io.op := 0.U // FIXME - // fma.io.roundingMode := hardfloat.consts.round_near_even - // fma.io.detectTininess := hardfloat.consts.tininess_afterRounding - // fma.io.a := unbox(in1, S, Some(tile.FType.S)) - // fma.io.b := unbox(in2, S, Some(tile.FType.S)) - // fma.io.c := unbox(in3, S, Some(tile.FType.S)) - val dpu = Module(new DotProductPipe(dotProductDim, t.exp, t.sig)) dpu.io.in.valid := io.in.valid dpu.io.in.bits.a := in1 dpu.io.in.bits.b := in2 dpu.io.in.bits.c := in3 + dpu.io.stall := io.stall io.out.valid := dpu.io.out.valid io.out.bits.data := ieee(box(dpu.io.out.bits.data, S)) @@ -63,6 +57,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { // val roundingMode = UInt(3.W) // val detectTininess = UInt(1.W) })) + val stall = Input(Bool()) val out = Valid(new Bundle { val data = Bits((recFLen).W) }) @@ -70,7 +65,8 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { val mul = Seq.fill(dim)(Module(new hardfloat.MulRecFN(expWidth, sigWidth))) mul.zipWithIndex.foreach { case (m, i) => - m.io.roundingMode := hardfloat.consts.round_near_even // consts.round_near_maxMag + // FIXME: these settings are arbitrary + m.io.roundingMode := hardfloat.consts.round_near_even m.io.detectTininess := hardfloat.consts.tininess_afterRounding m.io.a := io.in.bits.a(i) m.io.b := io.in.bits.b(i) @@ -79,7 +75,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { val mulStageOut = Pipe(io.in.valid, VecInit(mul.map(_.io.out))) val mulStageC = Pipe(io.in.valid, io.in.bits.c) - // mul stage end ------------------------------------------------------- + // mul stage end ------------------------------------------------------------- val add1 = Seq.fill(dim / 2)(Module(new hardfloat.AddRecFN(expWidth, sigWidth))) add1.zipWithIndex.foreach { case (a, i) => @@ -93,7 +89,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { val add1StageOut = Pipe(mulStageOut.valid, VecInit(add1.map(_.io.out))) val add1StageC = Pipe(mulStageC) - // add1 stage end ----------------------------------------------------- + // add1 stage end ------------------------------------------------------------ val add2 = Module(new hardfloat.AddRecFN(expWidth, sigWidth)) add2.io.subOp := 0.U // FIXME @@ -106,7 +102,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { val add2StageOut = Pipe(add1StageOut.valid, add2.io.out) val add2StageC = Pipe(add1StageC) - // add2 stage end ----------------------------------------------------- + // add2 stage end ------------------------------------------------------------ val acc = Module(new hardfloat.AddRecFN(expWidth, sigWidth)) acc.io.subOp := 0.U // FIXME @@ -119,7 +115,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { io.out.bits.data := Pipe(add2StageOut.valid, acc.io.out).bits // FIXME: exception output ignored - // acc stage end ----------------------------------------------------- + // acc stage end ------------------------------------------------------------- } class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module {