Skip to content

Commit

Permalink
create demo undocumented numpy function
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong committed Nov 5, 2023
1 parent 3ee5346 commit 59d3993
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions test_scripts/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import numpy as np


def one_hot(
indices,
depth,
on_value = None,
off_value = None,
axis = None,
dtype = None,
):
on_none = on_value is None
off_none = off_value is None

if dtype is None:
if on_none and off_none:
dtype = np.float32
else:
if not on_none:
dtype = np.array(on_value).dtype
elif not off_none:
dtype = np.array(off_value).dtype

res = np.eye(depth, dtype=dtype)[np.array(indices, dtype="int64").reshape(-1)]
res = res.reshape(list(indices.shape) + [depth])

if not on_none and not off_none:
res = np.where(res == 1, on_value, off_value)

if axis is not None:
res = np.moveaxis(res, -1, axis)

return res

0 comments on commit 59d3993

Please sign in to comment.