Specify cost function from command line
This commit is contained in:
@@ -19,7 +19,7 @@ import Utils._
|
|||||||
*/
|
*/
|
||||||
// TODO: eventually explore compiling a single target memory using multiple
|
// TODO: eventually explore compiling a single target memory using multiple
|
||||||
// different kinds of target memory.
|
// different kinds of target memory.
|
||||||
trait CostMetric {
|
trait CostMetric extends Serializable {
|
||||||
/**
|
/**
|
||||||
* Cost function that returns the cost of compiling a memory using a certain
|
* Cost function that returns the cost of compiling a memory using a certain
|
||||||
* macro.
|
* macro.
|
||||||
@@ -100,6 +100,32 @@ object NewDefaultMetric extends CostMetric {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
object MacroCompilerUtil {
|
||||||
|
import java.io._
|
||||||
|
import java.util.Base64
|
||||||
|
|
||||||
|
// Adapted from https://stackoverflow.com/a/134918
|
||||||
|
|
||||||
|
/** Serialize an arbitrary object to String.
|
||||||
|
* Used to pass structured values through as an annotation. */
|
||||||
|
def objToString(o: Serializable): String = {
|
||||||
|
val baos: ByteArrayOutputStream = new ByteArrayOutputStream
|
||||||
|
val oos: ObjectOutputStream = new ObjectOutputStream(baos)
|
||||||
|
oos.writeObject(o)
|
||||||
|
oos.close()
|
||||||
|
return Base64.getEncoder.encodeToString(baos.toByteArray)
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Deserialize an arbitrary object from String. */
|
||||||
|
def objFromString(s: String): AnyRef = {
|
||||||
|
val data = Base64.getDecoder.decode(s)
|
||||||
|
val ois: ObjectInputStream = new ObjectInputStream(new ByteArrayInputStream(data))
|
||||||
|
val o = ois.readObject
|
||||||
|
ois.close()
|
||||||
|
return o
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
object CostMetric {
|
object CostMetric {
|
||||||
/** Define some default metric. */
|
/** Define some default metric. */
|
||||||
val default: CostMetric = NewDefaultMetric
|
val default: CostMetric = NewDefaultMetric
|
||||||
@@ -108,30 +134,48 @@ object CostMetric {
|
|||||||
def getCostMetric(m: String, params: Map[String, String]): CostMetric = m match {
|
def getCostMetric(m: String, params: Map[String, String]): CostMetric = m match {
|
||||||
case "default" => default
|
case "default" => default
|
||||||
case "PalmerMetric" => PalmerMetric
|
case "PalmerMetric" => PalmerMetric
|
||||||
case "ExternalMetric" => new ExternalMetric(params.get("path").get)
|
case "ExternalMetric" => {
|
||||||
|
try {
|
||||||
|
new ExternalMetric(params.get("path").get)
|
||||||
|
} catch {
|
||||||
|
case e: NoSuchElementException => throw new IllegalArgumentException("Missing parameter 'path'")
|
||||||
|
}
|
||||||
|
}
|
||||||
case "NewDefaultMetric" => NewDefaultMetric
|
case "NewDefaultMetric" => NewDefaultMetric
|
||||||
case _ => throw new IllegalArgumentException("Invalid cost metric " + m)
|
case _ => throw new IllegalArgumentException("Invalid cost metric " + m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
object MacroCompilerAnnotation {
|
object MacroCompilerAnnotation {
|
||||||
def apply(c: String, mem: File, lib: Option[File], synflops: Boolean): Annotation =
|
/**
|
||||||
apply(c, mem.toString, lib map (_.toString), synflops)
|
* Parameters associated to this MacroCompilerAnnotation.
|
||||||
|
* @param mem Path to memory lib
|
||||||
|
* @param lib Path to library lib or None if no libraries
|
||||||
|
* @param costMetric Cost metric to use
|
||||||
|
* @param synflops True to syn flops
|
||||||
|
*/
|
||||||
|
case class Params(mem: String, lib: Option[String], costMetric: CostMetric, synflops: Boolean)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a MacroCompilerAnnotation.
|
||||||
|
* @param c Name of the module(?) for this annotation.
|
||||||
|
* @param p Parameters (see above).
|
||||||
|
*/
|
||||||
|
def apply(c: String, p: Params): Annotation =
|
||||||
|
Annotation(CircuitName(c), classOf[MacroCompilerTransform], MacroCompilerUtil.objToString(p))
|
||||||
|
|
||||||
def apply(c: String, mem: String, lib: Option[String], synflops: Boolean): Annotation = {
|
|
||||||
Annotation(CircuitName(c), classOf[MacroCompilerTransform],
|
|
||||||
s"${mem} %s ${synflops}".format(lib getOrElse ""))
|
|
||||||
}
|
|
||||||
private val matcher = "([^ ]+) ([^ ]*) (true|false)".r
|
|
||||||
def unapply(a: Annotation) = a match {
|
def unapply(a: Annotation) = a match {
|
||||||
case Annotation(CircuitName(c), t, matcher(mem, lib, synflops)) if t == classOf[MacroCompilerTransform] =>
|
case Annotation(CircuitName(c), t, serialized) if t == classOf[MacroCompilerTransform] => {
|
||||||
Some((c, Some(mem), if (lib.isEmpty) None else Some(lib), synflops.toBoolean))
|
val p: Params = MacroCompilerUtil.objFromString(serialized).asInstanceOf[Params]
|
||||||
|
Some(c, p)
|
||||||
|
}
|
||||||
case _ => None
|
case _ => None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class MacroCompilerPass(mems: Option[Seq[Macro]],
|
class MacroCompilerPass(mems: Option[Seq[Macro]],
|
||||||
libs: Option[Seq[Macro]]) extends firrtl.passes.Pass {
|
libs: Option[Seq[Macro]],
|
||||||
|
costMetric: CostMetric = CostMetric.default) extends firrtl.passes.Pass {
|
||||||
def compile(mem: Macro, lib: Macro): Option[(Module, ExtModule)] = {
|
def compile(mem: Macro, lib: Macro): Option[(Module, ExtModule)] = {
|
||||||
val pairedPorts = mem.sortedPorts zip lib.sortedPorts
|
val pairedPorts = mem.sortedPorts zip lib.sortedPorts
|
||||||
|
|
||||||
@@ -437,7 +481,7 @@ class MacroCompilerPass(mems: Option[Seq[Macro]],
|
|||||||
(best, cost)
|
(best, cost)
|
||||||
case ((best, cost), lib) =>
|
case ((best, cost), lib) =>
|
||||||
// Run the cost function to evaluate this potential compile.
|
// Run the cost function to evaluate this potential compile.
|
||||||
CostMetric.default.cost(mem, lib) match {
|
costMetric.cost(mem, lib) match {
|
||||||
case Some(newCost) => {
|
case Some(newCost) => {
|
||||||
System.err.println(s"Cost of ${lib.src.name} for ${mem.src.name}: ${newCost}")
|
System.err.println(s"Cost of ${lib.src.name} for ${mem.src.name}: ${newCost}")
|
||||||
if (newCost > cost) (best, cost)
|
if (newCost > cost) (best, cost)
|
||||||
@@ -469,10 +513,9 @@ class MacroCompilerTransform extends Transform {
|
|||||||
def inputForm = MidForm
|
def inputForm = MidForm
|
||||||
def outputForm = MidForm
|
def outputForm = MidForm
|
||||||
def execute(state: CircuitState) = getMyAnnotations(state) match {
|
def execute(state: CircuitState) = getMyAnnotations(state) match {
|
||||||
case Seq(MacroCompilerAnnotation(state.circuit.main, memFile, libFile, synflops)) =>
|
case Seq(MacroCompilerAnnotation(state.circuit.main, MacroCompilerAnnotation.Params(memFile, libFile, costMetric, synflops))) =>
|
||||||
require(memFile.isDefined)
|
|
||||||
// Read, eliminate None, get only SRAM, make firrtl macro
|
// Read, eliminate None, get only SRAM, make firrtl macro
|
||||||
val mems: Option[Seq[Macro]] = mdf.macrolib.Utils.readMDFFromPath(memFile) match {
|
val mems: Option[Seq[Macro]] = mdf.macrolib.Utils.readMDFFromPath(Some(memFile)) match {
|
||||||
case Some(x:Seq[mdf.macrolib.Macro]) =>
|
case Some(x:Seq[mdf.macrolib.Macro]) =>
|
||||||
Some(Utils.filterForSRAM(Some(x)) getOrElse(List()) map {new Macro(_)})
|
Some(Utils.filterForSRAM(Some(x)) getOrElse(List()) map {new Macro(_)})
|
||||||
case _ => None
|
case _ => None
|
||||||
@@ -483,7 +526,7 @@ class MacroCompilerTransform extends Transform {
|
|||||||
case _ => None
|
case _ => None
|
||||||
}
|
}
|
||||||
val transforms = Seq(
|
val transforms = Seq(
|
||||||
new MacroCompilerPass(mems, libs),
|
new MacroCompilerPass(mems, libs, costMetric),
|
||||||
new SynFlopsPass(synflops, libs getOrElse mems.get))
|
new SynFlopsPass(synflops, libs getOrElse mems.get))
|
||||||
(transforms foldLeft state)((s, xform) => xform runTransform s).copy(form=outputForm)
|
(transforms foldLeft state)((s, xform) => xform runTransform s).copy(form=outputForm)
|
||||||
case _ => state
|
case _ => state
|
||||||
@@ -529,9 +572,9 @@ object MacroCompiler extends App {
|
|||||||
" -cp, --cost-param: Cost function parameter. (Optional depending on the cost function.). e.g. -c ExternalMetric -cp path /path/to/my/cost/script",
|
" -cp, --cost-param: Cost function parameter. (Optional depending on the cost function.). e.g. -c ExternalMetric -cp path /path/to/my/cost/script",
|
||||||
" --syn-flops: Produces synthesizable flop-based memories (for all memories and library memory macros); likely useful for simulation purposes") mkString "\n"
|
" --syn-flops: Produces synthesizable flop-based memories (for all memories and library memory macros); likely useful for simulation purposes") mkString "\n"
|
||||||
|
|
||||||
def parseArgs(map: MacroParamMap, costMap: CostParamMap, synflops: Boolean, args: List[String]): (MacroParamMap, Boolean) =
|
def parseArgs(map: MacroParamMap, costMap: CostParamMap, synflops: Boolean, args: List[String]): (MacroParamMap, CostParamMap, Boolean) =
|
||||||
args match {
|
args match {
|
||||||
case Nil => (map, synflops)
|
case Nil => (map, costMap, synflops)
|
||||||
case ("-m" | "--macro-list") :: value :: tail =>
|
case ("-m" | "--macro-list") :: value :: tail =>
|
||||||
parseArgs(map + (Macros -> value), costMap, synflops, tail)
|
parseArgs(map + (Macros -> value), costMap, synflops, tail)
|
||||||
case ("-l" | "--library") :: value :: tail =>
|
case ("-l" | "--library") :: value :: tail =>
|
||||||
@@ -551,7 +594,7 @@ object MacroCompiler extends App {
|
|||||||
}
|
}
|
||||||
|
|
||||||
def run(args: List[String]) {
|
def run(args: List[String]) {
|
||||||
val (params, synflops) = parseArgs(Map[MacroParam, String](), Map[String, String](), false, args)
|
val (params, costParams, synflops) = parseArgs(Map[MacroParam, String](), Map[String, String](), false, args)
|
||||||
try {
|
try {
|
||||||
val macros = Utils.filterForSRAM(mdf.macrolib.Utils.readMDFFromPath(params.get(Macros))).get map (x => (new Macro(x)).blackbox)
|
val macros = Utils.filterForSRAM(mdf.macrolib.Utils.readMDFFromPath(params.get(Macros))).get map (x => (new Macro(x)).blackbox)
|
||||||
|
|
||||||
@@ -562,8 +605,16 @@ object MacroCompiler extends App {
|
|||||||
// Note: the last macro in the input list is (seemingly arbitrarily)
|
// Note: the last macro in the input list is (seemingly arbitrarily)
|
||||||
// determined as the firrtl "top-level module".
|
// determined as the firrtl "top-level module".
|
||||||
val circuit = Circuit(NoInfo, macros, macros.last.name)
|
val circuit = Circuit(NoInfo, macros, macros.last.name)
|
||||||
val annotations = AnnotationMap(Seq(MacroCompilerAnnotation(
|
val annotations = AnnotationMap(
|
||||||
circuit.main, params.get(Macros).get, params.get(Library), synflops)))
|
Seq(MacroCompilerAnnotation(
|
||||||
|
circuit.main,
|
||||||
|
MacroCompilerAnnotation.Params(
|
||||||
|
params.get(Macros).get, params.get(Library),
|
||||||
|
CostMetric.getCostMetric(params.getOrElse(CostFunc, "default"), costParams),
|
||||||
|
synflops
|
||||||
|
)
|
||||||
|
))
|
||||||
|
)
|
||||||
val state = CircuitState(circuit, HighForm, Some(annotations))
|
val state = CircuitState(circuit, HighForm, Some(annotations))
|
||||||
|
|
||||||
// Run the compiler.
|
// Run the compiler.
|
||||||
|
|||||||
Reference in New Issue
Block a user