From 6a3aa549d34ffbafd4154081fb063be847452114 Mon Sep 17 00:00:00 2001
From: Hansung Kim
Date: Mon, 14 Oct 2024 15:02:25 -0700
Subject: [PATCH] Add skeleton for Hopper Tensor Core

---
 .../radiance/core/TensorCoreDecoupled.scala   | 90 +++++++++++++++++++
 .../radiance/TensorCoreDecoupledTest.scala    | 23 +++++
 2 files changed, 113 insertions(+)
 create mode 100644 src/main/scala/radiance/core/TensorCoreDecoupled.scala
 create mode 100644 src/test/scala/radiance/TensorCoreDecoupledTest.scala

diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala
new file mode 100644
index 0000000..10bfedb
--- /dev/null
+++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala
@@ -0,0 +1,90 @@
+// See LICENSE.SiFive for license details.
+// See LICENSE.Berkeley for license details.
+
+package radiance.core
+
+import chisel3._
+import chisel3.util._
+
+/** Skeleton of a decoupled tensor core.
+ *
+ * Only the control FSM is implemented so far; all handshake outputs are tied
+ * to constants and payloads are left as DontCare (see the TODO at the bottom).
+ *
+ * @param numWarps number of warps; `wid` fields are log2Ceil(numWarps) bits wide
+ * @param numLanes number of lanes; `dataWidth` is numLanes * wordSize
+ */
+class TensorCoreDecoupled(val numWarps: Int, val numLanes: Int) extends Module {
+  val numWarpBits = log2Ceil(numWarps)
+  val wordSize = 4 // TODO FP16
+  val dataWidth = numLanes * wordSize // TODO FP16
+
+  val io = IO(new Bundle{
+    val initiate = Flipped(Decoupled(new Bundle{
+      val wid = UInt(numWarpBits.W)
+    }))
+    val dataA = Flipped(Decoupled(new TensorMemResp(dataWidth)))
+    val dataB = Flipped(Decoupled(new TensorMemResp(dataWidth)))
+    val addressA = Decoupled(new TensorMemReq)
+    val addressB = Decoupled(new TensorMemReq)
+    val writeback = Decoupled(new Bundle{
+      val wid = UInt(numWarpBits.W)
+      val last = Bool()
+    })
+  })
+
+  // FSM
+  //
+  val state = RegInit(TensorState.idle)
+  // TODO: just transition every cycle for now
+  //
+  // The next state has to be selected in hardware with switch/is. A Scala
+  // `match` on `state` is evaluated once at elaboration time and compares
+  // Scala objects rather than signal values, so the register would never
+  // equal an enum literal and only a default arm would ever be taken.
+  switch(state) {
+    is(TensorState.idle) {
+      when(io.initiate.fire) { state := TensorState.smemRead }
+    }
+    is(TensorState.smemRead) { state := TensorState.compute }
+    is(TensorState.compute) { state := TensorState.writeback }
+    is(TensorState.writeback) {
+      // hold until writeback is cleared
+      when(io.writeback.ready) { state := TensorState.idle }
+    }
+  }
+
+  // TODO
+  io.dataA.ready := true.B
+  io.dataB.ready := true.B
+  io.addressA.valid := false.B
+  io.addressB.valid := false.B
+  io.addressA.bits := DontCare
+  io.addressB.bits := DontCare
+  io.initiate.ready := true.B
+  io.writeback.valid := true.B
+  io.writeback.bits := DontCare
+}
+
+/** Memory request carrying a 32-bit address. */
+class TensorMemReq extends Bundle {
+  // TODO: tag
+  val address = UInt(32.W)
+}
+/** Memory response carrying a 32-bit data word.
+ *
+ * NOTE(review): `dataWidth` is not used yet; `data` is fixed at 32 bits.
+ */
+class TensorMemResp(val dataWidth: Int) extends Bundle {
+  // TODO: tag
+  val data = UInt(32.W)
+}
+
+
+/** States of the TensorCoreDecoupled control FSM. */
+object TensorState extends ChiselEnum {
+  val idle = Value(0.U)
+  val smemRead = Value(1.U)
+  val compute = Value(2.U)
+  val writeback = Value(3.U)
+}
diff --git a/src/test/scala/radiance/TensorCoreDecoupledTest.scala b/src/test/scala/radiance/TensorCoreDecoupledTest.scala
new file mode 100644
index 0000000..5dd734a
--- /dev/null
+++ b/src/test/scala/radiance/TensorCoreDecoupledTest.scala
@@ -0,0 +1,23 @@
+package radiance.core
+
+import chisel3._
+import chisel3.util._
+import chiseltest._
+import org.scalatest.flatspec.AnyFlatSpec
+
+class TensorCoreDecoupledTest extends AnyFlatSpec with ChiselScalatestTester {
+  behavior of "TensorCoreDecoupled"
+
+  it should "do the right thing" in {
+    test(new TensorCoreDecoupled(8, 8))
+    { c =>
+      c.io.initiate.valid.poke(true.B)
+      c.io.dataA.valid.poke(false.B)
+      c.io.dataA.bits.data.poke(0.U)
+      c.io.dataB.valid.poke(false.B)
+      c.io.dataB.bits.data.poke(0.U)
+      c.clock.step()
+      c.io.writeback.valid.expect(true.B)
+    }
+  }
+}