Commit 87d8f2d

Fixed issue with data generator for block sparse

1 parent fdc1daf commit 87d8f2d
File tree: 1 file changed

sam/onyx/generate_matrices.py: 120 additions & 48 deletions
@@ -5,6 +5,9 @@
 import random
 import scipy.sparse as ss
 import tempfile
+
+import torch
+
 from sam.onyx.fiber_tree import *
 import argparse
 import math
@@ -13,21 +16,23 @@
 from sam.sim.test.test import *
 
 
-class MatrixGenerator():
+class MatrixGenerator:
 
     def __init__(self, name='B', shape=None, mode_ordering=None, block_size=1, block_naive=True, sparsity=0.6,
-                 format='CSF', dump_dir=None, tensor=None, value_cap=None) -> None:
+                 format='CSF', dump_dir=None, tensor=None, value_cap=None, transpose=False) -> None:
 
         # assert dimension is not None
         # self.dimension = dimension
         self.shape = shape
         self.array = None
+        self.blocked_array = None
         self.sparsity = sparsity
         self.format = format
         self.name = name
         self.mode_ordering = mode_ordering
         self.block_size = block_size
         self.block_naive = block_naive
+        self.transpose = transpose
         if value_cap is None:
             self.value_cap = int(math.pow(2, 8)) - 1
         else:
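For context, a hypothetical use of the extended constructor might look like the sketch below. The argument values are illustrative, not from the commit; only the parameter names shown in this hunk and the dump_outputs method shown in later hunks are assumed to exist.

    # Hypothetical usage sketch; values are illustrative.
    gen = MatrixGenerator(name='B', shape=[16, 16], block_size=4,
                          sparsity=0.7, format='CSF', dump_dir='/tmp/B_dump',
                          transpose=False)
    gen.dump_outputs()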
@@ -67,23 +72,18 @@ def _create_matrix(self, value_cap=int(math.pow(2, 8)) - 1):
         '''
         Routine to create the actual matrix from the dimension/shape
         '''
-        if self.block_size > 1:
-            # self.shape = [*self.shape[:len(self.shape) - 2], self.shape[len(self.shape) - 2] * self.block_size,
-            #               self.shape[len(self.shape) - 2] * self.block_size]
-            self.shape[len(self.shape) - 2] *= self.block_size
-            self.shape[len(self.shape) - 1] *= self.block_size
-        self.array = numpy.random.randint(low=1, high=value_cap, size=self.shape)
+        self.array = numpy.random.randint(low=0, high=value_cap, size=self.shape)
+
         for idx, x in numpy.ndenumerate(self.array):
+            rand = random.random()
             if self.block_size > 1:
                 if (idx[-2] % self.block_size == 0) and (idx[-1]) % self.block_size == 0:
-                    if random.random() < self.sparsity:
+                    if rand < self.sparsity:
                         self.array[..., idx[-2]:idx[-2]+self.block_size, idx[-1]:idx[-1]+self.block_size] = 0
             else:
-                if random.random() < self.sparsity:
+                if rand < self.sparsity:
                     self.array[idx] = 0
 
-        # print(self.array[...,:self.shape[-2]:self.block_size,:self.shape[-1]:self.block_size])
-
     def _create_fiber_tree(self):
         # self.fiber_tree = FiberTree(tensor=self.array if self.block_naive else self.array[...,
         # self.shape[-2]:self.block_size,:self.shape[-1]:self.block_size])
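The rewritten loop above draws one random value per visited index but only acts at block-anchor coordinates (both trailing indices divisible by block_size), so each block_size x block_size tile is kept or zeroed as a unit. A minimal standalone sketch of the same masking idea, assuming a matrix whose trailing dimensions are multiples of block_size (block_sparsify is a hypothetical helper, not part of the commit):

    import numpy
    import random

    def block_sparsify(array, block_size=2, sparsity=0.6):
        # Zero whole block_size x block_size tiles with probability `sparsity`.
        out = array.copy()
        for i in range(0, out.shape[-2], block_size):
            for j in range(0, out.shape[-1], block_size):
                if random.random() < sparsity:
                    out[..., i:i + block_size, j:j + block_size] = 0
        return out

    dense = numpy.random.randint(0, 255, size=(4, 4))
    print(block_sparsify(dense, block_size=2, sparsity=0.5))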
@@ -123,16 +123,17 @@ def dump_outputs(self, format=None, tpose=False, dump_shape=True, glb_override=F
 
         print(f"Using dump directory - {use_dir}")
 
-        all_zeros = not np.any(self.array)
+        # all_zeros = not np.any(self.array, out=self.array)
+        all_zeros = False
 
         debug_coord_arr = []
         debug_seg_arr = []
         debug_val_arr = []
         # Transpose it first if necessary
-        if tpose is True:
-            self.array = numpy.transpose(self.array)
-            self.shape = self.array.shape
-            self.fiber_tree = FiberTree(tensor=self.array)
+        # if tpose is True:
+        #     self.array = numpy.transpose(self.array)
+        #     self.shape = self.array.shape
+        #     self.fiber_tree = FiberTree(tensor=self.array)
 
         if format is not None:
             self.format = format
@@ -171,27 +172,19 @@ def dump_outputs(self, format=None, tpose=False, dump_shape=True, glb_override=F
             # In CSF format, need to iteratively create seg/coord arrays
             tmp_lvl_list = []
             small_tmp_lvl_list = []
-            # if self.block_size == 1:
             tmp_lvl_list.append(self.fiber_tree.get_root())
-            # else:
-            #     tmp_lvl_list.append(self.tmp_fiber_tree.get_root())
             small_tmp_lvl_list.append(self.tmp_fiber_tree.get_root())
-            # print(small_tmp_lvl_list)
-            # print(tmp_lvl_list)
 
             seg_arr, coord_arr = None, None
             if self.block_size > 1:
                 seg_arr, coord_arr = self._dump_csf(small_tmp_lvl_list)
-                # print(seg_arr, coord_arr)
             else:
                 seg_arr, coord_arr = self._dump_csf(tmp_lvl_list)
             if glb_override:
                 lines = [len(seg_arr), *seg_arr, len(coord_arr), *coord_arr]
                 self.write_array(lines, name=f"tensor_{self.name}_mode_{self.mode_ordering[0]}", dump_dir=use_dir,
                                  hex=print_hex)
             else:
-                print(self.mode_ordering)
-                # print(seg_arr, coord_arr)
                 self.write_array(seg_arr, name=f"tensor_{self.name}_mode_{self.mode_ordering[0]}_seg", dump_dir=use_dir,
                                  hex=print_hex)
                 self.write_array(coord_arr, name=f"tensor_{self.name}_mode_{self.mode_ordering[0]}_crd",
@@ -222,24 +215,27 @@ def dump_outputs(self, format=None, tpose=False, dump_shape=True, glb_override=F
                     # self.write_array(tmp_lvl_list, name=f"tensor_{self.name}_mode_vals" dump_dir=use_dir)
                     self.write_array(lines, name=f"tensor_{self.name}_mode_vals", dump_dir=use_dir, hex=print_hex)
                 else:
-                    reached_full_vals = False
-                    if self.block_size > 1:
-                        tmp_lst = tmp_lvl_list
-                        # Retrieve values from full tensor
-                        while not reached_full_vals:
-                            next_tmp_lvl_list = []
-                            for fib in tmp_lst:
-                                crd_payloads_tmp = fib.get_coord_payloads()
-                                if type(crd_payloads_tmp[0][1]) is not FiberTreeFiber:
-                                    reached_full_vals = True
-                                    for crd, pld in crd_payloads_tmp:
-                                        next_tmp_lvl_list.append(pld)
-                                else:
-                                    for crd, pld in crd_payloads_tmp:
-                                        next_tmp_lvl_list.append(pld)
-                            tmp_lst = next_tmp_lvl_list
-                    self.write_array(tmp_lst, name=f"tensor_{self.name}_mode_vals", dump_dir=use_dir,
-                                     hex=print_hex)
+                    # reached_full_vals = False
+                    # if self.block_size > 1:
+                    #     tmp_lst = tmp_lvl_list
+                    #     # Retrieve values from full tensor
+                    #     while not reached_full_vals:
+                    #         next_tmp_lvl_list = []
+                    #         for fib in tmp_lst:
+                    #             crd_payloads_tmp = fib.get_coord_payloads()
+                    #             if type(crd_payloads_tmp[0][1]) is not FiberTreeFiber:
+                    #                 reached_full_vals = True
+                    #                 for crd, pld in crd_payloads_tmp:
+                    #                     next_tmp_lvl_list.append(pld)
+                    #             else:
+                    #                 for crd, pld in crd_payloads_tmp:
+                    #                     next_tmp_lvl_list.append(pld)
+                    #         tmp_lst = next_tmp_lvl_list
+                    if self.block_size == 1:
+                        self.write_array(tmp_lst, name=f"tensor_{self.name}_mode_vals", dump_dir=use_dir,
+                                         hex=print_hex)
+                    else:
+                        self.write_blocked_array(self.array, name=f"tensor_{self.name}_mode_vals", dump_dir=use_dir, skip_zeros=True)
             else:
                 seg_arr, coord_arr = self._dump_csf(tmp_lst)
                 if glb_override:
@@ -260,7 +256,10 @@ def dump_outputs(self, format=None, tpose=False, dump_shape=True, glb_override=F
                 lines = [len(flat_array), *flat_array]
                 self.write_array(lines, name=f"tensor_{self.name}_mode_vals", dump_dir=use_dir, hex=print_hex)
             else:
-                self.write_array(flat_array, name=f"tensor_{self.name}_mode_vals", dump_dir=use_dir, hex=print_hex)
+                if self.block_size == 1:
+                    self.write_array(flat_array, name=f"tensor_{self.name}_mode_vals", dump_dir=use_dir, hex=print_hex)
+                else:
+                    self.write_blocked_array(self.array, name=f"tensor_{self.name}_mode_vals", dump_dir=use_dir)
         elif self.format == "COO":
             crd_dict = dict()
             order = len(self.array.shape)
@@ -293,8 +292,14 @@ def dump_outputs(self, format=None, tpose=False, dump_shape=True, glb_override=F
                                  hex=print_hex)
 
         if dump_shape:
-            final_shape = [x // self.block_size for x in self.array.shape]
-            self.write_array(final_shape, name=f"tensor_{self.name}_mode_shape", dump_dir=use_dir, hex=print_hex)
+            # final_shape = self.array.shape
+            final_shape = [x if i < len(self.array.shape) - 2 else x // self.block_size for i, x in enumerate(self.array.shape)]
+            # final_shape = final_shape[self.mode_ordering]
+            shape = []
+            for elem in self.mode_ordering:
+                shape.append(final_shape[elem])
+
+            self.write_array(shape, name=f"tensor_{self.name}_mode_shape", dump_dir=use_dir, hex=print_hex)
 
         # Transpose it back
         if tpose is True:
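The shape dump above divides only the two innermost dimensions by block_size and then permutes the result by mode_ordering. A short worked example with hypothetical values:

    # An 8x6 matrix of 2x2 blocks stores a 4x3 block grid;
    # mode_ordering then permutes the dumped shape.
    array_shape = (8, 6)
    block_size = 2
    mode_ordering = [1, 0]
    final_shape = [x if i < len(array_shape) - 2 else x // block_size
                   for i, x in enumerate(array_shape)]  # -> [4, 3]
    shape = [final_shape[e] for e in mode_ordering]     # -> [3, 4]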
@@ -323,7 +328,7 @@ def _dump_csf(self, level_list):
 
         return seg_arr, coord_arr
 
-    def write_array(self, str_list, name, dump_dir=None, hex=False):
+    def write_array(self, str_list, name, dump_dir=None, hex=False, integer=False, num_repeats=1):
         """Write an array/list to a file
 
         Args:
336341
full_path = dump_dir + "/" + name
337342
with open(full_path, "w+") as wr_file:
338343
for item in str_list:
339-
item_int = int(item)
344+
if integer:
345+
item_int = int(item)
346+
else:
347+
item_int = item
348+
if hex:
349+
wr_file.write(f"{item_int:04X}\n")
350+
else:
351+
for _ in range(num_repeats):
352+
wr_file.write(f"{item_int}\n")
353+
354+
def write_blocked_array(self, str_list, name, dump_dir=None, integer=False, hex=False, skip_zeros=False):
355+
if dump_dir is None:
356+
dump_dir = self.dump_dir
357+
358+
tiles = tile_and_unroll_nd(str_list, [self.block_size, self.block_size])
359+
360+
if skip_zeros:
361+
print("Skipping zeros")
362+
tiles = [tile for tile in tiles if numpy.sum(tile) != 0]
363+
364+
tiles = np.concatenate(tiles)
365+
366+
full_path = dump_dir + "/" + name
367+
with open(full_path, "w+") as wr_file:
368+
for item in tiles:
369+
if integer:
370+
item_int = int(item)
371+
else:
372+
item_int = item
340373
if hex:
341374
wr_file.write(f"{item_int:04X}\n")
342375
else:
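Judging from tile_and_unroll_nd (added at the bottom of this diff), write_blocked_array emits values tile by tile in row-major block order, each tile flattened, with all-zero tiles optionally dropped. A small sketch of that assumed ordering for a 4x4 matrix of 2x2 blocks:

    import numpy

    A = numpy.array([[1, 2, 0, 0],
                     [3, 4, 0, 0],
                     [0, 0, 5, 6],
                     [0, 0, 7, 8]])
    tiles = [A[i:i + 2, j:j + 2].flatten()
             for i in range(0, 4, 2) for j in range(0, 4, 2)]
    tiles = [t for t in tiles if t.sum() != 0]  # mirrors skip_zeros=True
    print(numpy.concatenate(tiles))             # [1 2 3 4 5 6 7 8]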
@@ -361,6 +394,45 @@ def __setitem__(self, key, val):
         self.array[key] = val
 
 
+def tile_and_unroll_nd(array, tile_size):
+    """
+    Generalized function to tile and unroll a multidimensional array.
+
+    Parameters:
+    - array: Input numpy array of any dimensionality.
+    - tile_size: A tuple that specifies the size of the tiles for the last dimensions.
+
+    Returns:
+    - A list of flattened tiles.
+    """
+    array_shape = array.shape
+    num_non_tiled_dims = len(array_shape) - len(tile_size)  # Non-tiled dimensions
+    tiled_dims = array_shape[num_non_tiled_dims:]  # Dimensions that will be tiled
+
+    # Calculate the number of tiles along the tiling dimensions
+    tiled_sizes = [dim // tile for dim, tile in zip(tiled_dims, tile_size)]
+
+    # Prepare a list to collect the tiles
+    tiles = []
+
+    # Iterate over the non-tiled dimensions
+    for index in np.ndindex(*array_shape[:num_non_tiled_dims]):
+        # Iterate over the tiled dimensions and extract each tile
+        for tile_index in np.ndindex(*tiled_sizes):
+            # Slice the array to extract the current tile
+            slices = tuple(
+                slice(idx * size, (idx + 1) * size) for idx, size in zip(tile_index, tile_size)
+            )
+            full_slices = index + slices  # Combine non-tiled and tiled indices
+            tile = array[full_slices]
+            if type(tile) == torch.Tensor:
+                tile = tile.numpy()
+            # Flatten the tile and append to the list
+            tiles.append(tile.flatten())
+
+    return tiles
+
+
 def get_runs(v1, v2):
     """Get the average run length/runs of each vector
 
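A hedged usage sketch for the new tile_and_unroll_nd helper: the leading (non-tiled) dimensions are iterated first, then the tile grid of each trailing slice. The input here is illustrative:

    import numpy as np

    batch = np.arange(2 * 4 * 4).reshape(2, 4, 4)
    tiles = tile_and_unroll_nd(batch, [2, 2])
    print(len(tiles))  # 8 tiles: 2 batches x a 2x2 grid of 2x2 tiles each
    print(tiles[0])    # [0 1 4 5], the top-left tile of batch 0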