diff --git a/macros/src/main/scala/MacroCompiler.scala b/macros/src/main/scala/MacroCompiler.scala index 18b5ad0a..08f4ee34 100644 --- a/macros/src/main/scala/MacroCompiler.scala +++ b/macros/src/main/scala/MacroCompiler.scala @@ -36,12 +36,28 @@ trait CostMetric extends Serializable { * it cannot be compiled. */ def cost(mem: Macro, lib: Macro): Option[BigInt] + + /** + * Helper function to return the map of argments (or an empty map if there are none). + */ + def commandLineParams(): Map[String, String] + + // We also want this to show up for the class itself. + def name(): String +} + +// Is there a better way to do this? (static method associated to CostMetric) +trait CostMetricCompanion { + def name(): String + + /** Construct this cost metric from a command line mapping. */ + def construct(m: Map[String, String]): CostMetric } // Some default cost functions. /** Palmer's old metric. */ -object PalmerMetric extends CostMetric { +object PalmerMetric extends CostMetric with CostMetricCompanion { override def cost(mem: Macro, lib: Macro): Option[BigInt] = { /* Palmer: A quick cost function (that must be kept in sync with * memory_cost()) that attempts to avoid compiling unncessary @@ -51,6 +67,10 @@ object PalmerMetric extends CostMetric { // (mem.depth * mem.width) ??? } + + override def commandLineParams = Map() + override def name = "PalmerMetric" + override def construct(m: Map[String, String]) = PalmerMetric } /** @@ -87,11 +107,27 @@ class ExternalMetric(path: String) extends CostMetric { case e: NumberFormatException => None } } + + override def commandLineParams = Map("path" -> path) + override def name = ExternalMetric.name +} + +object ExternalMetric extends CostMetricCompanion { + override def name = "ExternalMetric" + + /** Construct this cost metric from a command line mapping. */ + override def construct(m: Map[String, String]) = { + val pathOption = m.get("path") + pathOption match { + case Some(path:String) => new ExternalMetric(path) + case _ => throw new IllegalArgumentException("ExternalMetric missing option 'path'") + } + } } /** The current default metric in barstools, re-defined by Donggyu. */ // TODO: write tests for this function to make sure it selects the right things -object NewDefaultMetric extends CostMetric { +object NewDefaultMetric extends CostMetric with CostMetricCompanion { override def cost(mem: Macro, lib: Macro): Option[BigInt] = { val memMask = mem.src.ports map (_.maskGran) find (_.isDefined) map (_.get) val libMask = lib.src.ports map (_.maskGran) find (_.isDefined) map (_.get) @@ -105,6 +141,10 @@ object NewDefaultMetric extends CostMetric { (lib.src.depth * lib.src.width + 1) // weights on # cells ) } + + override def commandLineParams = Map() + override def name = "NewDefaultMetric" + override def construct(m: Map[String, String]) = NewDefaultMetric } object MacroCompilerUtil { @@ -137,19 +177,31 @@ object CostMetric { /** Define some default metric. */ val default: CostMetric = NewDefaultMetric - /** Select a cost function from string. */ - def getCostMetric(m: String, params: Map[String, String]): CostMetric = m match { - case "default" => default - case "PalmerMetric" => PalmerMetric - case "ExternalMetric" => { - try { - new ExternalMetric(params.get("path").get) - } catch { - case e: NoSuchElementException => throw new IllegalArgumentException("Missing parameter 'path'") - } + val costMetricCreators: scala.collection.mutable.Map[String, CostMetricCompanion] = scala.collection.mutable.Map() + + // Register some default metrics + registerCostMetric(PalmerMetric) + registerCostMetric(ExternalMetric) + registerCostMetric(NewDefaultMetric) + + /** + * Register a cost metric. + * @param createFuncHelper Companion object to fetch the name and construct + * the metric. + */ + def registerCostMetric(createFuncHelper: CostMetricCompanion): Unit = { + costMetricCreators.update(createFuncHelper.name, createFuncHelper) + } + + /** Select a cost metric from string. */ + def getCostMetric(m: String, params: Map[String, String]): CostMetric = { + if (m == "default") { + CostMetric.default + } else if (!costMetricCreators.contains(m)) { + throw new IllegalArgumentException("Invalid cost metric " + m) + } else { + costMetricCreators.get(m).get.construct(params) } - case "NewDefaultMetric" => NewDefaultMetric - case _ => throw new IllegalArgumentException("Invalid cost metric " + m) } } diff --git a/macros/src/test/scala/CostFunction.scala b/macros/src/test/scala/CostFunction.scala index 71d1270d..44d25b68 100644 --- a/macros/src/test/scala/CostFunction.scala +++ b/macros/src/test/scala/CostFunction.scala @@ -8,9 +8,13 @@ import mdf.macrolib._ * A test metric that simply favours memories with smaller widths, to test that * the metric is chosen properly. */ -object TestMinWidthMetric extends CostMetric { +object TestMinWidthMetric extends CostMetric with CostMetricCompanion { // Smaller width = lower cost = favoured override def cost(mem: Macro, lib: Macro): Option[BigInt] = Some(lib.src.width) + + override def commandLineParams = Map() + override def name = "TestMinWidthMetric" + override def construct(m: Map[String, String]) = TestMinWidthMetric } /** Test that cost metric selection is working. */ @@ -19,7 +23,10 @@ class SelectCostMetric extends MacroCompilerSpec with HasSRAMGenerator { val lib = s"lib-SelectCostMetric.json" val v = s"SelectCostMetric.v" - override val costMetric = TestMinWidthMetric + // Cost metrics must be registered for them to work with the command line. + CostMetric.registerCostMetric(TestMinWidthMetric) + + override val costMetric = Some(TestMinWidthMetric) val libSRAMs = Seq( SRAMMacro( diff --git a/macros/src/test/scala/MacroCompilerSpec.scala b/macros/src/test/scala/MacroCompilerSpec.scala index 6dcc0efb..6fd95eb5 100644 --- a/macros/src/test/scala/MacroCompilerSpec.scala +++ b/macros/src/test/scala/MacroCompilerSpec.scala @@ -16,11 +16,26 @@ abstract class MacroCompilerSpec extends org.scalatest.FlatSpec with org.scalate val vPrefix: String = testDir // Override this to use a different cost metric. - val costMetric: CostMetric = CostMetric.default + // If this is None, the compile() call will not have any -c/-cp arguments, and + // execute() will use CostMetric.default. + val costMetric: Option[CostMetric] = None + private def getCostMetric: CostMetric = costMetric.getOrElse(CostMetric.default) + + private def costMetricCmdLine = { + costMetric match { + case None => Nil + case Some(m) => { + val name = m.name + val params = m.commandLineParams + List("-c", name) ++ params.flatMap{ case (key, value) => List("-cp", key, value) } + } + } + } private def args(mem: String, lib: Option[String], v: String, synflops: Boolean) = List("-m", mem.toString, "-v", v) ++ (lib match { case None => Nil case Some(l) => List("-l", l.toString) }) ++ + costMetricCmdLine ++ (if (synflops) List("--syn-flops") else Nil) // Run the full compiler as if from the command line interface. @@ -83,7 +98,7 @@ abstract class MacroCompilerSpec extends org.scalatest.FlatSpec with org.scalate val macros = mems map (_.blackbox) val circuit = Circuit(NoInfo, macros, macros.last.name) val passes = Seq( - new MacroCompilerPass(Some(mems), libs, costMetric), + new MacroCompilerPass(Some(mems), libs, getCostMetric), new SynFlopsPass(synflops, libs getOrElse mems), RemoveEmpty) val result: Circuit = (passes foldLeft circuit)((c, pass) => pass run c)