Fix corner case in compiling a small mem using a large lib (#32)
* Refactor bit pairs calculation into a separate function * Minor clarifications * Clarify MacroCompilerSpec helpers * Add SmallTagArrayTest test * Fix corner case in compiling a small mem using a large lib
This commit is contained in:
@@ -102,17 +102,20 @@ class MacroCompilerPass(mems: Option[Seq[Macro]],
|
||||
})
|
||||
}
|
||||
|
||||
def compile(mem: Macro, lib: Macro): Option[(Module, ExtModule)] = {
|
||||
/**
|
||||
* Calculate bit pairs.
|
||||
* This is a list of submemories by width.
|
||||
* The tuples are (lsb, msb) inclusive.
|
||||
* Example: (0, 7) and (8, 15) might be a split for a width=16 memory into two width=8 target memories.
|
||||
* Another example: (0, 3), (4, 7), (8, 11) may be a split for a width-12 memory into 3 width-4 target memories.
|
||||
*
|
||||
* @param mem Memory to compile
|
||||
* @param lib Lib to compile with
|
||||
* @return Bit pairs or empty list if there was an error.
|
||||
*/
|
||||
private def calculateBitPairs(mem: Macro, lib: Macro): Seq[(BigInt, BigInt)] = {
|
||||
val pairedPorts = mem.sortedPorts zip lib.sortedPorts
|
||||
|
||||
// Width mapping
|
||||
|
||||
/**
|
||||
* This is a list of submemories by width.
|
||||
* The tuples are (lsb, msb) inclusive.
|
||||
* e.g. (0, 7) and (8, 15) might be a split for a width=16 memory into two
|
||||
* width=8 memories.
|
||||
*/
|
||||
val bitPairs = ArrayBuffer[(BigInt, BigInt)]()
|
||||
var currentLSB: BigInt = 0
|
||||
|
||||
@@ -133,7 +136,7 @@ class MacroCompilerPass(mems: Option[Seq[Macro]],
|
||||
// Helper function to check if it's time to split memories.
|
||||
// @param effectiveLibWidth Split memory when we have this many bits.
|
||||
def splitMemory(effectiveLibWidth: Int): Unit = {
|
||||
assert (!alreadySplit)
|
||||
assert(!alreadySplit)
|
||||
|
||||
if (bitsInCurrentMem == effectiveLibWidth) {
|
||||
bitPairCandidates += ((currentLSB, memBit - 1))
|
||||
@@ -142,8 +145,8 @@ class MacroCompilerPass(mems: Option[Seq[Macro]],
|
||||
}
|
||||
|
||||
// Make sure we don't have a maskGran larger than the width of the memory.
|
||||
assert (memPort.src.effectiveMaskGran <= memPort.src.width.get)
|
||||
assert (libPort.src.effectiveMaskGran <= libPort.src.width.get)
|
||||
assert(memPort.src.effectiveMaskGran <= memPort.src.width.get)
|
||||
assert(libPort.src.effectiveMaskGran <= libPort.src.width.get)
|
||||
|
||||
val libWidth = libPort.src.width.get
|
||||
|
||||
@@ -182,8 +185,8 @@ class MacroCompilerPass(mems: Option[Seq[Macro]],
|
||||
splitMemory(memMask.get)
|
||||
} else {
|
||||
// e.g. mem mask = 13, lib width = 8
|
||||
System.err.println(s"Unmasked target memory: unaligned mem maskGran ${p} with lib (${lib.src.name}) width ${libPort.src.width.get} not supported")
|
||||
return None
|
||||
System.err.println(s"Unmasked target memory: unaligned mem maskGran $p with lib (${lib.src.name}) width ${libPort.src.width.get} not supported")
|
||||
return Seq()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -199,8 +202,8 @@ class MacroCompilerPass(mems: Option[Seq[Macro]],
|
||||
// Mem maskGran is a multiple of lib maskGran, carry on as normal.
|
||||
splitMemory(libWidth)
|
||||
} else {
|
||||
System.err.println(s"Mem maskGran ${m} is not a multiple of lib maskGran ${l}: currently not supported")
|
||||
return None
|
||||
System.err.println(s"Mem maskGran $m is not a multiple of lib maskGran $l: currently not supported")
|
||||
return Seq()
|
||||
}
|
||||
} else { // m < l
|
||||
// Lib maskGran > mem maskGran.
|
||||
@@ -218,8 +221,8 @@ class MacroCompilerPass(mems: Option[Seq[Macro]],
|
||||
// of treating it as simply a width 4 (!!!) memory.
|
||||
// This would require a major refactor though.
|
||||
} else {
|
||||
System.err.println(s"Lib maskGran ${m} is not a multiple of mem maskGran ${l}: currently not supported")
|
||||
return None
|
||||
System.err.println(s"Lib maskGran $m is not a multiple of mem maskGran $l: currently not supported")
|
||||
return Seq()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -228,7 +231,7 @@ class MacroCompilerPass(mems: Option[Seq[Macro]],
|
||||
|
||||
// Choose an actual bit pair to add.
|
||||
// We'll have to choose the smallest one (e.g. unmasked read port might be more tolerant of a bigger split than the masked write port).
|
||||
if (bitPairCandidates.length == 0) {
|
||||
if (bitPairCandidates.isEmpty) {
|
||||
// No pair needed to split, just continue
|
||||
} else {
|
||||
val bestPair = bitPairCandidates.reduceLeft((leftPair, rightPair) => {
|
||||
@@ -240,7 +243,22 @@ class MacroCompilerPass(mems: Option[Seq[Macro]],
|
||||
}
|
||||
// Add in the last chunk if there are any leftovers
|
||||
bitPairs += ((currentLSB, mem.src.width.toInt - 1))
|
||||
// Check bit pairs
|
||||
|
||||
bitPairs.toSeq
|
||||
}
|
||||
|
||||
def compile(mem: Macro, lib: Macro): Option[(Module, ExtModule)] = {
|
||||
assert(mem.sortedPorts.lengthCompare(lib.sortedPorts.length) == 0,
|
||||
"mem and lib should have an equal number of ports")
|
||||
val pairedPorts = mem.sortedPorts zip lib.sortedPorts
|
||||
|
||||
// Width mapping. See calculateBitPairs.
|
||||
val bitPairs: Seq[(BigInt, BigInt)] = calculateBitPairs(mem, lib)
|
||||
if (bitPairs.isEmpty) {
|
||||
System.err.println("Error occurred during bitPairs calculations (bitPairs is empty).")
|
||||
return None
|
||||
}
|
||||
// Check bit pairs.
|
||||
checkBitPairs(bitPairs)
|
||||
|
||||
// Depth mapping
|
||||
@@ -278,8 +296,9 @@ class MacroCompilerPass(mems: Option[Seq[Macro]],
|
||||
for ((off, i) <- (0 until mem.src.depth by lib.src.depth).zipWithIndex) {
|
||||
for (j <- bitPairs.indices) {
|
||||
val name = s"mem_${i}_${j}"
|
||||
// Create the instance.
|
||||
stmts += WDefInstance(NoInfo, name, lib.src.name, lib.tpe)
|
||||
// connect extra ports
|
||||
// Connect extra ports of the lib.
|
||||
stmts ++= lib.extraPorts map { case (portName, portValue) =>
|
||||
Connect(NoInfo, WSubField(WRef(name), portName), portValue)
|
||||
}
|
||||
@@ -383,14 +402,29 @@ class MacroCompilerPass(mems: Option[Seq[Macro]],
|
||||
} else {
|
||||
require(isPowerOfTwo(libPort.src.effectiveMaskGran), "only powers of two masks supported for now")
|
||||
|
||||
val effectiveLibWidth = if (memPort.src.maskGran.get < libPort.src.effectiveMaskGran) memPort.src.maskGran.get else libPort.src.width.get
|
||||
// How much of this lib's width we are effectively using.
|
||||
// If we have a mem maskGran less than the lib's maskGran, we'll have to take the smaller maskGran.
|
||||
// Example: if we have a lib whose maskGran is 8 but our mem's maskGran is 4.
|
||||
// The other case is if we're using a larger lib than mem.
|
||||
val usingLessThanLibMaskGran = (memPort.src.maskGran.get < libPort.src.effectiveMaskGran)
|
||||
val effectiveLibWidth = if (usingLessThanLibMaskGran)
|
||||
memPort.src.maskGran.get
|
||||
else
|
||||
libPort.src.width.get
|
||||
|
||||
cat(((0 until libPort.src.width.get by libPort.src.effectiveMaskGran) map (i => {
|
||||
if (memPort.src.maskGran.get < libPort.src.effectiveMaskGran && i >= effectiveLibWidth) {
|
||||
if (usingLessThanLibMaskGran && i >= effectiveLibWidth) {
|
||||
// If the memMaskGran is smaller than the lib's gran, then
|
||||
// zero out the upper bits.
|
||||
zero
|
||||
} else {
|
||||
bits(WRef(mem), (low + i) / memPort.src.effectiveMaskGran)
|
||||
if (i >= memPort.src.width.get) {
|
||||
// If our bit is larger than the whole width of the mem, just zero out the upper bits.
|
||||
zero
|
||||
} else {
|
||||
// Pick the appropriate bit from the mem mask.
|
||||
bits(WRef(mem), (low + i) / memPort.src.effectiveMaskGran)
|
||||
}
|
||||
}
|
||||
})).reverse)
|
||||
}
|
||||
@@ -589,9 +623,11 @@ class MacroCompilerTransform extends Transform {
|
||||
|
||||
// FIXME: Use firrtl.LowerFirrtlOptimizations
|
||||
class MacroCompilerOptimizations extends SeqTransform {
|
||||
def inputForm = LowForm
|
||||
def outputForm = LowForm
|
||||
def transforms = Seq(
|
||||
def inputForm: CircuitForm = LowForm
|
||||
|
||||
def outputForm: CircuitForm = LowForm
|
||||
|
||||
def transforms: Seq[Transform] = Seq(
|
||||
passes.RemoveValidIf,
|
||||
new firrtl.transforms.ConstantPropagation,
|
||||
passes.memlib.VerilogMemDelays,
|
||||
@@ -602,11 +638,12 @@ class MacroCompilerOptimizations extends SeqTransform {
|
||||
}
|
||||
|
||||
class MacroCompiler extends Compiler {
|
||||
def emitter = new VerilogEmitter
|
||||
def transforms =
|
||||
def emitter: Emitter = new VerilogEmitter
|
||||
|
||||
def transforms: Seq[Transform] =
|
||||
Seq(new MacroCompilerTransform) ++
|
||||
getLoweringTransforms(firrtl.HighForm, firrtl.LowForm) ++
|
||||
Seq(new MacroCompilerOptimizations)
|
||||
getLoweringTransforms(firrtl.HighForm, firrtl.LowForm) ++
|
||||
Seq(new MacroCompilerOptimizations)
|
||||
}
|
||||
|
||||
object MacroCompiler extends App {
|
||||
|
||||
Reference in New Issue
Block a user