Implement command line cost metric selection
This commit is contained in:
@@ -36,12 +36,28 @@ trait CostMetric extends Serializable {
|
|||||||
* it cannot be compiled.
|
* it cannot be compiled.
|
||||||
*/
|
*/
|
||||||
def cost(mem: Macro, lib: Macro): Option[BigInt]
|
def cost(mem: Macro, lib: Macro): Option[BigInt]
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Helper function to return the map of argments (or an empty map if there are none).
|
||||||
|
*/
|
||||||
|
def commandLineParams(): Map[String, String]
|
||||||
|
|
||||||
|
// We also want this to show up for the class itself.
|
||||||
|
def name(): String
|
||||||
|
}
|
||||||
|
|
||||||
|
// Is there a better way to do this? (static method associated to CostMetric)
|
||||||
|
trait CostMetricCompanion {
|
||||||
|
def name(): String
|
||||||
|
|
||||||
|
/** Construct this cost metric from a command line mapping. */
|
||||||
|
def construct(m: Map[String, String]): CostMetric
|
||||||
}
|
}
|
||||||
|
|
||||||
// Some default cost functions.
|
// Some default cost functions.
|
||||||
|
|
||||||
/** Palmer's old metric. */
|
/** Palmer's old metric. */
|
||||||
object PalmerMetric extends CostMetric {
|
object PalmerMetric extends CostMetric with CostMetricCompanion {
|
||||||
override def cost(mem: Macro, lib: Macro): Option[BigInt] = {
|
override def cost(mem: Macro, lib: Macro): Option[BigInt] = {
|
||||||
/* Palmer: A quick cost function (that must be kept in sync with
|
/* Palmer: A quick cost function (that must be kept in sync with
|
||||||
* memory_cost()) that attempts to avoid compiling unncessary
|
* memory_cost()) that attempts to avoid compiling unncessary
|
||||||
@@ -51,6 +67,10 @@ object PalmerMetric extends CostMetric {
|
|||||||
// (mem.depth * mem.width)
|
// (mem.depth * mem.width)
|
||||||
???
|
???
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def commandLineParams = Map()
|
||||||
|
override def name = "PalmerMetric"
|
||||||
|
override def construct(m: Map[String, String]) = PalmerMetric
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -87,11 +107,27 @@ class ExternalMetric(path: String) extends CostMetric {
|
|||||||
case e: NumberFormatException => None
|
case e: NumberFormatException => None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def commandLineParams = Map("path" -> path)
|
||||||
|
override def name = ExternalMetric.name
|
||||||
|
}
|
||||||
|
|
||||||
|
object ExternalMetric extends CostMetricCompanion {
|
||||||
|
override def name = "ExternalMetric"
|
||||||
|
|
||||||
|
/** Construct this cost metric from a command line mapping. */
|
||||||
|
override def construct(m: Map[String, String]) = {
|
||||||
|
val pathOption = m.get("path")
|
||||||
|
pathOption match {
|
||||||
|
case Some(path:String) => new ExternalMetric(path)
|
||||||
|
case _ => throw new IllegalArgumentException("ExternalMetric missing option 'path'")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** The current default metric in barstools, re-defined by Donggyu. */
|
/** The current default metric in barstools, re-defined by Donggyu. */
|
||||||
// TODO: write tests for this function to make sure it selects the right things
|
// TODO: write tests for this function to make sure it selects the right things
|
||||||
object NewDefaultMetric extends CostMetric {
|
object NewDefaultMetric extends CostMetric with CostMetricCompanion {
|
||||||
override def cost(mem: Macro, lib: Macro): Option[BigInt] = {
|
override def cost(mem: Macro, lib: Macro): Option[BigInt] = {
|
||||||
val memMask = mem.src.ports map (_.maskGran) find (_.isDefined) map (_.get)
|
val memMask = mem.src.ports map (_.maskGran) find (_.isDefined) map (_.get)
|
||||||
val libMask = lib.src.ports map (_.maskGran) find (_.isDefined) map (_.get)
|
val libMask = lib.src.ports map (_.maskGran) find (_.isDefined) map (_.get)
|
||||||
@@ -105,6 +141,10 @@ object NewDefaultMetric extends CostMetric {
|
|||||||
(lib.src.depth * lib.src.width + 1) // weights on # cells
|
(lib.src.depth * lib.src.width + 1) // weights on # cells
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def commandLineParams = Map()
|
||||||
|
override def name = "NewDefaultMetric"
|
||||||
|
override def construct(m: Map[String, String]) = NewDefaultMetric
|
||||||
}
|
}
|
||||||
|
|
||||||
object MacroCompilerUtil {
|
object MacroCompilerUtil {
|
||||||
@@ -137,19 +177,31 @@ object CostMetric {
|
|||||||
/** Define some default metric. */
|
/** Define some default metric. */
|
||||||
val default: CostMetric = NewDefaultMetric
|
val default: CostMetric = NewDefaultMetric
|
||||||
|
|
||||||
/** Select a cost function from string. */
|
val costMetricCreators: scala.collection.mutable.Map[String, CostMetricCompanion] = scala.collection.mutable.Map()
|
||||||
def getCostMetric(m: String, params: Map[String, String]): CostMetric = m match {
|
|
||||||
case "default" => default
|
// Register some default metrics
|
||||||
case "PalmerMetric" => PalmerMetric
|
registerCostMetric(PalmerMetric)
|
||||||
case "ExternalMetric" => {
|
registerCostMetric(ExternalMetric)
|
||||||
try {
|
registerCostMetric(NewDefaultMetric)
|
||||||
new ExternalMetric(params.get("path").get)
|
|
||||||
} catch {
|
/**
|
||||||
case e: NoSuchElementException => throw new IllegalArgumentException("Missing parameter 'path'")
|
* Register a cost metric.
|
||||||
}
|
* @param createFuncHelper Companion object to fetch the name and construct
|
||||||
|
* the metric.
|
||||||
|
*/
|
||||||
|
def registerCostMetric(createFuncHelper: CostMetricCompanion): Unit = {
|
||||||
|
costMetricCreators.update(createFuncHelper.name, createFuncHelper)
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Select a cost metric from string. */
|
||||||
|
def getCostMetric(m: String, params: Map[String, String]): CostMetric = {
|
||||||
|
if (m == "default") {
|
||||||
|
CostMetric.default
|
||||||
|
} else if (!costMetricCreators.contains(m)) {
|
||||||
|
throw new IllegalArgumentException("Invalid cost metric " + m)
|
||||||
|
} else {
|
||||||
|
costMetricCreators.get(m).get.construct(params)
|
||||||
}
|
}
|
||||||
case "NewDefaultMetric" => NewDefaultMetric
|
|
||||||
case _ => throw new IllegalArgumentException("Invalid cost metric " + m)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,9 +8,13 @@ import mdf.macrolib._
|
|||||||
* A test metric that simply favours memories with smaller widths, to test that
|
* A test metric that simply favours memories with smaller widths, to test that
|
||||||
* the metric is chosen properly.
|
* the metric is chosen properly.
|
||||||
*/
|
*/
|
||||||
object TestMinWidthMetric extends CostMetric {
|
object TestMinWidthMetric extends CostMetric with CostMetricCompanion {
|
||||||
// Smaller width = lower cost = favoured
|
// Smaller width = lower cost = favoured
|
||||||
override def cost(mem: Macro, lib: Macro): Option[BigInt] = Some(lib.src.width)
|
override def cost(mem: Macro, lib: Macro): Option[BigInt] = Some(lib.src.width)
|
||||||
|
|
||||||
|
override def commandLineParams = Map()
|
||||||
|
override def name = "TestMinWidthMetric"
|
||||||
|
override def construct(m: Map[String, String]) = TestMinWidthMetric
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Test that cost metric selection is working. */
|
/** Test that cost metric selection is working. */
|
||||||
@@ -19,7 +23,10 @@ class SelectCostMetric extends MacroCompilerSpec with HasSRAMGenerator {
|
|||||||
val lib = s"lib-SelectCostMetric.json"
|
val lib = s"lib-SelectCostMetric.json"
|
||||||
val v = s"SelectCostMetric.v"
|
val v = s"SelectCostMetric.v"
|
||||||
|
|
||||||
override val costMetric = TestMinWidthMetric
|
// Cost metrics must be registered for them to work with the command line.
|
||||||
|
CostMetric.registerCostMetric(TestMinWidthMetric)
|
||||||
|
|
||||||
|
override val costMetric = Some(TestMinWidthMetric)
|
||||||
|
|
||||||
val libSRAMs = Seq(
|
val libSRAMs = Seq(
|
||||||
SRAMMacro(
|
SRAMMacro(
|
||||||
|
|||||||
@@ -16,11 +16,26 @@ abstract class MacroCompilerSpec extends org.scalatest.FlatSpec with org.scalate
|
|||||||
val vPrefix: String = testDir
|
val vPrefix: String = testDir
|
||||||
|
|
||||||
// Override this to use a different cost metric.
|
// Override this to use a different cost metric.
|
||||||
val costMetric: CostMetric = CostMetric.default
|
// If this is None, the compile() call will not have any -c/-cp arguments, and
|
||||||
|
// execute() will use CostMetric.default.
|
||||||
|
val costMetric: Option[CostMetric] = None
|
||||||
|
private def getCostMetric: CostMetric = costMetric.getOrElse(CostMetric.default)
|
||||||
|
|
||||||
|
private def costMetricCmdLine = {
|
||||||
|
costMetric match {
|
||||||
|
case None => Nil
|
||||||
|
case Some(m) => {
|
||||||
|
val name = m.name
|
||||||
|
val params = m.commandLineParams
|
||||||
|
List("-c", name) ++ params.flatMap{ case (key, value) => List("-cp", key, value) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private def args(mem: String, lib: Option[String], v: String, synflops: Boolean) =
|
private def args(mem: String, lib: Option[String], v: String, synflops: Boolean) =
|
||||||
List("-m", mem.toString, "-v", v) ++
|
List("-m", mem.toString, "-v", v) ++
|
||||||
(lib match { case None => Nil case Some(l) => List("-l", l.toString) }) ++
|
(lib match { case None => Nil case Some(l) => List("-l", l.toString) }) ++
|
||||||
|
costMetricCmdLine ++
|
||||||
(if (synflops) List("--syn-flops") else Nil)
|
(if (synflops) List("--syn-flops") else Nil)
|
||||||
|
|
||||||
// Run the full compiler as if from the command line interface.
|
// Run the full compiler as if from the command line interface.
|
||||||
@@ -83,7 +98,7 @@ abstract class MacroCompilerSpec extends org.scalatest.FlatSpec with org.scalate
|
|||||||
val macros = mems map (_.blackbox)
|
val macros = mems map (_.blackbox)
|
||||||
val circuit = Circuit(NoInfo, macros, macros.last.name)
|
val circuit = Circuit(NoInfo, macros, macros.last.name)
|
||||||
val passes = Seq(
|
val passes = Seq(
|
||||||
new MacroCompilerPass(Some(mems), libs, costMetric),
|
new MacroCompilerPass(Some(mems), libs, getCostMetric),
|
||||||
new SynFlopsPass(synflops, libs getOrElse mems),
|
new SynFlopsPass(synflops, libs getOrElse mems),
|
||||||
RemoveEmpty)
|
RemoveEmpty)
|
||||||
val result: Circuit = (passes foldLeft circuit)((c, pass) => pass run c)
|
val result: Circuit = (passes foldLeft circuit)((c, pass) => pass run c)
|
||||||
|
|||||||
Reference in New Issue
Block a user