Support fp16 operands and fp16 accumulation in TensorDPU
TODO: support fp32 accumulation
This commit is contained in:
@@ -46,23 +46,70 @@ class TensorDotProductUnitTest extends AnyFlatSpec with ChiselScalatestTester {
|
||||
|
||||
implicit val p: Parameters = Parameters.empty
|
||||
|
||||
it should "pass" in {
|
||||
test(new TensorDotProductUnit)
|
||||
it should "pass fp16" in {
|
||||
test(new TensorDotProductUnit(half = true))
|
||||
// .withAnnotations(Seq(VerilatorBackendAnnotation))
|
||||
// .withAnnotations(Seq(WriteVcdAnnotation))
|
||||
{ c =>
|
||||
c.io.in.valid.poke(true.B)
|
||||
c.io.stall.poke(false.B)
|
||||
// (1,3,5,7)*(2,4,6,8) + 9 = 109
|
||||
c.io.in.bits.a(0).poke(0x3f800000L.U(64.W))
|
||||
c.io.in.bits.a(1).poke(0x40400000L.U(64.W))
|
||||
c.io.in.bits.a(2).poke(0x40a00000L.U(64.W))
|
||||
c.io.in.bits.a(3).poke(0x40e00000L.U(64.W))
|
||||
c.io.in.bits.b(0).poke(0x40000000L.U(64.W))
|
||||
c.io.in.bits.b(1).poke(0x40800000L.U(64.W))
|
||||
c.io.in.bits.b(2).poke(0x40c00000L.U(64.W))
|
||||
c.io.in.bits.b(3).poke(0x41000000L.U(64.W))
|
||||
c.io.in.bits.c .poke(0x41100000L.U(64.W))
|
||||
c.io.in.bits.a(0).poke(0x3c00.U(16.W))
|
||||
c.io.in.bits.a(1).poke(0x4200.U(16.W))
|
||||
c.io.in.bits.a(2).poke(0x4500.U(16.W))
|
||||
c.io.in.bits.a(3).poke(0x4700.U(16.W))
|
||||
c.io.in.bits.b(0).poke(0x4000.U(16.W))
|
||||
c.io.in.bits.b(1).poke(0x4400.U(16.W))
|
||||
c.io.in.bits.b(2).poke(0x4600.U(16.W))
|
||||
c.io.in.bits.b(3).poke(0x4800.U(16.W))
|
||||
c.io.in.bits.c .poke(0x4880.U(16.W))
|
||||
|
||||
c.io.out.valid.expect(false.B)
|
||||
|
||||
c.clock.step()
|
||||
c.io.in.valid.poke(false.B)
|
||||
c.io.out.valid.expect(false.B)
|
||||
|
||||
// stall the pipeline
|
||||
c.io.stall.poke(true.B)
|
||||
c.clock.step()
|
||||
c.io.stall.poke(true.B)
|
||||
c.clock.step()
|
||||
c.io.stall.poke(true.B)
|
||||
c.clock.step()
|
||||
c.io.stall.poke(false.B)
|
||||
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
// result is expected after the 4-cycle latency (4 unstalled steps) plus the 3 stalled cycles
|
||||
|
||||
c.io.out.valid.expect(true.B)
|
||||
c.io.out.bits.data.expect(0x56d0.U)
|
||||
|
||||
c.clock.step()
|
||||
|
||||
c.io.out.valid.expect(false.B)
|
||||
}
|
||||
}
|
||||
|
||||
it should "pass fp32" in {
|
||||
test(new TensorDotProductUnit(half = false))
|
||||
// .withAnnotations(Seq(VerilatorBackendAnnotation))
|
||||
// .withAnnotations(Seq(WriteVcdAnnotation))
|
||||
{ c =>
|
||||
c.io.in.valid.poke(true.B)
|
||||
c.io.stall.poke(false.B)
|
||||
// (1,3,5,7)*(2,4,6,8) + 9 = 109
|
||||
c.io.in.bits.a(0).poke(0x3f800000L.U(32.W))
|
||||
c.io.in.bits.a(1).poke(0x40400000L.U(32.W))
|
||||
c.io.in.bits.a(2).poke(0x40a00000L.U(32.W))
|
||||
c.io.in.bits.a(3).poke(0x40e00000L.U(32.W))
|
||||
c.io.in.bits.b(0).poke(0x40000000L.U(32.W))
|
||||
c.io.in.bits.b(1).poke(0x40800000L.U(32.W))
|
||||
c.io.in.bits.b(2).poke(0x40c00000L.U(32.W))
|
||||
c.io.in.bits.b(3).poke(0x41000000L.U(32.W))
|
||||
c.io.in.bits.c .poke(0x41100000L.U(32.W))
|
||||
|
||||
c.io.out.valid.expect(false.B)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user