@@ -22,21 +22,31 @@ package accel
2222import chisel3 ._
2323import chisel3 .util ._
2424import vta .dpi ._
25+ import vta .core ._
26+ import vta .util .config ._
27+ import vta .shell ._
2528
29+ class TestConfig extends Config (new CoreConfig ++ new PynqConfig )
2630/** Compute
2731 *
2832 * Bit Slice GEMM:
2933 *
3034 * 1. Wait for launch to be asserted
31- * 2. Issue 2 read request for 8-byte value at inp1_baddr address and inp2_baddr address
35+ * 2. Issue 1 read request for 8-bit value at inp1_baddr address (read matrix)
3236 * 3. Wait for the value
3337 * 4. Increment read-address for next value
34- * 5. Wait for sliced accumulator
35- * 6. Check if counter (cnt) is equal to length process,
36- otherwise goto step 2
37- * 7. Check if reset slice accumulator
38- * 8. Wait for overall accumulator
39- * 8. Issue a write request for 8-byte value at out_baddr address
38+ * 5. Repeat until all inp1 data have been read
39+
40+ * 6. Issue 1 read request for 8-bit value at inp2_baddr address (read vector)
41+ * 7. Wait for the value
42+ * 8. Increment read-address for next value
43+ * 9. Repeat until all inp2 data have been read
44+
45+ * 10. Wait for output to be calculated
46+ * 11. Issue a write request for 8-byte value at out_baddr address
47+ * 12. Increment write-address for next value to write
48+ * 13. Check if counter (cntout) is equal to length to asser finish,
49+ otherwise go to step 11
4050 */
4151class Compute (implicit config : AccelConfig ) extends Module {
4252 val io = IO (new Bundle {
@@ -47,19 +57,24 @@ class Compute(implicit config: AccelConfig) extends Module {
4757 val ptrs = Input (Vec (config.nPtrs, UInt (config.ptrBits.W )))
4858 val mem = new VTAMemDPIMaster
4959 })
50- val sIdle :: sReadAReq :: sReadAData :: sReadBReq :: sReadBData :: sWriteReq :: sWriteData :: Nil = Enum (7 )
60+ implicit val p : Parameters = new TestConfig
61+ val sIdle :: sReadAReq :: sReadAData :: sReadADone :: sReadBReq :: sReadBData :: sReadBDone :: sInpDone :: sWait:: sWriteReq :: sWriteData :: sWriteDone :: Nil = Enum (12 )
5162 val state = RegInit (sIdle)
5263 val shift = io.vals(0 )
5364 val length = io.vals(1 )
5465 val rstAccum = io.vals(2 )
5566 val startDot = io.vals(3 )
5667 val cycles = RegInit (0 .U (config.regBits.W ))
57- val reg1 = Reg (chiselTypeOf(io.mem.rd.bits))
58- val reg2 = Reg (chiselTypeOf(io.mem.rd.bits))
59- val cnt = Reg (UInt (config.regBits.W ))
68+ val mvc = Module (new MatrixVectorMultiplication )
69+ val reg1 = Reg (chiselTypeOf(mvc.io.wgt.data.bits))
70+ val reg2 = Reg (chiselTypeOf(mvc.io.inp.data.bits))
71+ val cntwgt = Reg (UInt (config.regBits.W ))
72+ val cntinp = Reg (UInt (config.regBits.W ))
73+ val cntout = Reg (UInt (config.regBits.W ))
6074 val raddr1 = Reg (UInt (config.ptrBits.W ))
6175 val raddr2 = Reg (UInt (config.ptrBits.W ))
6276 val waddr = Reg (UInt (config.ptrBits.W ))
77+ val accum = Module (new Accmulator (size = p(CoreKey ).blockOut, accBits = p(CoreKey ).accBits))
6378
6479 switch (state) {
6580 is (sIdle) {
@@ -73,14 +88,38 @@ class Compute(implicit config: AccelConfig) extends Module {
7388 }
7489 is (sReadAData) {
7590 when (io.mem.rd.valid) {
91+ state := sReadADone
92+ }
93+ }
94+ is (sReadADone) {
95+ when (cntwgt === (length * length) - 1 .U ) {
7696 state := sReadBReq
97+ } .otherwise {
98+ state := sReadAReq
7799 }
78100 }
79101 is (sReadBReq) {
80102 state := sReadBData
81103 }
82104 is (sReadBData) {
83105 when (io.mem.rd.valid) {
106+ state := sReadBDone
107+ }
108+ }
109+ is (sReadBDone) {
110+ when (cntinp === length- 1 .U ) {
111+ state := sInpDone
112+ } .otherwise {
113+ state := sReadBReq
114+ }
115+ }
116+ // Both input is processed
117+ is (sInpDone) {
118+ state := sWait
119+ }
120+ // Wait for computation
121+ is (sWait) {
122+ when (accum.io.ready) {
84123 state := sWriteReq
85124 }
86125 }
@@ -89,15 +128,18 @@ class Compute(implicit config: AccelConfig) extends Module {
89128 state := sWriteData
90129 }
91130 is (sWriteData) {
92- when (cnt === (length - 1 .U )) {
131+ state := sWriteDone
132+ }
133+ is (sWriteDone) {
134+ when (cntout === (length - 1 .U )) {
93135 state := sIdle
94136 } .otherwise {
95- state := sReadAReq
137+ state := sWriteReq
96138 }
97139 }
98140 }
99141
100- val last = state === sWriteData && cnt === (length - 1 .U )
142+ val last = state === sWriteDone && cntout === (length - 1 .U )
101143
102144 // cycle counter
103145 when (state === sIdle) {
@@ -114,10 +156,12 @@ class Compute(implicit config: AccelConfig) extends Module {
114156 raddr1 := io.ptrs(0 )
115157 raddr2 := io.ptrs(1 )
116158 waddr := io.ptrs(2 )
117- } .elsewhen (state === sWriteData ) { // increment input array by 1-byte
159+ } .elsewhen (state === sReadADone ) { // increment input array by 1-byte
118160 raddr1 := raddr1 + 1 .U
161+ } .elsewhen (state === sReadBDone) { // increment input array by 1-byte
119162 raddr2 := raddr2 + 1 .U
120- waddr := waddr
163+ } .elsewhen (state === sWriteDone) {
164+ waddr := waddr + 4 .U // writing 4 bytes
121165 }
122166
123167 // create request
@@ -128,59 +172,70 @@ class Compute(implicit config: AccelConfig) extends Module {
128172
129173 // read
130174 when (state === sReadAData && io.mem.rd.valid) {
131- reg1 := io.mem.rd.bits(7 , 0 )
175+ reg1(cntwgt / length)(cntwgt % length) := io.mem.rd.bits(7 , 0 )
132176 }
133177
134178 when (state === sReadBData && io.mem.rd.valid) {
135- reg2 := io.mem.rd.bits(7 , 0 )
179+ reg2( 0 )(cntinp) := io.mem.rd.bits(7 , 0 )
136180 }
137181
138182 io.mem.rd.ready := state === sReadAData | state === sReadBData
183+ mvc.io.inp.data.valid := state === sInpDone // 2 inputs have been processed
184+ mvc.io.wgt.data.valid := state === sInpDone // 2 inputs have been processed
185+
186+ mvc.io.wgt.data.bits <> reg1
187+ mvc.io.inp.data.bits <> reg2
188+ // Modify when shift operation is supported
189+ mvc.io.reset := false .B
190+ mvc.io.acc_i.data.valid := true .B
191+ for (i <- 0 until p(CoreKey ).blockOut) {
192+ mvc.io.acc_i.data.bits(0 )(i) := 0 .U
193+ }
139194
140-
141- val sliceAccum = Module (new Accumulator (63 ))
142- val overallAccum = Module (new Accumulator (64 ))
143-
144- sliceAccum.io.valid := state === sWriteReq // 2 inputs have been processed
145- sliceAccum.io.in := reg1 * reg2
146- sliceAccum.io.clear := startDot
147- overallAccum.io.clear := rstAccum
148- overallAccum.io.valid := last // last element has been processed
149- overallAccum.io.in := sliceAccum.io.sum << shift(7 ,0 ) // limit to 8 bits
195+ accum.io.in := mvc.io.acc_o.data.bits
196+ accum.io.shift := shift
197+ accum.io.clear := rstAccum
198+ accum.io.valid := mvc.io.acc_o.data.valid
150199
151200 // write
152- io.mem.wr.valid := overallAccum.io.ready
153- io.mem.wr.bits := overallAccum.io.sum
154-
201+ io.mem.wr.valid := state === sWriteData
202+ io.mem.wr.bits := accum.io.sum(cntout)
155203
156204 // count read/write
157205 when (state === sIdle) {
158- cnt := 0 .U
159- } .elsewhen (state === sWriteData) {
160- cnt := cnt + 1 .U
206+ cntwgt := 0 .U
207+ cntinp := 0 .U
208+ cntout := 0 .U
209+ } .elsewhen (state === sReadADone) {
210+ cntwgt := cntwgt + 1 .U
211+ } .elsewhen (state === sReadBDone) {
212+ cntinp := cntinp + 1 .U
213+ } .elsewhen (state === sWriteDone) {
214+ cntout := cntout + 1 .U
161215 }
162216
163- io.finish := overallAccum.io.ready // data has been added
217+ io.finish := last // data has been added
164218}
165-
166-
167- class Accumulator (dataBits : Int = 8 ) extends Module {
219+ // Shift operation until supported in MVM
220+ class Accmulator (size : Int = 16 , accBits : Int = 32 ) extends Module {
168221 val io = IO (new Bundle {
169222 val clear = Input (Bool ())
170223 val valid = Input (Bool ())
171224 val ready = Output (Bool ())
172- val in = Input (UInt (dataBits.W ))
173- val sum = Output (UInt ((dataBits).W ))
225+ val in = Input (Vec (1 , Vec (size, (UInt (accBits.W )))))
226+ val shift = Input (UInt (8 .W ))
227+ val sum = Output (Vec (size, (UInt (accBits.W ))))
174228 })
229+ val reg = RegInit (VecInit (Seq .fill(size)(0 .U (accBits.W ))))
175230
176- val reg = RegInit ( 0 . U ((dataBits). W ))
177- val ready = RegNext (io.valid)
178- when (io.clear) {
179- reg := 0 . U
180- } .elsewhen ( io.valid) {
181- reg := reg + io.in
182- }
183- io.ready := ready
184- io.sum := reg
231+ for (i <- 0 until size) {
232+ when (io.clear) {
233+ reg(i) := 0 . U
234+ } .elsewhen(io.valid) {
235+ reg(i) := reg(i) + ( io.in( 0 )(i) << io.shift)
236+ }
237+ }
238+ io.ready := RegNext (io.valid)
239+ io.sum := reg
185240}
186241
0 commit comments