Move barstools tapeout src to tools/tapeout

This commit is contained in:
Jerry Zhao
2024-04-19 11:06:07 -07:00
parent 4830ebf239
commit 33a1fe3f7b
19 changed files with 0 additions and 0 deletions

View File

@@ -0,0 +1,205 @@
// See LICENSE for license details.
package barstools.macros
/** Trait which can calculate the cost of compiling a memory against a certain
* library memory macro using a cost function.
*/
// TODO: eventually explore compiling a single target memory using multiple
// different kinds of target memory.
trait CostMetric extends Serializable {
/** Cost function that returns the cost of compiling a memory using a certain
* macro.
*
* @param mem Memory macro to compile (target memory)
* @param lib Library memory macro to use (library memory)
* @return The cost of this compile, defined by this cost metric, or None if
* it cannot be compiled.
*/
def cost(mem: Macro, lib: Macro): Option[Double]
/** Helper function to return the map of arguments (or an empty map if there are none).
*/
def commandLineParams(): Map[String, String]
// We also want this to show up for the class itself.
def name(): String
}
// Is there a better way to do this? (static method associated to CostMetric)
trait CostMetricCompanion {
def name(): String
/** Construct this cost metric from a command line mapping. */
def construct(m: Map[String, String]): CostMetric
}
// Some default cost functions.
/** Palmer's old metric.
* TODO: figure out what is the difference between this metric and the current
* default metric and either revive or delete this metric.
*/
object OldMetric extends CostMetric with CostMetricCompanion {
override def cost(mem: Macro, lib: Macro): Option[Double] = {
/* Palmer: A quick cost function (that must be kept in sync with
* memory_cost()) that attempts to avoid compiling unnecessary
* memories. This is a lower bound on the cost of compiling a
* memory: it assumes 100% bit-cell utilization when mapping. */
// val cost = 100 * (mem.depth * mem.width) / (lib.depth * lib.width) +
// (mem.depth * mem.width)
???
}
override def commandLineParams() = Map.empty[String, String]
override def name() = "OldMetric"
override def construct(m: Map[String, String]): CostMetric = OldMetric
}
/** An external cost function.
* Calls the specified path with paths to the JSON MDF representation of the mem
* and lib macros. The external executable should print a Double.
* None will be returned if the external executable does not print a valid
* Double.
*/
class ExternalMetric(path: String) extends CostMetric {
import mdf.macrolib.Utils.writeMacroToPath
import java.io._
import scala.language.postfixOps
import sys.process._
override def cost(mem: Macro, lib: Macro): Option[Double] = {
// Create temporary files.
val memFile = File.createTempFile("_macrocompiler_mem_", ".json")
val libFile = File.createTempFile("_macrocompiler_lib_", ".json")
writeMacroToPath(Some(memFile.getAbsolutePath), mem.src)
writeMacroToPath(Some(libFile.getAbsolutePath), lib.src)
// !! executes the given command
val result: String = (s"$path ${memFile.getAbsolutePath} ${libFile.getAbsolutePath}" !!).trim
// Remove temporary files.
memFile.delete()
libFile.delete()
try {
Some(result.toDouble)
} catch {
case _: NumberFormatException => None
}
}
override def commandLineParams() = Map("path" -> path)
override def name(): String = ExternalMetric.name()
}
object ExternalMetric extends CostMetricCompanion {
override def name() = "ExternalMetric"
/** Construct this cost metric from a command line mapping. */
override def construct(m: Map[String, String]): ExternalMetric = {
val pathOption = m.get("path")
pathOption match {
case Some(path: String) => new ExternalMetric(path)
case _ => throw new IllegalArgumentException("ExternalMetric missing option 'path'")
}
}
}
/** The current default metric in barstools, re-defined by Donggyu. */
// TODO: write tests for this function to make sure it selects the right things
object DefaultMetric extends CostMetric with CostMetricCompanion {
override def cost(mem: Macro, lib: Macro): Option[Double] = {
val memMask = mem.src.ports.map(_.maskGran).find(_.isDefined).flatten
val libMask = lib.src.ports.map(_.maskGran).find(_.isDefined).flatten
val memWidth = (memMask, libMask) match {
case (None, _) => mem.src.width
case (Some(p), None) =>
(mem.src.width / p) * math.ceil(
p.toDouble / lib.src.width
) * lib.src.width //We map the mask to distinct memories
case (Some(p), Some(m)) =>
if (m <= p) (mem.src.width / p) * math.ceil(p.toDouble / m) * m //Using multiple m's to create a p (integrally)
else (mem.src.width / p) * m //Waste the extra maskbits
}
val maskPenalty = (memMask, libMask) match {
case (None, Some(_)) => 0.001
case (_, _) => 0
}
val depthCost = math.ceil(mem.src.depth.toDouble / lib.src.depth.toDouble)
val widthCost = math.ceil(memWidth / lib.src.width.toDouble)
val bitsCost = (lib.src.depth * lib.src.width).toDouble
// Fraction of wasted bits plus const per mem
val requestedBits = (mem.src.depth * mem.src.width).toDouble
val bitsWasted = depthCost * widthCost * bitsCost - requestedBits
val wastedConst = 0.05 // 0 means waste as few bits with no regard for instance count
val costPerInst = wastedConst * depthCost * widthCost
Some(1.0 * bitsWasted / requestedBits + costPerInst + maskPenalty)
}
override def commandLineParams() = Map.empty[String, String]
override def name() = "DefaultMetric"
override def construct(m: Map[String, String]): CostMetric = DefaultMetric
}
object MacroCompilerUtil {
import java.io._
import java.util.Base64
// Adapted from https://stackoverflow.com/a/134918
/** Serialize an arbitrary object to String.
* Used to pass structured values through as an annotation.
*/
def objToString(o: Serializable): String = {
val byteOutput: ByteArrayOutputStream = new ByteArrayOutputStream
val objectOutput: ObjectOutputStream = new ObjectOutputStream(byteOutput)
objectOutput.writeObject(o)
objectOutput.close()
Base64.getEncoder.encodeToString(byteOutput.toByteArray)
}
/** Deserialize an arbitrary object from String. */
def objFromString(s: String): AnyRef = {
val data = Base64.getDecoder.decode(s)
val ois: ObjectInputStream = new ObjectInputStream(new ByteArrayInputStream(data))
val o = ois.readObject
ois.close()
o
}
}
object CostMetric {
/** Define some default metric. */
val default: CostMetric = DefaultMetric
val costMetricCreators: scala.collection.mutable.Map[String, CostMetricCompanion] = scala.collection.mutable.Map()
// Register some default metrics
registerCostMetric(OldMetric)
registerCostMetric(ExternalMetric)
registerCostMetric(DefaultMetric)
/** Register a cost metric.
* @param createFuncHelper Companion object to fetch the name and construct
* the metric.
*/
def registerCostMetric(createFuncHelper: CostMetricCompanion): Unit = {
costMetricCreators.update(createFuncHelper.name(), createFuncHelper)
}
/** Select a cost metric from string. */
def getCostMetric(m: String, params: Map[String, String]): CostMetric = {
if (m == "default") {
CostMetric.default
} else if (!costMetricCreators.contains(m)) {
throw new IllegalArgumentException("Invalid cost metric " + m)
} else {
costMetricCreators(m).construct(params)
}
}
}

View File

@@ -0,0 +1,981 @@
// See LICENSE for license details.
/** Terminology note:
* mem - target memory to compile, in design (e.g. Mem() in rocket)
* lib - technology SRAM(s) to use to compile mem
*/
package barstools.macros
import barstools.macros.Utils._
import firrtl.Utils.{one, zero, BoolType}
import firrtl.annotations._
import firrtl.ir._
import firrtl.options.Dependency
import firrtl.stage.TransformManager.TransformDependency
import firrtl.stage.{FirrtlSourceAnnotation, FirrtlStage, Forms, OutputFileAnnotation, RunFirrtlTransformAnnotation}
import firrtl.{PrimOps, _}
import mdf.macrolib.{PolarizedPort, PortPolarity, SRAMCompiler, SRAMGroup, SRAMMacro}
import java.io.{File, FileWriter}
import scala.annotation.tailrec
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
case class MacroCompilerException(msg: String) extends Exception(msg)
// TODO The parameters could be unpacked here instead of keeping it in a serialized form
case class MacroCompilerAnnotation(content: String) extends NoTargetAnnotation {
import MacroCompilerAnnotation.Params
def params: Params = MacroCompilerUtil.objFromString(content).asInstanceOf[Params]
}
/** The MacroCompilerAnnotation to trigger the macro compiler.
* Note that this annotation does NOT actually target any modules for
* compilation. It simply holds all the settings for the memory compiler. The
* actual selection of which memories to compile is set in the Params.
*
* To use, simply annotate the entire circuit itself with this annotation and
* include [[MacroCompilerTransform]].
*/
object MacroCompilerAnnotation {
/** Macro compiler mode. */
sealed trait CompilerMode
/** Strict mode - must compile all memories or error out. */
case object Strict extends CompilerMode
/** Synflops mode - compile all memories with synflops (do not map to lib at all). */
case object Synflops extends CompilerMode
/** CompileAndSynflops mode - compile all memories and create mock versions of the target libs with synflops. */
case object CompileAndSynflops extends CompilerMode
/** FallbackSynflops - compile all memories to SRAM when possible and fall back to synflops if a memory fails. * */
case object FallbackSynflops extends CompilerMode
/** CompileAvailable - compile what is possible and do nothing with uncompiled memories. * */
case object CompileAvailable extends CompilerMode
/** The default mode for the macro compiler.
* TODO: Maybe set the default to FallbackSynflops (typical for
* vlsi_mem_gen-like scripts) once it's implemented?
*/
val Default: CompilerMode = CompileAvailable
// Options as list of (CompilerMode, command-line name, description)
val options: Seq[(CompilerMode, String, String)] = Seq(
(Default, "default", "Select the default option from below."),
(Strict, "strict", "Compile all memories to library or return an error."),
(
Synflops,
"synflops",
"Produces synthesizable flop-based memories for all memories (do not map to lib at all); likely useful for simulation purposes."
),
(
CompileAndSynflops,
"compileandsynflops",
"Compile all memories and create mock versions of the target libs with synflops; likely also useful for simulation purposes."
),
(
FallbackSynflops,
"fallbacksynflops",
"Compile all memories to library when possible and fall back to synthesizable flop-based memories when library synth is not possible."
),
(
CompileAvailable,
"compileavailable",
"Compile all memories to library when possible and do nothing in case of errors. (default)"
)
)
/** Helper function to select a compiler mode. */
def stringToCompilerMode(str: String): CompilerMode = options.collectFirst {
case (mode, cmd, _) if cmd == str => mode
} match {
case Some(x) => x
case None => throw new IllegalArgumentException("No such compiler mode " + str)
}
/** Parameters associated to this MacroCompilerAnnotation.
*
* @param mem Path to memory lib
* @param memFormat Type of memory lib (Some("conf"), Some("mdf"), or None (defaults to mdf))
* @param lib Path to library lib or None if no libraries
* @param hammerIR Path to HammerIR output or None (not generated in this case)
* @param costMetric Cost metric to use
* @param mode Compiler mode (see CompilerMode)
* @param forceCompile Set of memories to force compiling to lib regardless of the mode
* @param forceSynflops Set of memories to force compiling as flops regardless of the mode
*/
case class Params(
mem: String,
memFormat: Option[String],
lib: Option[String],
hammerIR: Option[String],
costMetric: CostMetric,
mode: CompilerMode,
useCompiler: Boolean,
forceCompile: Set[String],
forceSynflops: Set[String])
extends Serializable
/** Create a MacroCompilerAnnotation.
* @param c Top-level circuit name (see class description)
* @param p Parameters (see above).
*/
def apply(c: String, p: Params): MacroCompilerAnnotation =
MacroCompilerAnnotation(MacroCompilerUtil.objToString(p))
}
class MacroCompilerPass(
mems: Option[Seq[Macro]],
libs: Option[Seq[Macro]],
compilers: Option[SRAMCompiler],
hammerIR: Option[String],
costMetric: CostMetric = CostMetric.default,
mode: MacroCompilerAnnotation.CompilerMode = MacroCompilerAnnotation.Default)
extends firrtl.passes.Pass {
// Helper function to check the legality of bitPairs.
// e.g. ((0,21), (22,43)) is legal
// ((0,21), (22,21)) is illegal and will throw an assert
private def checkBitPairs(bitPairs: Seq[(BigInt, BigInt)]): Unit = {
bitPairs.foldLeft(BigInt(-1))((lastBit, nextPair) => {
assert(lastBit + 1 == nextPair._1, s"Pair's first bit ${nextPair._1} does not follow last bit $lastBit")
assert(nextPair._2 >= nextPair._1, s"Pair $nextPair in bitPairs $bitPairs is illegal")
nextPair._2
})
}
/** Calculate bit pairs.
* This is a list of submemories by width.
* The tuples are (lsb, msb) inclusive.
* Example: (0, 7) and (8, 15) might be a split for a width=16 memory into two width=8 target memories.
* Another example: (0, 3), (4, 7), (8, 11) may be a split for a width-12 memory into 3 width-4 target memories.
*
* @param mem Memory to compile
* @param lib Lib to compile with
* @return Bit pairs or empty list if there was an error.
*/
private def calculateBitPairs(mem: Macro, lib: Macro): Seq[(BigInt, BigInt)] = {
val pairedPorts = mem.sortedPorts.zip(lib.sortedPorts)
val bitPairs = ArrayBuffer[(BigInt, BigInt)]()
var currentLSB: BigInt = 0
// Process every bit in the mem width.
for (memBit <- 0 until mem.src.width) {
val bitsInCurrentMem = memBit - currentLSB
// We'll need to find a bitPair that works for *all* the ports of the memory.
// e.g. unmasked read port and masked write port.
// For each port, store a tentative candidate for the split.
// Afterwards, figure out which one to use.
val bitPairCandidates = ArrayBuffer[(BigInt, BigInt)]()
for ((memPort, libPort) <- pairedPorts) {
// Sanity check to make sure we only split once per bit, once per port.
var alreadySplit: Boolean = false
// Helper function to check if it's time to split memories.
// @param effectiveLibWidth Split memory when we have this many bits.
def splitMemory(effectiveLibWidth: Int): Unit = {
assert(!alreadySplit)
if (bitsInCurrentMem == effectiveLibWidth) {
bitPairCandidates += ((currentLSB, memBit - 1))
alreadySplit = true
}
}
// Make sure we don't have a maskGran larger than the width of the memory.
assert(memPort.src.effectiveMaskGran <= memPort.src.width.get)
assert(libPort.src.effectiveMaskGran <= libPort.src.width.get)
val libWidth = libPort.src.width.get
// Don't consider cases of maskGran == width as "masked" since those masks
// effectively function as write-enable bits.
val memMask = if (memPort.src.effectiveMaskGran == memPort.src.width.get) None else memPort.src.maskGran
val libMask = if (libPort.src.effectiveMaskGran == libPort.src.width.get) None else libPort.src.maskGran
(memMask, libMask) match {
// Neither lib nor mem is masked.
// No problems here.
case (None, None) => splitMemory(libWidth)
// Only the lib is masked.
// Not an issue; we can just make all the bits in the lib mask enabled.
case (None, Some(_)) => splitMemory(libWidth)
// Only the mem is masked.
case (Some(p), None) =>
if (p % libPort.src.width.get == 0) {
// If the mem mask is a multiple of the lib width, then we're good.
// Just roll over every lib width as usual.
// e.g. lib width=4, mem maskGran={4, 8, 12, 16, ...}
splitMemory(libWidth)
} else if (libPort.src.width.get % p == 0) {
// Lib width is a multiple of the mem mask.
// Consider the case where mem mask = 4 but lib width = 8, unmasked.
// We can still compile, but will need to waste the extra bits.
splitMemory(memMask.get)
} else {
// No neat multiples.
// We might still be able to compile extremely inefficiently.
if (p < libPort.src.width.get) {
// Compile using mem mask as the effective width. (note that lib is not masked)
// e.g. mem mask = 3, lib width = 8
splitMemory(memMask.get)
} else {
// e.g. mem mask = 13, lib width = 8
System.err.println(
s"Unmasked target memory: unaligned mem maskGran $p with lib (${lib.src.name}) width ${libPort.src.width.get} not supported"
)
return Seq()
}
}
// Both lib and mem are masked.
case (Some(m), Some(l)) =>
if (m == l) {
// Lib maskGran == mem maskGran, no problems
splitMemory(libWidth)
} else if (m > l) {
// Mem maskGran > lib maskGran
if (m % l == 0) {
// Mem maskGran is a multiple of lib maskGran, carry on as normal.
splitMemory(libWidth)
} else {
System.err.println(s"Mem maskGran $m is not a multiple of lib maskGran $l: currently not supported")
return Seq()
}
} else { // m < l
// Lib maskGran > mem maskGran.
if (l % m == 0) {
// Lib maskGran is a multiple of mem maskGran.
// e.g. lib maskGran = 8, mem maskGran = 4.
// In this case we can only compile very wastefully (by treating
// lib as a mem maskGran width memory) :(
splitMemory(memMask.get)
// TODO: there's an optimization that could allow us to pack more
// bits in and be more efficient.
// e.g. say if mem maskGran = 4, lib maskGran = 8, libWidth = 32
// We could use 16 of bit (bits 0-3, 8-11, 16-19, 24-27) instead
// of treating it as simply a width 4 (!!!) memory.
// This would require a major refactor though.
} else {
System.err.println(s"Lib maskGran $m is not a multiple of mem maskGran $l: currently not supported")
return Seq()
}
}
}
}
// Choose an actual bit pair to add.
// We'll have to choose the smallest one (e.g. unmasked read port might be more tolerant of a bigger split than the masked write port).
if (bitPairCandidates.isEmpty) {
// No pair needed to split, just continue
} else {
val bestPair = bitPairCandidates.reduceLeft((leftPair, rightPair) => {
if (leftPair._2 - leftPair._1 + 1 > rightPair._2 - rightPair._1 + 1) leftPair else rightPair
})
bitPairs += bestPair
currentLSB = bestPair._2 + BigInt(1) // advance the LSB pointer
}
}
// Add in the last chunk if there are any leftovers
bitPairs += ((currentLSB, mem.src.width - 1))
bitPairs
}.toSeq
def compile(mem: Macro, lib: Macro): Option[(Module, Macro)] = {
assert(
mem.sortedPorts.lengthCompare(lib.sortedPorts.length) == 0,
"mem and lib should have an equal number of ports"
)
val pairedPorts = mem.sortedPorts.zip(lib.sortedPorts)
// Width mapping. See calculateBitPairs.
val bitPairs: Seq[(BigInt, BigInt)] = calculateBitPairs(mem, lib)
if (bitPairs.isEmpty) {
System.err.println("Error occurred during bitPairs calculations (bitPairs is empty).")
return None
}
// Check bit pairs.
checkBitPairs(bitPairs)
// Depth mapping
val stmts = ArrayBuffer[Statement]()
val outputs = mutable.HashMap[String, ArrayBuffer[(Expression, Expression)]]()
val selects = mutable.HashMap[String, Expression]()
val selectRegs = mutable.HashMap[String, Expression]()
/* Palmer: If we've got a parallel memory then we've got to take the
* address bits into account. */
if (mem.src.depth > lib.src.depth) {
mem.src.ports.foreach { port =>
val high = MacroCompilerMath.ceilLog2(mem.src.depth)
val low = MacroCompilerMath.ceilLog2(lib.src.depth)
val ref = WRef(port.address.name)
val nodeName = s"${ref.name}_sel"
val tpe = UIntType(IntWidth(high - low))
selects(ref.name) = WRef(nodeName, tpe)
stmts += DefNode(NoInfo, nodeName, bits(ref, high - 1, low))
// Donggyu: output selection should be piped
if (port.output.isDefined) {
val regName = s"${ref.name}_sel_reg"
val enable = (port.chipEnable, port.readEnable) match {
case (Some(ce), Some(re)) =>
and(WRef(ce.name, BoolType), WRef(re.name, BoolType))
case (Some(ce), None) => WRef(ce.name, BoolType)
case (None, Some(re)) => WRef(re.name, BoolType)
case (None, None) => one
}
selectRegs(ref.name) = WRef(regName, tpe)
stmts += DefRegister(NoInfo, regName, tpe, WRef(port.clock.get.name), zero, WRef(regName))
stmts += Connect(NoInfo, WRef(regName), Mux(enable, WRef(nodeName), WRef(regName), tpe))
}
}
}
for ((_, i) <- BigInt(0).until(mem.src.depth, lib.src.depth).zipWithIndex) {
for (j <- bitPairs.indices) {
val name = s"mem_${i}_$j"
// Create the instance.
stmts += WDefInstance(NoInfo, name, lib.src.name, lib.tpe)
// Connect extra ports of the lib.
stmts ++= lib.extraPorts.map { case (portName, portValue) =>
Connect(NoInfo, WSubField(WRef(name), portName), portValue)
}
}
for ((memPort, libPort) <- pairedPorts) {
val addrMatch = selects.get(memPort.src.address.name) match {
case None => one
case Some(addr) =>
val index = UIntLiteral(i, IntWidth(bitWidth(addr.tpe)))
DoPrim(PrimOps.Eq, Seq(addr, index), Nil, index.tpe)
}
val addrMatchReg = selectRegs.get(memPort.src.address.name) match {
case None => one
case Some(reg) =>
val index = UIntLiteral(i, IntWidth(bitWidth(reg.tpe)))
DoPrim(PrimOps.Eq, Seq(reg, index), Nil, index.tpe)
}
def andAddrMatch(e: Expression) = {
and(e, addrMatch)
}
val cats = ArrayBuffer[Expression]()
for (((low, high), j) <- bitPairs.zipWithIndex) {
val inst = WRef(s"mem_${i}_$j", lib.tpe)
def connectPorts2(mem: Expression, lib: String, polarity: Option[PortPolarity]): Statement =
Connect(NoInfo, WSubField(inst, lib), portToExpression(mem, polarity))
def connectPorts(mem: Expression, lib: String, polarity: PortPolarity): Statement =
connectPorts2(mem, lib, Some(polarity))
// Clock port mapping
/* Palmer: FIXME: I don't handle memories with read/write clocks yet. */
/* Colin not all libPorts have clocks but all memPorts do*/
libPort.src.clock.foreach { cPort =>
stmts += connectPorts(WRef(memPort.src.clock.get.name), cPort.name, cPort.polarity)
}
// Adress port mapping
/* Palmer: The address port to a memory is just the low-order bits of
* the top address. */
stmts += connectPorts(WRef(memPort.src.address.name), libPort.src.address.name, libPort.src.address.polarity)
// Output port mapping
(memPort.src.output, libPort.src.output) match {
case (Some(PolarizedPort(mem, _)), Some(PolarizedPort(lib, lib_polarity))) =>
/* Palmer: In order to produce the output of a memory we need to cat
* together a bunch of narrower memories, which can only be
* done after generating all the memories. This saves up the
* output statements for later. */
val name = s"${mem}_${i}_$j" // This name is the output from the instance (mem vs ${mem}).
val exp = portToExpression(bits(WSubField(inst, lib), high - low, 0), Some(lib_polarity))
stmts += DefNode(NoInfo, name, exp)
cats += WRef(name)
case (None, Some(_)) =>
/* Palmer: If the inner memory has an output port but the outer
* one doesn't then it's safe to just leave the outer
* port floating. */
case (None, None) =>
/* Palmer: If there's no output ports at all (ie, read-only
* port on the memory) then just don't worry about it,
* there's nothing to do. */
case (Some(PolarizedPort(mem, _)), None) =>
System.err.println("WARNING: Unable to match output ports on memory")
System.err.println(s" outer output port: $mem")
return None
}
// Input port mapping
(memPort.src.input, libPort.src.input) match {
case (Some(PolarizedPort(mem, _)), Some(PolarizedPort(lib, lib_polarity))) =>
/* Palmer: The input port to a memory just needs to happen in parallel,
* this does a part select to narrow the memory down. */
stmts += connectPorts(bits(WRef(mem), high, low), lib, lib_polarity)
case (None, Some(lib)) =>
/* Palmer: If the inner memory has an input port but the other
* one doesn't then it's safe to just leave the inner
* port floating. This should be handled by the
* default value of the write enable, so nothing should
* every make it into the memory. */
//Firrtl cares about dangling inputs now tie it off
stmts += IsInvalid(NoInfo, WSubField(inst, lib.name))
case (None, None) =>
/* Palmer: If there's no input ports at all (ie, read-only
* port on the memory) then just don't worry about it,
* there's nothing to do. */
case (Some(PolarizedPort(mem, _)), None) =>
System.err.println("WARNING: Unable to match input ports on memory")
System.err.println(s" outer input port: $mem")
return None
}
// Mask port mapping
val memMask = memPort.src.maskPort match {
case Some(PolarizedPort(mem, _)) =>
/* Palmer: The bits from the outer memory's write mask that will be
* used as the write mask for this inner memory. */
if (libPort.src.effectiveMaskGran == libPort.src.width.get) {
bits(WRef(mem), low / memPort.src.effectiveMaskGran)
} else {
require(isPowerOfTwo(libPort.src.effectiveMaskGran), "only powers of two masks supported for now")
// How much of this lib's width we are effectively using.
// If we have a mem maskGran less than the lib's maskGran, we'll have to take the smaller maskGran.
// Example: if we have a lib whose maskGran is 8 but our mem's maskGran is 4.
// The other case is if we're using a larger lib than mem.
val usingLessThanLibMaskGran = memPort.src.maskGran.get < libPort.src.effectiveMaskGran
val effectiveLibWidth =
if (usingLessThanLibMaskGran)
memPort.src.maskGran.get
else
libPort.src.width.get
cat(
(0 until libPort.src.width.get by libPort.src.effectiveMaskGran)
.map(i => {
if (usingLessThanLibMaskGran && i >= effectiveLibWidth) {
// If the memMaskGran is smaller than the lib's gran, then
// zero out the upper bits.
zero
} else {
if ((low + i) >= memPort.src.width.get) {
// If our bit is larger than the whole width of the mem, just zero out the upper bits.
zero
} else {
// Pick the appropriate bit from the mem mask.
bits(WRef(mem), (low + i) / memPort.src.effectiveMaskGran)
}
}
})
.reverse
)
}
case None =>
/* If there is a lib mask port but no mem mask port, just turn on
* all bits of the lib mask port. */
if (libPort.src.maskPort.isDefined) {
val width = libPort.src.width.get / libPort.src.effectiveMaskGran
val value = (BigInt(1) << width) - 1
UIntLiteral(value, IntWidth(width))
} else {
// No mask ports on either side.
// We treat a "mask" of a single bit to be equivalent to a write
// enable (as used below).
one
}
}
// Write enable port mapping
val memWriteEnable = memPort.src.writeEnable match {
case Some(PolarizedPort(mem, _)) =>
/* Palmer: The outer memory's write enable port, or a constant 1 if
* there isn't a write enable port. */
WRef(mem)
case None =>
/* Palmer: If there is no input port on the source memory port
* then we don't ever want to turn on this write
* enable. Otherwise, we just _always_ turn on the
* write enable port on the inner memory. */
if (memPort.src.input.isEmpty) zero else one
}
// Chip enable port mapping
val memChipEnable = memPort.src.chipEnable match {
case Some(PolarizedPort(mem, _)) => WRef(mem)
case None => one
}
// Read enable port mapping
/* Palmer: It's safe to ignore read enables, but we pass them through
* to the vendor memory if there's a port on there that
* implements the read enables. */
(memPort.src.readEnable, libPort.src.readEnable) match {
case (_, None) =>
case (Some(PolarizedPort(mem, _)), Some(PolarizedPort(lib, lib_polarity))) =>
stmts += connectPorts(andAddrMatch(WRef(mem)), lib, lib_polarity)
case (None, Some(PolarizedPort(lib, lib_polarity))) =>
stmts += connectPorts(andAddrMatch(and(not(memWriteEnable), memChipEnable)), lib, lib_polarity)
}
/* Palmer: This is actually the memory compiler: it figures out how to
* implement the outer memory's collection of ports using what
* the inner memory has availiable. */
((libPort.src.maskPort, libPort.src.writeEnable, libPort.src.chipEnable): @unchecked) match {
case (
Some(PolarizedPort(mask, mask_polarity)),
Some(PolarizedPort(we, we_polarity)),
Some(PolarizedPort(en, en_polarity))
) =>
/* Palmer: This is the simple option: every port exists. */
stmts += connectPorts(memMask, mask, mask_polarity)
stmts += connectPorts(andAddrMatch(memWriteEnable), we, we_polarity)
stmts += connectPorts(andAddrMatch(memChipEnable), en, en_polarity)
case (Some(PolarizedPort(mask, mask_polarity)), Some(PolarizedPort(we, we_polarity)), None) =>
/* Palmer: If we don't have a chip enable but do have mask ports. */
stmts += connectPorts(memMask, mask, mask_polarity)
stmts += connectPorts(andAddrMatch(and(memWriteEnable, memChipEnable)), we, we_polarity)
case (None, Some(PolarizedPort(we, we_polarity)), chipEnable) =>
if (bitWidth(memMask.tpe) == 1) {
/* Palmer: If we're expected to provide mask ports without a
* memory that actually has them then we can use the
* write enable port instead of the mask port. */
chipEnable match {
case Some(PolarizedPort(en, en_polarity)) =>
stmts += connectPorts(andAddrMatch(and(memWriteEnable, memMask)), we, we_polarity)
stmts += connectPorts(andAddrMatch(memChipEnable), en, en_polarity)
case _ =>
stmts += connectPorts(
andAddrMatch(and(and(memWriteEnable, memChipEnable), memMask)),
we,
we_polarity
)
}
} else {
System.err.println("cannot emulate multi-bit mask ports with write enable")
return None
}
case (None, None, None) =>
// No write ports to match up (this may be a read-only port).
// This isn't necessarily an error condition.
}
}
// Cat macro outputs for selection
memPort.src.output match {
case Some(PolarizedPort(mem, _)) if cats.nonEmpty =>
val name = s"${mem}_$i"
stmts += DefNode(NoInfo, name, cat(cats.toSeq.reverse))
outputs.getOrElseUpdate(mem, ArrayBuffer[(Expression, Expression)]()) +=
(addrMatchReg -> WRef(name))
case _ =>
}
}
}
// Connect mem outputs
val zeroOutputValue: Expression = UIntLiteral(0, IntWidth(mem.src.width))
mem.src.ports.foreach { port =>
port.output match {
case Some(PolarizedPort(mem, _)) =>
outputs.get(mem) match {
case Some(select) =>
val output = select.foldRight(zeroOutputValue) { case ((cond, tval), fval) =>
Mux(cond, tval, fval, fval.tpe)
}
stmts += Connect(NoInfo, WRef(mem), output)
case None =>
}
case None =>
}
}
Some((mem.module(Block(stmts.toSeq)), lib))
}
def run(c: Circuit): Circuit = {
var firstLib = true
val modules = (mems, libs) match {
case (Some(mems), Some(libs)) =>
// Try to compile each of the memories in mems.
// The 'state' is c.modules, which is a list of all the firrtl modules
// in the 'circuit'.
mems.foldLeft(c.modules) { (modules, mem) =>
val sram = mem.src
def groupMatchesMask(group: SRAMGroup, mem: SRAMMacro): Boolean = {
val memMask = mem.ports.map(_.maskGran).find(_.isDefined).flatten
val libMask = group.ports.map(_.maskGran).find(_.isDefined).flatten
(memMask, libMask) match {
case (None, _) => true
case (Some(_), None) => false
case (Some(m), Some(l)) => l <= m //Ignore memories that don't have nice mask
}
}
// Add compiler memories that might map well to libs
val compLibs = compilers match {
case Some(SRAMCompiler(_, groups)) =>
groups
.filter(g => g.family == sram.family && groupMatchesMask(g, sram))
.map(g => {
for {
w <- g.width
d <- g.depth if (sram.width % w == 0) && (sram.depth % d == 0)
} yield Seq(new Macro(buildSRAMMacro(g, d, w, g.vt.head)))
})
case None => Seq()
}
val fullLibs = libs ++ compLibs.flatten.flatten
// Try to compile mem against each lib in libs, keeping track of the
// best compiled version, external lib used, and cost.
val (best, _) = fullLibs.foldLeft(None: Option[(Module, Macro)], Double.MaxValue) {
case ((best, cost), lib) if mem.src.ports.size != lib.src.ports.size =>
/* Palmer: FIXME: This just assumes the Chisel and vendor ports are in the same
* order, but I'm starting with what actually gets generated. */
System.err.println(s"INFO: unable to compile ${mem.src.name} using ${lib.src.name} port count must match")
(best, cost)
case ((best, cost), lib) =>
// Run the cost function to evaluate this potential compile.
costMetric.cost(mem, lib) match {
case Some(newCost) =>
//System.err.println(s"Cost of ${lib.src.name} for ${mem.src.name}: ${newCost}")
// Try compiling
compile(mem, lib) match {
// If it was successful and the new cost is lower
case Some(p) if newCost < cost => (Some(p), newCost)
case _ => (best, cost)
}
case _ => (best, cost) // Cost function rejected this combination.
}
}
// If we were able to compile anything, then replace the original module
// in the modules list with a compiled version, as well as the extmodule
// stub for the lib.
best match {
case None =>
if (mode == MacroCompilerAnnotation.Strict)
throw MacroCompilerException(
s"Target memory ${mem.src.name} could not be compiled and strict mode is activated - aborting."
)
else
modules
case Some((mod, bb)) =>
hammerIR match {
case Some(f) =>
val hammerIRWriter = new FileWriter(new File(f), !firstLib)
if (firstLib) hammerIRWriter.write("[\n")
hammerIRWriter.write(bb.src.toJSON().toString())
hammerIRWriter.write("\n,\n")
hammerIRWriter.close()
firstLib = false
case None =>
}
modules.filterNot(m => m.name == mod.name || m.name == bb.blackbox.name) ++ Seq(mod, bb.blackbox)
}
}
case _ => c.modules
}
c.copy(modules = modules)
}
}
class MacroCompilerTransform extends Transform with DependencyAPIMigration {
override def prerequisites: Seq[TransformDependency] = Forms.LowForm
override def optionalPrerequisites: Seq[TransformDependency] = Forms.LowFormOptimized
override def optionalPrerequisiteOf: Seq[Dependency[Emitter]] = Forms.LowEmitters
override def invalidates(a: Transform) = false
def execute(state: CircuitState): CircuitState = state.annotations.collect { case a: MacroCompilerAnnotation =>
a
} match {
case Seq(anno: MacroCompilerAnnotation) =>
val MacroCompilerAnnotation.Params(
memFile,
memFileFormat,
libFile,
hammerIR,
costMetric,
mode,
useCompiler,
forceCompile,
forceSynflops
) = anno.params
if (mode == MacroCompilerAnnotation.FallbackSynflops) {
throw new UnsupportedOperationException("Not implemented yet")
}
// Check that we don't have any modules both forced to compile and synflops.
assert(forceCompile.intersect(forceSynflops).isEmpty, "Cannot have modules both forced to compile and synflops")
// Read, eliminate None, get only SRAM, make firrtl macro
val mems: Option[Seq[Macro]] = (memFileFormat match {
case Some("conf") => readConfFromPath(Some(memFile))
case _ => mdf.macrolib.Utils.readMDFFromPath(Some(memFile))
}) match {
case Some(x: Seq[mdf.macrolib.Macro]) =>
Some(filterForSRAM(Some(x)).getOrElse(List()).map { new Macro(_) })
case _ => None
}
val libs: Option[Seq[Macro]] = mdf.macrolib.Utils.readMDFFromPath(libFile) match {
case Some(x: Seq[mdf.macrolib.Macro]) =>
Some(filterForSRAM(Some(x)).getOrElse(List()).map { new Macro(_) })
case _ => None
}
val compilers: Option[mdf.macrolib.SRAMCompiler] = mdf.macrolib.Utils.readMDFFromPath(libFile) match {
case Some(x: Seq[mdf.macrolib.Macro]) =>
if (useCompiler) {
findSRAMCompiler(Some(x))
} else None
case _ => None
}
// Helper function to turn a set of mem names into a Seq[Macro].
def setToSeqMacro(names: Set[String]): Seq[Macro] = {
names.toSeq.map(memName => mems.get.collectFirst { case m if m.src.name == memName => m }.get)
}
// Build lists of memories for compilation and synflops.
val memCompile = mems.map { actualMems =>
val memsAdjustedForMode = if (mode == MacroCompilerAnnotation.Synflops) Seq.empty else actualMems
memsAdjustedForMode.filterNot(m => forceSynflops.contains(m.src.name)) ++ setToSeqMacro(forceCompile)
}
val memSynflops: Seq[Macro] = mems.map { actualMems =>
//
val memsAdjustedForMode = if (mode == MacroCompilerAnnotation.Synflops) actualMems else Seq.empty
memsAdjustedForMode.filterNot(m => forceCompile.contains(m.src.name)) ++ setToSeqMacro(forceSynflops)
}.getOrElse(Seq.empty)
val transforms = Seq(
new MacroCompilerPass(memCompile, libs, compilers, hammerIR, costMetric, mode),
new SynFlopsPass(
true,
memSynflops ++ (if (mode == MacroCompilerAnnotation.CompileAndSynflops) {
libs.get
} else {
Seq.empty
})
)
)
transforms.foldLeft(state)((s, xform) => xform.runTransform(s))
case _ => state
}
}
class MacroCompilerOptimizations extends SeqTransform with DependencyAPIMigration {
override def prerequisites: Seq[TransformDependency] = Forms.LowForm
override def optionalPrerequisites: Seq[TransformDependency] = Forms.LowFormOptimized
override def optionalPrerequisiteOf: Seq[Dependency[Emitter]] = Forms.LowEmitters
override def invalidates(a: Transform) = false
def transforms: Seq[Transform] = Seq(
passes.RemoveValidIf,
new firrtl.transforms.ConstantPropagation,
passes.memlib.VerilogMemDelays,
new firrtl.transforms.ConstantPropagation,
passes.SplitExpressions,
passes.CommonSubexpressionElimination
)
}
object MacroCompiler extends App {
sealed trait MacroParam
case object Macros extends MacroParam
case object MacrosFormat extends MacroParam
case object Library extends MacroParam
case object Verilog extends MacroParam
case object Firrtl extends MacroParam
case object HammerIR extends MacroParam
case object CostFunc extends MacroParam
case object Mode extends MacroParam
case object UseCompiler extends MacroParam
type MacroParamMap = Map[MacroParam, String]
type CostParamMap = Map[String, String]
type ForcedMemories = (Set[String], Set[String])
val modeOptions: Seq[String] = MacroCompilerAnnotation.options.map { case (_, cmd, description) =>
s" $cmd: $description"
}
val usage: String = (Seq(
"Options:",
" -n, --macro-conf: The set of macros to compile in firrtl-generated conf format (exclusive with -m)",
" -m, --macro-mdf: The set of macros to compile in MDF JSON format (exclusive with -n)",
" -l, --library: The set of macros that have blackbox instances",
" -u, --use-compiler: Flag, whether to use the memory compiler defined in library",
" -v, --verilog: Verilog output",
" -f, --firrtl: FIRRTL output (optional)",
" -hir, --hammer-ir: Hammer-IR output currently only needed for IP compilers",
" -c, --cost-func: Cost function to use. Optional (default: \"default\")",
" -cp, --cost-param: Cost function parameter. (Optional depending on the cost function.). e.g. -c ExternalMetric -cp path /path/to/my/cost/script",
" --force-compile [mem]: Force the given memory to be compiled to target libs regardless of the mode",
" --force-synflops [mem]: Force the given memory to be compiled via synflops regardless of the mode",
" --mode:"
) ++ modeOptions).mkString("\n")
@tailrec
def parseArgs(
map: MacroParamMap,
costMap: CostParamMap,
forcedMemories: ForcedMemories,
args: List[String]
): (MacroParamMap, CostParamMap, ForcedMemories) =
args match {
case Nil => (map, costMap, forcedMemories)
case ("-n" | "--macro-conf") :: value :: tail =>
parseArgs(map + (Macros -> value) + (MacrosFormat -> "conf"), costMap, forcedMemories, tail)
case ("-m" | "--macro-mdf") :: value :: tail =>
parseArgs(map + (Macros -> value) + (MacrosFormat -> "mdf"), costMap, forcedMemories, tail)
case ("-l" | "--library") :: value :: tail =>
parseArgs(map + (Library -> value), costMap, forcedMemories, tail)
case ("-u" | "--use-compiler") :: tail =>
parseArgs(map + (UseCompiler -> ""), costMap, forcedMemories, tail)
case ("-v" | "--verilog") :: value :: tail =>
parseArgs(map + (Verilog -> value), costMap, forcedMemories, tail)
case ("-f" | "--firrtl") :: value :: tail =>
parseArgs(map + (Firrtl -> value), costMap, forcedMemories, tail)
case ("-hir" | "--hammer-ir") :: value :: tail =>
parseArgs(map + (HammerIR -> value), costMap, forcedMemories, tail)
case ("-c" | "--cost-func") :: value :: tail =>
parseArgs(map + (CostFunc -> value), costMap, forcedMemories, tail)
case ("-cp" | "--cost-param") :: value1 :: value2 :: tail =>
parseArgs(map, costMap + (value1 -> value2), forcedMemories, tail)
case "--force-compile" :: value :: tail =>
parseArgs(map, costMap, forcedMemories.copy(_1 = forcedMemories._1 + value), tail)
case "--force-synflops" :: value :: tail =>
parseArgs(map, costMap, forcedMemories.copy(_2 = forcedMemories._2 + value), tail)
case "--mode" :: value :: tail =>
parseArgs(map + (Mode -> value), costMap, forcedMemories, tail)
case arg :: _ =>
println(s"Unknown field $arg\n")
println(usage)
sys.exit(1)
}
def run(args: List[String]): Unit = {
val (params, costParams, forcedMemories) =
parseArgs(Map[MacroParam, String](), Map[String, String](), (Set.empty, Set.empty), args)
try {
val macros = params.get(MacrosFormat) match {
case Some("conf") =>
filterForSRAM(readConfFromPath(params.get(Macros))).get.map(x => new Macro(x).blackbox)
case _ =>
filterForSRAM(mdf.macrolib.Utils.readMDFFromPath(params.get(Macros))).get
.map(x => new Macro(x).blackbox)
}
if (macros.nonEmpty) {
// Note: the last macro in the input list is (seemingly arbitrarily)
// determined as the firrtl "top-level module".
val circuit = Circuit(NoInfo, macros, macros.last.name)
val annotations = AnnotationSeq(
Seq(
MacroCompilerAnnotation(
circuit.main,
MacroCompilerAnnotation.Params(
params(Macros),
params.get(MacrosFormat),
params.get(Library),
params.get(HammerIR),
CostMetric.getCostMetric(params.getOrElse(CostFunc, "default"), costParams),
MacroCompilerAnnotation.stringToCompilerMode(params.getOrElse(Mode, "default")),
params.contains(UseCompiler),
forceCompile = forcedMemories._1,
forceSynflops = forcedMemories._2
)
)
)
)
// The actual MacroCompilerTransform basically just generates an input circuit
val macroCompilerInput = CircuitState(circuit, annotations)
val macroCompiled = (new MacroCompilerTransform).execute(macroCompilerInput)
// Run FIRRTL compiler
// For each generated module, have to create a new circuit with that module
// as top, and all other modules as ExtModules. This guarantees all modules
// are elaborated
val verilog = macroCompiled.circuit.modules
.map(_.name)
.map { macroName =>
val (mainMod, otherMods) = macroCompiled.circuit.modules.partition(_.name == macroName)
val extMods = otherMods.map(m => ExtModule(NoInfo, m.name, m.ports, m.name, Nil))
val circuit = Circuit(NoInfo, mainMod ++ extMods, macroName)
(new FirrtlStage)
.execute(
Array.empty,
Seq(
OutputFileAnnotation(params.get(Verilog).get),
RunFirrtlTransformAnnotation(new VerilogEmitter),
EmitCircuitAnnotation(classOf[VerilogEmitter]),
FirrtlSourceAnnotation(circuit.serialize)
)
)
.collect { case c: EmittedVerilogCircuitAnnotation => c }
.head
.value
.value
}
.mkString("\n")
val verilogWriter = new FileWriter(new File(params.get(Verilog).get))
verilogWriter.write(verilog)
verilogWriter.close()
params.get(HammerIR) match {
case Some(hammerIRFile: String) =>
val lines = FileUtils.getLines(hammerIRFile).toList
val hammerIRWriter = new FileWriter(new File(hammerIRFile))
// JSON means we need to destroy the last comma :(
lines.dropRight(1).foreach(l => hammerIRWriter.write(l + "\n"))
hammerIRWriter.write("]\n")
hammerIRWriter.close()
case None =>
}
} else {
// Warn user
System.err.println("WARNING: Empty *.mems.conf file. No memories generated.")
// Emit empty verilog file if no macros found
params.get(Verilog) match {
case Some(verilogFile: String) =>
// Create an empty verilog file
val verilogWriter = new FileWriter(new File(verilogFile))
verilogWriter.close()
case None =>
}
params.get(HammerIR) match {
case Some(hammerIRFile: String) =>
// Create an empty HammerIR file
val hammerIRWriter = new FileWriter(new File(hammerIRFile))
hammerIRWriter.write("[]\n")
hammerIRWriter.close()
case None =>
}
}
} catch {
case e: java.util.NoSuchElementException =>
if (args.isEmpty) {
println("Command line arguments must be specified")
} else {
e.printStackTrace()
}
e.printStackTrace()
sys.exit(1)
case e: MacroCompilerException =>
println(usage)
e.printStackTrace()
sys.exit(1)
case e: Throwable =>
throw e
}
}
run(args.toList)
}

View File

@@ -0,0 +1,152 @@
// See LICENSE for license details.
package barstools.macros
import barstools.macros.Utils._
import firrtl.Utils.{one, zero}
import firrtl._
import firrtl.ir._
import firrtl.passes.MemPortUtils.memPortField
import scala.collection.mutable
class SynFlopsPass(synflops: Boolean, libs: Seq[Macro]) extends firrtl.passes.Pass {
val extraMods: mutable.ArrayBuffer[Module] = scala.collection.mutable.ArrayBuffer.empty[Module]
lazy val libMods: Map[String, Module] = libs.map { lib =>
lib.src.name -> {
val (dataType, dataWidth) = lib.src.ports.foldLeft(None: Option[BigInt])((res, port) =>
(res, port.maskPort) match {
case (_, None) =>
res
case (None, Some(_)) =>
Some(port.effectiveMaskGran)
case (Some(x), Some(_)) =>
assert(x == port.effectiveMaskGran)
res
}
) match {
case None => (UIntType(IntWidth(lib.src.width)), lib.src.width)
case Some(gran) => (UIntType(IntWidth(gran)), gran.intValue)
}
val maxDepth = firrtl.Utils.min(lib.src.depth, 1 << 26)
// Change macro to be mapped onto to look like the below mem
// by changing its depth, and width
val lib_macro = new Macro(
lib.src.copy(
name = "split_" + lib.src.name,
depth = maxDepth,
width = dataWidth,
ports = lib.src.ports.map(p =>
p.copy(
width = p.width.map(_ => dataWidth),
depth = p.depth.map(_ => maxDepth),
maskGran = p.maskGran.map(_ => dataWidth)
)
)
)
)
val mod_macro = new MacroCompilerPass(None, None, None, None).compile(lib, lib_macro)
val (real_mod, real_macro) = mod_macro.get
val mem = DefMemory(
NoInfo,
"ram",
dataType,
maxDepth,
1, // writeLatency
1, // readLatency. This is possible because of VerilogMemDelays
real_macro.readers.indices.map(i => s"R_$i"),
real_macro.writers.indices.map(i => s"W_$i"),
real_macro.readwriters.indices.map(i => s"RW_$i")
)
val readConnects = real_macro.readers.zipWithIndex.flatMap { case (r, i) =>
val clock = portToExpression(r.src.clock.get)
val address = portToExpression(r.src.address)
val enable = (r.src.chipEnable, r.src.readEnable) match {
case (Some(en_port), Some(re_port)) =>
and(portToExpression(en_port), portToExpression(re_port))
case (Some(en_port), None) => portToExpression(en_port)
case (None, Some(re_port)) => portToExpression(re_port)
case (None, None) => one
}
val data = memPortField(mem, s"R_$i", "data")
val read = data
Seq(
Connect(NoInfo, memPortField(mem, s"R_$i", "clk"), clock),
Connect(NoInfo, memPortField(mem, s"R_$i", "addr"), address),
Connect(NoInfo, memPortField(mem, s"R_$i", "en"), enable),
Connect(NoInfo, WRef(r.src.output.get.name), read)
)
}
val writeConnects = real_macro.writers.zipWithIndex.flatMap { case (w, i) =>
val clock = portToExpression(w.src.clock.get)
val address = portToExpression(w.src.address)
val enable = (w.src.chipEnable, w.src.writeEnable) match {
case (Some(en), Some(we)) =>
and(portToExpression(en), portToExpression(we))
case (Some(en), None) => portToExpression(en)
case (None, Some(we)) => portToExpression(we)
case (None, None) => zero // is it possible?
}
val mask = w.src.maskPort match {
case Some(m) => portToExpression(m)
case None => one
}
val data = memPortField(mem, s"W_$i", "data")
val write = portToExpression(w.src.input.get)
Seq(
Connect(NoInfo, memPortField(mem, s"W_$i", "clk"), clock),
Connect(NoInfo, memPortField(mem, s"W_$i", "addr"), address),
Connect(NoInfo, memPortField(mem, s"W_$i", "en"), enable),
Connect(NoInfo, memPortField(mem, s"W_$i", "mask"), mask),
Connect(NoInfo, data, write)
)
}
val readwriteConnects = real_macro.readwriters.zipWithIndex.flatMap { case (rw, i) =>
val clock = portToExpression(rw.src.clock.get)
val address = portToExpression(rw.src.address)
val wmode = rw.src.writeEnable match {
case Some(we) => portToExpression(we)
case None => zero // is it possible?
}
val wmask = rw.src.maskPort match {
case Some(wm) => portToExpression(wm)
case None => one
}
val enable = (rw.src.chipEnable, rw.src.readEnable) match {
case (Some(en), Some(re)) =>
and(portToExpression(en), or(portToExpression(re), wmode))
case (Some(en), None) => portToExpression(en)
case (None, Some(re)) => or(portToExpression(re), wmode)
case (None, None) => one
}
val wdata = memPortField(mem, s"RW_$i", "wdata")
val rdata = memPortField(mem, s"RW_$i", "rdata")
val write = portToExpression(rw.src.input.get)
val read = rdata
Seq(
Connect(NoInfo, memPortField(mem, s"RW_$i", "clk"), clock),
Connect(NoInfo, memPortField(mem, s"RW_$i", "addr"), address),
Connect(NoInfo, memPortField(mem, s"RW_$i", "en"), enable),
Connect(NoInfo, memPortField(mem, s"RW_$i", "wmode"), wmode),
Connect(NoInfo, memPortField(mem, s"RW_$i", "wmask"), wmask),
Connect(NoInfo, WRef(rw.src.output.get.name), read),
Connect(NoInfo, wdata, write)
)
}
extraMods.append(real_macro.module(Block(mem +: (readConnects ++ writeConnects ++ readwriteConnects))))
real_mod
}
}.toMap
def run(c: Circuit): Circuit = {
if (!synflops) c
else c.copy(modules = c.modules.map(m => libMods.getOrElse(m.name, m)) ++ extraMods)
}
}

View File

@@ -0,0 +1,262 @@
// See LICENSE for license details.
package barstools.macros
import firrtl.Utils.BoolType
import firrtl.ir._
import firrtl.passes.memlib._
import firrtl.{PrimOps, _}
import mdf.macrolib.{Input => _, Output => _, _}
import scala.language.implicitConversions
object MacroCompilerMath {
def ceilLog2(x: BigInt): Int = (x - 1).bitLength
}
class FirrtlMacroPort(port: MacroPort) {
val src: MacroPort = port
val isReader: Boolean = port.output.nonEmpty && port.input.isEmpty
val isWriter: Boolean = port.input.nonEmpty && port.output.isEmpty
val isReadWriter: Boolean = port.input.nonEmpty && port.output.nonEmpty
val addrType: UIntType = UIntType(IntWidth(MacroCompilerMath.ceilLog2(port.depth.get).max(1)))
val dataType: UIntType = UIntType(IntWidth(port.width.get))
val maskType: UIntType = UIntType(IntWidth(port.width.get / port.effectiveMaskGran))
// Bundle representing this macro port.
val tpe: BundleType = BundleType(
Seq(Field(port.address.name, Flip, addrType)) ++
port.clock.map(p => Field(p.name, Flip, ClockType)) ++
port.input.map(p => Field(p.name, Flip, dataType)) ++
port.output.map(p => Field(p.name, Default, dataType)) ++
port.chipEnable.map(p => Field(p.name, Flip, BoolType)) ++
port.readEnable.map(p => Field(p.name, Flip, BoolType)) ++
port.writeEnable.map(p => Field(p.name, Flip, BoolType)) ++
port.maskPort.map(p => Field(p.name, Flip, maskType))
)
val ports: Seq[Port] = tpe.fields.map(f =>
Port(
NoInfo,
f.name,
f.flip match {
case Default => Output
case Flip => Input
},
f.tpe
)
)
}
// Reads an SRAMMacro and generates firrtl blackboxes.
class Macro(srcMacro: SRAMMacro) {
val src: SRAMMacro = srcMacro
val firrtlPorts: Seq[FirrtlMacroPort] = srcMacro.ports.map { new FirrtlMacroPort(_) }
val writers: Seq[FirrtlMacroPort] = firrtlPorts.filter(p => p.isWriter)
val readers: Seq[FirrtlMacroPort] = firrtlPorts.filter(p => p.isReader)
val readwriters: Seq[FirrtlMacroPort] = firrtlPorts.filter(p => p.isReadWriter)
val sortedPorts: Seq[FirrtlMacroPort] = writers ++ readers ++ readwriters
val extraPorts: Seq[(String, UIntLiteral)] = srcMacro.extraPorts.map { p =>
assert(p.portType == Constant) // TODO: release it?
val name = p.name
val width = BigInt(p.width.toLong)
val value = BigInt(p.value.toLong)
name -> UIntLiteral(value, IntWidth(width))
}
// Bundle representing this memory blackbox
val tpe: BundleType = BundleType(firrtlPorts.flatMap(_.tpe.fields))
private val modPorts = firrtlPorts.flatMap(_.ports) ++
extraPorts.map { case (name, value) => Port(NoInfo, name, Input, value.tpe) }
val blackbox: ExtModule = ExtModule(NoInfo, srcMacro.name, modPorts, srcMacro.name, Nil)
def module(body: Statement): Module = Module(NoInfo, srcMacro.name, modPorts, body)
}
object Utils {
def filterForSRAM(s: Option[Seq[mdf.macrolib.Macro]]): Option[Seq[mdf.macrolib.SRAMMacro]] = {
s match {
case Some(l: Seq[mdf.macrolib.Macro]) =>
Some(l.filter { _.isInstanceOf[mdf.macrolib.SRAMMacro] }.map { m => m.asInstanceOf[mdf.macrolib.SRAMMacro] })
case _ => None
}
}
// This utility reads a conf in and returns MDF like mdf.macrolib.Utils.readMDFFromPath
def readConfFromPath(path: Option[String]): Option[Seq[mdf.macrolib.Macro]] = {
path.map(p => Utils.readConfFromString(FileUtils.getText(p)))
}
def readConfFromString(str: String): Seq[mdf.macrolib.Macro] = {
MemConf.fromString(str).map { m: MemConf =>
val ports = m.ports.map { case (port, num) => Seq.fill(num)(port) }.reduce(_ ++ _)
SRAMMacro(
m.name,
m.width,
m.depth,
Utils.portSpecToFamily(ports),
Utils.portSpecToMacroPort(m.width, m.depth, m.maskGranularity, ports)
)
}
}
def portSpecToFamily(ports: Seq[MemPort]): String = {
val numR = ports.count { case ReadPort => true; case _ => false }
val numW = ports.count { case WritePort | MaskedWritePort => true; case _ => false }
val numRW = ports.count { case ReadWritePort | MaskedReadWritePort => true; case _ => false }
val numRStr = if (numR > 0) s"${numR}r" else ""
val numWStr = if (numW > 0) s"${numW}w" else ""
val numRWStr = if (numRW > 0) s"${numRW}rw" else ""
numRStr + numWStr + numRWStr
}
// This translates between two represenations of ports
def portSpecToMacroPort(width: Int, depth: BigInt, maskGran: Option[Int], ports: Seq[MemPort]): Seq[MacroPort] = {
var numR = 0
var numW = 0
var numRW = 0
ports.map {
case ReadPort =>
val portName = s"R$numR"
numR += 1
MacroPort(
width = Some(width),
depth = Some(depth),
address = PolarizedPort(s"${portName}_addr", ActiveHigh),
clock = Some(PolarizedPort(s"${portName}_clk", PositiveEdge)),
readEnable = Some(PolarizedPort(s"${portName}_en", ActiveHigh)),
output = Some(PolarizedPort(s"${portName}_data", ActiveHigh))
)
case WritePort =>
val portName = s"W$numW"
numW += 1
MacroPort(
width = Some(width),
depth = Some(depth),
address = PolarizedPort(s"${portName}_addr", ActiveHigh),
clock = Some(PolarizedPort(s"${portName}_clk", PositiveEdge)),
writeEnable = Some(PolarizedPort(s"${portName}_en", ActiveHigh)),
input = Some(PolarizedPort(s"${portName}_data", ActiveHigh))
)
case MaskedWritePort =>
val portName = s"W$numW"
numW += 1
MacroPort(
width = Some(width),
depth = Some(depth),
address = PolarizedPort(s"${portName}_addr", ActiveHigh),
clock = Some(PolarizedPort(s"${portName}_clk", PositiveEdge)),
writeEnable = Some(PolarizedPort(s"${portName}_en", ActiveHigh)),
maskPort = Some(PolarizedPort(s"${portName}_mask", ActiveHigh)),
maskGran = maskGran,
input = Some(PolarizedPort(s"${portName}_data", ActiveHigh))
)
case ReadWritePort =>
val portName = s"RW$numRW"
numRW += 1
MacroPort(
width = Some(width),
depth = Some(depth),
address = PolarizedPort(s"${portName}_addr", ActiveHigh),
clock = Some(PolarizedPort(s"${portName}_clk", PositiveEdge)),
chipEnable = Some(PolarizedPort(s"${portName}_en", ActiveHigh)),
writeEnable = Some(PolarizedPort(s"${portName}_wmode", ActiveHigh)),
input = Some(PolarizedPort(s"${portName}_wdata", ActiveHigh)),
output = Some(PolarizedPort(s"${portName}_rdata", ActiveHigh))
)
case MaskedReadWritePort =>
val portName = s"RW$numRW"
numRW += 1
MacroPort(
width = Some(width),
depth = Some(depth),
address = PolarizedPort(s"${portName}_addr", ActiveHigh),
clock = Some(PolarizedPort(s"${portName}_clk", PositiveEdge)),
chipEnable = Some(PolarizedPort(s"${portName}_en", ActiveHigh)),
writeEnable = Some(PolarizedPort(s"${portName}_wmode", ActiveHigh)),
maskPort = Some(PolarizedPort(s"${portName}_wmask", ActiveHigh)),
maskGran = maskGran,
input = Some(PolarizedPort(s"${portName}_wdata", ActiveHigh)),
output = Some(PolarizedPort(s"${portName}_rdata", ActiveHigh))
)
}
}
def findSRAMCompiler(s: Option[Seq[mdf.macrolib.Macro]]): Option[mdf.macrolib.SRAMCompiler] = {
s match {
case Some(l: Seq[mdf.macrolib.Macro]) =>
l.collectFirst { case x: mdf.macrolib.SRAMCompiler =>
x
}
case _ => None
}
}
def buildSRAMMacros(s: mdf.macrolib.SRAMCompiler): Seq[mdf.macrolib.SRAMMacro] = {
for {
g <- s.groups
d <- g.depth
w <- g.width
vt <- g.vt
} yield mdf.macrolib.SRAMMacro(
makeName(g, d, w, vt),
w,
d,
g.family,
g.ports.map(_.copy(width = Some(w), depth = Some(d))),
vt,
g.mux,
g.extraPorts
)
}
def buildSRAMMacro(g: mdf.macrolib.SRAMGroup, d: Int, w: Int, vt: String): mdf.macrolib.SRAMMacro = {
mdf.macrolib.SRAMMacro(
makeName(g, d, w, vt),
w,
d,
g.family,
g.ports.map(_.copy(width = Some(w), depth = Some(d))),
vt,
g.mux,
g.extraPorts
)
}
def makeName(g: mdf.macrolib.SRAMGroup, depth: Int, width: Int, vt: String): String = {
g.name.foldLeft("") { (builder, next) =>
next match {
case "depth" | "DEPTH" => builder + depth
case "width" | "WIDTH" => builder + width
case "vt" => builder + vt.toLowerCase
case "VT" => builder + vt.toUpperCase
case "family" => builder + g.family.toLowerCase
case "FAMILY" => builder + g.family.toUpperCase
case "mux" | "MUX" => builder + g.mux
case other => builder + other
}
}
}
def and(e1: Expression, e2: Expression): DoPrim =
DoPrim(PrimOps.And, Seq(e1, e2), Nil, e1.tpe)
def or(e1: Expression, e2: Expression): DoPrim =
DoPrim(PrimOps.Or, Seq(e1, e2), Nil, e1.tpe)
def bits(e: Expression, high: BigInt, low: BigInt): Expression =
DoPrim(PrimOps.Bits, Seq(e), Seq(high, low), UIntType(IntWidth(high - low + 1)))
def bits(e: Expression, idx: BigInt): Expression = bits(e, idx, idx)
def cat(es: Seq[Expression]): Expression =
if (es.size == 1) es.head
else DoPrim(PrimOps.Cat, Seq(es.head, cat(es.tail)), Nil, UnknownType)
def not(e: Expression): DoPrim =
DoPrim(PrimOps.Not, Seq(e), Nil, e.tpe)
// Convert a port to a FIRRTL expression, handling polarity along the way.
def portToExpression(pp: PolarizedPort): Expression =
portToExpression(WRef(pp.name), Some(pp.polarity))
def portToExpression(exp: Expression, polarity: Option[PortPolarity]): Expression =
polarity match {
case Some(ActiveLow) | Some(NegativeEdge) => not(exp)
case _ => exp
}
// Check if a number is a power of two
def isPowerOfTwo(x: Int): Boolean = (x & (x - 1)) == 0
}