Skip to content

Commit

Permalink
cr-nimble refactoring related changes
Browse files Browse the repository at this point in the history
  • Loading branch information
shailesh1729 committed Aug 27, 2022
1 parent 9b96a53 commit 8b9eabe
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 11 deletions.
13 changes: 9 additions & 4 deletions examples/pursuit/cosamp_step_by_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@
import cr.sparse as crs
import cr.sparse.dict as crdict
import cr.sparse.data as crdata
from cr.nimble.dsp import (
nonzero_indices,
nonzero_values,
largest_indices
)

# %%
# Problem Setup
Expand Down Expand Up @@ -147,7 +152,7 @@

# %%
# Pick the indices of 3K atoms with largest matches with the residual
I_sub = crs.largest_indices(h, K3)
I_sub = largest_indices(h, K3)
# Update the flags array
flags = flags.at[I_sub].set(True)
# Sort the ``I_sub`` array with the help of flags array
Expand All @@ -165,7 +170,7 @@
# Compute the least squares solution of ``y`` over this subdictionary
x_sub, r_sub_norms, rank_sub, s_sub = jnp.linalg.lstsq(Phi_sub, y)
# Pick the indices of K largest entries in in ``x_sub``
Ia = crs.largest_indices(x_sub, K)
Ia = largest_indices(x_sub, K)
print(f"{Ia=}")
# %%
# We need to map the indices in ``Ia`` to the actual indices of atoms in ``Phi``
Expand Down Expand Up @@ -215,7 +220,7 @@
h = Phi.T @ r
# %%
# Pick the indices of 2K atoms with largest matches with the residual
I_2k = crs.largest_indices(h, K2 if iterations else K3)
I_2k = largest_indices(h, K2 if iterations else K3)
# We can check if these include the atoms missed out in first iteration.
print(jnp.intersect1d(omega, I_2k))
# %%
Expand All @@ -237,7 +242,7 @@
# Compute the least squares solution of ``y`` over this subdictionary
x_sub, r_sub_norms, rank_sub, s_sub = jnp.linalg.lstsq(Phi_sub, y)
# Pick the indices of K largest entries in in ``x_sub``
Ia = crs.largest_indices(x_sub, K)
Ia = largest_indices(x_sub, K)
print(Ia)
# %%
# We need to map the indices in ``Ia`` to the actual indices of atoms in ``Phi``
Expand Down
5 changes: 4 additions & 1 deletion examples/pursuit/cs1bit_biht.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import cr.sparse.data as crdata
import cr.sparse.cs.cs1bit as cs1bit

from cr.nimble.dsp import (
build_signal_from_indices_and_values
)

# %%
# Setup
Expand Down Expand Up @@ -76,7 +79,7 @@
sol = cs1bit.biht_jit(Phi, y, K, tau)
# %%
# reconstructed signal
x_rec = crs.build_signal_from_indices_and_values(N, sol.I, sol.x_I)
x_rec = build_signal_from_indices_and_values(N, sol.I, sol.x_I)

# %%
# Verification
Expand Down
13 changes: 9 additions & 4 deletions examples/rec_l1/spikes_l1ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
import cr.sparse.data as crdata
import cr.sparse.lop as lop
import cr.sparse.cvx.l1ls as l1ls
from cr.nimble.dsp import (
hard_threshold_by,
support,
largest_indices_by
)

# %%
# Setup
Expand Down Expand Up @@ -114,7 +119,7 @@
# %%
# Thresholding for large values
# '''''''''''''''''''''''''''''''''''''
x = crs.hard_threshold_by(sol.x, 0.5)
x = hard_threshold_by(sol.x, 0.5)
plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')
plt.subplot(211)
plt.plot(xs)
Expand All @@ -124,8 +129,8 @@
# %%
# Verifying the support recovery
# '''''''''''''''''''''''''''''''''''''
support_xs = crs.support(xs)
support_x = crs.support(x)
support_xs = support(xs)
support_x = support(x)
jnp.all(jnp.equal(support_xs, support_x))


Expand All @@ -134,7 +139,7 @@
# ------------------------------------------------

# Identify the sub-matrix of columns for the support of recovered solution's large entries
support_x = crs.largest_indices_by(sol.x, 0.5)
support_x = largest_indices_by(sol.x, 0.5)
AI = A.columns(support_x)
print(AI.shape)

Expand Down
8 changes: 6 additions & 2 deletions examples/sparse_vector_normals.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
import jax.numpy as jnp
import cr.sparse as crs
import cr.sparse.data as crdata
from cr.nimble.dsp import (
nonzero_indices,
nonzero_values
)

# %%
# Let's define the size of model and number of sparse entries
Expand All @@ -31,11 +35,11 @@

# %%
# We can easily find the locations of non-zero entries
print(crs.nonzero_indices(x))
print(nonzero_indices(x))

# %%
# We can extract corresponding non-zero values in a compact vector
print(crs.nonzero_values(x))
print(nonzero_values(x))

# %%
# Let's plot the vector to see where the non-zero entries are
Expand Down

0 comments on commit 8b9eabe

Please sign in to comment.