diff --git a/src/main/scala/radiance/core/TensorDPU.scala b/src/main/scala/radiance/core/TensorDPU.scala index 90ac883..20a59dd 100644 --- a/src/main/scala/radiance/core/TensorDPU.scala +++ b/src/main/scala/radiance/core/TensorDPU.scala @@ -10,22 +10,26 @@ import freechips.rocketchip.tile // Implements the four-element dot product (FEDP) unit in Volta Tensor Cores. // `half`: if True, generate fp16 MACs; if False fp32. class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUParameters { - val t = if (half) tile.FType.H else tile.FType.S + val tIn = if (half) tile.FType.H else tile.FType.S + // output datatype fixed to single-precision + val tOut = tile.FType.S - val fLen = t.ieeeWidth + val inFLen = tIn.ieeeWidth + val outFLen = tOut.ieeeWidth + val fLen = outFLen // needed for HasFPUParameters val minFLen = 16 // fp16 def xLen = 32 val dotProductDim = 4 val io = IO(new Bundle { val in = Flipped(Valid(new Bundle { - val a = Vec(dotProductDim, Bits((fLen).W)) - val b = Vec(dotProductDim, Bits((fLen).W)) - val c = Bits((fLen).W) + val a = Vec(dotProductDim, Bits((inFLen).W)) + val b = Vec(dotProductDim, Bits((inFLen).W)) + val c = Bits((inFLen).W) })) val stall = Input(Bool()) val out = Valid(new Bundle { - val data = Bits((fLen).W) + val data = Bits((outFLen).W) }) }) @@ -33,11 +37,11 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar // make sure recoding/uncoding happens only at the edge, not at every // pipeline stage inside the dpu val tag = if (half) H else S - val in1 = io.in.bits.a.map(x => unbox(recode(x, tag), tag, Some(t))) - val in2 = io.in.bits.b.map(x => unbox(recode(x, tag), tag, Some(t))) - val in3 = unbox(recode(io.in.bits.c, tag), tag, Some(t)) + val in1 = io.in.bits.a.map(x => unbox(recode(x, tag), tag, Some(tIn))) + val in2 = io.in.bits.b.map(x => unbox(recode(x, tag), tag, Some(tIn))) + val in3 = unbox(recode(io.in.bits.c, tag), tag, Some(tIn)) - val dpu = Module(new 
DotProductPipe(dotProductDim, t.exp, t.sig)) + val dpu = Module(new DotProductPipe(dotProductDim, tIn, tOut)) dpu.io.in.valid := io.in.valid dpu.io.in.bits.a := in1 dpu.io.in.bits.b := in2 @@ -45,7 +49,7 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar dpu.io.stall := io.stall io.out.valid := dpu.io.out.valid - io.out.bits.data := ieee(box(dpu.io.out.bits.data, tag)) + io.out.bits.data := ieee(box(dpu.io.out.bits.data, S)) } // Copied from chisel3.util.Pipe. @@ -94,74 +98,116 @@ object StallingPipe { // Computes d = a(0)*b(0) + ... + a(`dim`-1)*b(`dim`-1) + c. // Fully pipelined with a fixed latency determined by `dim`. -class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { +class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) extends Module { require(dim == 4, "DPU currently only supports dimension 4") + val expWidth = inputType.exp + val sigWidth = inputType.sig + val outExpWidth = outputType.exp + val outSigWidth = outputType.sig - val recFLen = expWidth + sigWidth + 1 + val recInFLen = expWidth + sigWidth + 1 + val recOutFLen = outExpWidth + outSigWidth + 1 val io = IO(new Bundle { val in = Flipped(Valid(new Bundle { - val a = Vec(4, Bits((recFLen).W)) - val b = Vec(4, Bits((recFLen).W)) - val c = Bits((recFLen).W) + val a = Vec(4, Bits((recInFLen).W)) + val b = Vec(4, Bits((recInFLen).W)) + val c = Bits((recInFLen).W) // val roundingMode = UInt(3.W) // val detectTininess = UInt(1.W) })) val stall = Input(Bool()) val out = Valid(new Bundle { - val data = Bits((recFLen).W) + val data = Bits((recOutFLen).W) }) }) - val mul = Seq.fill(dim)(Module(new hardfloat.MulRecFN(expWidth, sigWidth))) - mul.zipWithIndex.foreach { case (m, i) => + val mul = Seq.fill(dim)(Module(new hardfloat.MulFullRawFN(expWidth, sigWidth))) + val mulOuts = mul.zipWithIndex.map { case (m, i) => // FIXME: these settings are arbitrary - m.io.roundingMode := hardfloat.consts.round_near_even - m.io.detectTininess 
:= hardfloat.consts.tininess_afterRounding - m.io.a := io.in.bits.a(i) - m.io.b := io.in.bits.b(i) + // m.io.roundingMode := hardfloat.consts.round_near_even + // m.io.detectTininess := hardfloat.consts.tininess_afterRounding + // m.io.a := io.in.bits.a(i) + // m.io.b := io.in.bits.b(i) + val rawInA = hardfloat.rawFloatFromRecFN(expWidth, sigWidth, io.in.bits.a(i)) + val rawInB = hardfloat.rawFloatFromRecFN(expWidth, sigWidth, io.in.bits.b(i)) + m.io.a := rawInA + m.io.b := rawInB + // m.io.invalidExc output ignored + // assert(m.io.invalidExc === false.B) } - val mulStageOut = StallingPipe(io.stall, io.in.valid, VecInit(mul.map(_.io.out))) + val mulStageOut = StallingPipe(io.stall, io.in.valid, VecInit(mul.map(_.io.rawOut))) val mulStageC = StallingPipe(io.stall, io.in.valid, io.in.bits.c) + val mulExpWidth = mulStageOut.bits.head.expWidth + val mulSigWidth = mulStageOut.bits.head.sigWidth + // mul stage end ------------------------------------------------------------- - val add1 = Seq.fill(dim / 2)(Module(new hardfloat.AddRecFN(expWidth, sigWidth))) - add1.zipWithIndex.foreach { case (a, i) => - a.io.subOp := 0.U // FIXME + val add1 = Seq.fill(dim / 2)(Module(new hardfloat.AddRawFN(mulExpWidth, mulSigWidth))) + val add1Outs = add1.zipWithIndex.map { case (a, i) => + a.io.subOp := 0.U // FIXME: don't know what subOp does; tied to 0 for now a.io.a := mulStageOut.bits(2 * i + 0) a.io.b := mulStageOut.bits(2 * i + 1) a.io.roundingMode := hardfloat.consts.round_near_even - a.io.detectTininess := hardfloat.consts.tininess_afterRounding + // a.io.detectTininess := hardfloat.consts.tininess_afterRounding + // a.io.invalidExc output ignored + // assert(a.io.invalidExc === false.B) + + // round back to fp32 recoded format + // FIXME: awkward to do this in the middle; do right after mul?
+ val addExpWidth = a.io.rawOut.expWidth + val addSigWidth = a.io.rawOut.sigWidth + val roundRawFNToRecFN = + Module(new hardfloat.RoundAnyRawFNToRecFN(addExpWidth, addSigWidth, outExpWidth, outSigWidth, 0)) + roundRawFNToRecFN.io.invalidExc := a.io.invalidExc + roundRawFNToRecFN.io.infiniteExc := false.B + roundRawFNToRecFN.io.in := a.io.rawOut + roundRawFNToRecFN.io.roundingMode := hardfloat.consts.round_near_even + roundRawFNToRecFN.io.detectTininess := hardfloat.consts.tininess_afterRounding + roundRawFNToRecFN.io.out + // roundRawFNToRecFN.io.exceptionFlags ignored } - val add1StageOut = StallingPipe(io.stall, mulStageOut.valid, VecInit(add1.map(_.io.out))) + // val add1StageOut = StallingPipe(io.stall, mulStageOut.valid, VecInit(add1.map(_.io.out))) + val add1StageOut = StallingPipe(io.stall, mulStageOut.valid, VecInit(add1Outs)) val add1StageC = StallingPipe(io.stall, mulStageOut.valid, mulStageC.bits) // add1 stage end ------------------------------------------------------------ - val add2 = Module(new hardfloat.AddRecFN(expWidth, sigWidth)) + val add2 = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth)) add2.io.subOp := 0.U // FIXME assert(add1StageOut.bits.length == 2) add2.io.a := add1StageOut.bits(0) add2.io.b := add1StageOut.bits(1) add2.io.roundingMode := hardfloat.consts.round_near_even add2.io.detectTininess := hardfloat.consts.tininess_afterRounding + assert(add2.io.exceptionFlags === 0.U) val add2StageOut = StallingPipe(io.stall, add1StageOut.valid, add2.io.out) val add2StageC = StallingPipe(io.stall, add1StageOut.valid, add1StageC.bits) // add2 stage end ------------------------------------------------------------ - val acc = Module(new hardfloat.AddRecFN(expWidth, sigWidth)) + // convert to recoded format for addition to C + // TODO: raw+raw addition might be cheaper? 
+ val recToRec = Module( + new hardfloat.RecFNToRecFN(expWidth, sigWidth, outExpWidth, outSigWidth)) + recToRec.io.in := add2StageC.bits + recToRec.io.roundingMode := hardfloat.consts.round_near_even + recToRec.io.detectTininess := hardfloat.consts.tininess_afterRounding + assert(recToRec.io.exceptionFlags === 0.U) + val add2StageCRec = recToRec.io.out + + val acc = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth)) acc.io.subOp := 0.U // FIXME acc.io.a := add2StageOut.bits - acc.io.b := add2StageC.bits + acc.io.b := add2StageCRec acc.io.roundingMode := hardfloat.consts.round_near_even acc.io.detectTininess := hardfloat.consts.tininess_afterRounding + assert(acc.io.exceptionFlags === 0.U) val accStageOut = StallingPipe(io.stall, add2StageOut.valid, acc.io.out) - // FIXME: exception output ignored // acc stage end ------------------------------------------------------------- diff --git a/src/test/scala/radiance/TensorDPUTest.scala b/src/test/scala/radiance/TensorDPUTest.scala index 53cffc3..a978e5c 100644 --- a/src/test/scala/radiance/TensorDPUTest.scala +++ b/src/test/scala/radiance/TensorDPUTest.scala @@ -85,7 +85,7 @@ class TensorDotProductUnitTest extends AnyFlatSpec with ChiselScalatestTester { // 4-cycle latency + stalls c.io.out.valid.expect(true.B) - c.io.out.bits.data.expect(0x56d0.U) + c.io.out.bits.data.expect(0x42da0000L.U) c.clock.step()