Fix fp exception by rounding right after MulRawFN

This commit is contained in:
Hansung Kim
2024-08-06 18:17:42 -07:00
parent b7a342fcf6
commit 32c7aed263
3 changed files with 1826 additions and 1140 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -25,7 +25,7 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar
val in = Flipped(Valid(new Bundle {
val a = Vec(dotProductDim, Bits((inFLen).W))
val b = Vec(dotProductDim, Bits((inFLen).W))
val c = Bits((inFLen).W)
val c = Bits((outFLen).W) // note C has the out length for accumulation
}))
val stall = Input(Bool())
val out = Valid(new Bundle {
@@ -39,7 +39,7 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar
val tag = if (half) H else S
val in1 = io.in.bits.a.map(x => unbox(recode(x, tag), tag, Some(tIn)))
val in2 = io.in.bits.b.map(x => unbox(recode(x, tag), tag, Some(tIn)))
val in3 = unbox(recode(io.in.bits.c, tag), tag, Some(tIn))
val in3 = unbox(recode(io.in.bits.c, S), S, Some(tOut))
val dpu = Module(new DotProductPipe(dotProductDim, tIn, tOut))
dpu.io.in.valid := io.in.valid
@@ -111,7 +111,7 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
val in = Flipped(Valid(new Bundle {
val a = Vec(4, Bits((recInFLen).W))
val b = Vec(4, Bits((recInFLen).W))
val c = Bits((recInFLen).W)
val c = Bits((recOutFLen).W)
// val roundingMode = UInt(3.W)
// val detectTininess = UInt(1.W)
}))
@@ -121,55 +121,54 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
})
})
val rawZero = hardfloat.rawFloatFromRecFN(expWidth, sigWidth, 0.U(recInFLen.W))
val mul = Seq.fill(dim)(Module(new hardfloat.MulFullRawFN(expWidth, sigWidth)))
val mulOuts = mul.zipWithIndex.map { case (m, i) =>
// FIXME: these settings are arbitrary
// m.io.roundingMode := hardfloat.consts.round_near_even
// m.io.detectTininess := hardfloat.consts.tininess_afterRounding
// m.io.a := io.in.bits.a(i)
// m.io.b := io.in.bits.b(i)
val rawInA = hardfloat.rawFloatFromRecFN(expWidth, sigWidth, io.in.bits.a(i))
val rawInB = hardfloat.rawFloatFromRecFN(expWidth, sigWidth, io.in.bits.b(i))
m.io.a := rawInA
m.io.b := rawInB
// m.io.invalidExc output ignored
// assert(rawInA.isNaN === false.B)
// assert(rawInA.isInf === false.B)
// assert(rawInB.isNaN === false.B)
// assert(rawInB.isInf === false.B)
// tie down to zero when invalid
val rawInAOr0 = Mux(io.in.valid, rawInA, rawZero)
val rawInBOr0 = Mux(io.in.valid, rawInB, rawZero)
m.io.a := rawInAOr0
m.io.b := rawInBOr0
// assert(m.io.invalidExc === false.B)
// round fp16*fp16 raw result back to fp32 recoded format
val mulExpWidth = m.io.rawOut.expWidth
val mulSigWidth = m.io.rawOut.sigWidth
val roundRawFNToRecFN =
Module(new hardfloat.RoundAnyRawFNToRecFN(
mulExpWidth, mulSigWidth, outExpWidth, outSigWidth, 0))
roundRawFNToRecFN.io.invalidExc := m.io.invalidExc
roundRawFNToRecFN.io.infiniteExc := false.B
roundRawFNToRecFN.io.in := m.io.rawOut
roundRawFNToRecFN.io.roundingMode := hardfloat.consts.round_near_even
roundRawFNToRecFN.io.detectTininess := hardfloat.consts.tininess_afterRounding
// assert(roundRawFNToRecFN.io.exceptionFlags === 0.U)
roundRawFNToRecFN.io.out
}
val mulStageOut = StallingPipe(io.stall, io.in.valid, VecInit(mul.map(_.io.rawOut)))
val mulStageOut = StallingPipe(io.stall, io.in.valid, VecInit(mulOuts))
val mulStageC = StallingPipe(io.stall, io.in.valid, io.in.bits.c)
val mulExpWidth = mulStageOut.bits.head.expWidth
val mulSigWidth = mulStageOut.bits.head.sigWidth
// mul stage end -------------------------------------------------------------
val add1 = Seq.fill(dim / 2)(Module(new hardfloat.AddRawFN(mulExpWidth, mulSigWidth)))
val add1 = Seq.fill(dim / 2)(Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth)))
val add1Outs = add1.zipWithIndex.map { case (a, i) =>
a.io.subOp := 0.U // FIXME dont know what this is
a.io.a := mulStageOut.bits(2 * i + 0)
a.io.b := mulStageOut.bits(2 * i + 1)
a.io.roundingMode := hardfloat.consts.round_near_even
// a.io.detectTininess := hardfloat.consts.tininess_afterRounding
// a.io.invalidExc output ignored
// assert(a.io.invalidExc === false.B)
// round back to fp32 recoded format
// FIXME: awkward to do this in the middle; do right after mul?
val addExpWidth = a.io.rawOut.expWidth
val addSigWidth = a.io.rawOut.sigWidth
val roundRawFNToRecFN =
Module(new hardfloat.RoundAnyRawFNToRecFN(addExpWidth, addSigWidth, outExpWidth, outSigWidth, 0))
roundRawFNToRecFN.io.invalidExc := a.io.invalidExc
roundRawFNToRecFN.io.infiniteExc := false.B
roundRawFNToRecFN.io.in := a.io.rawOut
roundRawFNToRecFN.io.roundingMode := hardfloat.consts.round_near_even
roundRawFNToRecFN.io.detectTininess := hardfloat.consts.tininess_afterRounding
roundRawFNToRecFN.io.out
// roundRawFNToRecFN.io.exceptionFlags ignored
a.io.detectTininess := hardfloat.consts.tininess_afterRounding
// assert(a.io.exceptionFlags === 0.U)
a.io.out
}
// val add1StageOut = StallingPipe(io.stall, mulStageOut.valid, VecInit(add1.map(_.io.out)))
val add1StageOut = StallingPipe(io.stall, mulStageOut.valid, VecInit(add1Outs))
val add1StageC = StallingPipe(io.stall, mulStageOut.valid, mulStageC.bits)
@@ -177,35 +176,25 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
val add2 = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
add2.io.subOp := 0.U // FIXME
assert(add1StageOut.bits.length == 2)
add2.io.a := add1StageOut.bits(0)
add2.io.b := add1StageOut.bits(1)
add2.io.roundingMode := hardfloat.consts.round_near_even
add2.io.detectTininess := hardfloat.consts.tininess_afterRounding
assert(add2.io.exceptionFlags === 0.U)
// assert(add2.io.exceptionFlags === 0.U)
val add2StageOut = StallingPipe(io.stall, add1StageOut.valid, add2.io.out)
val add2StageC = StallingPipe(io.stall, add1StageOut.valid, add1StageC.bits)
// add2 stage end ------------------------------------------------------------
// convert to recoded format for addition to C
// TODO: raw+raw addition might be cheaper?
val recToRec = Module(
new hardfloat.RecFNToRecFN(expWidth, sigWidth, outExpWidth, outSigWidth))
recToRec.io.in := add2StageC.bits
recToRec.io.roundingMode := hardfloat.consts.round_near_even
recToRec.io.detectTininess := hardfloat.consts.tininess_afterRounding
assert(recToRec.io.exceptionFlags === 0.U)
val add2StageCRec = recToRec.io.out
val acc = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
acc.io.subOp := 0.U // FIXME
acc.io.a := add2StageOut.bits
acc.io.b := add2StageCRec
// acc.io.b := add2StageCRec
acc.io.b := add2StageC.bits
acc.io.roundingMode := hardfloat.consts.round_near_even
acc.io.detectTininess := hardfloat.consts.tininess_afterRounding
assert(acc.io.exceptionFlags === 0.U)
// assert(acc.io.exceptionFlags === 0.U)
val accStageOut = StallingPipe(io.stall, add2StageOut.valid, acc.io.out)

View File

@@ -62,7 +62,7 @@ class TensorDotProductUnitTest extends AnyFlatSpec with ChiselScalatestTester {
c.io.in.bits.b(1).poke(0x4400.U(16.W))
c.io.in.bits.b(2).poke(0x4600.U(16.W))
c.io.in.bits.b(3).poke(0x4800.U(16.W))
c.io.in.bits.c .poke(0x4880.U(16.W))
c.io.in.bits.c .poke(0x41100000L.U(32.W))
c.io.out.valid.expect(false.B)
@@ -93,6 +93,42 @@ class TensorDotProductUnitTest extends AnyFlatSpec with ChiselScalatestTester {
}
}
it should "pass fp16 2" in {
test(new TensorDotProductUnit(half = true))
.withAnnotations(Seq(VerilatorBackendAnnotation))
// .withAnnotations(Seq(WriteVcdAnnotation))
{ c =>
c.io.in.valid.poke(true.B)
c.io.stall.poke(false.B)
c.io.in.bits.a(0).poke(0x0000.U(16.W))
c.io.in.bits.a(1).poke(0x3c00.U(16.W))
c.io.in.bits.a(2).poke(0x4000.U(16.W))
c.io.in.bits.a(3).poke(0x4200.U(16.W))
c.io.in.bits.b(0).poke(0x0000.U(16.W))
c.io.in.bits.b(1).poke(0x4800.U(16.W))
c.io.in.bits.b(2).poke(0x4c00.U(16.W))
c.io.in.bits.b(3).poke(0x4e00.U(16.W))
c.io.in.bits.c .poke(0x00000000.U(32.W))
c.io.out.valid.expect(false.B)
c.clock.step()
c.io.in.valid.poke(false.B)
c.io.out.valid.expect(false.B)
c.clock.step()
c.clock.step()
c.clock.step()
// 4-cycle latency
c.io.out.valid.expect(true.B)
c.io.out.bits.data.expect(0x42e00000L.U)
c.clock.step()
c.io.out.valid.expect(false.B)
}
}
it should "pass fp32" in {
test(new TensorDotProductUnit(half = false))
// .withAnnotations(Seq(VerilatorBackendAnnotation))