Add cost function selection test
This commit is contained in:
@@ -90,6 +90,7 @@ class ExternalMetric(path: String) extends CostMetric {
|
||||
}
|
||||
|
||||
/** The current default metric in barstools, re-defined by Donggyu. */
|
||||
// TODO: write tests for this function to make sure it selects the right things
|
||||
object NewDefaultMetric extends CostMetric {
|
||||
override def cost(mem: Macro, lib: Macro): Option[BigInt] = {
|
||||
val memMask = mem.src.ports map (_.maskGran) find (_.isDefined) map (_.get)
|
||||
|
||||
111
macros/src/test/scala/CostFunction.scala
Normal file
111
macros/src/test/scala/CostFunction.scala
Normal file
@@ -0,0 +1,111 @@
|
||||
package barstools.macros
|
||||
|
||||
import mdf.macrolib._
|
||||
|
||||
/** Tests to check that the cost function mechanism is working properly. */
|
||||
|
||||
/**
|
||||
* A test metric that simply favours memories with smaller widths, to test that
|
||||
* the metric is chosen properly.
|
||||
*/
|
||||
object TestMinWidthMetric extends CostMetric {
|
||||
// Smaller width = lower cost = favoured
|
||||
override def cost(mem: Macro, lib: Macro): Option[BigInt] = Some(lib.src.width)
|
||||
}
|
||||
|
||||
/** Test that cost metric selection is working. */
|
||||
class SelectCostMetric extends MacroCompilerSpec with HasSRAMGenerator {
|
||||
val mem = s"mem-SelectCostMetric.json"
|
||||
val lib = s"lib-SelectCostMetric.json"
|
||||
val v = s"SelectCostMetric.v"
|
||||
|
||||
override val costMetric = TestMinWidthMetric
|
||||
|
||||
val libSRAMs = Seq(
|
||||
SRAMMacro(
|
||||
macroType=SRAM,
|
||||
name="SRAM_WIDTH_128",
|
||||
depth=1024,
|
||||
width=128,
|
||||
family="1rw",
|
||||
ports=Seq(
|
||||
generateReadWritePort("", 128, 1024)
|
||||
)
|
||||
),
|
||||
SRAMMacro(
|
||||
macroType=SRAM,
|
||||
name="SRAM_WIDTH_64",
|
||||
depth=1024,
|
||||
width=64,
|
||||
family="1rw",
|
||||
ports=Seq(
|
||||
generateReadWritePort("", 64, 1024)
|
||||
)
|
||||
),
|
||||
SRAMMacro(
|
||||
macroType=SRAM,
|
||||
name="SRAM_WIDTH_32",
|
||||
depth=1024,
|
||||
width=32,
|
||||
family="1rw",
|
||||
ports=Seq(
|
||||
generateReadWritePort("", 32, 1024)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
val memSRAMs = Seq(generateSRAM("target_memory", "", 128, 1024))
|
||||
|
||||
writeToLib(lib, libSRAMs)
|
||||
writeToMem(mem, memSRAMs)
|
||||
|
||||
// Check that the min width SRAM was chosen, even though it is less efficient.
|
||||
val output =
|
||||
"""
|
||||
circuit target_memory :
|
||||
module target_memory :
|
||||
input clk : Clock
|
||||
input addr : UInt<10>
|
||||
input din : UInt<128>
|
||||
output dout : UInt<128>
|
||||
input write_en : UInt<1>
|
||||
|
||||
inst mem_0_0 of SRAM_WIDTH_32
|
||||
inst mem_0_1 of SRAM_WIDTH_32
|
||||
inst mem_0_2 of SRAM_WIDTH_32
|
||||
inst mem_0_3 of SRAM_WIDTH_32
|
||||
mem_0_0.clk <= clk
|
||||
mem_0_0.addr <= addr
|
||||
node dout_0_0 = bits(mem_0_0.dout, 31, 0)
|
||||
mem_0_0.din <= bits(din, 31, 0)
|
||||
mem_0_0.write_en <= and(and(write_en, UInt<1>("h1")), UInt<1>("h1"))
|
||||
mem_0_1.clk <= clk
|
||||
mem_0_1.addr <= addr
|
||||
node dout_0_1 = bits(mem_0_1.dout, 31, 0)
|
||||
mem_0_1.din <= bits(din, 63, 32)
|
||||
mem_0_1.write_en <= and(and(write_en, UInt<1>("h1")), UInt<1>("h1"))
|
||||
mem_0_2.clk <= clk
|
||||
mem_0_2.addr <= addr
|
||||
node dout_0_2 = bits(mem_0_2.dout, 31, 0)
|
||||
mem_0_2.din <= bits(din, 95, 64)
|
||||
mem_0_2.write_en <= and(and(write_en, UInt<1>("h1")), UInt<1>("h1"))
|
||||
mem_0_3.clk <= clk
|
||||
mem_0_3.addr <= addr
|
||||
node dout_0_3 = bits(mem_0_3.dout, 31, 0)
|
||||
mem_0_3.din <= bits(din, 127, 96)
|
||||
mem_0_3.write_en <= and(and(write_en, UInt<1>("h1")), UInt<1>("h1"))
|
||||
node dout_0 = cat(dout_0_3, cat(dout_0_2, cat(dout_0_1, dout_0_0)))
|
||||
dout <= mux(UInt<1>("h1"), dout_0, UInt<1>("h0"))
|
||||
|
||||
extmodule SRAM_WIDTH_32 :
|
||||
input clk : Clock
|
||||
input addr : UInt<10>
|
||||
input din : UInt<32>
|
||||
output dout : UInt<32>
|
||||
input write_en : UInt<1>
|
||||
|
||||
defname = SRAM_WIDTH_32
|
||||
"""
|
||||
|
||||
compileExecuteAndTest(mem, lib, v, output)
|
||||
}
|
||||
@@ -15,6 +15,9 @@ abstract class MacroCompilerSpec extends org.scalatest.FlatSpec with org.scalate
|
||||
val libPrefix: String = testDir
|
||||
val vPrefix: String = testDir
|
||||
|
||||
// Override this to use a different cost metric.
|
||||
val costMetric: CostMetric = CostMetric.default
|
||||
|
||||
private def args(mem: String, lib: Option[String], v: String, synflops: Boolean) =
|
||||
List("-m", mem.toString, "-v", v) ++
|
||||
(lib match { case None => Nil case Some(l) => List("-l", l.toString) }) ++
|
||||
@@ -80,7 +83,7 @@ abstract class MacroCompilerSpec extends org.scalatest.FlatSpec with org.scalate
|
||||
val macros = mems map (_.blackbox)
|
||||
val circuit = Circuit(NoInfo, macros, macros.last.name)
|
||||
val passes = Seq(
|
||||
new MacroCompilerPass(Some(mems), libs),
|
||||
new MacroCompilerPass(Some(mems), libs, costMetric),
|
||||
new SynFlopsPass(synflops, libs getOrElse mems),
|
||||
RemoveEmpty)
|
||||
val result: Circuit = (passes foldLeft circuit)((c, pass) => pass run c)
|
||||
|
||||
Reference in New Issue
Block a user