Add placeholder tensor core DPU module

This commit is contained in:
Hansung Kim
2024-05-27 21:16:53 -07:00
parent 387e05404e
commit 615815acf5
2 changed files with 111 additions and 0 deletions

View File

@@ -0,0 +1,70 @@
// See LICENSE.SiFive for license details.
// See LICENSE.Berkeley for license details.
package radiance.core
import chisel3._
import chisel3.util._
import freechips.rocketchip.rocket._
class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module {
require(latency <= 2)
val io = IO(new Bundle {
val validin = Input(Bool())
val op = Input(Bits(2.W))
val a = Input(Bits((expWidth + sigWidth + 1).W))
val b = Input(Bits((expWidth + sigWidth + 1).W))
val c = Input(Bits((expWidth + sigWidth + 1).W))
val roundingMode = Input(UInt(3.W))
val detectTininess = Input(UInt(1.W))
val out = Output(Bits((expWidth + sigWidth + 1).W))
val exceptionFlags = Output(Bits(5.W))
val validout = Output(Bool())
})
//------------------------------------------------------------------------
//------------------------------------------------------------------------
val mulAddRecFNToRaw_preMul = Module(new hardfloat.MulAddRecFNToRaw_preMul(expWidth, sigWidth))
val mulAddRecFNToRaw_postMul = Module(new hardfloat.MulAddRecFNToRaw_postMul(expWidth, sigWidth))
mulAddRecFNToRaw_preMul.io.op := io.op
mulAddRecFNToRaw_preMul.io.a := io.a
mulAddRecFNToRaw_preMul.io.b := io.b
mulAddRecFNToRaw_preMul.io.c := io.c
val mulAddResult =
(mulAddRecFNToRaw_preMul.io.mulAddA *
mulAddRecFNToRaw_preMul.io.mulAddB) +&
mulAddRecFNToRaw_preMul.io.mulAddC
val valid_stage0 = Wire(Bool())
val roundingMode_stage0 = Wire(UInt(3.W))
val detectTininess_stage0 = Wire(UInt(1.W))
val postmul_regs = if(latency>0) 1 else 0
mulAddRecFNToRaw_postMul.io.fromPreMul := Pipe(io.validin, mulAddRecFNToRaw_preMul.io.toPostMul, postmul_regs).bits
mulAddRecFNToRaw_postMul.io.mulAddResult := Pipe(io.validin, mulAddResult, postmul_regs).bits
mulAddRecFNToRaw_postMul.io.roundingMode := Pipe(io.validin, io.roundingMode, postmul_regs).bits
roundingMode_stage0 := Pipe(io.validin, io.roundingMode, postmul_regs).bits
detectTininess_stage0 := Pipe(io.validin, io.detectTininess, postmul_regs).bits
valid_stage0 := Pipe(io.validin, false.B, postmul_regs).valid
//------------------------------------------------------------------------
//------------------------------------------------------------------------
val roundRawFNToRecFN = Module(new hardfloat.RoundRawFNToRecFN(expWidth, sigWidth, 0))
val round_regs = if(latency==2) 1 else 0
roundRawFNToRecFN.io.invalidExc := Pipe(valid_stage0, mulAddRecFNToRaw_postMul.io.invalidExc, round_regs).bits
roundRawFNToRecFN.io.in := Pipe(valid_stage0, mulAddRecFNToRaw_postMul.io.rawOut, round_regs).bits
roundRawFNToRecFN.io.roundingMode := Pipe(valid_stage0, roundingMode_stage0, round_regs).bits
roundRawFNToRecFN.io.detectTininess := Pipe(valid_stage0, detectTininess_stage0, round_regs).bits
io.validout := Pipe(valid_stage0, false.B, round_regs).valid
roundRawFNToRecFN.io.infiniteExc := false.B
io.out := roundRawFNToRecFN.io.out
io.exceptionFlags := roundRawFNToRecFN.io.exceptionFlags
}

View File

@@ -0,0 +1,41 @@
package radiance.core
import chisel3._
import chisel3.stage.PrintFullStackTraceAnnotation
import chisel3.util._
import chiseltest._
import chiseltest.simulator.VerilatorFlags
import org.scalatest.flatspec.AnyFlatSpec
class MulAddTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "MulAddRecFNPipe"
it should "do basic arithmetic" in {
test(new MulAddRecFNPipe(2, 8, 23))
// .withAnnotations(Seq(WriteVcdAnnotation))
{ c =>
c.io.validin.poke(true.B)
// 0: MADD
// 1: MSUB
// 2: NMSUB
// 3: NMADD
c.io.op.poke(0.U)
// rounding mode (p.113 of spec)
// 0: round to nearest, ties to even
c.io.roundingMode.poke(0.U)
c.io.detectTininess.poke(hardfloat.consts.tininess_beforeRounding)
c.io.a.poke(0x3f800000.U/*2.0*/)
c.io.b.poke(0x3f800000.U/*3.0*/)
c.io.c.poke(0x00000000.U/*0.0*/)
c.clock.step()
c.io.validin.poke(false.B)
c.io.validout.expect(false.B)
c.clock.step()
c.io.validout.expect(true.B)
c.io.out.expect(0x40c00000.U/*6.0*/)
c.clock.step()
c.io.validout.expect(false.B)
}
}
}