diff --git a/macros/src/main/scala/MacroCompiler.scala b/macros/src/main/scala/MacroCompiler.scala index d26f5912..642b85ef 100644 --- a/macros/src/main/scala/MacroCompiler.scala +++ b/macros/src/main/scala/MacroCompiler.scala @@ -32,7 +32,9 @@ trait CostMetric { def cost(mem: Macro, lib: Macro): Option[BigInt] } -/** Some default cost functions. */ +// Some default cost functions. + +/** Palmer's old metric. */ object PalmerMetric extends CostMetric { override def cost(mem: Macro, lib: Macro): Option[BigInt] = { /* Palmer: A quick cost function (that must be kept in sync with @@ -45,7 +47,43 @@ object PalmerMetric extends CostMetric { } } -// The current default metric in barstools, re-defined by Donggyu. +/** + * An external cost function. + * Calls the specified path with paths to the JSON MDF representation of the mem + * and lib macros. The external executable should return a BigInt. + * None will be returned if the external executable does not return a valid + * BigInt. + */ +class ExternalMetric(path: String) extends CostMetric { + import mdf.macrolib.Utils.writeMacroToPath + import java.io._ + import scala.language.postfixOps // for !! postfix op + import sys.process._ + + override def cost(mem: Macro, lib: Macro): Option[BigInt] = { + // Create temporary files. + val memFile = File.createTempFile("_macrocompiler_mem_", ".json") + val libFile = File.createTempFile("_macrocompiler_lib_", ".json") + + writeMacroToPath(Some(memFile.getAbsolutePath), mem.src) + writeMacroToPath(Some(libFile.getAbsolutePath), lib.src) + + // !! executes the given command + val result: String = (s"${path} ${memFile.getAbsolutePath} ${libFile.getAbsolutePath}" !!).trim + + // Remove temporary files. + memFile.delete() + libFile.delete() + + try { + Some(BigInt(result)) + } catch { + case e: NumberFormatException => None + } + } +} + +/** The current default metric in barstools, re-defined by Donggyu. */ object NewDefaultMetric extends CostMetric { override def cost(mem: Macro, lib: Macro): Option[BigInt] = { val memMask = mem.src.ports map (_.maskGran) find (_.isDefined) map (_.get)