400 lines
17 KiB
Scala
400 lines
17 KiB
Scala
// See LICENSE for license details.
|
|
|
|
package barstools.tapeout.transforms.macros
|
|
|
|
import firrtl._
|
|
import firrtl.ir._
|
|
import firrtl.PrimOps
|
|
import firrtl.Utils._
|
|
import firrtl.annotations._
|
|
import firrtl.CompilerUtils.getLoweringTransforms
|
|
import scala.collection.mutable.{ArrayBuffer, HashMap}
|
|
import java.io.{File, FileWriter}
|
|
import Utils._
|
|
|
|
/** Builds and decodes the annotation that carries the options of [[MacroCompilerTransform]].
  *
  * The payload is a single string of three space-separated fields:
  * `<mem path> <lib path, possibly empty> <true|false>`.
  * Fields are split on single spaces, so a path that itself contains a space
  * will not be recognised by `unapply`.
  */
object MacroCompilerAnnotation {
  /** Creates the annotation targeting circuit `c`.
    *
    * @param c        name of the circuit the annotation is attached to
    * @param mem      file listing the macros to compile
    * @param lib      optional file listing the available library macros
    * @param synflops flag forwarded to the transform (decoded again by `unapply`)
    */
  def apply(c: String, mem: File, lib: Option[File], synflops: Boolean): Annotation = {
    val libPath = lib map (_.toString) getOrElse ""
    // Plain interpolation only. The payload used to be interpolated first and then
    // passed through `String.format`, so a '%' in the mem path was interpreted as a
    // format directive and either corrupted the payload or threw.
    Annotation(CircuitName(c), classOf[MacroCompilerTransform],
               s"${mem} ${libPath} ${synflops}")
  }

  // mem path (non-empty), lib path (may be empty), boolean flag.
  private val matcher = "([^ ]+) ([^ ]*) (true|false)".r

  /** Extracts `(circuit name, mem file, lib file, synflops)` from an annotation created by
    * `apply`. Yields `None` for annotations of other transforms or with a payload that does
    * not match the expected three-field layout. The mem file is always defined on a match;
    * an empty lib field is decoded as `None`.
    */
  def unapply(a: Annotation): Option[(String, Option[File], Option[File], Boolean)] = a match {
    case Annotation(CircuitName(c), t, matcher(mem, lib, synflops)) if t == classOf[MacroCompilerTransform] =>
      Some((c, Some(new File(mem)), if (lib.isEmpty) None else Some(new File(lib)), synflops.toBoolean))
    case _ => None
  }
}
|
|
|
|
/** Pass that implements each requested memory macro out of library macros.
  *
  * @param mems macros that have to be implemented; when `None` the circuit is left untouched
  * @param libs library macros available as building blocks; when `None` the circuit is
  *             left untouched
  */
class MacroCompilerPass(mems: Option[Seq[Macro]],
                        libs: Option[Seq[Macro]]) extends firrtl.passes.Pass {
  /** Tries to build `mem` out of instances of `lib`.
    *
    * The memory is first cut into bit slices ("parallel mapping"), then replicated along
    * the depth ("serial mapping"); every (depth index, slice) pair becomes one instance
    * named `mem_<i>_<j>`.
    *
    * @param mem the macro to implement
    * @param lib the library macro used as building block
    * @return the module implementing `mem` together with the blackbox of `lib`, or `None`
    *         when an unsupported mask configuration or an unmatched data port is found
    *         (a message is printed to stderr in those cases)
    */
  def compile(mem: Macro, lib: Macro): Option[(Module, ExtModule)] = {
    // Ports are matched purely by position in the sorted port lists.
    val pairedPorts = mem.sortedPorts zip lib.sortedPorts

    // Parallel mapping: `pairs` collects the inclusive (low, high) bit range of each slice.
    val pairs = ArrayBuffer[(BigInt, BigInt)]()
    var last = 0
    for (i <- 0 until mem.width.toInt) {
      if (i <= last + 1) {
        /* Palmer: Every memory is going to have to fit at least a single bit. */
        // continue
        // NOTE(review): this also skips i == last + 1, so a slice is never closed after a
        // single bit (e.g. with a 1-bit wide lib) -- confirm whether `i <= last` was meant.
      } else if ((i - last) % lib.width.toInt == 0) {
        /* Palmer: It's possible that we rolled over a memory's width here,
                   if so generate one. */
        pairs += ((last, i-1))
        last = i
      } else {
        /* Palmer: FIXME: This is a mess, I must just be super confused. */
        for ((memPort, libPort) <- pairedPorts) {
          (memPort.maskGran, libPort.maskGran) match {
            case (_, Some(p)) if p == 1 => // continue
            case (Some(p), _) if i % p == 0 =>
              // Several ports can hit the same mask boundary for one `i`. Only the first
              // one may close the slice; otherwise an empty (i, i-1) range would be
              // appended, producing an extra instance and a negative-width bit select.
              if (last < i) {
                pairs += ((last, i-1))
                last = i
              }
            case (_, None) => // continue
            case (_, Some(p)) if p == lib.width => // continue
            case _ =>
              System.err println "Bit-mask (or unmasked) target memories are supported only"
              return None
          }
        }
      }
    }
    // Whatever is left after the last split forms the final slice.
    pairs += ((last, mem.width.toInt - 1))

    // Serial mapping
    val stmts = ArrayBuffer[Statement]()
    // address port name -> reference to the node holding its high-order (select) bits
    val selects = HashMap[String, Expression]()
    // output port name -> (address match condition, concatenated data) per depth index
    val outputs = HashMap[String, ArrayBuffer[(Expression, Expression)]]()
    /* Palmer: If we've got a parallel memory then we've got to take the
     *         address bits into account. */
    if (mem.depth > lib.depth) {
      mem.ports foreach { port =>
        val high = ceilLog2(mem.depth)
        val low = ceilLog2(lib.depth)
        val ref = WRef(port.addressName)
        val name = s"${ref.name}_sel"
        selects(ref.name) = WRef(name, UIntType(IntWidth(high-low)))
        stmts += DefNode(NoInfo, name, bits(ref, high-1, low))
      }
    }
    for ((_, i) <- (0 until mem.depth.toInt by lib.depth.toInt).zipWithIndex) {
      // One library instance per bit slice for this depth index.
      for (j <- pairs.indices) {
        val name = s"mem_${i}_${j}"
        stmts += WDefInstance(NoInfo, name, lib.name, lib.tpe)
        // connect extra ports
        stmts ++= lib.extraPorts map { case (portName, portValue) =>
          Connect(NoInfo, WSubField(WRef(name), portName), portValue)
        }
      }
      for ((memPort, libPort) <- pairedPorts) {
        // True when the select bits of the address pick this depth index; constant one
        // when no select node was created for this address port.
        val addrMatch = selects get memPort.addressName match {
          case None => one
          case Some(addr) =>
            val index = UIntLiteral(i, IntWidth(bitWidth(addr.tpe)))
            DoPrim(PrimOps.Eq, Seq(addr, index), Nil, index.tpe)
        }
        def andAddrMatch(e: Expression) = and(e, addrMatch)
        // Per-slice output nodes of this port, in slice order.
        val cats = ArrayBuffer[Expression]()
        for (((low, high), j) <- pairs.zipWithIndex) {
          val inst = WRef(s"mem_${i}_${j}", lib.tpe)

          // Drives port `lib` of the current instance with `mem`, passed through `invert`
          // together with the port's polarity.
          def connectPorts(mem: Expression,
                           lib: String,
                           polarity: Option[PortPolarity]): Statement =
            Connect(NoInfo, WSubField(inst, lib), invert(mem, polarity))

          // Clock port mapping
          /* Palmer: FIXME: I don't handle memories with read/write clocks yet. */
          stmts += connectPorts(WRef(memPort.clockName),
                                libPort.clockName,
                                libPort.clockPolarity)

          // Address port mapping
          /* Palmer: The address port to a memory is just the low-order bits of
           *         the top address. */
          stmts += connectPorts(WRef(memPort.addressName),
                                libPort.addressName,
                                libPort.addressPolarity)

          // Output port mapping
          (memPort.outputName, libPort.outputName) match {
            case (Some(mem), Some(lib)) =>
              /* Palmer: In order to produce the output of a memory we need to cat
               *         together a bunch of narrower memories, which can only be
               *         done after generating all the memories.  This saves up the
               *         output statements for later. */
              val name = s"${mem}_${i}_${j}"
              val exp = invert(bits(WSubField(inst, lib), high-low, 0), libPort.outputPolarity)
              stmts += DefNode(NoInfo, name, exp)
              cats += WRef(name)
            case (None, Some(lib)) =>
              /* Palmer: If the inner memory has an output port but the outer
               *         one doesn't then it's safe to just leave the outer
               *         port floating. */
            case (None, None) =>
              /* Palmer: If there's no output ports at all (ie, read-only
               *         port on the memory) then just don't worry about it,
               *         there's nothing to do. */
            case (Some(mem), None) =>
              System.err println "WARNING: Unable to match output ports on memory"
              System.err println s"  outer output port: ${mem}"
              return None
          }

          // Input port mapping
          (memPort.inputName, libPort.inputName) match {
            case (Some(mem), Some(lib)) =>
              /* Palmer: The input port to a memory just needs to happen in parallel,
               *         this does a part select to narrow the memory down. */
              stmts += connectPorts(bits(WRef(mem), high, low), lib, libPort.inputPolarity)
            case (None, Some(lib)) =>
              /* Palmer: If the inner memory has an input port but the other
               *         one doesn't then it's safe to just leave the inner
               *         port floating.  This should be handled by the
               *         default value of the write enable, so nothing should
               *         ever make it into the memory. */
            case (None, None) =>
              /* Palmer: If there's no input ports at all (ie, read-only
               *         port on the memory) then just don't worry about it,
               *         there's nothing to do. */
            case (Some(mem), None) =>
              System.err println "WARNING: Unable to match input ports on memory"
              System.err println s"  outer input port: ${mem}"
              return None
          }

          // Mask port mapping
          val memMask = memPort.maskName match {
            case Some(mem) =>
              /* Palmer: The bits from the outer memory's write mask that will be
               *         used as the write mask for this inner memory. */
              if (libPort.effectiveMaskGran == libPort.width) {
                // One mask bit covers the whole library word: take the outer mask bit
                // that covers the lowest bit of this slice.
                bits(WRef(mem), low / memPort.effectiveMaskGran)
              } else {
                if (libPort.effectiveMaskGran != 1) {
                  // TODO
                  System.err println "only single-bit mask supported"
                  return None
                }
                // Bit-granular library mask: one outer mask bit per data bit of the slice.
                cat(((low to high) map (i => bits(WRef(mem), i / memPort.effectiveMaskGran))).reverse)
              }
            case None =>
              /* Palmer: The outer port has no mask. Use a constant one if the
               *         library port has no mask either, otherwise an all-ones
               *         literal as wide as the library mask. */
              if (!libPort.maskName.isDefined) one
              else {
                val width = libPort.width / libPort.effectiveMaskGran
                val value = (BigInt(1) << width.toInt) - 1
                UIntLiteral(value, IntWidth(width))
              }
          }

          // Write enable port mapping
          val memWriteEnable = memPort.writeEnableName match {
            case Some(mem) =>
              /* Palmer: The outer memory's write enable port, or a constant 1 if
               *         there isn't a write enable port. */
              WRef(mem)
            case None =>
              /* Palmer: If there is no input port on the source memory port
               *         then we don't ever want to turn on this write
               *         enable.  Otherwise, we just _always_ turn on the
               *         write enable port on the inner memory. */
              if (!memPort.inputName.isDefined) zero else one
          }

          // Chip enable port mapping
          val memChipEnable = memPort.chipEnableName match {
            case Some(mem) => WRef(mem)
            case None      => one
          }

          // Read enable port mapping
          /* Palmer: It's safe to ignore read enables, but we pass them through
           *         to the vendor memory if there's a port on there that
           *         implements the read enables. */
          (memPort.readEnableName, libPort.readEnableName) match {
            case (_, None) =>
            case (Some(mem), Some(lib)) =>
              stmts += connectPorts(andAddrMatch(WRef(mem)), lib, libPort.readEnablePolarity)
            case (None, Some(lib)) =>
              stmts += connectPorts(andAddrMatch(not(memWriteEnable)), lib, libPort.readEnablePolarity)
          }

          /* Palmer: This is actually the memory compiler: it figures out how to
           *         implement the outer memory's collection of ports using what
           *         the inner memory has available. */
          // Combinations not listed below are not handled (the match is @unchecked).
          ((libPort.maskName, libPort.writeEnableName, libPort.chipEnableName): @unchecked) match {
            case (Some(mask), Some(we), Some(en)) =>
              /* Palmer: This is the simple option: every port exists. */
              stmts += connectPorts(memMask, mask, libPort.maskPolarity)
              stmts += connectPorts(andAddrMatch(memWriteEnable), we, libPort.writeEnablePolarity)
              stmts += connectPorts(andAddrMatch(memChipEnable), en, libPort.chipEnablePolarity)
            case (Some(mask), Some(we), None) =>
              /* No chip enable on the library port: fold the chip enable into
               * the write enable. */
              stmts += connectPorts(memMask, mask, libPort.maskPolarity)
              stmts += connectPorts(andAddrMatch(and(memWriteEnable, memChipEnable)),
                                    we, libPort.writeEnablePolarity)
            case (None, Some(we), Some(en)) if bitWidth(memMask.tpe) == 1 =>
              /* Palmer: If we're expected to provide mask ports without a
               *         memory that actually has them then we can use the
               *         write enable port instead of the mask port. */
              stmts += connectPorts(andAddrMatch(and(memWriteEnable, memMask)),
                                    we, libPort.writeEnablePolarity)
              stmts += connectPorts(andAddrMatch(memChipEnable), en, libPort.chipEnablePolarity)
            case (None, Some(we), Some(en)) =>
              // TODO
              System.err println "cannot emulate multi-bit mask ports with write enable"
              return None
            case (None, None, None) =>
              /* Palmer: There's nothing to do here since there aren't any
               *         ports to match up. */
          }
        }
        // Cat macro outputs for selection
        memPort.outputName match {
          case Some(mem) if cats.nonEmpty =>
            val name = s"${mem}_${i}"
            // Slices were collected low-to-high; reverse before concatenating.
            stmts += DefNode(NoInfo, name, cat(cats.toSeq.reverse))
            (outputs getOrElseUpdate (mem, ArrayBuffer[(Expression, Expression)]())) +=
              (addrMatch -> WRef(name))
          case _ =>
        }
      }
    }
    // Connect mem outputs
    mem.ports foreach { port =>
      port.outputName match {
        case Some(mem) => outputs get mem match {
          case Some(select) =>
            // Mux chain over the depth indices, ending in `zero` when nothing matches.
            val output = (select foldRight (zero: Expression)) {
              case ((cond, tval), fval) => Mux(cond, tval, fval, fval.tpe) }
            stmts += Connect(NoInfo, WRef(mem), output)
          case None =>
        }
        case None =>
      }
    }

    Some((mem.module(Block(stmts.toSeq)), lib.blackbox))
  }

  /** For every macro in `mems`, picks a library macro from `libs` using the cost estimate
    * below and replaces the modules carrying the same names with the compiled module and
    * the library blackbox. Macros for which no library macro compiles are left as they
    * are. Does nothing unless both `mems` and `libs` are defined.
    */
  def run(c: Circuit): Circuit = {
    val modules = (mems, libs) match {
      case (Some(mems), Some(libs)) => (mems foldLeft c.modules){ (modules, mem) =>
        // Accumulator: (best compiled result so far, its cost).
        val (best, _) = (libs foldLeft (None: Option[(Module, ExtModule)], BigInt(Long.MaxValue))){
          case ((best, area), lib) if mem.ports.size != lib.ports.size =>
            /* Palmer: FIXME: This just assumes the Chisel and vendor ports are in the same
             * order, but I'm starting with what actually gets generated. */
            System.err println s"INFO: unable to compile ${mem.name} using ${lib.name} port count must match"
            (best, area)
          case ((best, area), lib) =>
            /* Palmer: A quick cost function (that must be kept in sync with
             * memory_cost()) that attempts to avoid compiling unnecessary
             * memories.  This is a lower bound on the cost of compiling a
             * memory: it assumes 100% bit-cell utilization when mapping. */
            // val cost = 100 * (mem.depth * mem.width) / (lib.depth * lib.width) +
            //                  (mem.depth * mem.width)
            // Donggyu: I re-define cost
            // (instances along depth) * (instances along width) * (cells per instance + 1)
            val cost = (((mem.depth - 1) / lib.depth) + 1) *
                       (((mem.width - 1) / lib.width) + 1) *
                       (lib.depth * lib.width + 1) // weights on # cells
            System.err println s"Cost of ${lib.name} for ${mem.name}: ${cost}"
            // Only compile when this candidate is not costlier than the current best.
            if (cost > area) (best, area)
            else compile(mem, lib) match {
              case None => (best, area)
              case Some(p) => (Some(p), cost)
            }
        }
        best match {
          case None => modules
          case Some((mod, bb)) =>
            (modules filterNot (m => m.name == mod.name || m.name == bb.name)) ++ Seq(mod, bb)
        }
      }
      case _ => c.modules
    }
    c.copy(modules = modules)
  }
}
|
|
|
|
/** Transform driven by a single [[MacroCompilerAnnotation]] on the main circuit.
  *
  * It loads the macro descriptions named by the annotation and then runs, in order,
  * [[MacroCompilerPass]], `SynFlopsPass` and `firrtl.passes.SplitExpressions`.
  */
class MacroCompilerTransform extends Transform {
  def inputForm = HighForm
  def outputForm = HighForm

  /** Runs the macro compilation pipeline on `state`.
    *
    * Only the case of exactly one matching annotation for the main circuit is handled.
    * The mem file must be present in the annotation (`require`).
    */
  def execute(state: CircuitState) = getMyAnnotations(state) match {
    case Seq(MacroCompilerAnnotation(state.circuit.main, memFile, libFile, synflops)) =>
      require(memFile.isDefined)
      // Wrap every JSON entry of a description file in a Macro.
      def loadMacros(file: Option[File]): Option[Seq[Macro]] =
        readJSON(file).map(entries => entries.map(entry => new Macro(entry)))
      val mems: Option[Seq[Macro]] = loadMacros(memFile)
      val libs: Option[Seq[Macro]] = loadMacros(libFile)
      // SynFlopsPass receives the library macros when present, else the memory macros.
      val pipeline = Seq(
        new MacroCompilerPass(mems, libs),
        new SynFlopsPass(synflops, libs getOrElse mems.get),
        firrtl.passes.SplitExpressions
      )
      pipeline.foldLeft(state) { (current, stage) => stage runTransform current }
  }
}
|
|
|
|
/** Compiler that runs [[MacroCompilerTransform]], lowers the result from high to low
  * form and emits Verilog.
  */
class MacroCompiler extends Compiler {
  def emitter = new VerilogEmitter

  def transforms = {
    val lowering = getLoweringTransforms(firrtl.HighForm, firrtl.LowForm)
    // LowFirrtlOptimization is intentionally not appended ("Todo: This is dangerous").
    Seq(new MacroCompilerTransform) ++ lowering
  }
}
|
|
|
|
/** Command-line entry point: reads a macro list (and optionally a library), compiles the
  * macros and writes the resulting Verilog.
  */
object MacroCompiler extends App {
  /** Keys of the file-valued command-line options. */
  sealed trait MacroParam
  case object Macros extends MacroParam
  case object Library extends MacroParam
  case object Verilog extends MacroParam
  type MacroParamMap = Map[MacroParam, File]
  // The flag is spelled exactly as parseArgs accepts it (the text used to say --syn-flop,
  // which the parser rejected).
  val usage = Seq(
    "Options:",
    "  -m, --macro-list: The set of macros to compile",
    "  -l, --library: The set of macros that have blackbox instances",
    "  -v, --verilog: Verilog output",
    "  --syn-flops: Produces synthesizable flop-based memories") mkString "\n"

  /** Folds the argument list into the option map and the syn-flops flag.
    *
    * @param map      options collected so far
    * @param synflops flag value collected so far
    * @param args     arguments still to be consumed
    * @return the completed option map and flag
    * @throws Exception carrying the usage text when an unknown argument is met
    */
  def parseArgs(map: MacroParamMap, synflops: Boolean, args: List[String]): (MacroParamMap, Boolean) =
    args match {
      case Nil => (map, synflops)
      case ("-m" | "--macro-list") :: value :: tail =>
        parseArgs(map + (Macros  -> new File(value)), synflops, tail)
      case ("-l" | "--library") :: value :: tail =>
        parseArgs(map + (Library -> new File(value)), synflops, tail)
      case ("-v" | "--verilog") :: value :: tail =>
        parseArgs(map + (Verilog -> new File(value)), synflops, tail)
      // "--syn-flop" is kept as an alias because the usage text advertised that spelling.
      case ("--syn-flops" | "--syn-flop") :: tail =>
        parseArgs(map, true, tail)
      case arg :: _ =>
        println(s"Unknown field $arg\n")
        throw new Exception(usage)
    }

  /** Parses `args`, compiles the listed macros and writes Verilog to the `--verilog` file.
    *
    * The output file is opened (and therefore created) even when the macro list is empty.
    *
    * @throws Exception carrying the usage text when a lookup fails with
    *                   `NoSuchElementException` (for instance a missing `--verilog` option)
    */
  def run(args: List[String]): Unit = {
    val (params, synflops) = parseArgs(Map[MacroParam, File](), false, args)
    try {
      val macros = readJSON(params get Macros).get map (x => (new Macro(x)).blackbox)
      val verilog = new FileWriter(params(Verilog))
      try {
        if (macros.nonEmpty) {
          // The last macro is used as the main module of the synthetic circuit.
          val circuit = Circuit(NoInfo, macros, macros.last.name)
          val annotations = AnnotationMap(Seq(MacroCompilerAnnotation(
            circuit.main, params(Macros), params get Library, synflops)))
          val state = CircuitState(circuit, HighForm, Some(annotations))
          (new MacroCompiler).compile(state, verilog)
        }
      } finally {
        // Release the file handle even when compilation fails.
        verilog.close()
      }
    } catch {
      case e: java.util.NoSuchElementException =>
        // Keep the original failure as the cause so it is not lost behind the usage text.
        throw new Exception(usage, e)
    }
  }

  run(args.toList)
}
|