Specify cost function from command line
This commit is contained in:
@@ -19,7 +19,7 @@ import Utils._
|
||||
*/
|
||||
// TODO: eventually explore compiling a single target memory using multiple
|
||||
// different kinds of target memory.
|
||||
trait CostMetric {
|
||||
trait CostMetric extends Serializable {
|
||||
/**
|
||||
* Cost function that returns the cost of compiling a memory using a certain
|
||||
* macro.
|
||||
@@ -100,6 +100,32 @@ object NewDefaultMetric extends CostMetric {
|
||||
}
|
||||
}
|
||||
|
||||
object MacroCompilerUtil {
|
||||
import java.io._
|
||||
import java.util.Base64
|
||||
|
||||
// Adapted from https://stackoverflow.com/a/134918
|
||||
|
||||
/** Serialize an arbitrary object to String.
|
||||
* Used to pass structured values through as an annotation. */
|
||||
def objToString(o: Serializable): String = {
|
||||
val baos: ByteArrayOutputStream = new ByteArrayOutputStream
|
||||
val oos: ObjectOutputStream = new ObjectOutputStream(baos)
|
||||
oos.writeObject(o)
|
||||
oos.close()
|
||||
return Base64.getEncoder.encodeToString(baos.toByteArray)
|
||||
}
|
||||
|
||||
/** Deserialize an arbitrary object from String. */
|
||||
def objFromString(s: String): AnyRef = {
|
||||
val data = Base64.getDecoder.decode(s)
|
||||
val ois: ObjectInputStream = new ObjectInputStream(new ByteArrayInputStream(data))
|
||||
val o = ois.readObject
|
||||
ois.close()
|
||||
return o
|
||||
}
|
||||
}
|
||||
|
||||
object CostMetric {
|
||||
/** Define some default metric. */
|
||||
val default: CostMetric = NewDefaultMetric
|
||||
@@ -108,30 +134,48 @@ object CostMetric {
|
||||
def getCostMetric(m: String, params: Map[String, String]): CostMetric = m match {
|
||||
case "default" => default
|
||||
case "PalmerMetric" => PalmerMetric
|
||||
case "ExternalMetric" => new ExternalMetric(params.get("path").get)
|
||||
case "ExternalMetric" => {
|
||||
try {
|
||||
new ExternalMetric(params.get("path").get)
|
||||
} catch {
|
||||
case e: NoSuchElementException => throw new IllegalArgumentException("Missing parameter 'path'")
|
||||
}
|
||||
}
|
||||
case "NewDefaultMetric" => NewDefaultMetric
|
||||
case _ => throw new IllegalArgumentException("Invalid cost metric " + m)
|
||||
}
|
||||
}
|
||||
|
||||
object MacroCompilerAnnotation {
|
||||
def apply(c: String, mem: File, lib: Option[File], synflops: Boolean): Annotation =
|
||||
apply(c, mem.toString, lib map (_.toString), synflops)
|
||||
/**
|
||||
* Parameters associated to this MacroCompilerAnnotation.
|
||||
* @param mem Path to memory lib
|
||||
* @param lib Path to library lib or None if no libraries
|
||||
* @param costMetric Cost metric to use
|
||||
* @param synflops True to syn flops
|
||||
*/
|
||||
case class Params(mem: String, lib: Option[String], costMetric: CostMetric, synflops: Boolean)
|
||||
|
||||
/**
|
||||
* Create a MacroCompilerAnnotation.
|
||||
* @param c Name of the module(?) for this annotation.
|
||||
* @param p Parameters (see above).
|
||||
*/
|
||||
def apply(c: String, p: Params): Annotation =
|
||||
Annotation(CircuitName(c), classOf[MacroCompilerTransform], MacroCompilerUtil.objToString(p))
|
||||
|
||||
def apply(c: String, mem: String, lib: Option[String], synflops: Boolean): Annotation = {
|
||||
Annotation(CircuitName(c), classOf[MacroCompilerTransform],
|
||||
s"${mem} %s ${synflops}".format(lib getOrElse ""))
|
||||
}
|
||||
private val matcher = "([^ ]+) ([^ ]*) (true|false)".r
|
||||
def unapply(a: Annotation) = a match {
|
||||
case Annotation(CircuitName(c), t, matcher(mem, lib, synflops)) if t == classOf[MacroCompilerTransform] =>
|
||||
Some((c, Some(mem), if (lib.isEmpty) None else Some(lib), synflops.toBoolean))
|
||||
case Annotation(CircuitName(c), t, serialized) if t == classOf[MacroCompilerTransform] => {
|
||||
val p: Params = MacroCompilerUtil.objFromString(serialized).asInstanceOf[Params]
|
||||
Some(c, p)
|
||||
}
|
||||
case _ => None
|
||||
}
|
||||
}
|
||||
|
||||
class MacroCompilerPass(mems: Option[Seq[Macro]],
|
||||
libs: Option[Seq[Macro]]) extends firrtl.passes.Pass {
|
||||
libs: Option[Seq[Macro]],
|
||||
costMetric: CostMetric = CostMetric.default) extends firrtl.passes.Pass {
|
||||
def compile(mem: Macro, lib: Macro): Option[(Module, ExtModule)] = {
|
||||
val pairedPorts = mem.sortedPorts zip lib.sortedPorts
|
||||
|
||||
@@ -437,7 +481,7 @@ class MacroCompilerPass(mems: Option[Seq[Macro]],
|
||||
(best, cost)
|
||||
case ((best, cost), lib) =>
|
||||
// Run the cost function to evaluate this potential compile.
|
||||
CostMetric.default.cost(mem, lib) match {
|
||||
costMetric.cost(mem, lib) match {
|
||||
case Some(newCost) => {
|
||||
System.err.println(s"Cost of ${lib.src.name} for ${mem.src.name}: ${newCost}")
|
||||
if (newCost > cost) (best, cost)
|
||||
@@ -469,10 +513,9 @@ class MacroCompilerTransform extends Transform {
|
||||
def inputForm = MidForm
|
||||
def outputForm = MidForm
|
||||
def execute(state: CircuitState) = getMyAnnotations(state) match {
|
||||
case Seq(MacroCompilerAnnotation(state.circuit.main, memFile, libFile, synflops)) =>
|
||||
require(memFile.isDefined)
|
||||
case Seq(MacroCompilerAnnotation(state.circuit.main, MacroCompilerAnnotation.Params(memFile, libFile, costMetric, synflops))) =>
|
||||
// Read, eliminate None, get only SRAM, make firrtl macro
|
||||
val mems: Option[Seq[Macro]] = mdf.macrolib.Utils.readMDFFromPath(memFile) match {
|
||||
val mems: Option[Seq[Macro]] = mdf.macrolib.Utils.readMDFFromPath(Some(memFile)) match {
|
||||
case Some(x:Seq[mdf.macrolib.Macro]) =>
|
||||
Some(Utils.filterForSRAM(Some(x)) getOrElse(List()) map {new Macro(_)})
|
||||
case _ => None
|
||||
@@ -483,7 +526,7 @@ class MacroCompilerTransform extends Transform {
|
||||
case _ => None
|
||||
}
|
||||
val transforms = Seq(
|
||||
new MacroCompilerPass(mems, libs),
|
||||
new MacroCompilerPass(mems, libs, costMetric),
|
||||
new SynFlopsPass(synflops, libs getOrElse mems.get))
|
||||
(transforms foldLeft state)((s, xform) => xform runTransform s).copy(form=outputForm)
|
||||
case _ => state
|
||||
@@ -529,9 +572,9 @@ object MacroCompiler extends App {
|
||||
" -cp, --cost-param: Cost function parameter. (Optional depending on the cost function.). e.g. -c ExternalMetric -cp path /path/to/my/cost/script",
|
||||
" --syn-flops: Produces synthesizable flop-based memories (for all memories and library memory macros); likely useful for simulation purposes") mkString "\n"
|
||||
|
||||
def parseArgs(map: MacroParamMap, costMap: CostParamMap, synflops: Boolean, args: List[String]): (MacroParamMap, Boolean) =
|
||||
def parseArgs(map: MacroParamMap, costMap: CostParamMap, synflops: Boolean, args: List[String]): (MacroParamMap, CostParamMap, Boolean) =
|
||||
args match {
|
||||
case Nil => (map, synflops)
|
||||
case Nil => (map, costMap, synflops)
|
||||
case ("-m" | "--macro-list") :: value :: tail =>
|
||||
parseArgs(map + (Macros -> value), costMap, synflops, tail)
|
||||
case ("-l" | "--library") :: value :: tail =>
|
||||
@@ -551,7 +594,7 @@ object MacroCompiler extends App {
|
||||
}
|
||||
|
||||
def run(args: List[String]) {
|
||||
val (params, synflops) = parseArgs(Map[MacroParam, String](), Map[String, String](), false, args)
|
||||
val (params, costParams, synflops) = parseArgs(Map[MacroParam, String](), Map[String, String](), false, args)
|
||||
try {
|
||||
val macros = Utils.filterForSRAM(mdf.macrolib.Utils.readMDFFromPath(params.get(Macros))).get map (x => (new Macro(x)).blackbox)
|
||||
|
||||
@@ -562,8 +605,16 @@ object MacroCompiler extends App {
|
||||
// Note: the last macro in the input list is (seemingly arbitrarily)
|
||||
// determined as the firrtl "top-level module".
|
||||
val circuit = Circuit(NoInfo, macros, macros.last.name)
|
||||
val annotations = AnnotationMap(Seq(MacroCompilerAnnotation(
|
||||
circuit.main, params.get(Macros).get, params.get(Library), synflops)))
|
||||
val annotations = AnnotationMap(
|
||||
Seq(MacroCompilerAnnotation(
|
||||
circuit.main,
|
||||
MacroCompilerAnnotation.Params(
|
||||
params.get(Macros).get, params.get(Library),
|
||||
CostMetric.getCostMetric(params.getOrElse(CostFunc, "default"), costParams),
|
||||
synflops
|
||||
)
|
||||
))
|
||||
)
|
||||
val state = CircuitState(circuit, HighForm, Some(annotations))
|
||||
|
||||
// Run the compiler.
|
||||
|
||||
Reference in New Issue
Block a user