Add cost function selection test
This commit is contained in:
@@ -90,6 +90,7 @@ class ExternalMetric(path: String) extends CostMetric {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** The current default metric in barstools, re-defined by Donggyu. */
|
/** The current default metric in barstools, re-defined by Donggyu. */
|
||||||
|
// TODO: write tests for this function to make sure it selects the right things
|
||||||
object NewDefaultMetric extends CostMetric {
|
object NewDefaultMetric extends CostMetric {
|
||||||
override def cost(mem: Macro, lib: Macro): Option[BigInt] = {
|
override def cost(mem: Macro, lib: Macro): Option[BigInt] = {
|
||||||
val memMask = mem.src.ports map (_.maskGran) find (_.isDefined) map (_.get)
|
val memMask = mem.src.ports map (_.maskGran) find (_.isDefined) map (_.get)
|
||||||
|
|||||||
111
macros/src/test/scala/CostFunction.scala
Normal file
111
macros/src/test/scala/CostFunction.scala
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
package barstools.macros
|
||||||
|
|
||||||
|
import mdf.macrolib._
|
||||||
|
|
||||||
|
/** Tests to check that the cost function mechanism is working properly. */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A test metric that simply favours memories with smaller widths, to test that
|
||||||
|
* the metric is chosen properly.
|
||||||
|
*/
|
||||||
|
object TestMinWidthMetric extends CostMetric {
|
||||||
|
// Smaller width = lower cost = favoured
|
||||||
|
override def cost(mem: Macro, lib: Macro): Option[BigInt] = Some(lib.src.width)
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Test that cost metric selection is working. */
|
||||||
|
class SelectCostMetric extends MacroCompilerSpec with HasSRAMGenerator {
|
||||||
|
val mem = s"mem-SelectCostMetric.json"
|
||||||
|
val lib = s"lib-SelectCostMetric.json"
|
||||||
|
val v = s"SelectCostMetric.v"
|
||||||
|
|
||||||
|
override val costMetric = TestMinWidthMetric
|
||||||
|
|
||||||
|
val libSRAMs = Seq(
|
||||||
|
SRAMMacro(
|
||||||
|
macroType=SRAM,
|
||||||
|
name="SRAM_WIDTH_128",
|
||||||
|
depth=1024,
|
||||||
|
width=128,
|
||||||
|
family="1rw",
|
||||||
|
ports=Seq(
|
||||||
|
generateReadWritePort("", 128, 1024)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
SRAMMacro(
|
||||||
|
macroType=SRAM,
|
||||||
|
name="SRAM_WIDTH_64",
|
||||||
|
depth=1024,
|
||||||
|
width=64,
|
||||||
|
family="1rw",
|
||||||
|
ports=Seq(
|
||||||
|
generateReadWritePort("", 64, 1024)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
SRAMMacro(
|
||||||
|
macroType=SRAM,
|
||||||
|
name="SRAM_WIDTH_32",
|
||||||
|
depth=1024,
|
||||||
|
width=32,
|
||||||
|
family="1rw",
|
||||||
|
ports=Seq(
|
||||||
|
generateReadWritePort("", 32, 1024)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
val memSRAMs = Seq(generateSRAM("target_memory", "", 128, 1024))
|
||||||
|
|
||||||
|
writeToLib(lib, libSRAMs)
|
||||||
|
writeToMem(mem, memSRAMs)
|
||||||
|
|
||||||
|
// Check that the min width SRAM was chosen, even though it is less efficient.
|
||||||
|
val output =
|
||||||
|
"""
|
||||||
|
circuit target_memory :
|
||||||
|
module target_memory :
|
||||||
|
input clk : Clock
|
||||||
|
input addr : UInt<10>
|
||||||
|
input din : UInt<128>
|
||||||
|
output dout : UInt<128>
|
||||||
|
input write_en : UInt<1>
|
||||||
|
|
||||||
|
inst mem_0_0 of SRAM_WIDTH_32
|
||||||
|
inst mem_0_1 of SRAM_WIDTH_32
|
||||||
|
inst mem_0_2 of SRAM_WIDTH_32
|
||||||
|
inst mem_0_3 of SRAM_WIDTH_32
|
||||||
|
mem_0_0.clk <= clk
|
||||||
|
mem_0_0.addr <= addr
|
||||||
|
node dout_0_0 = bits(mem_0_0.dout, 31, 0)
|
||||||
|
mem_0_0.din <= bits(din, 31, 0)
|
||||||
|
mem_0_0.write_en <= and(and(write_en, UInt<1>("h1")), UInt<1>("h1"))
|
||||||
|
mem_0_1.clk <= clk
|
||||||
|
mem_0_1.addr <= addr
|
||||||
|
node dout_0_1 = bits(mem_0_1.dout, 31, 0)
|
||||||
|
mem_0_1.din <= bits(din, 63, 32)
|
||||||
|
mem_0_1.write_en <= and(and(write_en, UInt<1>("h1")), UInt<1>("h1"))
|
||||||
|
mem_0_2.clk <= clk
|
||||||
|
mem_0_2.addr <= addr
|
||||||
|
node dout_0_2 = bits(mem_0_2.dout, 31, 0)
|
||||||
|
mem_0_2.din <= bits(din, 95, 64)
|
||||||
|
mem_0_2.write_en <= and(and(write_en, UInt<1>("h1")), UInt<1>("h1"))
|
||||||
|
mem_0_3.clk <= clk
|
||||||
|
mem_0_3.addr <= addr
|
||||||
|
node dout_0_3 = bits(mem_0_3.dout, 31, 0)
|
||||||
|
mem_0_3.din <= bits(din, 127, 96)
|
||||||
|
mem_0_3.write_en <= and(and(write_en, UInt<1>("h1")), UInt<1>("h1"))
|
||||||
|
node dout_0 = cat(dout_0_3, cat(dout_0_2, cat(dout_0_1, dout_0_0)))
|
||||||
|
dout <= mux(UInt<1>("h1"), dout_0, UInt<1>("h0"))
|
||||||
|
|
||||||
|
extmodule SRAM_WIDTH_32 :
|
||||||
|
input clk : Clock
|
||||||
|
input addr : UInt<10>
|
||||||
|
input din : UInt<32>
|
||||||
|
output dout : UInt<32>
|
||||||
|
input write_en : UInt<1>
|
||||||
|
|
||||||
|
defname = SRAM_WIDTH_32
|
||||||
|
"""
|
||||||
|
|
||||||
|
compileExecuteAndTest(mem, lib, v, output)
|
||||||
|
}
|
||||||
@@ -15,6 +15,9 @@ abstract class MacroCompilerSpec extends org.scalatest.FlatSpec with org.scalate
|
|||||||
val libPrefix: String = testDir
|
val libPrefix: String = testDir
|
||||||
val vPrefix: String = testDir
|
val vPrefix: String = testDir
|
||||||
|
|
||||||
|
// Override this to use a different cost metric.
|
||||||
|
val costMetric: CostMetric = CostMetric.default
|
||||||
|
|
||||||
private def args(mem: String, lib: Option[String], v: String, synflops: Boolean) =
|
private def args(mem: String, lib: Option[String], v: String, synflops: Boolean) =
|
||||||
List("-m", mem.toString, "-v", v) ++
|
List("-m", mem.toString, "-v", v) ++
|
||||||
(lib match { case None => Nil case Some(l) => List("-l", l.toString) }) ++
|
(lib match { case None => Nil case Some(l) => List("-l", l.toString) }) ++
|
||||||
@@ -80,7 +83,7 @@ abstract class MacroCompilerSpec extends org.scalatest.FlatSpec with org.scalate
|
|||||||
val macros = mems map (_.blackbox)
|
val macros = mems map (_.blackbox)
|
||||||
val circuit = Circuit(NoInfo, macros, macros.last.name)
|
val circuit = Circuit(NoInfo, macros, macros.last.name)
|
||||||
val passes = Seq(
|
val passes = Seq(
|
||||||
new MacroCompilerPass(Some(mems), libs),
|
new MacroCompilerPass(Some(mems), libs, costMetric),
|
||||||
new SynFlopsPass(synflops, libs getOrElse mems),
|
new SynFlopsPass(synflops, libs getOrElse mems),
|
||||||
RemoveEmpty)
|
RemoveEmpty)
|
||||||
val result: Circuit = (passes foldLeft circuit)((c, pass) => pass run c)
|
val result: Circuit = (passes foldLeft circuit)((c, pass) => pass run c)
|
||||||
|
|||||||
Reference in New Issue
Block a user