fp16 gemmini support
This commit is contained in:
@@ -3,14 +3,17 @@
|
|||||||
|
|
||||||
package radiance.subsystem
|
package radiance.subsystem
|
||||||
|
|
||||||
|
import chisel3._
|
||||||
import chisel3.util._
|
import chisel3.util._
|
||||||
import org.chipsalliance.cde.config._
|
import org.chipsalliance.cde.config._
|
||||||
import freechips.rocketchip.rocket._
|
import freechips.rocketchip.rocket._
|
||||||
import freechips.rocketchip.tile._
|
import freechips.rocketchip.tile._
|
||||||
import freechips.rocketchip.subsystem._
|
import freechips.rocketchip.subsystem._
|
||||||
import gemmini.{CapacityInKilobytes, GemminiFPConfigs}
|
import gemmini._
|
||||||
|
import gemmini.Arithmetic.FloatArithmetic._
|
||||||
import radiance.tile._
|
import radiance.tile._
|
||||||
import radiance.memory._
|
import radiance.memory._
|
||||||
|
import radiance.subsystem.RadianceGemminiDataType.{BF16, FP16, FP32, Int8}
|
||||||
|
|
||||||
case class RadianceSharedMemKey(address: BigInt,
|
case class RadianceSharedMemKey(address: BigInt,
|
||||||
size: Int,
|
size: Int,
|
||||||
@@ -84,9 +87,14 @@ class WithRadianceCores(
|
|||||||
), useVxCache)
|
), useVxCache)
|
||||||
}
|
}
|
||||||
|
|
||||||
class WithRadianceGemmini(location: HierarchicalLocation,
|
object RadianceGemminiDataType extends Enumeration {
|
||||||
crossing: RocketCrossingParams,
|
type Type = Value
|
||||||
dim: Int, accSizeInKB: Int, tileSize: Int) extends Config((site, _, up) => {
|
val FP32, FP16, BF16, Int8 = Value
|
||||||
|
}
|
||||||
|
|
||||||
|
class WithRadianceGemmini(location: HierarchicalLocation, crossing: RocketCrossingParams,
|
||||||
|
dim: Int, accSizeInKB: Int, tileSize: Int,
|
||||||
|
dataType: RadianceGemminiDataType.Type, dmaBytes: Int) extends Config((site, _, up) => {
|
||||||
case TilesLocated(`location`) => {
|
case TilesLocated(`location`) => {
|
||||||
val prev = up(TilesLocated(`location`))
|
val prev = up(TilesLocated(`location`))
|
||||||
val idOffset = up(NumTiles)
|
val idOffset = up(NumTiles)
|
||||||
@@ -100,7 +108,31 @@ class WithRadianceGemmini(location: HierarchicalLocation,
|
|||||||
}.sum
|
}.sum
|
||||||
val smKey = site(RadianceSharedMemKey).get
|
val smKey = site(RadianceSharedMemKey).get
|
||||||
val tileParams = GemminiTileParams(
|
val tileParams = GemminiTileParams(
|
||||||
gemminiConfig = GemminiFPConfigs.FP32DefaultConfig.copy(
|
gemminiConfig = {
|
||||||
|
implicit val arithmetic: Arithmetic[Float] =
|
||||||
|
Arithmetic.FloatArithmetic.asInstanceOf[Arithmetic[Float]]
|
||||||
|
dataType match {
|
||||||
|
case FP32 => GemminiFPConfigs.FP32DefaultConfig
|
||||||
|
case FP16 => GemminiFPConfigs.FP16DefaultConfig.copy(
|
||||||
|
acc_scale_args = Some(ScaleArguments(
|
||||||
|
(t: Float, u: Float) => {t},
|
||||||
|
1, Float(8, 24), -1, identity = "1.0", c_str = "((x))"
|
||||||
|
)),
|
||||||
|
mvin_scale_args = Some(ScaleArguments(
|
||||||
|
(t: Float, u: Float) => t * u,
|
||||||
|
1, Float(5, 11), -1, identity = "1.0", c_str="((x) * (scale))"
|
||||||
|
)),
|
||||||
|
mvin_scale_acc_args = None,
|
||||||
|
has_training_convs = false,
|
||||||
|
// hardcode_d_to_garbage_addr = true,
|
||||||
|
acc_read_full_width = false, // set to true to output fp32
|
||||||
|
)
|
||||||
|
case BF16 => GemminiFPConfigs.BF16DefaultConfig
|
||||||
|
// TODO: Int8
|
||||||
|
}}.copy(
|
||||||
|
dataflow = Dataflow.WS,
|
||||||
|
ex_read_from_acc = false,
|
||||||
|
ex_write_to_spad = false,
|
||||||
has_training_convs = false,
|
has_training_convs = false,
|
||||||
has_max_pool = false,
|
has_max_pool = false,
|
||||||
use_tl_ext_mem = true,
|
use_tl_ext_mem = true,
|
||||||
@@ -112,8 +144,10 @@ class WithRadianceGemmini(location: HierarchicalLocation,
|
|||||||
meshRows = dim,
|
meshRows = dim,
|
||||||
meshColumns = dim,
|
meshColumns = dim,
|
||||||
tile_latency = 0,
|
tile_latency = 0,
|
||||||
|
mesh_output_delay = 1,
|
||||||
|
acc_latency = 3,
|
||||||
dma_maxbytes = site(CacheBlockBytes),
|
dma_maxbytes = site(CacheBlockBytes),
|
||||||
dma_buswidth = 256, // TODO: parameterize
|
dma_buswidth = dmaBytes,
|
||||||
tl_ext_mem_base = smKey.address,
|
tl_ext_mem_base = smKey.address,
|
||||||
sp_banks = smKey.numBanks,
|
sp_banks = smKey.numBanks,
|
||||||
sp_capacity = CapacityInKilobytes(smKey.size >> 10),
|
sp_capacity = CapacityInKilobytes(smKey.size >> 10),
|
||||||
@@ -130,7 +164,8 @@ class WithRadianceGemmini(location: HierarchicalLocation,
|
|||||||
}
|
}
|
||||||
case NumTiles => up(NumTiles) + 1
|
case NumTiles => up(NumTiles) + 1
|
||||||
}) {
|
}) {
|
||||||
def this(location: HierarchicalLocation = InSubsystem, dim: Int, accSizeInKB: Int, tileSize: Int) =
|
def this(location: HierarchicalLocation = InSubsystem, dim: Int, accSizeInKB: Int, tileSize: Int,
|
||||||
|
dataType: RadianceGemminiDataType.Type = RadianceGemminiDataType.FP32, dmaBytes: Int = 256) =
|
||||||
this(location, RocketCrossingParams(
|
this(location, RocketCrossingParams(
|
||||||
master = HierarchicalElementMasterPortParams.locationDefault(location),
|
master = HierarchicalElementMasterPortParams.locationDefault(location),
|
||||||
slave = HierarchicalElementSlavePortParams.locationDefault(location),
|
slave = HierarchicalElementSlavePortParams.locationDefault(location),
|
||||||
@@ -138,7 +173,7 @@ class WithRadianceGemmini(location: HierarchicalLocation,
|
|||||||
case InSubsystem => CBUS
|
case InSubsystem => CBUS
|
||||||
case InCluster(clusterId) => CCBUS(clusterId)
|
case InCluster(clusterId) => CCBUS(clusterId)
|
||||||
}
|
}
|
||||||
), dim, accSizeInKB, tileSize)
|
), dim, accSizeInKB, tileSize, dataType, dmaBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
class WithRadianceSharedMem(address: BigInt,
|
class WithRadianceSharedMem(address: BigInt,
|
||||||
|
|||||||
@@ -44,7 +44,10 @@ class RadianceCluster (
|
|||||||
val gemminiTiles = leafTiles.values.filter(_.isInstanceOf[GemminiTile]).toSeq.asInstanceOf[Seq[GemminiTile]]
|
val gemminiTiles = leafTiles.values.filter(_.isInstanceOf[GemminiTile]).toSeq.asInstanceOf[Seq[GemminiTile]]
|
||||||
val gemminis = gemminiTiles.map(_.gemmini)
|
val gemminis = gemminiTiles.map(_.gemmini)
|
||||||
val gemminiConfigs = gemminis.map(_.config)
|
val gemminiConfigs = gemminis.map(_.config)
|
||||||
// val gemminiConfig = thisClusterParams.gemminiConfig.get // TODO: handle None gracefully
|
|
||||||
|
if (!(gemminiConfigs.tail.map(_.inputType == gemminiConfigs.head.inputType).reduce(_ && _))) {
|
||||||
|
println("******** WARNING ********\n******** gemmini data types do not match\n******** WARNING ********")
|
||||||
|
}
|
||||||
|
|
||||||
val radianceTiles = leafTiles.values.filter(_.isInstanceOf[RadianceTile]).toSeq.asInstanceOf[Seq[RadianceTile]]
|
val radianceTiles = leafTiles.values.filter(_.isInstanceOf[RadianceTile]).toSeq.asInstanceOf[Seq[RadianceTile]]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user