tensor: Parameterize dimension in TensorDotProductUnit

This commit is contained in:
Hansung Kim
2024-10-25 21:44:36 -07:00
parent 51dfebb6a7
commit 46a57fdf9b
4 changed files with 65 additions and 43 deletions

View File

@@ -543,7 +543,7 @@ class TensorCoreDecoupled(
require(tilingParams.mc * ncSubstep == numLanes,
"substep tile size doesn't match writeback throughput")
val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)(
Module(new TensorDotProductUnit(half = false))
Module(new TensorDotProductUnit(dim = 4, half = false))
))
// reshape operands for easier routing to DPU

View File

@@ -9,7 +9,10 @@ import freechips.rocketchip.tile
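// Implements a `dim`-element dot product unit, modeled on the four-element
// dot product (FEDP) unit in Volta Tensor Cores.
// `dim`: number of (a, b) element pairs multiplied and summed per operation.
// `half`: if true, generate fp16 MACs; if false, fp32.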
class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUParameters {
class TensorDotProductUnit(
val dim: Int = 4,
val half: Boolean
) extends Module with tile.HasFPUParameters {
val tIn = if (half) tile.FType.H else tile.FType.S
// output datatype fixed to single-precision
val tOut = tile.FType.S
@@ -19,12 +22,11 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar
val fLen = outFLen // needed for HasFPUParameters
val minFLen = 16 // fp16
def xLen = 32
val dotProductDim = 4
val io = IO(new Bundle {
val in = Flipped(Valid(new Bundle {
val a = Vec(dotProductDim, Bits((inFLen).W))
val b = Vec(dotProductDim, Bits((inFLen).W))
val a = Vec(dim, Bits((inFLen).W))
val b = Vec(dim, Bits((inFLen).W))
val c = Bits((outFLen).W) // note C has the out length for accumulation
}))
// 'stall' is effectively out.ready, combinationally coupled to in.ready
@@ -43,7 +45,7 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar
val in2 = io.in.bits.b.map(x => unbox(recode(x, tag), tag, Some(tIn)))
val in3 = unbox(recode(io.in.bits.c, S), S, Some(tOut))
val dpu = Module(new DotProductPipe(dotProductDim, tIn, tOut))
val dpu = Module(new DotProductPipe(dim, tIn, tOut))
dpu.io.in.valid := io.in.valid
dpu.io.in.bits.a := in1
dpu.io.in.bits.b := in2
@@ -101,7 +103,6 @@ object StallingPipe {
// Computes d = a(0)*b(0) + ... + a(`dim`-1)*b(`dim`-1) + c.
// Fully pipelined with a fixed latency determined by `dim`, which must be a
// power of two.
class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) extends Module {
require(dim == 4, "DPU currently only supports dimension 4")
val expWidth = inputType.exp
val sigWidth = inputType.sig
val outExpWidth = outputType.exp
@@ -111,8 +112,8 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
val recOutFLen = outExpWidth + outSigWidth + 1
val io = IO(new Bundle {
val in = Flipped(Valid(new Bundle {
val a = Vec(4, Bits((recInFLen).W))
val b = Vec(4, Bits((recInFLen).W))
val a = Vec(dim, Bits((recInFLen).W))
val b = Vec(dim, Bits((recInFLen).W))
val c = Bits((recOutFLen).W)
// val roundingMode = UInt(3.W)
// val detectTininess = UInt(1.W)
@@ -141,6 +142,7 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
// assert(m.io.invalidExc === false.B)
// round fp16*fp16 raw result back to fp32 recoded format
// @perf: possibly pipeline here for better timing
val mulExpWidth = m.io.rawOut.expWidth
val mulSigWidth = m.io.rawOut.sigWidth
val roundRawFNToRecFN =
@@ -160,45 +162,65 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
// mul stage end -------------------------------------------------------------
val add1 = Seq.fill(dim / 2)(Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth)))
val add1Outs = add1.zipWithIndex.map { case (a, i) =>
a.io.subOp := 0.U // FIXME dont know what this is
a.io.a := mulStageOut.bits(2 * i + 0)
a.io.b := mulStageOut.bits(2 * i + 1)
a.io.roundingMode := hardfloat.consts.round_near_even
a.io.detectTininess := hardfloat.consts.tininess_afterRounding
// assert(a.io.exceptionFlags === 0.U)
a.io.out
// Reduce-add the `dim` mul results down to one value with a binary adder
// tree of log2(`dim`) pipelined stages.
val log2Dim = log2Ceil(dim)
require(dim == (1 << log2Dim), s"dim (${dim}) is not power of two!")
// instantiate wires for input values to each reduction pipeline stage
val interim = (log2Dim to 0 by -1).map { i =>
Wire(Valid(Vec(1 << i, Bits(recOutFLen.W))))
}
// instantiate wires for pipe registers for C
val interimC = (log2Dim to 0 by -1).map( _ => Wire(Valid(Bits(recOutFLen.W))) )
// connect the first stage inputs
interim(0) := mulStageOut
interimC(0) := mulStageC
val add1StageOut = StallingPipe(io.stall, mulStageOut.valid, VecInit(add1Outs))
val add1StageC = StallingPipe(io.stall, mulStageOut.valid, mulStageC.bits)
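// Build the adder tree: each reduce step instantiates one level of adders and
// pipes its sums (and the accumulator C alongside them) into the next stage.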
val (addStageOut, addStageC) = (interim zip interimC).reduce {
(inputsAndC, outputsAndC) => {
val (inputs, inC) = inputsAndC
val (outputs, outC) = outputsAndC
// add1 stage end ------------------------------------------------------------
require(inputs.bits.length == 2 * outputs.bits.length)
val thisDim = inputs.bits.length
val adders = Seq.fill(thisDim / 2)(
Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
)
val addOuts = adders.zipWithIndex.map { case (a, i) =>
a.io.subOp := 0.U // 0 selects addition (a + b); 1 would compute a - b
a.io.a := inputs.bits(2 * i + 0)
a.io.b := inputs.bits(2 * i + 1)
a.io.roundingMode := hardfloat.consts.round_near_even
a.io.detectTininess := hardfloat.consts.tininess_afterRounding
// assert(a.io.exceptionFlags === 0.U)
a.io.out
}
val add2 = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
add2.io.subOp := 0.U // FIXME
add2.io.a := add1StageOut.bits(0)
add2.io.b := add1StageOut.bits(1)
add2.io.roundingMode := hardfloat.consts.round_near_even
add2.io.detectTininess := hardfloat.consts.tininess_afterRounding
// assert(add2.io.exceptionFlags === 0.U)
// pipeline and connect outputs to the next stage
outputs := StallingPipe(io.stall, inputs.valid, VecInit(addOuts))
outC := StallingPipe(io.stall, inputs.valid, inC.bits)
assert(inputs.valid === inC.valid,
"adder inputs valid and C pipe valid went out-of-sync")
val add2StageOut = StallingPipe(io.stall, add1StageOut.valid, add2.io.out)
val add2StageC = StallingPipe(io.stall, add1StageOut.valid, add1StageC.bits)
(outputs, outC)
}
}
require(addStageOut.bits.length == 1)
// add2 stage end ------------------------------------------------------------
// add stages end ------------------------------------------------------------
// add final A and B dot-product result to accumulator C
val acc = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
acc.io.subOp := 0.U // 0 selects addition (a + b); 1 would compute a - b
acc.io.a := add2StageOut.bits
// acc.io.b := add2StageCRec
acc.io.b := add2StageC.bits
acc.io.a := addStageOut.bits(0)
acc.io.b := addStageC.bits
acc.io.roundingMode := hardfloat.consts.round_near_even
acc.io.detectTininess := hardfloat.consts.tininess_afterRounding
// assert(acc.io.exceptionFlags === 0.U)
val accStageOut = StallingPipe(io.stall, add2StageOut.valid, acc.io.out)
val accStageOut = StallingPipe(io.stall, addStageOut.valid, acc.io.out)
// acc stage end -------------------------------------------------------------

View File

@@ -9,7 +9,7 @@ class TensorCoreDecoupledTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "TensorCoreDecoupled"
it should "do the right thing" in {
test(new TensorCoreDecoupled(8, 8, tilingParams = TensorTilingParams()))
test(new TensorCoreDecoupled(8, 8, numSourceIds = 4, tilingParams = TensorTilingParams()))
{ c =>
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.wid.poke(0.U)

View File

@@ -46,8 +46,8 @@ class TensorDotProductUnitTest extends AnyFlatSpec with ChiselScalatestTester {
implicit val p: Parameters = Parameters.empty
it should "pass fp16" in {
test(new TensorDotProductUnit(half = true))
it should "pass 4-dim fp16" in {
test(new TensorDotProductUnit(4, half = true))
// .withAnnotations(Seq(VerilatorBackendAnnotation))
// .withAnnotations(Seq(WriteVcdAnnotation))
{ c =>
@@ -93,9 +93,9 @@ class TensorDotProductUnitTest extends AnyFlatSpec with ChiselScalatestTester {
}
}
it should "pass fp16 2" in {
test(new TensorDotProductUnit(half = true))
.withAnnotations(Seq(VerilatorBackendAnnotation))
it should "pass 4-dim fp16 2" in {
test(new TensorDotProductUnit(4, half = true))
// .withAnnotations(Seq(VerilatorBackendAnnotation))
// .withAnnotations(Seq(WriteVcdAnnotation))
{ c =>
c.io.in.valid.poke(true.B)
@@ -129,8 +129,8 @@ class TensorDotProductUnitTest extends AnyFlatSpec with ChiselScalatestTester {
}
}
it should "pass fp32" in {
test(new TensorDotProductUnit(half = false))
it should "pass 4-dim fp32" in {
test(new TensorDotProductUnit(4, half = false))
// .withAnnotations(Seq(VerilatorBackendAnnotation))
// .withAnnotations(Seq(WriteVcdAnnotation))
{ c =>