diff --git a/src/main/scala/radiance/core/TensorDPU.scala b/src/main/scala/radiance/core/TensorDPU.scala
new file mode 100644
index 0000000..ce60990
--- /dev/null
+++ b/src/main/scala/radiance/core/TensorDPU.scala
@@ -0,0 +1,91 @@
+// See LICENSE.SiFive for license details.
+// See LICENSE.Berkeley for license details.
+
+package radiance.core
+
+import chisel3._
+import chisel3.util._
+import freechips.rocketchip.rocket._
+
+/** Pipelined fused multiply-add (a * b + c, sign variants chosen by `op`)
+  * assembled from the hardfloat pre-multiply, post-multiply and rounding
+  * modules.
+  *
+  * Operands and result are `expWidth + sigWidth + 1` bits wide, i.e. they use
+  * hardfloat's recoded (recFN) encoding rather than raw IEEE-754 bit patterns.
+  *
+  * @param latency  number of register stages, at most 2. Any value above 0
+  *                 registers the multiplier output; 2 additionally registers
+  *                 the post-multiply result in front of the rounder.
+  * @param expWidth exponent width passed to the hardfloat modules
+  * @param sigWidth significand width passed to the hardfloat modules
+  */
+class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module {
+  require(latency <= 2, s"latency must be at most 2, got $latency")
+
+  val io = IO(new Bundle {
+    val validin = Input(Bool())
+    val op = Input(Bits(2.W))
+    val a = Input(Bits((expWidth + sigWidth + 1).W))
+    val b = Input(Bits((expWidth + sigWidth + 1).W))
+    val c = Input(Bits((expWidth + sigWidth + 1).W))
+    val roundingMode = Input(UInt(3.W))
+    val detectTininess = Input(UInt(1.W))
+    val out = Output(Bits((expWidth + sigWidth + 1).W))
+    val exceptionFlags = Output(Bits(5.W))
+    val validout = Output(Bool())
+  })
+
+  //------------------------------------------------------------------------
+  // Stage 0: decode the operands and perform the integer multiply-add.
+  //------------------------------------------------------------------------
+
+  val mulAddRecFNToRaw_preMul = Module(new hardfloat.MulAddRecFNToRaw_preMul(expWidth, sigWidth))
+  val mulAddRecFNToRaw_postMul = Module(new hardfloat.MulAddRecFNToRaw_postMul(expWidth, sigWidth))
+
+  mulAddRecFNToRaw_preMul.io.op := io.op
+  mulAddRecFNToRaw_preMul.io.a := io.a
+  mulAddRecFNToRaw_preMul.io.b := io.b
+  mulAddRecFNToRaw_preMul.io.c := io.c
+
+  // +& is the width-expanding add, so the carry out of the sum is kept.
+  val mulAddResult =
+    (mulAddRecFNToRaw_preMul.io.mulAddA *
+      mulAddRecFNToRaw_preMul.io.mulAddB) +&
+      mulAddRecFNToRaw_preMul.io.mulAddC
+
+  val valid_stage0 = Wire(Bool())
+  val roundingMode_stage0 = Wire(UInt(3.W))
+  val detectTininess_stage0 = Wire(UInt(1.W))
+
+  // One register between multiply and post-multiply unless latency is 0.
+  val postmul_regs = if(latency>0) 1 else 0
+  mulAddRecFNToRaw_postMul.io.fromPreMul := Pipe(io.validin, mulAddRecFNToRaw_preMul.io.toPostMul, postmul_regs).bits
+  mulAddRecFNToRaw_postMul.io.mulAddResult := Pipe(io.validin, mulAddResult, postmul_regs).bits
+  mulAddRecFNToRaw_postMul.io.roundingMode := Pipe(io.validin, io.roundingMode, postmul_regs).bits
+  roundingMode_stage0 := Pipe(io.validin, io.roundingMode, postmul_regs).bits
+  detectTininess_stage0 := Pipe(io.validin, io.detectTininess, postmul_regs).bits
+  // The payload is a dummy; only the delayed valid bit is used.
+  valid_stage0 := Pipe(io.validin, false.B, postmul_regs).valid
+
+  //------------------------------------------------------------------------
+  // Stage 1: round the raw result back to the recoded format.
+  //------------------------------------------------------------------------
+
+  val roundRawFNToRecFN = Module(new hardfloat.RoundRawFNToRecFN(expWidth, sigWidth, 0))
+
+  // A second register, in front of the rounder, only when latency is 2.
+  val round_regs = if(latency==2) 1 else 0
+  roundRawFNToRecFN.io.invalidExc := Pipe(valid_stage0, mulAddRecFNToRaw_postMul.io.invalidExc, round_regs).bits
+  roundRawFNToRecFN.io.in := Pipe(valid_stage0, mulAddRecFNToRaw_postMul.io.rawOut, round_regs).bits
+  roundRawFNToRecFN.io.roundingMode := Pipe(valid_stage0, roundingMode_stage0, round_regs).bits
+  roundRawFNToRecFN.io.detectTininess := Pipe(valid_stage0, detectTininess_stage0, round_regs).bits
+  // The payload is a dummy; only the delayed valid bit is used.
+  io.validout := Pipe(valid_stage0, false.B, round_regs).valid
+
+  // The rounder's infinite-exception input is tied low in this datapath.
+  roundRawFNToRecFN.io.infiniteExc := false.B
+
+  io.out := roundRawFNToRecFN.io.out
+  io.exceptionFlags := roundRawFNToRecFN.io.exceptionFlags
+}
diff --git a/src/test/scala/radiance/TensorDPUTest.scala b/src/test/scala/radiance/TensorDPUTest.scala
new file mode 100644
index 0000000..4ecaa9b
--- /dev/null
+++ b/src/test/scala/radiance/TensorDPUTest.scala
@@ -0,0 +1,55 @@
+package radiance.core
+
+import chisel3._
+import chisel3.stage.PrintFullStackTraceAnnotation
+import chisel3.util._
+import chiseltest._
+import chiseltest.simulator.VerilatorFlags
+import org.scalatest.flatspec.AnyFlatSpec
+
+/** Smoke test for [[MulAddRecFNPipe]]: sends 2.0 * 3.0 + 0.0 through the
+  * two-stage pipeline and checks the result and the valid-bit timing.
+  */
+class MulAddTest extends AnyFlatSpec with ChiselScalatestTester {
+  behavior of "MulAddRecFNPipe"
+
+  it should "do basic arithmetic" in {
+    // hardfloat includes the hidden bit in sigWidth, so binary32 is (8, 24).
+    test(new MulAddRecFNPipe(2, 8, 24))
+//      .withAnnotations(Seq(WriteVcdAnnotation))
+    { c =>
+      // Operands must be in hardfloat's 33-bit recoded format (1 sign bit,
+      // 9 exponent bits, 23 fraction bits), not raw IEEE-754. For normal
+      // numbers the recoded exponent is the IEEE exponent field + 129.
+      val recZero = "h000000000".U(33.W)  // 0.0
+      val recTwo = "h080800000".U(33.W)   // 2.0 (IEEE 0x40000000)
+      val recThree = "h080c00000".U(33.W) // 3.0 (IEEE 0x40400000)
+      val recSix = "h081400000".U(33.W)   // 6.0 (IEEE 0x40c00000)
+
+      c.io.validin.poke(true.B)
+      // 0: MADD
+      // 1: MSUB
+      // 2: NMSUB
+      // 3: NMADD
+      c.io.op.poke(0.U)
+      // rounding mode (p.113 of spec)
+      // 0: round to nearest, ties to even
+      c.io.roundingMode.poke(0.U)
+      c.io.detectTininess.poke(hardfloat.consts.tininess_beforeRounding)
+      c.io.a.poke(recTwo)
+      c.io.b.poke(recThree)
+      c.io.c.poke(recZero)
+      c.clock.step()
+      c.io.validin.poke(false.B)
+      // After one edge the operation is still in the first pipeline stage.
+      c.io.validout.expect(false.B)
+      c.clock.step()
+      c.io.validout.expect(true.B)
+      c.io.out.expect(recSix)
+      c.clock.step()
+      // validin was dropped after one cycle, so the pipeline drains.
+      c.io.validout.expect(false.B)
+    }
+  }
+}
+