Extend inline ARM_SVE broadcast fma
davschneller committed Oct 20, 2024
1 parent d87a5af commit c480249
Showing 4 changed files with 72 additions and 23 deletions.
2 changes: 1 addition & 1 deletion pspamm/codegen/architectures/arm/generator.py
@@ -240,7 +240,7 @@ def make_microkernel(self,
bs.append(B_regs[bki_reg, bni])

for Vmi in range(bm//v_size):
# TODO:
# TODO: refactor cell_indices into the cursors/blocks
cell_indices = {}
for bki in range(bk): # inside this k-block
for bni in range(bn): # inside this n-block
70 changes: 52 additions & 18 deletions pspamm/codegen/architectures/arm_sve/generator.py
@@ -78,8 +78,21 @@ def make_reg_blocks(self, bm: int, bn: int, bk: int, v_size: int, nnz: int, m: i

# k-broadcasting only works in 128-bit lanes
elem128 = 16 // self.get_precision().size()
vk = -(bk // -elem128)
assert ((bn + bk) * vm + bn * bk <= 32) # Needs to fit in SVE z registers
vkext = -(bk // -elem128)

# inline broadcasting is only allowed for the lower-numbered registers
self.inline_broadcast = False
if bn*vkext < 16 if self.get_precision().size() == 8 else bn*vkext < 8:
self.inline_broadcast = True
if bk == 1:
self.inline_broadcast = False

if self.inline_broadcast:
vk = vkext
else:
vk = bk

assert ((bn + bk) * vm + bn * vk <= 32) # Needs to fit in SVE z registers

prec = {
Precision.DOUBLE: "d",
@@ -88,12 +101,15 @@ def make_reg_blocks(self, bm: int, bn: int, bk: int, v_size: int, nnz: int, m: i
Precision.BFLOAT16: "h",
}[self.get_precision()]

# use max(vm, 1) in case bm < v_size, otherwise we get no A_regs/C_regs
A_regs = Matrix([[z(max(vm, 1) * c + r , prec) for c in range(bk)] for r in range(max(vm, 1))])
B_regs = Matrix([[z(max(vm, 1) * bk + bn * r + c, prec) for c in range(bn)] for r in range(bk)])
C_regs = Matrix([[z(32 - max(vm, 1) * bn + max(vm, 1) * c + r, prec) for c in range(bn)] for r in range(max(vm, 1))])
# make room for the two broadcasting registers
a_offset = 1 if bn * vk == 1 else 0
assert ((bn + bk) * vm + bn * vk + a_offset <= 32)

A_regs = Matrix([[z(vm * c + r + bn * vk + a_offset, prec) for c in range(bk)] for r in range(vm)])
B_regs = Matrix([[z(bn * r + c, prec) for c in range(bn)] for r in range(vk)])
C_regs = Matrix([[z(32 - vm * bn + vm * c + r, prec) for c in range(bn)] for r in range(vm)])

b_reg = max(vm, 1) * bk
b_reg = 0
alpha_reg = [z(b_reg, prec), z(b_reg, prec)]
beta_reg = [z(b_reg + 1, prec), z(b_reg + 1, prec)]

@@ -175,7 +191,7 @@ def init_registers(self,
# 'ptrue' doesn't work for initialising the overhead predicate when using single precision -> see valid patterns from above
# overhead = "\"ptrue p0.{suffix}, #{overhead}{eol}\"\n\t" if bm != 0 else "" # define overhead predicate
overhead_m = "\"mov {gen_reg}{overhead_counter}, #{overhead_m}{eol}\"\n\t\"whilelo p0.{suffix}, {gen_reg}zr, {gen_reg}{overhead_counter}{eol}\"\n\t" if bmmod != 0 else ""
overhead_k = "" #"\"mov {gen_reg}{overhead_counter}, #{overhead_k}{eol}\"\n\t\"whilelo p1.{suffix}, {gen_reg}zr, {gen_reg}{overhead_counter}{eol}\"\n\t" if bkmod != 0 else ""
overhead_k = "" # "\"mov {gen_reg}{overhead_counter}, #{overhead_k}{eol}\"\n\t\"whilelo p1.{suffix}, {gen_reg}zr, {gen_reg}{overhead_counter}{eol}\"\n\t" if bkmod != 0 else ""
all_true = "\"ptrue p7.{suffix}, #31{eol}\"" # define all true predicate
init_registers = (comment + overhead_m + overhead_k + all_true).format(suffix=p_suffix,
gen_reg=gen_reg,
@@ -274,7 +290,7 @@ def move_register_block(self,
asm.add(ld(addr, registers[ir, ic], True, comment, pred=p_zeroing, is_B=is_B, scalar_offs=False,
add_reg=additional_regs[2]))

prev_overhead = int(p.ugly[1]) == 0 # determine if we previously used p0 (overhead predicate)
prev_overhead = p is None or int(p.ugly[1]) == 0 # determine if we previously used p0 (overhead predicate)

return asm

@@ -331,31 +347,43 @@ def make_microkernel(self,
# for ld1rw (single prec): immediate offset is multiple of 4 in range of 0 to 252
# for ld1rd (double prec): immediate offset is multiple of 8 in range of 0 to 504
# in both cases: instruction encodes the immediate offset within 6 bits
max_offs = (2 ** 6 - 1) * multiple
if not self.inline_broadcast:
max_offs = (2 ** 6 - 1) * multiple
divider = 1
else:
max_offs = 127
divider = 16
for Vmi in range(Vm):
# set all v_size predicates to true, we want to replicate a B element into a whole vector
p_zeroing = self.pred_n_trues(v_size, v_size, "z")
for bki in range(bk): # inside this k-block
bki_reg = bki // elem128
for bni in range(bn): # inside this n-block
to_cell = Coords(down=bki, right=bni)
if B.has_nonzero_cell(B_ptr, to_B_block, to_cell):
B_cell_addr, B_comment = B.look(B_ptr, to_B_block, to_cell)
if B_regs[bki, bni] not in bs:
if B_regs[bki_reg, bni] not in bs:
# max_offs is the maximum allowed immediate offset when using ld1rd/ld1rw to broadcast a scalar value
if B_cell_addr.disp > max_offs:
if B_cell_addr.disp - cur11 > 0 and B_cell_addr.disp - cur11 <= max_offs:
B_cell_addr.disp = B_cell_addr.disp - cur11
if B_cell_addr.disp > max_offs or B_cell_addr.disp % divider != 0:
moved = B_cell_addr.disp - cur11
if moved > 0 and moved <= max_offs and moved % divider == 0:
B_cell_addr.disp = moved
else:
asm.add(add(B_cell_addr.disp, additional_regs[0], "", B_cell_addr.base))
cur11 = B_cell_addr.disp
B_cell_addr.disp = 0

B_cell_addr.base = additional_regs[0]

asm.add(ld(B_cell_addr, B_regs[bki, bni], True, B_comment, pred=p_zeroing, is_B=True))
bs.append(B_regs[bki, bni])
if not self.inline_broadcast:
asm.add(ld(B_cell_addr, B_regs[bki_reg, bni], True, B_comment, pred=p_zeroing, is_B=True))
else:
asm.add(ld(B_cell_addr, B_regs[bki_reg, bni], True, B_comment, pred=p_zeroing, sub128=True))
bs.append(B_regs[bki_reg, bni])

for Vmi in range(Vm):
# TODO: refactor cell_indices into the cursors/blocks
cell_indices = {}
p_merging = self.pred_n_trues(bm - Vmi * v_size, v_size, "m")
end_index = bm if Vmi + 1 == Vm else Vmi * v_size + v_size # end_index helps us print the right index ranges
for bki in range(bk): # inside this k-block
@@ -367,8 +395,14 @@ def make_microkernel(self,
end_index, bki, B_comment)

bki_reg = bki // elem128
bki_sub = bki % elem128
asm.add(fma(B_regs[bki, bni], A_regs[Vmi, bki], C_regs[Vmi, bni], comment=comment, pred=p_merging, bcast=None))
if (bki_reg, bni) not in cell_indices:
cell_indices[(bki_reg, bni)] = 0
if not self.inline_broadcast:
bcast = None
else:
bcast = cell_indices[(bki_reg, bni)]
asm.add(fma(B_regs[bki_reg, bni], A_regs[Vmi, bki], C_regs[Vmi, bni], comment=comment, pred=p_merging, bcast=bcast))
cell_indices[(bki_reg, bni)] += 1
return asm

def init_prefetching(self, prefetching):
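
For reference, a minimal standalone sketch of the register-budget arithmetic this file now performs (plain Python; `vm` is assumed here to be ceil(bm / v_size), and the function name and example numbers are illustrative, not part of the commit):

```python
def sve_register_plan(bm, bn, bk, v_size, prec_size):
    """Sketch of the budgeting done in make_reg_blocks (assumes 32 SVE z-registers)."""
    vm = -(bm // -v_size)        # ceil(bm / v_size): vectors per column of A/C (assumption)
    elem128 = 16 // prec_size    # elements per 128-bit lane
    vkext = -(bk // -elem128)    # B registers needed if k is packed into 128-bit lanes

    # indexed fmla can only name low-numbered registers:
    # z0-z15 for 64-bit elements, z0-z7 for 32-bit elements
    limit = 16 if prec_size == 8 else 8
    inline_broadcast = bn * vkext < limit and bk > 1

    vk = vkext if inline_broadcast else bk
    a_offset = 1 if bn * vk == 1 else 0  # reserve room for the alpha/beta broadcast registers
    assert (bn + bk) * vm + bn * vk + a_offset <= 32  # everything must fit in z0..z31
    return inline_broadcast, vk

# double precision (8-byte elements), bm=16, bn=4, bk=4, v_size=8:
# elem128 = 2, vkext = 2, bn * vkext = 8 < 16  ->  inline broadcast with vk = 2
print(sve_register_plan(16, 4, 4, 8, 8))  # (True, 2)
```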
16 changes: 14 additions & 2 deletions pspamm/codegen/architectures/arm_sve/inlineprinter.py
@@ -52,7 +52,8 @@ def visitFma(self, stmt: FmaStmt):
a = stmt.add_dest.ugly
p = self.p_string(stmt.pred)
if stmt.bcast is not None:
s = f"fmla {a}, {p}{m}, {b}[{stmt.bcast}]"
# NOTE: ignores predicate
s = f"fmla {a}, {m}, {b}[{stmt.bcast}]"
else:
s = f"fmla {a}, {p}{m}, {b}"

@@ -152,10 +153,16 @@ def visitMov(self, stmt: MovStmt):
def visitLoad(self, stmt: LoadStmt):
if isinstance(stmt.src, Label):
src_str = "#" + stmt.src.ugly
elif stmt.src.ugly_offset != "0" and stmt.scalar_offs:
elif (stmt.src.ugly_offset != "0" and stmt.scalar_offs):
self.addLine(f"mov {stmt.add_reg.ugly}, #{stmt.src.ugly_offset}", f"move immediate offset into {stmt.add_reg.ugly}")
# TODO: adapt ugly_lsl_shift to account for possible single precision instead of double precision
src_str = f"[{stmt.src.ugly_base}, {stmt.add_reg.ugly}, LSL #{stmt.dest.ugly_lsl_shift}]"
elif stmt.typ == AsmType.f64x4 or stmt.typ == AsmType.f64x2:
# (note: the 128-bit and 256-bit broadcasts need the following more rudimentary format here)
if stmt.src.ugly_offset == '0':
src_str = f"[{stmt.src.ugly_base}]"
else:
src_str = f"[{stmt.src.ugly_base}, #{stmt.src.ugly_offset}]"
else:
src_str = stmt.src.ugly if not stmt.is_B else stmt.src.ugly_no_vl_scaling

@@ -169,6 +176,10 @@ def visitLoad(self, stmt: LoadStmt):
s = f"ld1r{prec} {stmt.dest.ugly}, {p}{src_str}"
else:
s = f"ld1{prec} {stmt.dest.ugly}, {p}{src_str}"
elif stmt.typ == AsmType.f64x4 and stmt.aligned:
s = f"ld1ro{prec} {stmt.dest.ugly}, {p}{src_str}"
elif stmt.typ == AsmType.f64x2 and stmt.aligned:
s = f"ld1rq{prec} {stmt.dest.ugly}, {p}{src_str}"
else:
raise NotImplementedError()
self.addLine(s, stmt.comment)
@@ -180,6 +191,7 @@ def visitStore(self, stmt: StoreStmt):
self.addLine(f"mov {stmt.add_reg.ugly}, #{stmt.dest.ugly_offset}",
f"move immediate offset into {stmt.add_reg.ugly}")
# TODO: adapt ugly_lsl_shift to account for possible single precision instead of double precision
regsize = stmt.add_dest.size() // 16
dest_str = f"[{stmt.dest.ugly_base}, {stmt.add_reg.ugly}, LSL #{stmt.src.ugly_lsl_shift}]"
else:
dest_str = stmt.dest.ugly
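
As a rough guide to the instruction selection added to visitLoad above, a short Python sketch (string stand-ins replace the real AsmType values, and the helper itself is hypothetical). The ld1rq*/ld1ro* mnemonics are the SVE load-and-replicate instructions for 128-bit and 256-bit chunks; ld1ro* requires the F64MM extension, if memory serves:

```python
def broadcast_load_mnemonic(typ: str, prec: str) -> str:
    """Mirror the new visitLoad cases (sketch; prec is the size suffix, e.g. 'd' or 's')."""
    if typ == "f64x2":
        return f"ld1rq{prec}"   # load one 128-bit quadword and replicate it across the vector
    if typ == "f64x4":
        return f"ld1ro{prec}"   # load one 256-bit octaword and replicate it (FEAT_F64MM)
    return f"ld1{prec}"         # ordinary contiguous vector load

print(broadcast_load_mnemonic("f64x2", "d"))  # -> ld1rqd, used for the inline k-broadcast of B
```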
7 changes: 5 additions & 2 deletions pspamm/codegen/sugar.py
@@ -80,7 +80,7 @@ def lea(src: Register, dest: Operand, offset: int, comment:str = None):
stmt.comment = comment
return stmt

def ld(src: Union[Operand, int], dest: Operand, vector: bool, comment:str = None, dest2: Operand = None, pred: Register = None, is_B: bool = False, scalar_offs: bool = False, add_reg: AsmType.i64 = None):
def ld(src: Union[Operand, int], dest: Operand, vector: bool, comment:str = None, dest2: Operand = None, pred: Register = None, is_B: bool = False, scalar_offs: bool = False, add_reg: AsmType.i64 = None, sub128: bool = False):
stmt = LoadStmt()
stmt.src = src if isinstance(src, Operand) else pspamm.architecture.operands.c(src)
stmt.dest = dest
@@ -94,7 +94,10 @@ def ld(src: Union[Operand, int], dest: Operand, vector: bool, comment:str = None

if vector:
stmt.aligned = True
stmt.typ = AsmType.f64x8
if sub128:
stmt.typ = AsmType.f64x2
else:
stmt.typ = AsmType.f64x8
else:
stmt.aligned = False
stmt.typ = AsmType.i64
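
To illustrate the effect of the new `sub128` flag on ld(), a self-contained sketch of the type selection (the AsmType stand-ins below are placeholders, not the real pspamm classes):

```python
class AsmType:
    i64 = "i64"
    f64x2 = "f64x2"   # 128-bit (quadword) load, used for the inline broadcast path
    f64x8 = "f64x8"   # full-width SVE vector load

def load_type(vector: bool, sub128: bool = False) -> str:
    """Mimic the branch added to ld() in sugar.py."""
    if not vector:
        return AsmType.i64
    return AsmType.f64x2 if sub128 else AsmType.f64x8

print(load_type(vector=True, sub128=True))   # -> f64x2
print(load_type(vector=True))                # -> f64x8
```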
