Skip to content

Commit 0b53071

Browse files
BenjaminTujroesch
authored and committed
[VTA][Chisel] TSIM VTA Source Refactor (apache#4163)
* app init push * fix on readme * change name, add bit serial explanation * rm serialLoadMM, change doc * syntax change for readme * add parallel test functionality * fix readme * add python doc * syntax * init commit * fix empty line * fix typo
1 parent 5b1350f commit 0b53071

File tree

4 files changed

+202
-130
lines changed

4 files changed

+202
-130
lines changed

apps/gemm/hardware/chisel/src/main/scala/accel/Compute.scala

Lines changed: 104 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,31 @@ package accel
2222
import chisel3._
2323
import chisel3.util._
2424
import vta.dpi._
25+
import vta.core._
26+
import vta.util.config._
27+
import vta.shell._
2528

29+
class TestConfig extends Config(new CoreConfig ++ new PynqConfig)
2630
/** Compute
2731
*
2832
* Bit Slice GEMM:
2933
*
3034
* 1. Wait for launch to be asserted
31-
* 2. Issue 2 read request for 8-byte value at inp1_baddr address and inp2_baddr address
35+
* 2. Issue 1 read request for 8-bit value at inp1_baddr address (read matrix)
3236
* 3. Wait for the value
3337
* 4. Increment read-address for next value
34-
* 5. Wait for sliced accumulator
35-
* 6. Check if counter (cnt) is equal to length process,
36-
otherwise goto step 2
37-
* 7. Check if reset slice accumulator
38-
* 8. Wait for overall accumulator
39-
* 8. Issue a write request for 8-byte value at out_baddr address
38+
* 5. Repeat until all inp1 data have been read
39+
40+
* 6. Issue 1 read request for 8-bit value at inp2_baddr address (read vector)
41+
* 7. Wait for the value
42+
* 8. Increment read-address for next value
43+
* 9. Repeat until all inp2 data have been read
44+
45+
* 10. Wait for output to be calculated
46+
* 11. Issue a write request for 4-byte value at out_baddr address
47+
* 12. Increment write-address for next value to write
48+
* 13. Check if counter (cntout) is equal to length to assert finish,
49+
otherwise go to step 11
4050
*/
4151
class Compute(implicit config: AccelConfig) extends Module {
4252
val io = IO(new Bundle {
@@ -47,19 +57,24 @@ class Compute(implicit config: AccelConfig) extends Module {
4757
val ptrs = Input(Vec(config.nPtrs, UInt(config.ptrBits.W)))
4858
val mem = new VTAMemDPIMaster
4959
})
50-
val sIdle :: sReadAReq :: sReadAData :: sReadBReq :: sReadBData :: sWriteReq :: sWriteData :: Nil = Enum(7)
60+
implicit val p: Parameters = new TestConfig
61+
val sIdle :: sReadAReq :: sReadAData :: sReadADone ::sReadBReq :: sReadBData :: sReadBDone :: sInpDone ::sWait:: sWriteReq :: sWriteData :: sWriteDone :: Nil = Enum(12)
5162
val state = RegInit(sIdle)
5263
val shift = io.vals(0)
5364
val length = io.vals(1)
5465
val rstAccum = io.vals(2)
5566
val startDot = io.vals(3)
5667
val cycles = RegInit(0.U(config.regBits.W))
57-
val reg1 = Reg(chiselTypeOf(io.mem.rd.bits))
58-
val reg2 = Reg(chiselTypeOf(io.mem.rd.bits))
59-
val cnt = Reg(UInt(config.regBits.W))
68+
val mvc = Module(new MatrixVectorMultiplication)
69+
val reg1 = Reg(chiselTypeOf(mvc.io.wgt.data.bits))
70+
val reg2 = Reg(chiselTypeOf(mvc.io.inp.data.bits))
71+
val cntwgt = Reg(UInt(config.regBits.W))
72+
val cntinp = Reg(UInt(config.regBits.W))
73+
val cntout = Reg(UInt(config.regBits.W))
6074
val raddr1 = Reg(UInt(config.ptrBits.W))
6175
val raddr2 = Reg(UInt(config.ptrBits.W))
6276
val waddr = Reg(UInt(config.ptrBits.W))
77+
val accum = Module(new Accmulator(size = p(CoreKey).blockOut, accBits = p(CoreKey).accBits))
6378

6479
switch (state) {
6580
is (sIdle) {
@@ -73,14 +88,38 @@ class Compute(implicit config: AccelConfig) extends Module {
7388
}
7489
is (sReadAData) {
7590
when (io.mem.rd.valid) {
91+
state := sReadADone
92+
}
93+
}
94+
is (sReadADone) {
95+
when (cntwgt === (length * length) - 1.U) {
7696
state := sReadBReq
97+
} .otherwise {
98+
state := sReadAReq
7799
}
78100
}
79101
is (sReadBReq) {
80102
state := sReadBData
81103
}
82104
is (sReadBData) {
83105
when (io.mem.rd.valid) {
106+
state := sReadBDone
107+
}
108+
}
109+
is (sReadBDone) {
110+
when (cntinp === length-1.U) {
111+
state := sInpDone
112+
} .otherwise {
113+
state := sReadBReq
114+
}
115+
}
116+
// Both inputs have been processed
117+
is (sInpDone) {
118+
state := sWait
119+
}
120+
// Wait for computation
121+
is (sWait) {
122+
when (accum.io.ready) {
84123
state := sWriteReq
85124
}
86125
}
@@ -89,15 +128,18 @@ class Compute(implicit config: AccelConfig) extends Module {
89128
state := sWriteData
90129
}
91130
is (sWriteData) {
92-
when (cnt === (length - 1.U)) {
131+
state := sWriteDone
132+
}
133+
is (sWriteDone) {
134+
when (cntout === (length - 1.U)) {
93135
state := sIdle
94136
} .otherwise {
95-
state := sReadAReq
137+
state := sWriteReq
96138
}
97139
}
98140
}
99141

100-
val last = state === sWriteData && cnt === (length - 1.U)
142+
val last = state === sWriteDone && cntout === (length - 1.U)
101143

102144
// cycle counter
103145
when (state === sIdle) {
@@ -114,10 +156,12 @@ class Compute(implicit config: AccelConfig) extends Module {
114156
raddr1 := io.ptrs(0)
115157
raddr2 := io.ptrs(1)
116158
waddr := io.ptrs(2)
117-
} .elsewhen (state === sWriteData) { // increment input array by 1-byte
159+
} .elsewhen (state === sReadADone) { // increment input array by 1-byte
118160
raddr1 := raddr1 + 1.U
161+
} .elsewhen (state === sReadBDone) { // increment input array by 1-byte
119162
raddr2 := raddr2 + 1.U
120-
waddr := waddr
163+
} .elsewhen (state === sWriteDone) {
164+
waddr := waddr + 4.U // writing 4 bytes
121165
}
122166

123167
// create request
@@ -128,59 +172,70 @@ class Compute(implicit config: AccelConfig) extends Module {
128172

129173
// read
130174
when (state === sReadAData && io.mem.rd.valid) {
131-
reg1 := io.mem.rd.bits(7, 0)
175+
reg1(cntwgt/length)(cntwgt%length) := io.mem.rd.bits(7, 0)
132176
}
133177

134178
when (state === sReadBData && io.mem.rd.valid) {
135-
reg2 := io.mem.rd.bits(7, 0)
179+
reg2(0)(cntinp) := io.mem.rd.bits(7, 0)
136180
}
137181

138182
io.mem.rd.ready := state === sReadAData | state === sReadBData
183+
mvc.io.inp.data.valid := state === sInpDone // 2 inputs have been processed
184+
mvc.io.wgt.data.valid := state === sInpDone // 2 inputs have been processed
185+
186+
mvc.io.wgt.data.bits <> reg1
187+
mvc.io.inp.data.bits <> reg2
188+
// Modify when shift operation is supported
189+
mvc.io.reset := false.B
190+
mvc.io.acc_i.data.valid := true.B
191+
for (i <- 0 until p(CoreKey).blockOut) {
192+
mvc.io.acc_i.data.bits(0)(i) := 0.U
193+
}
139194

140-
141-
val sliceAccum = Module(new Accumulator(63))
142-
val overallAccum = Module(new Accumulator(64))
143-
144-
sliceAccum.io.valid := state === sWriteReq // 2 inputs have been processed
145-
sliceAccum.io.in := reg1 * reg2
146-
sliceAccum.io.clear := startDot
147-
overallAccum.io.clear := rstAccum
148-
overallAccum.io.valid := last // last element has been processed
149-
overallAccum.io.in := sliceAccum.io.sum << shift(7,0) // limit to 8 bits
195+
accum.io.in := mvc.io.acc_o.data.bits
196+
accum.io.shift := shift
197+
accum.io.clear := rstAccum
198+
accum.io.valid := mvc.io.acc_o.data.valid
150199

151200
// write
152-
io.mem.wr.valid := overallAccum.io.ready
153-
io.mem.wr.bits := overallAccum.io.sum
154-
201+
io.mem.wr.valid := state === sWriteData
202+
io.mem.wr.bits := accum.io.sum(cntout)
155203

156204
// count read/write
157205
when (state === sIdle) {
158-
cnt := 0.U
159-
} .elsewhen (state === sWriteData) {
160-
cnt := cnt + 1.U
206+
cntwgt := 0.U
207+
cntinp := 0.U
208+
cntout := 0.U
209+
} .elsewhen (state === sReadADone) {
210+
cntwgt := cntwgt + 1.U
211+
} .elsewhen (state === sReadBDone) {
212+
cntinp := cntinp + 1.U
213+
} .elsewhen (state === sWriteDone) {
214+
cntout := cntout + 1.U
161215
}
162216

163-
io.finish := overallAccum.io.ready // data has been added
217+
io.finish := last // data has been added
164218
}
165-
166-
167-
class Accumulator(dataBits: Int = 8) extends Module {
219+
// Performs the shift operation here until it is supported in MVM
220+
class Accmulator(size: Int = 16, accBits: Int = 32) extends Module {
168221
val io = IO(new Bundle {
169222
val clear = Input(Bool())
170223
val valid = Input(Bool())
171224
val ready = Output(Bool())
172-
val in = Input(UInt(dataBits.W))
173-
val sum = Output(UInt((dataBits).W))
225+
val in = Input(Vec(1, Vec(size, (UInt(accBits.W)))))
226+
val shift = Input(UInt(8.W))
227+
val sum = Output(Vec(size, (UInt(accBits.W))))
174228
})
229+
val reg = RegInit(VecInit(Seq.fill(size)(0.U(accBits.W))))
175230

176-
val reg = RegInit(0.U((dataBits).W))
177-
val ready = RegNext(io.valid)
178-
when (io.clear) {
179-
reg := 0.U
180-
} .elsewhen (io.valid) {
181-
reg := reg + io.in
182-
}
183-
io.ready := ready
184-
io.sum := reg
231+
for (i <- 0 until size) {
232+
when (io.clear) {
233+
reg(i) := 0.U
234+
} .elsewhen(io.valid) {
235+
reg(i) := reg(i) + (io.in(0)(i) << io.shift)
236+
}
237+
}
238+
io.ready := RegNext(io.valid)
239+
io.sum := reg
185240
}
186241

apps/gemm/hardware/chisel/src/main/scala/accel/RegFile.scala

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,9 @@ import vta.dpi._
3535
* Shift value | 0x08
3636
* Vector length | 0x0c
3737
* Reset Accumulator | 0x10
38-
* Reset Dot Module | 0x14
39-
* Input1 pointer lsb | 0x18
40-
* Input1 pointer msb | 0x1c
41-
* Input2 pointer lsb | 0x20
42-
* Input2 pointer msb | 0x24
43-
* Output pointer lsb | 0x28
44-
* Output pointer msb | 0x2c
38+
* Input1 pointer | 0x18
39+
* Input2 pointer | 0x20
40+
* Output pointer | 0x28
4541
* -------------------------------
4642
4743
* ------------------------------

apps/gemm/src/driver.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,12 @@ class Device {
6666

6767
uint32_t Run(DLTensor* inp1, DLTensor* inp2, uint32_t shiftVal, DLTensor* out, uint32_t reset) {
6868
uint32_t cycles;
69-
uint32_t length = inp1->shape[0];
70-
size_t size1 = (inp1->dtype.bits >> 3) * length;
69+
uint32_t length = inp2->shape[0];
70+
// 1 matrix 1 vector input
71+
size_t size1 = (inp1->dtype.bits >> 3) * length * length;
7172
size_t size2 = (inp2->dtype.bits >> 3) * length;
72-
size_t size3 = (64 >> 3);
73+
// 1 vector output
74+
size_t size3 = (32 >> 3) * length;
7375
inp1_ = this->MemAlloc(size1);
7476
inp2_ = this->MemAlloc(size2);
7577
out_ = this->MemAlloc(size3);
@@ -115,19 +117,17 @@ class Device {
115117

116118
void Launch(uint32_t length, uint32_t shiftVal, uint32_t reset) {
117119
dpi_->WriteReg(0x08, shiftVal);
118-
dpi_->WriteReg(0x0c, length); // vector length
120+
dpi_->WriteReg(0x0c, length); // tensor size
119121
dpi_->WriteReg(0x18, this->MemGetPhyAddr(inp1_));
120122
dpi_->WriteReg(0x20, this->MemGetPhyAddr(inp2_));
121123
dpi_->WriteReg(0x28, this->MemGetPhyAddr(out_));
122124
dpi_->WriteReg(0x00, 0x1); // launch
123-
dpi_->WriteReg(0x00, 0x0); // launch
125+
dpi_->WriteReg(0x00, 0x0);
124126

125127
if (reset == 1) {
126-
dpi_->WriteReg(0x10, 0x1); // reset accum
127-
dpi_->WriteReg(0x10, 0x0); // stop reset accum
128+
dpi_->WriteReg(0x10, 0x1); // reset accumulator
129+
dpi_->WriteReg(0x10, 0x0);
128130
}
129-
dpi_->WriteReg(0x14, 0x1); // reset dot
130-
dpi_->WriteReg(0x14, 0x0); // stop reset dot
131131
}
132132

133133
uint32_t WaitForCompletion() {

0 commit comments

Comments
 (0)