diff --git a/macros/src/main/scala/MacroCompiler.scala b/macros/src/main/scala/MacroCompiler.scala index ec70949e..78999cf8 100644 --- a/macros/src/main/scala/MacroCompiler.scala +++ b/macros/src/main/scala/MacroCompiler.scala @@ -67,18 +67,34 @@ class MacroCompilerPass(mems: Option[Seq[Macro]], // Serial mapping val stmts = ArrayBuffer[Statement]() - val selects = HashMap[String, Expression]() val outputs = HashMap[String, ArrayBuffer[(Expression, Expression)]]() + val selects = HashMap[String, Expression]() + val selectRegs = HashMap[String, Expression]() /* Palmer: If we've got a parallel memory then we've got to take the - * address bits into account. */ + * address bits into account. */ if (mem.src.depth > lib.src.depth) { mem.src.ports foreach { port => val high = ceilLog2(mem.src.depth) val low = ceilLog2(lib.src.depth) val ref = WRef(port.address.name) - val name = s"${ref.name}_sel" - selects(ref.name) = WRef(name, UIntType(IntWidth(high-low))) - stmts += DefNode(NoInfo, name, bits(ref, high-1, low)) + val nodeName = s"${ref.name}_sel" + val tpe = UIntType(IntWidth(high-low)) + selects(ref.name) = WRef(nodeName, tpe) + stmts += DefNode(NoInfo, nodeName, bits(ref, high-1, low)) + // Donggyu: output selection should be piped + if (port.output.isDefined) { + val regName = s"${ref.name}_sel_reg" + val enable = (port.chipEnable, port.readEnable) match { + case (Some(ce), Some(re)) => + and(WRef(ce.name, BoolType), WRef(re.name, BoolType)) + case (Some(ce), None) => WRef(ce.name, BoolType) + case (None, Some(re)) => WRef(re.name, BoolType) + case (None, None) => one + } + selectRegs(ref.name) = WRef(regName, tpe) + stmts += DefRegister(NoInfo, regName, tpe, WRef(port.clock.name), zero, WRef(regName)) + stmts += Connect(NoInfo, WRef(regName), Mux(enable, WRef(nodeName), WRef(regName), tpe)) + } } } for ((off, i) <- (0 until mem.src.depth by lib.src.depth).zipWithIndex) { @@ -97,7 +113,15 @@ class MacroCompilerPass(mems: Option[Seq[Macro]], val index = UIntLiteral(i, IntWidth(bitWidth(addr.tpe))) DoPrim(PrimOps.Eq, Seq(addr, index), Nil, index.tpe) } - def andAddrMatch(e: Expression) = and(e, addrMatch) + val addrMatchReg = selectRegs get memPort.src.address.name match { + case None => one + case Some(reg) => + val index = UIntLiteral(i, IntWidth(bitWidth(reg.tpe))) + DoPrim(PrimOps.Eq, Seq(reg, index), Nil, index.tpe) + } + def andAddrMatch(e: Expression) = { + and(e, addrMatch) + } val cats = ArrayBuffer[Expression]() for (((low, high), j) <- pairs.zipWithIndex) { val inst = WRef(s"mem_${i}_${j}", lib.tpe) @@ -272,7 +296,7 @@ class MacroCompilerPass(mems: Option[Seq[Macro]], val name = s"${mem}_${i}" stmts += DefNode(NoInfo, name, cat(cats.toSeq.reverse)) (outputs getOrElseUpdate (mem, ArrayBuffer[(Expression, Expression)]())) += - (addrMatch -> WRef(name)) + (addrMatchReg -> WRef(name)) case _ => } } @@ -353,7 +377,7 @@ class MacroCompilerTransform extends Transform { val transforms = Seq( new MacroCompilerPass(mems, libs), new SynFlopsPass(synflops, libs getOrElse mems.get)) - (transforms foldLeft state)((s, xform) => xform runTransform s).copy(form=outputForm) + (transforms foldLeft state)((s, xform) => xform runTransform s).copy(form=outputForm) case _ => state } }