@@ -81,6 +81,9 @@ def _is_sparsity_zero(
81
81
"""Returns True when a given value of topk_percent or topk_element along a particular top_k_dim
82
82
for an input tensor results in sparsity=0% (or top-100-percent). Otherwise, returns False.
83
83
"""
84
+ if topk_percent is None and topk_element is None :
85
+ return False # 100% sparse
86
+
84
87
top_k_total_size = _top_k_total_size (dense , top_k_dim )
85
88
k = _get_k_for_topk (topk_percent , topk_element , top_k_total_size )
86
89
return k == top_k_total_size
@@ -245,11 +248,20 @@ def __init__(
245
248
self ._dst_top_k_percent = dst_top_k_percent
246
249
247
250
self ._validate_conf ()
248
- # TODO (Min): Type checking for the following
249
251
self ._transform , self ._inverse_transform = (
250
252
(_fft_transform , _ifft_transform ) if algo is Algo .FFT else (_dct_transform , _idct_transform )
251
253
)
252
254
255
+ @property
256
+ def _sst_enabled (self ) -> bool :
257
+ """True if SST is enabled."""
258
+ return self ._sst_top_k_element is not None or self ._sst_top_k_percent is not None
259
+
260
+ @property
261
+ def _dst_enabled (self ) -> bool :
262
+ """True if DST is enabled."""
263
+ return self ._dst_top_k_element is not None or self ._dst_top_k_percent is not None
264
+
253
265
def _validate_conf (self ) -> None :
254
266
"""Validating if the config is valid.
255
267
@@ -262,16 +274,14 @@ def _validate_conf(self) -> None:
262
274
If validation fails.
263
275
"""
264
276
# assert that both top_k_elements and top_k_percent aren't set for sst and dst
265
- def one_and_only (a : Optional [int ], b : Optional [float ]) -> bool :
266
- return (a is None ) ^ (b is None )
277
+ def both_set (a : Optional [int ], b : Optional [float ]) -> bool :
278
+ return (a is not None ) and (b is not None )
267
279
268
- if not (
269
- one_and_only (self ._sst_top_k_element , self ._sst_top_k_percent )
270
- and one_and_only (self ._dst_top_k_element , self ._dst_top_k_percent )
280
+ if both_set (self ._sst_top_k_element , self ._sst_top_k_percent ) or both_set (
281
+ self ._dst_top_k_element , self ._dst_top_k_percent
271
282
):
272
283
raise ValueError (
273
- "One and only one of top_k_element and top_k_percent for "
274
- "each of sst and dst must be provided as an argument.\n "
284
+ "top_k_element and top_k_percent can't be both set\n "
275
285
f"Input values are: sst element={ self ._sst_top_k_element } , sst percent={ self ._sst_top_k_percent } , "
276
286
f"dst element={ self ._dst_top_k_element } , dst percent={ self ._dst_top_k_percent } "
277
287
)
@@ -296,7 +306,7 @@ def none_or_greater_0(a: Optional[int]) -> bool:
296
306
f"and dst element={ self ._dst_top_k_element } "
297
307
)
298
308
299
- def dense_to_sst (self , dense : Tensor ) -> Tensor :
309
+ def dense_to_sst (self , dense : Tensor ) -> Optional [ Tensor ] :
300
310
"""Get Signal Sparse Tensor (SST) from a dense tensor
301
311
302
312
Dense -> fft -> top-k -> results.
@@ -310,10 +320,14 @@ def dense_to_sst(self, dense: Tensor) -> Tensor:
310
320
Input dense tensor (no zeros).
311
321
312
322
Returns:
313
- (Tensor):
323
+ (Tensor, optional ):
314
324
Same shaped tensor as the input dense tensor, still in dense format but in frequency
315
325
domain (complex valued) and has zeros.
316
326
"""
327
+ if not self ._sst_enabled :
328
+ # Special case, SST is simply None, which represents an all-zero tensor.
329
+ return None
330
+
317
331
top_k_total_size = _top_k_total_size (dense , self ._sst_top_k_dim )
318
332
k = _get_k_for_topk (self ._sst_top_k_percent , self ._sst_top_k_element , top_k_total_size )
319
333
dense_freq = self ._transform (dense , dim = self ._sst_top_k_dim )
@@ -325,7 +339,7 @@ def dense_to_sst(self, dense: Tensor) -> Tensor:
325
339
real_dense_freq = dense_freq .real .abs ()
326
340
return _scatter_topk_to_sparse_tensor (real_dense_freq , dense_freq , k , dim = self ._sst_top_k_dim )
327
341
328
- def dense_sst_to_dst (self , dense : Tensor , sst : Tensor ) -> Tensor :
342
+ def dense_sst_to_dst (self , dense : Tensor , sst : Optional [ Tensor ] ) -> Optional [ Tensor ] :
329
343
"""Calculates DST from input dense and SST tensors.
330
344
331
345
dense - inverse_transform(sst)[using sst_dst_to_dense method] -> top-k -> dst
@@ -340,6 +354,13 @@ def dense_sst_to_dst(self, dense: Tensor, sst: Tensor) -> Tensor:
340
354
(Tensor):
341
355
Same shaped tensor, still dense format but has zeros. Non-zeros are top-k delta values.
342
356
"""
357
+ if not self ._dst_enabled :
358
+ # Special case, DST is simply None, which represents an all-zero tensor.
359
+ return None
360
+
361
+ if sst is None :
362
+ sst = torch .zeros_like (dense , dtype = torch .complex64 )
363
+
343
364
if not (dense .shape == sst .shape ):
344
365
raise ValueError ("dense and sst have different shapes!" )
345
366
@@ -349,7 +370,7 @@ def dense_sst_to_dst(self, dense: Tensor, sst: Tensor) -> Tensor:
349
370
del dense
350
371
return _scatter_topk_to_sparse_tensor (delta .abs (), delta , k , dim = self ._dst_top_k_dim )
351
372
352
- def sst_dst_to_dense (self , sst : Tensor , dst : Optional [Tensor ] = None ) -> Tensor :
373
+ def sst_dst_to_dense (self , sst : Optional [ Tensor ] , dst : Optional [Tensor ] = None ) -> Tensor :
353
374
"""From SST and DST returns a dense reconstructed tensor (RT). When argument dst=None, simply returns
354
375
the inverse transform of the SST tensor.
355
376
@@ -363,12 +384,19 @@ def sst_dst_to_dense(self, sst: Tensor, dst: Optional[Tensor] = None) -> Tensor:
363
384
(Tensor):
364
385
A dense tensor in real number domain from the SST.
365
386
"""
387
+ assert not (sst is None and dst is None ), "both-None-case is not useful"
388
+
389
+ if sst is None :
390
+ # Simply the delta is the reconstruction.
391
+ return dst
392
+
393
+ # Now, ifft and then add the delta.
366
394
dense_rt = torch .real (self ._inverse_transform (sst , dim = self ._sst_top_k_dim ))
367
395
if dst is not None :
368
396
dense_rt += dst
369
397
return dense_rt
370
398
371
- def lossy_compress (self , dense : Tensor ) -> Tuple [Tensor , Tensor , Tensor ]:
399
+ def lossy_compress (self , dense : Tensor ) -> Tuple [Tensor , Optional [ Tensor ], Optional [ Tensor ] ]:
372
400
"""From dense tensor to lossy reconstruction of dense tensor with the help of SST and DST
373
401
tensor calculation. If requested sparsity is zero (or top_100_percent) then simply returns
374
402
the input dense tensor as the reconstruction.
@@ -393,6 +421,8 @@ def lossy_compress(self, dense: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
393
421
# of the same size as dense.
394
422
return dense , None , dense
395
423
else :
424
+ # depending on whether self._sst_enabled and self._dst_enabled, None SST/DST tensors can be returned
425
+ # below as well.
396
426
sst = self .dense_to_sst (dense )
397
427
dst = self .dense_sst_to_dst (dense , sst )
398
428
return self .sst_dst_to_dense (sst , dst ), sst , dst
0 commit comments