From 68a7f66046db8975fbadeb7da4cfad767fd61e17 Mon Sep 17 00:00:00 2001 From: Zhongdi LUO Date: Thu, 2 Jul 2026 08:06:42 +0000 Subject: [PATCH] feat: add fp8 tensor dot product path --- src/main/scala/radiance/core/TensorDPU.scala | 46 ++++++++++++++++--- src/test/scala/radiance/FP8E4M3Test.scala | 48 ++++++++++++++++++++ 2 files changed, 87 insertions(+), 7 deletions(-) diff --git a/src/main/scala/radiance/core/TensorDPU.scala b/src/main/scala/radiance/core/TensorDPU.scala index 5666dc6..2504974 100644 --- a/src/main/scala/radiance/core/TensorDPU.scala +++ b/src/main/scala/radiance/core/TensorDPU.scala @@ -38,17 +38,37 @@ object FP8E4M3 { } } +object TensorInputType extends Enumeration { + val FP16, FP32, FP8E4M3 = Value + + def fromHalf(half: Boolean): Value = { + if (half) FP16 else FP32 + } +} + // Implements the four-element dot product (FEDP) unit in Volta Tensor Cores. // `half`: if True, generate fp16 MACs; if False fp32. class TensorDotProductUnit( - val dim: Int = 4, - val half: Boolean + val dim: Int, + val half: Boolean, + val inputType: TensorInputType.Value ) extends Module with tile.HasFPUParameters { - val tIn = if (half) tile.FType.H else tile.FType.S + def this(dim: Int = 4, half: Boolean) = { + this(dim, half, TensorInputType.fromHalf(half)) + } + + val tIn = inputType match { + case TensorInputType.FP16 => tile.FType.H + case TensorInputType.FP32 => tile.FType.S + case TensorInputType.FP8E4M3 => tile.FType.S + } // output datatype fixed to single-precision val tOut = tile.FType.S - val inFLen = tIn.ieeeWidth + val inFLen = inputType match { + case TensorInputType.FP8E4M3 => 8 + case _ => tIn.ieeeWidth + } val outFLen = tOut.ieeeWidth val fLen = outFLen // needed for HasFPUParameters val minFLen = 16 // fp16 @@ -71,9 +91,21 @@ class TensorDotProductUnit( // [IEEE] -> recode() -> unbox() -> [Hardfloat] -> box() -> ieee() -> [IEEE] // make sure recoding/uncoding happens only at the edge, not at every // pipeline stage inside the dpu - val tag = if (half) H else S - val in1 = io.in.bits.a.map(x => unbox(recode(x, tag), tag, Some(tIn))) - val in2 = io.in.bits.b.map(x => unbox(recode(x, tag), tag, Some(tIn))) + val tag = inputType match { + case TensorInputType.FP16 => H + case TensorInputType.FP32 => S + case TensorInputType.FP8E4M3 => S + } + private def recodeInput(x: Bits): UInt = { + inputType match { + case TensorInputType.FP8E4M3 => + unbox(recode(FP8E4M3.toFloat32(x.asUInt), S), S, Some(tIn)) + case _ => + unbox(recode(x.asUInt, tag), tag, Some(tIn)) + } + } + val in1 = io.in.bits.a.map(recodeInput) + val in2 = io.in.bits.b.map(recodeInput) val in3 = unbox(recode(io.in.bits.c, S), S, Some(tOut)) val dpu = Module(new DotProductPipe(dim, tIn, tOut)) diff --git a/src/test/scala/radiance/FP8E4M3Test.scala b/src/test/scala/radiance/FP8E4M3Test.scala index eba0a35..e67d49f 100644 --- a/src/test/scala/radiance/FP8E4M3Test.scala +++ b/src/test/scala/radiance/FP8E4M3Test.scala @@ -32,4 +32,52 @@ class FP8E4M3Test extends AnyFlatSpec with ChiselScalatestTester { } } } + + it should "run an 8-wide FP8 dot product with FP32 accumulation" in { + test(new TensorDotProductUnit(8, half = false, inputType = TensorInputType.FP8E4M3)) { c => + c.io.in.valid.poke(true.B) + c.io.stall.poke(false.B) + for (i <- 0 until 8) { + c.io.in.bits.a(i).poke(0x38.U(8.W)) + c.io.in.bits.b(i).poke(0x40.U(8.W)) + } + c.io.in.bits.c.poke(0x3f800000L.U(32.W)) + + c.io.out.valid.expect(false.B) + c.clock.step() + c.io.in.valid.poke(false.B) + c.io.out.valid.expect(false.B) + + c.clock.step() + c.clock.step() + c.clock.step() + c.clock.step() + c.io.out.valid.expect(true.B) + c.io.out.bits.data.expect(0x41880000L.U) + } + } + + it should "run an 8-wide fractional FP8 dot product with FP32 accumulation" in { + test(new TensorDotProductUnit(8, half = false, inputType = TensorInputType.FP8E4M3)) { c => + c.io.in.valid.poke(true.B) + c.io.stall.poke(false.B) + for (i <- 0 until 8) { + c.io.in.bits.a(i).poke(0x30.U(8.W)) + c.io.in.bits.b(i).poke(0x3c.U(8.W)) + } + c.io.in.bits.c.poke(0x40000000L.U(32.W)) + + c.io.out.valid.expect(false.B) + c.clock.step() + c.io.in.valid.poke(false.B) + c.io.out.valid.expect(false.B) + + c.clock.step() + c.clock.step() + c.clock.step() + c.clock.step() + c.io.out.valid.expect(true.B) + c.io.out.bits.data.expect(0x41000000L.U) + } + } }