Specify cost function from command line

This commit is contained in:
Edward Wang
2017-08-01 13:55:32 -07:00
committed by edwardcwang
parent 923a08dfa1
commit a25c84f72c

View File

@@ -19,7 +19,7 @@ import Utils._
*/
// TODO: eventually explore compiling a single target memory using multiple
// different kinds of target memory.
trait CostMetric {
trait CostMetric extends Serializable {
/**
* Cost function that returns the cost of compiling a memory using a certain
* macro.
@@ -100,6 +100,32 @@ object NewDefaultMetric extends CostMetric {
}
}
object MacroCompilerUtil {
import java.io._
import java.util.Base64
// Adapted from https://stackoverflow.com/a/134918
/** Serialize an arbitrary object to String.
* Used to pass structured values through as an annotation. */
def objToString(o: Serializable): String = {
val baos: ByteArrayOutputStream = new ByteArrayOutputStream
val oos: ObjectOutputStream = new ObjectOutputStream(baos)
oos.writeObject(o)
oos.close()
return Base64.getEncoder.encodeToString(baos.toByteArray)
}
/** Deserialize an arbitrary object from String. */
def objFromString(s: String): AnyRef = {
val data = Base64.getDecoder.decode(s)
val ois: ObjectInputStream = new ObjectInputStream(new ByteArrayInputStream(data))
val o = ois.readObject
ois.close()
return o
}
}
object CostMetric {
/** Define some default metric. */
val default: CostMetric = NewDefaultMetric
@@ -108,30 +134,48 @@ object CostMetric {
def getCostMetric(m: String, params: Map[String, String]): CostMetric = m match {
case "default" => default
case "PalmerMetric" => PalmerMetric
case "ExternalMetric" => new ExternalMetric(params.get("path").get)
case "ExternalMetric" => {
try {
new ExternalMetric(params.get("path").get)
} catch {
case e: NoSuchElementException => throw new IllegalArgumentException("Missing parameter 'path'")
}
}
case "NewDefaultMetric" => NewDefaultMetric
case _ => throw new IllegalArgumentException("Invalid cost metric " + m)
}
}
object MacroCompilerAnnotation {
def apply(c: String, mem: File, lib: Option[File], synflops: Boolean): Annotation =
apply(c, mem.toString, lib map (_.toString), synflops)
/**
* Parameters associated to this MacroCompilerAnnotation.
* @param mem Path to memory lib
* @param lib Path to library lib or None if no libraries
* @param costMetric Cost metric to use
* @param synflops True to syn flops
*/
case class Params(mem: String, lib: Option[String], costMetric: CostMetric, synflops: Boolean)
/**
* Create a MacroCompilerAnnotation.
* @param c Name of the module(?) for this annotation.
* @param p Parameters (see above).
*/
def apply(c: String, p: Params): Annotation =
Annotation(CircuitName(c), classOf[MacroCompilerTransform], MacroCompilerUtil.objToString(p))
def apply(c: String, mem: String, lib: Option[String], synflops: Boolean): Annotation = {
Annotation(CircuitName(c), classOf[MacroCompilerTransform],
s"${mem} %s ${synflops}".format(lib getOrElse ""))
}
private val matcher = "([^ ]+) ([^ ]*) (true|false)".r
def unapply(a: Annotation) = a match {
case Annotation(CircuitName(c), t, matcher(mem, lib, synflops)) if t == classOf[MacroCompilerTransform] =>
Some((c, Some(mem), if (lib.isEmpty) None else Some(lib), synflops.toBoolean))
case Annotation(CircuitName(c), t, serialized) if t == classOf[MacroCompilerTransform] => {
val p: Params = MacroCompilerUtil.objFromString(serialized).asInstanceOf[Params]
Some(c, p)
}
case _ => None
}
}
class MacroCompilerPass(mems: Option[Seq[Macro]],
libs: Option[Seq[Macro]]) extends firrtl.passes.Pass {
libs: Option[Seq[Macro]],
costMetric: CostMetric = CostMetric.default) extends firrtl.passes.Pass {
def compile(mem: Macro, lib: Macro): Option[(Module, ExtModule)] = {
val pairedPorts = mem.sortedPorts zip lib.sortedPorts
@@ -437,7 +481,7 @@ class MacroCompilerPass(mems: Option[Seq[Macro]],
(best, cost)
case ((best, cost), lib) =>
// Run the cost function to evaluate this potential compile.
CostMetric.default.cost(mem, lib) match {
costMetric.cost(mem, lib) match {
case Some(newCost) => {
System.err.println(s"Cost of ${lib.src.name} for ${mem.src.name}: ${newCost}")
if (newCost > cost) (best, cost)
@@ -469,10 +513,9 @@ class MacroCompilerTransform extends Transform {
def inputForm = MidForm
def outputForm = MidForm
def execute(state: CircuitState) = getMyAnnotations(state) match {
case Seq(MacroCompilerAnnotation(state.circuit.main, memFile, libFile, synflops)) =>
require(memFile.isDefined)
case Seq(MacroCompilerAnnotation(state.circuit.main, MacroCompilerAnnotation.Params(memFile, libFile, costMetric, synflops))) =>
// Read, eliminate None, get only SRAM, make firrtl macro
val mems: Option[Seq[Macro]] = mdf.macrolib.Utils.readMDFFromPath(memFile) match {
val mems: Option[Seq[Macro]] = mdf.macrolib.Utils.readMDFFromPath(Some(memFile)) match {
case Some(x:Seq[mdf.macrolib.Macro]) =>
Some(Utils.filterForSRAM(Some(x)) getOrElse(List()) map {new Macro(_)})
case _ => None
@@ -483,7 +526,7 @@ class MacroCompilerTransform extends Transform {
case _ => None
}
val transforms = Seq(
new MacroCompilerPass(mems, libs),
new MacroCompilerPass(mems, libs, costMetric),
new SynFlopsPass(synflops, libs getOrElse mems.get))
(transforms foldLeft state)((s, xform) => xform runTransform s).copy(form=outputForm)
case _ => state
@@ -529,9 +572,9 @@ object MacroCompiler extends App {
" -cp, --cost-param: Cost function parameter. (Optional depending on the cost function.). e.g. -c ExternalMetric -cp path /path/to/my/cost/script",
" --syn-flops: Produces synthesizable flop-based memories (for all memories and library memory macros); likely useful for simulation purposes") mkString "\n"
def parseArgs(map: MacroParamMap, costMap: CostParamMap, synflops: Boolean, args: List[String]): (MacroParamMap, Boolean) =
def parseArgs(map: MacroParamMap, costMap: CostParamMap, synflops: Boolean, args: List[String]): (MacroParamMap, CostParamMap, Boolean) =
args match {
case Nil => (map, synflops)
case Nil => (map, costMap, synflops)
case ("-m" | "--macro-list") :: value :: tail =>
parseArgs(map + (Macros -> value), costMap, synflops, tail)
case ("-l" | "--library") :: value :: tail =>
@@ -551,7 +594,7 @@ object MacroCompiler extends App {
}
def run(args: List[String]) {
val (params, synflops) = parseArgs(Map[MacroParam, String](), Map[String, String](), false, args)
val (params, costParams, synflops) = parseArgs(Map[MacroParam, String](), Map[String, String](), false, args)
try {
val macros = Utils.filterForSRAM(mdf.macrolib.Utils.readMDFFromPath(params.get(Macros))).get map (x => (new Macro(x)).blackbox)
@@ -562,8 +605,16 @@ object MacroCompiler extends App {
// Note: the last macro in the input list is (seemingly arbitrarily)
// determined as the firrtl "top-level module".
val circuit = Circuit(NoInfo, macros, macros.last.name)
val annotations = AnnotationMap(Seq(MacroCompilerAnnotation(
circuit.main, params.get(Macros).get, params.get(Library), synflops)))
val annotations = AnnotationMap(
Seq(MacroCompilerAnnotation(
circuit.main,
MacroCompilerAnnotation.Params(
params.get(Macros).get, params.get(Library),
CostMetric.getCostMetric(params.getOrElse(CostFunc, "default"), costParams),
synflops
)
))
)
val state = CircuitState(circuit, HighForm, Some(annotations))
// Run the compiler.