Skip to content

Commit

Permalink
Fix vm overhead
Browse files Browse the repository at this point in the history
  • Loading branch information
davschneller committed Oct 30, 2024
1 parent 13bd622 commit 9c430fc
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions pspamm/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,17 @@ def __init__(self,
bpattern = Matrix.load(mtx_filename)
if self.masks:
self.generator.set_sparse()
else:
assert self.k <= ldb

if lda == 0:
apattern = Matrix.load(mtx_filename)
if self.masks:
self.generator.set_sparse()
else:
assert self.m <= lda

assert self.m <= ldc

self.nnz = 0
self.flop = 0
Expand Down Expand Up @@ -286,7 +292,7 @@ def kernelK(asm, Bki, A_ptr, B_ptr):

if keep:
asm.add(self.generator.make_microkernel(self.A, self.B, A_ptr, B_ptr, self.A_regs, self.B_regs, regs, self.v_size, self.additional_regs, to_A, to_B))

if unroll:
for Bki in range(Bk):
kernelK(asm, Bki, A_ptr, B_ptr)
Expand Down Expand Up @@ -376,13 +382,14 @@ def make(self):
loop(self.loop_regs[0], 0, Bm, 1).body(*loopBody)
)

vm_overhead = (self.m % self.bm) // self.v_size
m_overhead = self.m % self.bm
vm_overhead = -(m_overhead // -self.v_size)

if vm_overhead > 0:
self.m = self.m % self.bm
self.bm = self.m % self.bm
self.A_regs = self.A_regs[0:self.bm // self.v_size, 0:self.bk]
self.C_regs = self.C_regs[0:self.bm // self.v_size, 0:self.bn]
self.A_regs = self.A_regs[0:vm_overhead, 0:self.bk]
self.C_regs = self.C_regs[0:vm_overhead, 0:self.bn]
self.A.r = self.m
asm.add(self.make_nk_unroll(self.unroll))

Expand Down

0 comments on commit 9c430fc

Please sign in to comment.