Support fp16 operand/accum in TensorDPU
TODO: fp32 accum
This commit is contained in:
@@ -8,9 +8,12 @@ import chisel3.util._
|
||||
import freechips.rocketchip.tile
|
||||
|
||||
// Implements the four-element dot product (FEDP) unit in Volta Tensor Cores.
|
||||
class TensorDotProductUnit extends Module with tile.HasFPUParameters {
|
||||
val fLen = 32
|
||||
val minFLen = 32
|
||||
// `half`: if true, generate fp16 MACs; if false, generate fp32 MACs.
|
||||
class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUParameters {
|
||||
val t = if (half) tile.FType.H else tile.FType.S
|
||||
|
||||
val fLen = t.ieeeWidth
|
||||
val minFLen = 16 // fp16
|
||||
def xLen = 32
|
||||
val dotProductDim = 4
|
||||
|
||||
@@ -26,12 +29,12 @@ class TensorDotProductUnit extends Module with tile.HasFPUParameters {
|
||||
})
|
||||
})
|
||||
|
||||
val t = tile.FType.S
|
||||
|
||||
// IEEE -> recode() -> unbox() -> Hardfloat -> box() -> ieee() -> IEEE
|
||||
val in1 = io.in.bits.a.map(x => unbox(recode(x, S), S, Some(tile.FType.S)))
|
||||
val in2 = io.in.bits.b.map(x => unbox(recode(x, S), S, Some(tile.FType.S)))
|
||||
val in3 = unbox(recode(io.in.bits.c, S), S, Some(tile.FType.S))
|
||||
// [IEEE] -> recode() -> unbox() -> [Hardfloat] -> box() -> ieee() -> [IEEE]
|
||||
// make sure recoding/decoding happens only at the edge, not at every
|
||||
// pipeline stage inside the dpu
|
||||
val in1 = io.in.bits.a.map(x => unbox(recode(x, S), S, Some(t)))
|
||||
val in2 = io.in.bits.b.map(x => unbox(recode(x, S), S, Some(t)))
|
||||
val in3 = unbox(recode(io.in.bits.c, S), S, Some(t))
|
||||
|
||||
val dpu = Module(new DotProductPipe(dotProductDim, t.exp, t.sig))
|
||||
dpu.io.in.valid := io.in.valid
|
||||
@@ -88,8 +91,8 @@ object StallingPipe {
|
||||
}
|
||||
}
|
||||
|
||||
// Computes d = a(0)*b(0) + ... + a(3)*b(3) + c.
|
||||
// Fully pipelined with a fixed latency of 4 cycles.
|
||||
// Computes d = a(0)*b(0) + ... + a(`dim`-1)*b(`dim`-1) + c.
|
||||
// Fully pipelined with a fixed latency determined by `dim`.
|
||||
class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
require(dim == 4, "DPU currently only supports dimension 4")
|
||||
|
||||
|
||||
@@ -46,23 +46,70 @@ class TensorDotProductUnitTest extends AnyFlatSpec with ChiselScalatestTester {
|
||||
|
||||
implicit val p: Parameters = Parameters.empty
|
||||
|
||||
it should "pass" in {
|
||||
test(new TensorDotProductUnit)
|
||||
it should "pass fp16" in {
|
||||
test(new TensorDotProductUnit(half = true))
|
||||
// .withAnnotations(Seq(VerilatorBackendAnnotation))
|
||||
// .withAnnotations(Seq(WriteVcdAnnotation))
|
||||
{ c =>
|
||||
c.io.in.valid.poke(true.B)
|
||||
c.io.stall.poke(false.B)
|
||||
// (1,3,5,7)*(2,4,6,8) + 9 = 109
|
||||
c.io.in.bits.a(0).poke(0x3f800000L.U(64.W))
|
||||
c.io.in.bits.a(1).poke(0x40400000L.U(64.W))
|
||||
c.io.in.bits.a(2).poke(0x40a00000L.U(64.W))
|
||||
c.io.in.bits.a(3).poke(0x40e00000L.U(64.W))
|
||||
c.io.in.bits.b(0).poke(0x40000000L.U(64.W))
|
||||
c.io.in.bits.b(1).poke(0x40800000L.U(64.W))
|
||||
c.io.in.bits.b(2).poke(0x40c00000L.U(64.W))
|
||||
c.io.in.bits.b(3).poke(0x41000000L.U(64.W))
|
||||
c.io.in.bits.c .poke(0x41100000L.U(64.W))
|
||||
c.io.in.bits.a(0).poke(0x3c00.U(16.W))
|
||||
c.io.in.bits.a(1).poke(0x4200.U(16.W))
|
||||
c.io.in.bits.a(2).poke(0x4500.U(16.W))
|
||||
c.io.in.bits.a(3).poke(0x4700.U(16.W))
|
||||
c.io.in.bits.b(0).poke(0x4000.U(16.W))
|
||||
c.io.in.bits.b(1).poke(0x4400.U(16.W))
|
||||
c.io.in.bits.b(2).poke(0x4600.U(16.W))
|
||||
c.io.in.bits.b(3).poke(0x4800.U(16.W))
|
||||
c.io.in.bits.c .poke(0x4880.U(16.W))
|
||||
|
||||
c.io.out.valid.expect(false.B)
|
||||
|
||||
c.clock.step()
|
||||
c.io.in.valid.poke(false.B)
|
||||
c.io.out.valid.expect(false.B)
|
||||
|
||||
// stall the pipeline
|
||||
c.io.stall.poke(true.B)
|
||||
c.clock.step()
|
||||
c.io.stall.poke(true.B)
|
||||
c.clock.step()
|
||||
c.io.stall.poke(true.B)
|
||||
c.clock.step()
|
||||
c.io.stall.poke(false.B)
|
||||
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
// 4-cycle latency + stalls
|
||||
|
||||
c.io.out.valid.expect(true.B)
|
||||
c.io.out.bits.data.expect(0x56d0.U)
|
||||
|
||||
c.clock.step()
|
||||
|
||||
c.io.out.valid.expect(false.B)
|
||||
}
|
||||
}
|
||||
|
||||
it should "pass fp32" in {
|
||||
test(new TensorDotProductUnit(half = false))
|
||||
// .withAnnotations(Seq(VerilatorBackendAnnotation))
|
||||
// .withAnnotations(Seq(WriteVcdAnnotation))
|
||||
{ c =>
|
||||
c.io.in.valid.poke(true.B)
|
||||
c.io.stall.poke(false.B)
|
||||
// (1,3,5,7)*(2,4,6,8) + 9 = 109
|
||||
c.io.in.bits.a(0).poke(0x3f800000L.U(32.W))
|
||||
c.io.in.bits.a(1).poke(0x40400000L.U(32.W))
|
||||
c.io.in.bits.a(2).poke(0x40a00000L.U(32.W))
|
||||
c.io.in.bits.a(3).poke(0x40e00000L.U(32.W))
|
||||
c.io.in.bits.b(0).poke(0x40000000L.U(32.W))
|
||||
c.io.in.bits.b(1).poke(0x40800000L.U(32.W))
|
||||
c.io.in.bits.b(2).poke(0x40c00000L.U(32.W))
|
||||
c.io.in.bits.b(3).poke(0x41000000L.U(32.W))
|
||||
c.io.in.bits.c .poke(0x41100000L.U(32.W))
|
||||
|
||||
c.io.out.valid.expect(false.B)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user