Skip to content

Commit

Permalink
[rtl] fix ix type for gather.
Browse files Browse the repository at this point in the history
  • Loading branch information
qinjun-li committed Nov 12, 2024
1 parent ae33cd9 commit c2f2846
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 46 deletions.
45 changes: 10 additions & 35 deletions t1/src/T1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -541,30 +541,7 @@ class T1(val parameter: T1Parameter)
parameter.instructionIndexBits
)

// todo: ix type gather read
// gather read state
val gatherOverlap: Bool = Wire(Bool())
val gatherNeedRead: Bool = requestRegDequeue.valid && decodeResult(Decoder.gather) &&
!decodeResult(Decoder.vtype) && !gatherOverlap
val gatherData: UInt = RegInit(0.U(parameter.datapathWidth.W))
val gatherReadRequest: DecoupledIO[VRFReadRequest] = Wire(Decoupled(readType))
val gatherReadLaneSelect: UInt = Wire(UInt(parameter.laneNumber.W))
val gatherReadResultFire = Pipe(gatherReadRequest.fire, gatherReadLaneSelect, parameter.vrfReadLatency).valid
val gatherReadFinish: Bool =
RegEnable(
!requestRegDequeue.fire,
false.B,
(gatherReadResultFire && gatherNeedRead) || requestRegDequeue.fire
)
val gatherReadDataOffset: UInt = Wire(UInt(5.W))

// todo
gatherReadRequest.valid := DontCare
gatherReadRequest.bits := DontCare
gatherReadRequest.ready := DontCare
gatherOverlap := DontCare
gatherReadLaneSelect := DontCare
gatherReadDataOffset := DontCare
val gatherNeedRead: Bool = requestRegDequeue.valid && decodeResult(Decoder.gather)

/** state machine register for each instruction. */
val slots: Seq[InstructionControl] = Seq.tabulate(parameter.chainingSize) { index =>
Expand Down Expand Up @@ -687,7 +664,11 @@ class T1(val parameter: T1Parameter)
val slotReady: Bool = Mux(specialInstruction, slots.map(_.state.idle).last, freeOR)

val source1Select: UInt =
Mux(decodeResult(Decoder.gather), gatherData, Mux(decodeResult(Decoder.itype), immSignExtend, source1Extend))
Mux(
decodeResult(Decoder.gather),
maskUnit.gatherData.bits,
Mux(decodeResult(Decoder.itype), immSignExtend, source1Extend)
)

// data eew for extend type
val extendDataEEW: Bool = (T1Issue.vsew(requestReg.bits.issue) - decodeResult(Decoder.topUop)(2, 1))(0)
Expand Down Expand Up @@ -854,6 +835,9 @@ class T1(val parameter: T1Parameter)
maskUnit.instReq.bits.vs2 := requestRegDequeue.bits.instruction(24, 20)
maskUnit.instReq.bits.vd := requestRegDequeue.bits.instruction(11, 7)
maskUnit.instReq.bits.vl := requestReg.bits.issue.vl
// gather read
maskUnit.gatherRead := gatherNeedRead
maskUnit.gatherData.ready := requestRegDequeue.fire

maskUnit.exeReq.zip(laneVec).foreach { case (maskInput, lane) =>
maskInput.valid := lane.maskUnitRequest.valid && !lane.maskRequestToLSU
Expand All @@ -867,15 +851,6 @@ class T1(val parameter: T1Parameter)
lane.tokenIO.maskRequestRelease := token.maskRequestRelease || lsu.tokenIO.offsetGroupRelease(index)
}

val gatherResultSelect: UInt = Mux1H(
gatherReadLaneSelect,
laneVec.map(_.vrfReadDataChannel)
)
// gather read result
when(gatherReadResultFire) {
gatherData := Mux(gatherOverlap, 0.U, (gatherResultSelect >> gatherReadDataOffset).asUInt)
}

// 连lane的环
parameter.crossLaneConnectCycles.zipWithIndex.foreach { case (cycles, index) =>
cycles.zipWithIndex.foreach { case (cycle, portIndex) =>
Expand Down Expand Up @@ -940,7 +915,7 @@ class T1(val parameter: T1Parameter)
// - for slide instruction, it is unordered, and may have RAW hazard,
// we detect the hazard and decide should we issue this slide or
// issue the instruction after the slide which already in the slot.
requestRegDequeue.ready := executionReady && slotReady && (!gatherNeedRead || gatherReadFinish) &&
requestRegDequeue.ready := executionReady && slotReady && (!gatherNeedRead || maskUnit.gatherData.valid) &&
tokenManager.issueAllow && instructionIndexFree && vrfAllocate

instructionToSlotOH := Mux(requestRegDequeue.fire, slotToEnqueue, 0.U)
Expand Down
13 changes: 7 additions & 6 deletions t1/src/mask/MaskReduce.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class MaskReduce(parameter: T1Parameter) extends Module {
val out: ValidIO[ReduceOutput] = IO(Valid(new ReduceOutput(parameter)))
val newInstruction: Bool = IO(Input(Bool()))
val validInst: Bool = IO(Input(Bool()))
val pop: Bool = IO(Input(Bool()))

val maskSize: Int = parameter.laneNumber * parameter.datapathWidth / 8

Expand Down Expand Up @@ -119,9 +120,9 @@ class MaskReduce(parameter: T1Parameter) extends Module {
}
}

val enqWriteMask: UInt = Fill(2, in.bits.eew(1)) ## in.bits.eew.orR ## true.B
val enqWriteMask: UInt = Fill(2, in.bits.eew(1)) ## in.bits.eew.orR ## true.B
val updateInitMask: UInt = FillInterleaved(8, enqWriteMask)
val updateMask: UInt = FillInterleaved(8, writeMask)
val updateMask: UInt = FillInterleaved(8, writeMask)
when(newInstruction) {
// todo: update reduceInit when first in.fire
reduceInit := in.bits.readVS1 & updateInitMask
Expand Down Expand Up @@ -149,9 +150,9 @@ class MaskReduce(parameter: T1Parameter) extends Module {
cutUInt(reqReg.source2, parameter.datapathWidth)
)
val sourceValidCalculate: UInt =
reqReg.fpSourceValid.map(fv =>
Mux(floatType, fv & reqReg.sourceValid, reqReg.sourceValid)
).getOrElse(reqReg.sourceValid)
reqReg.fpSourceValid
.map(fv => Mux(floatType, fv & reqReg.sourceValid, reqReg.sourceValid))
.getOrElse(reqReg.sourceValid)
sourceValid := Mux1H(
UIntToOH(crossFoldCount),
sourceValidCalculate.asBools
Expand Down Expand Up @@ -204,5 +205,5 @@ class MaskReduce(parameter: T1Parameter) extends Module {

out.valid := outValid
out.bits.data := Mux(updateResult, reduceResult, reduceInit)
out.bits.mask := writeMask & Fill(4, validInst)
out.bits.mask := writeMask & Fill(4, validInst && !pop)
}
92 changes: 87 additions & 5 deletions t1/src/mask/MaskUnit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ class MaskUnit(parameter: T1Parameter) extends Module {
@public
val writeRDData: UInt = IO(Output(UInt(parameter.xLen.W)))

@public
val gatherData: DecoupledIO[UInt] = IO(Decoupled(UInt(parameter.xLen.W)))

@public
val gatherRead: Bool = IO(Input(Bool()))

/** duplicate v0 for mask */
val v0: Vec[UInt] = RegInit(
VecInit(Seq.fill(parameter.vLen / parameter.datapathWidth)(0.U(parameter.datapathWidth.W)))
Expand Down Expand Up @@ -154,19 +160,66 @@ class MaskUnit(parameter: T1Parameter) extends Module {

val maskedWrite: BitLevelMaskWrite = Module(new BitLevelMaskWrite(parameter))

def gatherIndex(elementIndex: UInt, vlmul: UInt, sew: UInt): (UInt, UInt, UInt, UInt, Bool) = {
val intLMULInput: UInt = (1.U << vlmul(1, 0)).asUInt
val positionSize = parameter.laneParam.vlMaxBits - 1
val dataPosition = (changeUIntSize(elementIndex, positionSize) << sew).asUInt(positionSize - 1, 0)
val sewOHInput = UIntToOH(sew)(2, 0)

// The offset of the data starting position in 32 bits (currently only 32).
// Since the data may cross lanes, it will be optimized during fusion.
val dataOffset: UInt = (dataPosition(1) && sewOHInput(1, 0).orR) ## (dataPosition(0) && sewOHInput(0))
val accessLane = if (parameter.laneNumber > 1) dataPosition(log2Ceil(parameter.laneNumber) + 1, 2) else 0.U(1.W)
// 32 bit / group
val dataGroup = (dataPosition >> (log2Ceil(parameter.laneNumber) + 2)).asUInt
val offsetWidth: Int = parameter.laneParam.vrfParam.vrfOffsetBits
val offset = dataGroup(offsetWidth - 1, 0)
val accessRegGrowth = (dataGroup >> offsetWidth).asUInt
val decimalProportion = offset ## accessLane
// 1/8 register
val decimal = decimalProportion(decimalProportion.getWidth - 1, 0.max(decimalProportion.getWidth - 3))

/** elementIndex needs to be compared with vlMax(vLen * lmul /sew) This calculation is too complicated We can change
* the angle. Calculate the increment of the read register and compare it with lmul to know whether the index
* exceeds vlMax. vlmul needs to distinguish between integers and floating points
*/
val overlap =
(vlmul(2) && decimal >= intLMULInput(3, 1)) ||
(!vlmul(2) && accessRegGrowth >= intLMULInput) ||
(elementIndex >> log2Ceil(parameter.vLen)).asUInt.orR
val notNeedRead = overlap
val reallyGrowth: UInt = changeUIntSize(accessRegGrowth, 3)
(dataOffset, accessLane, offset, reallyGrowth, notNeedRead)
}
val (dataOffset, accessLane, offset, reallyGrowth, notNeedRead) =
gatherIndex(instReq.bits.readFromScala, instReq.bits.vlmul, instReq.bits.sew)
val idle :: sRead :: wRead :: sResponse :: Nil = Enum(4)
val gatherReadState: UInt = RegInit(idle)
val gatherRequestFire: Bool = gatherReadState === idle && gatherRead
val gatherSRead: Bool = gatherReadState === sRead
val gatherWaiteRead: Bool = gatherReadState === wRead
val gatherResponse: Bool = gatherReadState === sResponse
val gatherDatOffset: UInt = RegEnable(dataOffset, 0.U, gatherRequestFire)
val gatherLane: UInt = RegEnable(accessLane, 0.U, gatherRequestFire)
val gatherOffset: UInt = RegEnable(offset, 0.U, gatherRequestFire)
val gatherGrowth: UInt = RegEnable(reallyGrowth, 0.U, gatherRequestFire)

val instReg: MaskUnitInstReq = RegEnable(instReq.bits, 0.U.asTypeOf(instReq.bits), instReq.valid)
val instVlValid: Bool =
RegEnable(instReq.bits.vl.orR && instReq.valid, false.B, instReq.valid || lastReport.orR)
// viota mask read vs2. Also pretending to be reading vs1
val viotaReq: Bool = instReq.bits.decodeResult(Decoder.topUop) === "b01000".U
when(instReq.valid && viotaReq) { instReg.vs1 := instReq.bits.vs2 }
when(instReq.valid && viotaReq || gatherRequestFire) {
instReg.vs1 := instReq.bits.vs2
instReg.instructionIndex := instReq.bits.instructionIndex
}
// register for read vs1
val readVS1Reg: MaskUnitReadVs1 = RegInit(0.U.asTypeOf(new MaskUnitReadVs1(parameter)))
val sew1H: UInt = UIntToOH(instReg.sew)(2, 0)
// request for read vs1
val readVS1Req: MaskUnitReadReq = WireDefault(0.U.asTypeOf(new MaskUnitReadReq(parameter)))

when(instReq.valid) {
when(instReq.valid || gatherRequestFire) {
readVS1Reg.requestSend := false.B
readVS1Reg.dataValid := false.B
readVS1Reg.sendToExecution := false.B
Expand All @@ -187,6 +240,7 @@ class MaskUnit(parameter: T1Parameter) extends Module {
val orderReduce: Bool = instReg.decodeResult(Decoder.topUop) === BitPat("b101?1")
val ffo: Bool = instReg.decodeResult(Decoder.topUop) === BitPat("b0111?")
val extendType: Bool = unitType(3) && (subType(2) || subType(1))
val pop: Bool = instReg.decodeResult(Decoder.popCount)

// Instructions for writing vd without source
val noSource: Bool = mvVd || viota
Expand Down Expand Up @@ -564,13 +618,18 @@ class MaskUnit(parameter: T1Parameter) extends Module {
val anyDataValid: Bool = exeReqReg.zipWithIndex.map { case (d, i) => d.valid }.reduce(_ || _)

// try to read vs1
val readVs1Valid: Bool = (unitType(2) || compress) && !readVS1Reg.requestSend
val readVs1Valid: Bool = (unitType(2) || compress) && !readVS1Reg.requestSend || gatherSRead
readVS1Req.vs := instReg.vs1
when(compress) {
val logLaneNumber = log2Ceil(parameter.laneNumber)
readVS1Req.vs := instReg.vs1 + (readVS1Reg.readIndex >> (parameter.laneParam.vrfOffsetBits + logLaneNumber))
readVS1Req.offset := readVS1Reg.readIndex >> logLaneNumber
readVS1Req.readLane := changeUIntSize(readVS1Reg.readIndex, logLaneNumber)
}.elsewhen(gatherSRead) {
readVS1Req.vs := instReg.vs1 + gatherGrowth
readVS1Req.offset := gatherOffset
readVS1Req.readLane := gatherLane
readVS1Req.dataOffset := gatherDatOffset
}

// select execute group
Expand Down Expand Up @@ -782,10 +841,13 @@ class MaskUnit(parameter: T1Parameter) extends Module {
val read = readData(index)
read.ready := isWaiteForThisData
if (index == 0) {
read.ready := isWaiteForThisData || unitType(2) || compress
read.ready := isWaiteForThisData || unitType(2) || compress || gatherWaiteRead
when(read.fire) {
readVS1Reg.data := read.bits
readVS1Reg.dataValid := true.B
when(gatherWaiteRead) {
gatherReadState := sResponse
}
}
}
when(read.fire) {
Expand Down Expand Up @@ -871,6 +933,7 @@ class MaskUnit(parameter: T1Parameter) extends Module {
reduceUnit.in.bits.sign := !instReg.decodeResult(Decoder.unsigned1)
reduceUnit.newInstruction := !readVS1Reg.sendToExecution && reduceUnit.in.fire
reduceUnit.validInst := instReg.vl.orR
reduceUnit.pop := pop

reduceUnit.in.bits.fpSourceValid.foreach { sink =>
sink := VecInit(exeReqReg.map(_.bits.fpReduceValid.get)).asUInt
Expand Down Expand Up @@ -1030,5 +1093,24 @@ class MaskUnit(parameter: T1Parameter) extends Module {
lastReportValid,
indexToOH(instReg.instructionIndex, parameter.chainingSize)
)
writeRDData := compressUnit.writeData
writeRDData := Mux(pop, reduceUnit.out.bits.data, compressUnit.writeData)

// gather read state
when(gatherRequestFire) {
when(notNeedRead) {
gatherReadState := sResponse
}.otherwise {
gatherReadState := sRead
}
}

when(readCrossBar.input.head.fire && gatherSRead) {
gatherReadState := wRead
}

gatherData.valid := gatherResponse
gatherData.bits := Mux(readVS1Reg.dataValid, readVS1Reg.data, 0.U)
when(gatherData.fire) {
gatherReadState := idle
}
}

0 comments on commit c2f2846

Please sign in to comment.