Skip to content

Commit

Permalink
Add labels for manual training snippets
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Oct 7, 2023
1 parent 5dc7798 commit 7f41700
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions docs/source/snippets/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,43 @@ def eval(self) -> None:
for timestep in range(cfg["timesteps"]):
trainer.eval(timestep=timestep)
# [jax-end-step]

# =============================================================================

# [pytorch-start-manual-training]

# [pytorch-end-manual-training]

# [pytorch-start-manual-evaluation]
# assuming there is an environment named 'env'
# and an agent named 'agents' (or a state-preprocessor and a policy)

states, infos = env.reset()

for i in range(1000):
# state-preprocessor + policy
with torch.no_grad():
states = state_preprocessor(states)
actions = policy.act({"states": states})[0]

# step the environment
next_states, rewards, terminated, truncated, infos = env.step(actions)

# render the environment
env.render()

# check for termination/truncation
if terminated.any() or truncated.any():
states, infos = env.reset()
else:
states = next_states
# [pytorch-end-manual-evaluation]


# [jax-start-manual-training]

# [jax-end-manual-training]

# [jax-start-manual-evaluation]

# [jax-end-manual-evaluation]

0 comments on commit 7f41700

Please sign in to comment.