diff --git a/macros/src/main/scala/MacroCompiler.scala b/macros/src/main/scala/MacroCompiler.scala index ea800956..ec70949e 100644 --- a/macros/src/main/scala/MacroCompiler.scala +++ b/macros/src/main/scala/MacroCompiler.scala @@ -14,9 +14,12 @@ import java.io.{File, FileWriter} import Utils._ object MacroCompilerAnnotation { - def apply(c: String, mem: String, lib: Option[String], synflops: Boolean) = { + def apply(c: String, mem: File, lib: Option[File], synflops: Boolean): Annotation = + apply(c, mem.toString, lib map (_.toString), synflops) + + def apply(c: String, mem: String, lib: Option[String], synflops: Boolean): Annotation = { Annotation(CircuitName(c), classOf[MacroCompilerTransform], - s"${mem} %s ${synflops}".format(lib map (_.toString) getOrElse "")) + s"${mem} %s ${synflops}".format(lib getOrElse "")) } private val matcher = "([^ ]+) ([^ ]*) (true|false)".r def unapply(a: Annotation) = a match { @@ -331,8 +334,8 @@ class MacroCompilerPass(mems: Option[Seq[Macro]], } class MacroCompilerTransform extends Transform { - def inputForm = HighForm - def outputForm = HighForm + def inputForm = MidForm + def outputForm = MidForm def execute(state: CircuitState) = getMyAnnotations(state) match { case Seq(MacroCompilerAnnotation(state.circuit.main, memFile, libFile, synflops)) => require(memFile.isDefined) @@ -349,19 +352,32 @@ class MacroCompilerTransform extends Transform { } val transforms = Seq( new MacroCompilerPass(mems, libs), - new SynFlopsPass(synflops, libs getOrElse mems.get), - firrtl.passes.SplitExpressions - ) - ((transforms foldLeft state)((s, xform) => xform runTransform s)) + new SynFlopsPass(synflops, libs getOrElse mems.get)) + (transforms foldLeft state)((s, xform) => xform runTransform s).copy(form=outputForm) + case _ => state } } +// FIXME: Use firrtl.LowerFirrtlOptimizations +class MacroCompilerOptimizations extends SeqTransform { + def inputForm = LowForm + def outputForm = LowForm + def transforms = Seq( + passes.RemoveValidIf, + new firrtl.transforms.ConstantPropagation, + passes.memlib.VerilogMemDelays, + new firrtl.transforms.ConstantPropagation, + passes.Legalize, + passes.SplitExpressions, + passes.CommonSubexpressionElimination) +} + class MacroCompiler extends Compiler { def emitter = new VerilogEmitter def transforms = Seq(new MacroCompilerTransform) ++ - getLoweringTransforms(firrtl.HighForm, firrtl.LowForm) // ++ - // Seq(new LowFirrtlOptimization) // Todo: This is dangerous + getLoweringTransforms(firrtl.HighForm, firrtl.LowForm) ++ + Seq(new MacroCompilerOptimizations) } object MacroCompiler extends App { diff --git a/macros/src/main/scala/Utils.scala b/macros/src/main/scala/Utils.scala index 9a2a11b8..5c41ef8b 100644 --- a/macros/src/main/scala/Utils.scala +++ b/macros/src/main/scala/Utils.scala @@ -14,9 +14,9 @@ import scala.language.implicitConversions class FirrtlMacroPort(port: MacroPort) { val src = port - val isReader = !port.readEnable.isEmpty && port.writeEnable.isEmpty - val isWriter = !port.writeEnable.isEmpty && port.readEnable.isEmpty - val isReadWriter = !port.writeEnable.isEmpty && !port.readEnable.isEmpty + val isReader = port.output.nonEmpty && port.input.isEmpty + val isWriter = port.input.nonEmpty && port.output.isEmpty + val isReadWriter = port.input.nonEmpty && port.output.nonEmpty val addrType = UIntType(IntWidth(ceilLog2(port.depth) max 1)) val dataType = UIntType(IntWidth(port.width))