Skip to content

Commit 3cc7fa8

Browse files
min-xu-ai and flying-x
authored
[feat] support optional SST and DST (#1063)
* [feat] support sst disabled and dst disabled cases
* added tests

Co-authored-by: Min Xu <[email protected]>
1 parent 15d4cf1 commit 3cc7fa8

File tree

2 files changed

+77
-19
lines changed

2 files changed

+77
-19
lines changed

fairscale/experimental/wgit/signal_sparsity.py

Lines changed: 43 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,9 @@ def _is_sparsity_zero(
8181
"""Returns True when a given value of topk_percent or topk_element along a particular top_k_dim
8282
for an input tensor results in sparsity=0% (or top-100-percent). Otherwise, returns False.
8383
"""
84+
if topk_percent is None and topk_element is None:
85+
return False # 100% sparse
86+
8487
top_k_total_size = _top_k_total_size(dense, top_k_dim)
8588
k = _get_k_for_topk(topk_percent, topk_element, top_k_total_size)
8689
return k == top_k_total_size
@@ -245,11 +248,20 @@ def __init__(
245248
self._dst_top_k_percent = dst_top_k_percent
246249

247250
self._validate_conf()
248-
# TODO (Min): Type checking for the following
249251
self._transform, self._inverse_transform = (
250252
(_fft_transform, _ifft_transform) if algo is Algo.FFT else (_dct_transform, _idct_transform)
251253
)
252254

255+
@property
256+
def _sst_enabled(self) -> bool:
257+
"""True if SST is enabled."""
258+
return self._sst_top_k_element is not None or self._sst_top_k_percent is not None
259+
260+
@property
261+
def _dst_enabled(self) -> bool:
262+
"""True if DST is enabled."""
263+
return self._dst_top_k_element is not None or self._dst_top_k_percent is not None
264+
253265
def _validate_conf(self) -> None:
254266
"""Validating if the config is valid.
255267
@@ -262,16 +274,14 @@ def _validate_conf(self) -> None:
262274
If validation fails.
263275
"""
264276
# assert that both top_k_elements and top_k_percent aren't set for sst and dst
265-
def one_and_only(a: Optional[int], b: Optional[float]) -> bool:
266-
return (a is None) ^ (b is None)
277+
def both_set(a: Optional[int], b: Optional[float]) -> bool:
278+
return (a is not None) and (b is not None)
267279

268-
if not (
269-
one_and_only(self._sst_top_k_element, self._sst_top_k_percent)
270-
and one_and_only(self._dst_top_k_element, self._dst_top_k_percent)
280+
if both_set(self._sst_top_k_element, self._sst_top_k_percent) or both_set(
281+
self._dst_top_k_element, self._dst_top_k_percent
271282
):
272283
raise ValueError(
273-
"One and only one of top_k_element and top_k_percent for "
274-
"each of sst and dst must be provided as an argument.\n"
284+
"top_k_element and top_k_percent can't be both set\n"
275285
f"Input values are: sst element={self._sst_top_k_element}, sst percent={self._sst_top_k_percent}, "
276286
f"dst element={self._dst_top_k_element}, dst percent={self._dst_top_k_percent}"
277287
)
@@ -296,7 +306,7 @@ def none_or_greater_0(a: Optional[int]) -> bool:
296306
f"and dst element={self._dst_top_k_element}"
297307
)
298308

299-
def dense_to_sst(self, dense: Tensor) -> Tensor:
309+
def dense_to_sst(self, dense: Tensor) -> Optional[Tensor]:
300310
"""Get Signal Sparse Tensor (SST) from a dense tensor
301311
302312
Dense -> fft -> top-k -> results.
@@ -310,10 +320,14 @@ def dense_to_sst(self, dense: Tensor) -> Tensor:
310320
Input dense tensor (no zeros).
311321
312322
Returns:
313-
(Tensor):
323+
(Tensor, optional):
314324
Same shaped tensor as the input dense tensor, still in dense format but in frequency
315325
domain (complex valued) and has zeros.
316326
"""
327+
if not self._sst_enabled:
328+
# Special case, SST is simply None, which represents an all-zero tensor.
329+
return None
330+
317331
top_k_total_size = _top_k_total_size(dense, self._sst_top_k_dim)
318332
k = _get_k_for_topk(self._sst_top_k_percent, self._sst_top_k_element, top_k_total_size)
319333
dense_freq = self._transform(dense, dim=self._sst_top_k_dim)
@@ -325,7 +339,7 @@ def dense_to_sst(self, dense: Tensor) -> Tensor:
325339
real_dense_freq = dense_freq.real.abs()
326340
return _scatter_topk_to_sparse_tensor(real_dense_freq, dense_freq, k, dim=self._sst_top_k_dim)
327341

328-
def dense_sst_to_dst(self, dense: Tensor, sst: Tensor) -> Tensor:
342+
def dense_sst_to_dst(self, dense: Tensor, sst: Optional[Tensor]) -> Optional[Tensor]:
329343
"""Calculates DST from input dense and SST tensors.
330344
331345
dense - inverse_transform(sst)[using sst_dst_to_dense method] -> top-k -> dst
@@ -340,6 +354,13 @@ def dense_sst_to_dst(self, dense: Tensor, sst: Tensor) -> Tensor:
340354
(Tensor):
341355
Same shaped tensor, still dense format but has zeros. Non-zeros are top-k delta values.
342356
"""
357+
if not self._dst_enabled:
358+
# Special case, DST is simply None, which represents an all-zero tensor.
359+
return None
360+
361+
if sst is None:
362+
sst = torch.zeros_like(dense, dtype=torch.complex64)
363+
343364
if not (dense.shape == sst.shape):
344365
raise ValueError("dense and sst have different shapes!")
345366

@@ -349,7 +370,7 @@ def dense_sst_to_dst(self, dense: Tensor, sst: Tensor) -> Tensor:
349370
del dense
350371
return _scatter_topk_to_sparse_tensor(delta.abs(), delta, k, dim=self._dst_top_k_dim)
351372

352-
def sst_dst_to_dense(self, sst: Tensor, dst: Optional[Tensor] = None) -> Tensor:
373+
def sst_dst_to_dense(self, sst: Optional[Tensor], dst: Optional[Tensor] = None) -> Tensor:
353374
"""From SST and DST returns a dense reconstructed tensor (RT). When argument dst=None, simply returns
354375
the inverse transform of the SST tensor.
355376
@@ -363,12 +384,19 @@ def sst_dst_to_dense(self, sst: Tensor, dst: Optional[Tensor] = None) -> Tensor:
363384
(Tensor):
364385
A dense tensor in real number domain from the SST.
365386
"""
387+
assert not (sst is None and dst is None), "both-None-case is not useful"
388+
389+
if sst is None:
390+
# Simply the delta is the reconstruction.
391+
return dst
392+
393+
# Now, ifft and then add the delta.
366394
dense_rt = torch.real(self._inverse_transform(sst, dim=self._sst_top_k_dim))
367395
if dst is not None:
368396
dense_rt += dst
369397
return dense_rt
370398

371-
def lossy_compress(self, dense: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
399+
def lossy_compress(self, dense: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
372400
"""From dense tensor to lossy reconstruction of dense tensor with the help of SST and DST
373401
tensor calculation. If requested sparsity is zero (or top_100_percent) then simply returns
374402
the input dense tensor as the reconstruction.
@@ -393,6 +421,8 @@ def lossy_compress(self, dense: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
393421
# of the same size as dense.
394422
return dense, None, dense
395423
else:
424+
# depending on whether self._sst_enabled and self._dst_enabled, None SST/DST tensors can be returned
425+
# below as well.
396426
sst = self.dense_to_sst(dense)
397427
dst = self.dense_sst_to_dst(dense, sst)
398428
return self.sst_dst_to_dense(sst, dst), sst, dst

tests/experimental/wgit/test_signal_sparsity.py

Lines changed: 34 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -57,18 +57,15 @@ def kwargs(vals_list):
5757
return dict(zip(arg_key_list, vals_list))
5858

5959
# Validate value error is raised when, either:
60-
# 1. One and only one of sst (or dst) percent and element is not provided a value (not None).
61-
# 2. Both of sst (or dst) percent and element is set to None.
62-
# 3. top_k_percent and top_k_element are not in valid range (elem > 0) and for 0 < percent <= 100.
60+
# 1. both percent and element of sst (or dst) are provided a value (not None).
61+
# 2. top_k_percent and top_k_element are not in valid range (elem > 0) and for 0 < percent <= 100.
6362
element = 10
6463
percent = 50
6564
dim = 0
6665
args_list = [
6766
[element, percent, dim, element, None, dim], # case 1.
6867
[element, None, dim, element, percent, dim],
69-
[None, None, dim, element, None, dim], # case 2.
70-
[element, None, dim, None, None, dim],
71-
[0, None, dim, None, None, dim], # case 3.
68+
[0, None, dim, None, None, dim], # case 2.
7269
[None, 0, dim, None, None, dim],
7370
[element, None, dim, 0, None, dim],
7471
[element, None, dim, None, 0, dim],
@@ -399,3 +396,34 @@ def test_lossy_compress_sparsity_0(tensor, dim, top_k_percent, device):
399396
objects_are_equal(lossy_dense.to(device), tensor.to(device), raise_exception=True, rtol=RTOL, atol=ATOL)
400397
objects_are_equal(sst, None, raise_exception=True, rtol=RTOL, atol=ATOL)
401398
objects_are_equal(dst.to(device), tensor.to(device), raise_exception=True, rtol=RTOL, atol=ATOL)
399+
400+
401+
def test_sst_disabled():
402+
"""Tests the case where SST is disabled."""
403+
dense = torch.tensor([0.5000, 0.6000, 0.7000, 0.8000])
404+
result = torch.tensor([0.0, 0.0, 0.7000, 0.8000])
405+
sparser = SignalSparsity(dst_top_k_element=2, dst_top_k_dim=0)
406+
rt, sst, dst = sparser.lossy_compress(dense)
407+
objects_are_equal(rt, result, raise_exception=True, rtol=RTOL, atol=ATOL)
408+
objects_are_equal(dst, result, raise_exception=True, rtol=RTOL, atol=ATOL)
409+
assert sst is None
410+
411+
412+
def test_dst_disabled():
413+
"""Tests the case where DST is disabled."""
414+
dense = torch.tensor([0.5000, 0.6000, 0.7000, 0.8000, 0.9000])
415+
result_rt = torch.tensor([0.6000, 0.7618, 0.7000, 0.6382, 0.8000])
416+
result_sst = torch.tensor(
417+
[
418+
3.50000000000000000000 + 0.00000000000000000000j,
419+
0.00000000000000000000 + 0.00000000000000000000j,
420+
-0.25000002980232238770 + 0.08122986555099487305j,
421+
-0.25000002980232238770 - 0.08122986555099487305j,
422+
0.00000000000000000000 + 0.00000000000000000000j,
423+
]
424+
)
425+
sparser = SignalSparsity(sst_top_k_element=3, sst_top_k_dim=0)
426+
rt, sst, dst = sparser.lossy_compress(dense)
427+
objects_are_equal(rt, result_rt, raise_exception=True, rtol=RTOL, atol=ATOL)
428+
objects_are_equal(sst, result_sst, raise_exception=True, rtol=RTOL, atol=ATOL)
429+
assert dst is None

0 commit comments

Comments
 (0)