correct loop count to start after receiving command

This commit is contained in:
Richard Yan
2025-01-28 17:41:00 -08:00
parent d38f69fc5e
commit 52eeed277b

View File

@@ -177,10 +177,15 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer)
accCommandQueue.io.deq.ready := !ciscValid accCommandQueue.io.deq.ready := !ciscValid
assert(!accSlave.cmd.valid || accCommandQueue.io.enq.ready, "cisc command queue full") assert(!accSlave.cmd.valid || accCommandQueue.io.enq.ready, "cisc command queue full")
when (accCommandQueue.io.enq.fire) {
val enqId = accSlave.cmd.bits(6, 0)
startsLoop := VecInit(Seq(0, 1, 2, 9, 10, 12).map { x => enqId === x.U }).asUInt.orR
}
when (accCommandQueue.io.deq.fire) { when (accCommandQueue.io.deq.fire) {
ciscValid := true.B ciscValid := true.B
ciscId := accSlave.cmd.bits(7, 0) ciscId := accCommandQueue.io.deq.bits(7, 0)
ciscArgs := accSlave.cmd.bits(31, 8) ciscArgs := accCommandQueue.io.deq.bits(31, 8)
instCounter.reset() instCounter.reset()
} }
@@ -236,19 +241,15 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer)
is (0.U) { // compute on given hexadeciles is (0.U) { // compute on given hexadeciles
val strideInst = genStrideInst(ciscArgs(7, 0), ciscArgs(15, 8)) val strideInst = genStrideInst(ciscArgs(7, 0), ciscArgs(15, 8))
val accSkipInst = genAccSkipInst(ciscArgs(16), 0x2b8.U) val accSkipInst = genAccSkipInst(ciscArgs(16), 0x2b8.U)
startsLoop := true.B
ciscInst := microcodeEntry(Seq(boundsInst, strideInst, accSkipInst)) ciscInst := microcodeEntry(Seq(boundsInst, strideInst, accSkipInst))
} // replaces opcode 0: (a, b, accum) = (0, 2, 0), op 1 = (0, 2, 1), op 2 = (1, 3, 1), op 3 = (1, 3, 0) } // replaces opcode 0: (a, b, accum) = (0, 2, 0), op 1 = (0, 2, 1), op 2 = (1, 3, 1), op 3 = (1, 3, 0)
is (1.U) { // compute on given hexadeciles and mvout to spad is (1.U) { // compute on given hexadeciles and mvout to spad
val strideInst = genStrideInst(ciscArgs(7, 0), ciscArgs(15, 8)) val strideInst = genStrideInst(ciscArgs(7, 0), ciscArgs(15, 8))
// note that accumulation is disabled // note that accumulation is disabled
val accSkipInst = genAccSkipInst(0.U, ((ciscArgs(23, 16) * spadHexadecile.U) << 32).asUInt | 0x238.U) val accSkipInst = genAccSkipInst(0.U, ((ciscArgs(23, 16) * spadHexadecile.U) << 32).asUInt | 0x238.U)
startsLoop := true.B
ciscInst := microcodeEntry(Seq(boundsInst, strideInst, accSkipInst)) ciscInst := microcodeEntry(Seq(boundsInst, strideInst, accSkipInst))
} }
is (2.U) { // no actual invocation, fake job placeholder is (2.U) {} // no actual invocation, fake job placeholder
startsLoop := true.B
}
is (8.U) { // set a, b stride is (8.U) { // set a, b stride
val inst = Wire(ciscInstT) val inst = Wire(ciscInstT)
inst.inst := 0x1820b07b.U inst.inst := 0x1820b07b.U
@@ -258,13 +259,11 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer)
} }
is (9.U) { // move out to scratchpad is (9.U) { // move out to scratchpad
val accSkipInst = genAccSkipInst(0.U, ((ciscArgs(7, 0) * spadHexadecile.U) << 32).asUInt | 0x278.U) val accSkipInst = genAccSkipInst(0.U, ((ciscArgs(7, 0) * spadHexadecile.U) << 32).asUInt | 0x278.U)
startsLoop := true.B
ciscInst := microcodeEntry(Seq(boundsInst, accSkipInst)) ciscInst := microcodeEntry(Seq(boundsInst, accSkipInst))
} }
is (10.U) { // load to scratchpad hexadeciles is (10.U) { // load to scratchpad hexadeciles
val strideInst = genStrideInst(ciscArgs(7, 0), ciscArgs(15, 8)) val strideInst = genStrideInst(ciscArgs(7, 0), ciscArgs(15, 8))
val accSkipInst = genAccSkipInst(1.U, 0x2e0.U) val accSkipInst = genAccSkipInst(1.U, 0x2e0.U)
startsLoop := true.B
ciscInst := microcodeEntry(Seq(boundsInst, strideInst, accSkipInst)) ciscInst := microcodeEntry(Seq(boundsInst, strideInst, accSkipInst))
} // replaces opcode 10: (a, b) = (0, 2), opcode 11 = (1, 3), opcode 12 = (0, 0), opcode 13 = (2, 2) } // replaces opcode 10: (a, b) = (0, 2), opcode 11 = (1, 3), opcode 12 = (0, 0), opcode 13 = (2, 2)
is (11.U) { // set d, c stride is (11.U) { // set d, c stride
@@ -276,7 +275,6 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer)
} }
is (12.U) { // store to gmem is (12.U) { // store to gmem
val accSkipInst = genAccSkipInst(0.U, 0x78.U) val accSkipInst = genAccSkipInst(0.U, 0x78.U)
startsLoop := true.B
ciscInst := microcodeEntry(Seq(boundsInst, accSkipInst)) ciscInst := microcodeEntry(Seq(boundsInst, accSkipInst))
} }
@@ -291,7 +289,7 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer)
} }
val completionCount = PopCount(outer.gemmini.module.completion_io.completed) val completionCount = PopCount(outer.gemmini.module.completion_io.completed)
val loopStarted = Mux(ciscValid && instCounter.value === 0.U && startsLoop, 1.U, 0.U) val loopStarted = Mux(startsLoop, 1.U, 0.U)
runningLoops := runningLoops + loopStarted - completionCount runningLoops := runningLoops + loopStarted - completionCount
assert(runningLoops + loopStarted >= completionCount) assert(runningLoops + loopStarted >= completionCount)