From acd974716a96a36b809323311a9132dfc0926295 Mon Sep 17 00:00:00 2001 From: Sam Holt Date: Thu, 16 Jan 2025 01:45:56 +0000 Subject: [PATCH 1/6] fix: Fixed Isort and Pyink linting errors that break pre-commits. --- mujoco_playground/_src/mjx_env.py | 1 - mujoco_playground/_src/wrapper_torch.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mujoco_playground/_src/mjx_env.py b/mujoco_playground/_src/mjx_env.py index 329f437..4b03dd1 100644 --- a/mujoco_playground/_src/mjx_env.py +++ b/mujoco_playground/_src/mjx_env.py @@ -28,7 +28,6 @@ import numpy as np import tqdm - # Root path is used for loading XML strings directly using etils.epath. ROOT_PATH = epath.Path(__file__).parent # Base directory for external dependencies. diff --git a/mujoco_playground/_src/wrapper_torch.py b/mujoco_playground/_src/wrapper_torch.py index 4ddab24..da54f4f 100644 --- a/mujoco_playground/_src/wrapper_torch.py +++ b/mujoco_playground/_src/wrapper_torch.py @@ -228,5 +228,7 @@ def get_number_of_agents(self): def get_env_info(self): info = {} info['action_space'] = self.action_space # pytype: disable=attribute-error - info['observation_space'] = self.observation_space # pytype: disable=attribute-error + info['observation_space'] = ( + self.observation_space + ) # pytype: disable=attribute-error return info From afe80c63b28024ed8df78c01fbcf51a873165591 Mon Sep 17 00:00:00 2001 From: Sam Holt Date: Thu, 16 Jan 2025 02:37:08 +0000 Subject: [PATCH 2/6] fix: Fix all pylint errors down to zero. 
--- learning/train_jax_ppo.py | 10 +-- learning/train_rsl_rl.py | 8 +- .../_src/locomotion/g1/randomize.py | 1 + .../_src/locomotion/locomotion_test.py | 1 + .../franka_emika_panda/open_cabinet.py | 2 +- .../franka_emika_panda/pick_cartesian.py | 2 +- .../_src/manipulation/leap_hand/rotate_z.py | 3 +- .../_src/manipulation/manipulation_test.py | 1 + mujoco_playground/_src/random.py | 3 +- mujoco_playground/_src/wrapper_torch.py | 87 ++++++++++--------- mujoco_playground/config/locomotion_params.py | 6 +- .../config/manipulation_params.py | 8 +- .../madrona_benchmarking/benchmark.py | 26 +++--- .../madrona_benchmarking/make_plots.py | 33 ++++--- .../madrona_benchmarking/print_tables.py | 21 +++-- .../experimental/sim2sim/gamepad_reader.py | 9 +- .../experimental/sim2sim/play_bh_joystick.py | 1 + .../experimental/sim2sim/play_g1_joystick.py | 1 + .../experimental/sim2sim/play_go1_joystick.py | 1 + .../sim2sim/play_leap_reorient.py | 3 +- 20 files changed, 128 insertions(+), 99 deletions(-) diff --git a/learning/train_jax_ppo.py b/learning/train_jax_ppo.py index a251427..316c369 100644 --- a/learning/train_jax_ppo.py +++ b/learning/train_jax_ppo.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== +# pylint: disable=wrong-import-position """Train a PPO agent using JAX on the specified environment.""" import os @@ -39,7 +40,6 @@ import jax.numpy as jp import mediapy as media from ml_collections import config_dict -from ml_collections import config_flags import mujoco from orbax import checkpoint as ocp from tensorboardX import SummaryWriter @@ -267,11 +267,11 @@ def main(argv): print(f"Checkpoint path: {ckpt_path}") # Save environment configuration - with open(ckpt_path / "config.json", "w") as fp: + with open(ckpt_path / "config.json", "w", encoding="utf-8") as fp: json.dump(env_cfg.to_json(), fp, indent=4) # Define policy parameters function for saving checkpoints - def policy_params_fn(current_step, make_policy, params): + def policy_params_fn(current_step, make_policy, params): # pylint: disable=unused-argument orbax_checkpointer = ocp.PyTreeCheckpointer() save_args = orbax_utils.save_args_from_target(params) path = ckpt_path / f"{current_step}" @@ -352,7 +352,7 @@ def progress(num_steps, metrics): ) # Train or load the model - make_inference_fn, params, _ = train_fn( + make_inference_fn, params, _ = train_fn( # pylint: disable=no-value-for-parameter environment=env, progress_fn=progress, eval_env=None if _VISION.value else eval_env, @@ -389,7 +389,7 @@ def progress(num_steps, metrics): rollout = [state0] # Run evaluation rollout - for i in range(env_cfg.episode_length): + for _ in range(env_cfg.episode_length): act_rng, rng = jax.random.split(rng) ctrl, _ = jit_inference_fn(state.obs, act_rng) state = jit_step(state, ctrl) diff --git a/learning/train_rsl_rl.py b/learning/train_rsl_rl.py index b5fd08a..afd9b06 100644 --- a/learning/train_rsl_rl.py +++ b/learning/train_rsl_rl.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== +# pylint: disable=wrong-import-position """Train a PPO agent using RSL-RL for the specified environment.""" import os @@ -28,7 +29,6 @@ from absl import flags from absl import logging import jax -import jax.numpy as jp import mediapy as media from ml_collections import config_dict import mujoco @@ -133,7 +133,9 @@ def main(argv): wandb.config.update({"env_name": _ENV_NAME.value}) # Save environment config to JSON - with open(os.path.join(ckpt_path, "config.json"), "w") as fp: + with open( + os.path.join(ckpt_path, "config.json"), "w", encoding="utf-8" + ) as fp: json.dump(env_cfg.to_json(), fp, indent=4) # Domain randomization @@ -143,7 +145,7 @@ def main(argv): render_trajectory = [] # Callback to gather states for rendering - def render_callback(env, state): + def render_callback(_, state): render_trajectory.append(state) # Create the environment diff --git a/mujoco_playground/_src/locomotion/g1/randomize.py b/mujoco_playground/_src/locomotion/g1/randomize.py index 3feb61d..1d20c2f 100644 --- a/mujoco_playground/_src/locomotion/g1/randomize.py +++ b/mujoco_playground/_src/locomotion/g1/randomize.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== +"""Utilities for randomization.""" import jax from mujoco import mjx diff --git a/mujoco_playground/_src/locomotion/locomotion_test.py b/mujoco_playground/_src/locomotion/locomotion_test.py index 745b342..6afd8f8 100644 --- a/mujoco_playground/_src/locomotion/locomotion_test.py +++ b/mujoco_playground/_src/locomotion/locomotion_test.py @@ -24,6 +24,7 @@ class TestSuite(parameterized.TestCase): + """Tests for the locomotion environments.""" @parameterized.named_parameters( {"testcase_name": f"test_can_create_{env_name}", "env_name": env_name} diff --git a/mujoco_playground/_src/manipulation/franka_emika_panda/open_cabinet.py b/mujoco_playground/_src/manipulation/franka_emika_panda/open_cabinet.py index ceb2326..cf23d7e 100644 --- a/mujoco_playground/_src/manipulation/franka_emika_panda/open_cabinet.py +++ b/mujoco_playground/_src/manipulation/franka_emika_panda/open_cabinet.py @@ -19,8 +19,8 @@ import jax import jax.numpy as jp from ml_collections import config_dict -import mujoco from mujoco import mjx +import mujoco # pylint: disable=unused-import from mujoco.mjx._src import math from mujoco_playground._src import collision diff --git a/mujoco_playground/_src/manipulation/franka_emika_panda/pick_cartesian.py b/mujoco_playground/_src/manipulation/franka_emika_panda/pick_cartesian.py index 21cb548..6b699ea 100644 --- a/mujoco_playground/_src/manipulation/franka_emika_panda/pick_cartesian.py +++ b/mujoco_playground/_src/manipulation/franka_emika_panda/pick_cartesian.py @@ -82,7 +82,7 @@ class PandaPickCubeCartesian(pick.PandaPickCube): """Environment for training the Franka Panda robot to pick up a cube in Cartesian space.""" - def __init__( + def __init__( # pylint: disable=non-parent-init-called,super-init-not-called self, config=default_config(), config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None, diff --git 
a/mujoco_playground/_src/manipulation/leap_hand/rotate_z.py b/mujoco_playground/_src/manipulation/leap_hand/rotate_z.py index acd0b66..ac883e7 100644 --- a/mujoco_playground/_src/manipulation/leap_hand/rotate_z.py +++ b/mujoco_playground/_src/manipulation/leap_hand/rotate_z.py @@ -23,7 +23,6 @@ import numpy as np from mujoco_playground._src import mjx_env -from mujoco_playground._src import reward from mujoco_playground._src.manipulation.leap_hand import base as leap_hand_base from mujoco_playground._src.manipulation.leap_hand import leap_hand_constants as consts @@ -145,7 +144,7 @@ def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State: rewards = { k: v * self._config.reward_config.scales[k] for k, v in rewards.items() } - reward = sum(rewards.values()) * self.dt + reward = sum(rewards.values()) * self.dt # pylint: disable=redefined-outer-name state.info["last_last_act"] = state.info["last_act"] state.info["last_act"] = action diff --git a/mujoco_playground/_src/manipulation/manipulation_test.py b/mujoco_playground/_src/manipulation/manipulation_test.py index d8fa915..0b52d2e 100644 --- a/mujoco_playground/_src/manipulation/manipulation_test.py +++ b/mujoco_playground/_src/manipulation/manipulation_test.py @@ -24,6 +24,7 @@ class TestSuite(parameterized.TestCase): + """Tests for the manipulation environments.""" @parameterized.named_parameters( {"testcase_name": f"test_can_create_{env_name}", "env_name": env_name} diff --git a/mujoco_playground/_src/random.py b/mujoco_playground/_src/random.py index 5553d90..1c8b037 100644 --- a/mujoco_playground/_src/random.py +++ b/mujoco_playground/_src/random.py @@ -30,6 +30,7 @@ def uniform_quat(rng: jax.Array) -> jax.Array: ]) +# pylint: disable=line-too-long # Reference: https://github.com/google-deepmind/dm_control/blob/main/dm_control/locomotion/arenas/bowl.py def random_hfield( model: mujoco.MjModel, @@ -38,7 +39,7 @@ def random_hfield( terrain_smoothness: float = 0.4, ): """Randomize the 
heightfield.""" - from scipy import ndimage # pylint: disable=g-import-not-at-top + from scipy import ndimage # pylint: disable=g-import-not-at-top, import-outside-toplevel res = model.hfield_nrow[heightfield_id] diff --git a/mujoco_playground/_src/wrapper_torch.py b/mujoco_playground/_src/wrapper_torch.py index da54f4f..c943769 100644 --- a/mujoco_playground/_src/wrapper_torch.py +++ b/mujoco_playground/_src/wrapper_torch.py @@ -35,7 +35,7 @@ def _jax_to_torch(tensor): from jax._src.dlpack import to_dlpack # pylint: disable=import-outside-toplevel - import torch.utils.dlpack as tpack # pytype: disable=import-error + import torch.utils.dlpack as tpack # pytype: disable=import-error # pylint: disable=import-outside-toplevel tensor = to_dlpack(tensor) tensor = tpack.from_dlpack(tensor) @@ -44,7 +44,7 @@ def _jax_to_torch(tensor): def _torch_to_jax(tensor): from jax._src.dlpack import from_dlpack # pylint: disable=import-outside-toplevel - import torch.utils.dlpack as tpack # pytype: disable=import-error + import torch.utils.dlpack as tpack # pytype: disable=import-error # pylint: disable=import-outside-toplevel tensor = tpack.to_dlpack(tensor) tensor = from_dlpack(tensor) @@ -56,28 +56,29 @@ def get_load_path(root, load_run=-1, checkpoint=-1): runs = os.listdir(root) # TODO sort by date to handle change of month runs.sort() - if 'exported' in runs: - runs.remove('exported') + if "exported" in runs: + runs.remove("exported") last_run = os.path.join(root, runs[-1]) - except: - raise ValueError('No runs in this directory: ' + root) - if load_run == -1 or load_run == '-1': + except Exception as exc: + raise ValueError("No runs in this directory: " + root) from exc + if load_run == -1 or load_run == "-1": load_run = last_run else: load_run = os.path.join(root, load_run) if checkpoint == -1: - models = [file for file in os.listdir(load_run) if 'model' in file] - models.sort(key=lambda m: '{0:0>15}'.format(m)) + models = [file for file in os.listdir(load_run) if "model" in 
file] + models.sort(key=lambda m: m.zfill(15)) model = models[-1] else: - model = 'model_{}.pt'.format(checkpoint) + model = f"model_{checkpoint}.pt" load_path = os.path.join(load_run, model) return load_path class RSLRLBraxWrapper(VecEnv): + """Wrapper for Brax environments that interop with torch.""" def __init__( self, @@ -90,8 +91,7 @@ def __init__( render_callback=None, device_rank=None, ): - from rsl_rl.env import VecEnv # pytype: disable=import-error - import torch # pytype: disable=import-error + import torch # pytype: disable=import-error # pylint: disable=redefined-outer-name,unused-import,import-outside-toplevel self.seed = seed self.batch_size = num_actors @@ -100,11 +100,11 @@ def __init__( self.key = jax.random.PRNGKey(self.seed) if device_rank is not None: - gpu_devices = jax.devices('gpu') + gpu_devices = jax.devices("gpu") self.key = jax.device_put(self.key, gpu_devices[device_rank]) - self.device = f'cuda:{device_rank}' - print(f'Device -- {gpu_devices[device_rank]}') - print(f'Key device -- {self.key.devices()}') + self.device = f"cuda:{device_rank}" + print(f"Device -- {gpu_devices[device_rank]}") + print(f"Key device -- {self.key.devices()}") # split key into two for reset and randomization key_reset, key_randomization = jax.random.split(self.key) @@ -130,13 +130,13 @@ def __init__( self.asymmetric_obs = False obs_shape = self.env.env.unwrapped.observation_size - print(f'obs_shape: {obs_shape}') + print(f"obs_shape: {obs_shape}") if isinstance(obs_shape, dict): - print('Asymmetric observation space') + print("Asymmetric observation space") self.asymmetric_obs = True - self.num_obs = obs_shape['state'] - self.num_privileged_obs = obs_shape['privileged_state'] + self.num_obs = obs_shape["state"] + self.num_privileged_obs = obs_shape["privileged_state"] else: self.num_obs = obs_shape self.num_privileged_obs = None @@ -148,47 +148,48 @@ def __init__( # todo -- specific to leap environment self.success_queue = deque(maxlen=100) - print(f'JITing 
reset and step') + print("JITing reset and step") self.reset_fn = jax.jit(self.env.reset) self.step_fn = jax.jit(self.env.step) - print(f'Done JITing reset and step') + print("Done JITing reset and step") self.env_state = None def step(self, action): action = torch.clip(action, -1.0, 1.0) # pytype: disable=attribute-error action = _torch_to_jax(action) self.env_state = self.step_fn(self.env_state, action) + critic_obs = None if self.asymmetric_obs: - obs = _jax_to_torch(self.env_state.obs['state']) - critic_obs = _jax_to_torch(self.env_state.obs['privileged_state']) + obs = _jax_to_torch(self.env_state.obs["state"]) + critic_obs = _jax_to_torch(self.env_state.obs["privileged_state"]) else: obs = _jax_to_torch(self.env_state.obs) reward = _jax_to_torch(self.env_state.reward) done = _jax_to_torch(self.env_state.done) info = self.env_state.info - truncation = _jax_to_torch(info['truncation']) + truncation = _jax_to_torch(info["truncation"]) info_ret = { - 'time_outs': truncation, - 'observations': {'critic': critic_obs}, - 'log': {}, + "time_outs": truncation, + "observations": {"critic": critic_obs}, + "log": {}, } - if 'last_episode_success_count' in info: + if "last_episode_success_count" in info: last_episode_success_count = ( - _jax_to_torch(info['last_episode_success_count'])[done > 0] + _jax_to_torch(info["last_episode_success_count"])[done > 0] # pylint: disable=unsubscriptable-object .float() .tolist() ) if len(last_episode_success_count) > 0: self.success_queue.extend(last_episode_success_count) - info_ret['log']['last_episode_success_count'] = np.mean( + info_ret["log"]["last_episode_success_count"] = np.mean( self.success_queue ) for k, v in self.env_state.metrics.items(): - if k not in info_ret['log']: - info_ret['log'][k] = _jax_to_torch(v).float().mean().item() + if k not in info_ret["log"]: + info_ret["log"][k] = _jax_to_torch(v).float().mean().item() return obs, reward, done, info_ret @@ -197,38 +198,38 @@ def reset(self): self.env_state = 
self.reset_fn(self.key_reset) if self.asymmetric_obs: - obs = _jax_to_torch(self.env_state.obs['state']) - # critic_obs = jax_to_torch(self.env_state.obs['privileged_state']) + obs = _jax_to_torch(self.env_state.obs["state"]) + # critic_obs = jax_to_torch(self.env_state.obs["privileged_state"]) else: obs = _jax_to_torch(self.env_state.obs) return obs def reset_with_critic_obs(self): self.env_state = self.reset_fn(self.key_reset) - obs = _jax_to_torch(self.env_state.obs['state']) - critic_obs = _jax_to_torch(self.env_state.obs['privileged_state']) + obs = _jax_to_torch(self.env_state.obs["state"]) + critic_obs = _jax_to_torch(self.env_state.obs["privileged_state"]) return obs, critic_obs def get_observations(self): if self.asymmetric_obs: obs, critic_obs = self.reset_with_critic_obs() - return obs, {'observations': {'critic': critic_obs}} + return obs, {"observations": {"critic": critic_obs}} else: - return self.reset(), {'observations': {}} + return self.reset(), {"observations": {}} - def render(self, mode='human'): + def render(self, mode="human"): # pylint: disable=unused-argument if self.render_callback is not None: self.render_callback(self.env.env.env, self.env_state) else: - raise ValueError('No render callback specified') + raise ValueError("No render callback specified") def get_number_of_agents(self): return 1 def get_env_info(self): info = {} - info['action_space'] = self.action_space # pytype: disable=attribute-error - info['observation_space'] = ( + info["action_space"] = self.action_space # pytype: disable=attribute-error + info["observation_space"] = ( self.observation_space ) # pytype: disable=attribute-error return info diff --git a/mujoco_playground/config/locomotion_params.py b/mujoco_playground/config/locomotion_params.py index bcf1cc4..502d3d1 100644 --- a/mujoco_playground/config/locomotion_params.py +++ b/mujoco_playground/config/locomotion_params.py @@ -143,7 +143,8 @@ def rsl_rl_config(env_name: str) -> config_dict.ConfigDict: 
init_noise_std=1.0, actor_hidden_dims=[512, 256, 128], critic_hidden_dims=[512, 256, 128], - activation="elu", # can be elu, relu, selu, crelu, lrelu, tanh, sigmoid + activation="elu", # can be elu, relu, selu, crelu, lrelu, tanh, + # sigmoid class_name="ActorCritic", ), algorithm=config_dict.create( @@ -153,7 +154,8 @@ def rsl_rl_config(env_name: str) -> config_dict.ConfigDict: clip_param=0.2, entropy_coef=0.001, num_learning_epochs=5, - num_mini_batches=4, # mini batch size = num_envs*nsteps / nminibatches + num_mini_batches=4, # mini batch size = \ + # num_envs*nsteps / nminibatches learning_rate=3.0e-4, # 5.e-4 schedule="fixed", # could be adaptive, fixed gamma=0.99, diff --git a/mujoco_playground/config/manipulation_params.py b/mujoco_playground/config/manipulation_params.py index b1253ea..66b1390 100644 --- a/mujoco_playground/config/manipulation_params.py +++ b/mujoco_playground/config/manipulation_params.py @@ -179,7 +179,7 @@ def brax_vision_ppo_config(env_name: str) -> config_dict.ConfigDict: return rl_config -def rsl_rl_config(env_name: str) -> config_dict.ConfigDict: +def rsl_rl_config(env_name: str) -> config_dict.ConfigDict: # pylint: disable=unused-argument """Returns tuned RSL-RL PPO config for the given environment.""" rl_config = config_dict.create( @@ -189,7 +189,8 @@ def rsl_rl_config(env_name: str) -> config_dict.ConfigDict: init_noise_std=1.0, actor_hidden_dims=[512, 256, 128], critic_hidden_dims=[512, 256, 128], - activation="elu", # can be elu, relu, selu, crelu, lrelu, tanh, sigmoid + activation="elu", # can be elu, relu, selu, crelu, lrelu, tanh, \ + # sigmoid class_name="ActorCritic", ), algorithm=config_dict.create( @@ -199,7 +200,8 @@ def rsl_rl_config(env_name: str) -> config_dict.ConfigDict: clip_param=0.2, entropy_coef=0.001, num_learning_epochs=5, - num_mini_batches=4, # mini batch size = num_envs*nsteps / nminibatches + num_mini_batches=4, # mini batch size = \ + # num_envs*nsteps / nminibatches learning_rate=3.0e-4, # 5.e-4 
schedule="adaptive", # could be adaptive, fixed gamma=0.99, diff --git a/mujoco_playground/experimental/madrona_benchmarking/benchmark.py b/mujoco_playground/experimental/madrona_benchmarking/benchmark.py index 1d21f5b..31864e1 100644 --- a/mujoco_playground/experimental/madrona_benchmarking/benchmark.py +++ b/mujoco_playground/experimental/madrona_benchmarking/benchmark.py @@ -99,7 +99,8 @@ def unvmap(x, ind): if vision: # Load the compiled rendering backend to save time! - # os.environ["MADRONA_MWGPU_KERNEL_CACHE"] = "/madrona_mjx/build/cache" + # os.environ["MADRONA_MWGPU_KERNEL_CACHE"] = \ + # "/madrona_mjx/build/cache" # Coordinate between Jax and the Madrona rendering backend def limit_jax_mem(limit): @@ -127,7 +128,8 @@ def limit_jax_mem(limit): episode_length = int(3 / ctrl_dt) if img_size > 400: # Memory saving mode. - # Should not affect benchmarking results, as the same rendering calls are made. + # Should not affect benchmarking results, as the same rendering calls are + # made. env_specific = {"vision_config.history": 1} else: @@ -168,12 +170,12 @@ def limit_jax_mem(limit): ) jit_reset = jax.jit(env.reset) - state = jit_reset(jax.random.split(jax.random.PRNGKey(0), num_envs)) + state_outer = jit_reset(jax.random.split(jax.random.PRNGKey(0), num_envs)) # Random noise inference function. if mode in [MeasurementMode.STATE.value, MeasurementMode.STATE_VISION.value]: - def inference_fn(_obs, key): + def inference_fn(_, key): return ( jax.random.uniform( key, (num_envs, env.action_size), minval=-1.0, maxval=1.0 @@ -185,7 +187,6 @@ def inference_fn(_obs, key): else: # Randomly initialized Brax inference function. 
from brax.training.acme import running_statistics - from brax.training.acme import specs from brax.training.agents.ppo import losses as ppo_losses from brax.training.agents.ppo.networks import make_inference_fn from brax.training.agents.ppo.networks_vision import make_ppo_networks_vision @@ -193,7 +194,7 @@ def inference_fn(_obs, key): network_factory = make_ppo_networks_vision - env_state = unvmap(state, 0) + env_state = unvmap(state_outer, 0) preprocess_fn = running_statistics.normalize ppo_network = network_factory( env.observation_size, @@ -229,8 +230,9 @@ def inference_fn(_obs, key): def rollout(state, seed): """ Main benchmarking component. - The "token" system ensures proper timing, as it depends on all of the final output dimensions. - Naively returning the final state adds several GB of additional memory requirement for high res. + The "token" system ensures proper timing, as it depends on all of the + final output dimensions. Naively returning the final state adds several GB + of additional memory requirement for high res. 
""" def env_step(c, _): @@ -245,10 +247,10 @@ def env_step(c, _): (state, _), tokens = jax.lax.scan(env_step, (state, key_act), length=N) return jp.sum(tokens) - jit_rollout = rollout.lower(state, 0).compile() + jit_rollout = rollout.lower(state_outer, 0).compile() t0 = time.time() - output = jit_rollout(state, 1) + output = jit_rollout(state_outer, 1) jax.tree_util.tree_map( lambda x: x.block_until_ready(), output ) # Await device completion @@ -278,13 +280,13 @@ def env_step(c, _): # Check if the file already exists to write the header only if needed try: - with open(csv_file_path, "r") as f: + with open(csv_file_path, "r", encoding="utf-8") as f: existing_headers = f.readline().strip().split(",") except FileNotFoundError: existing_headers = [] # Open the CSV file in append mode - with open(csv_file_path, "a", newline="") as f: + with open(csv_file_path, "a", newline="", encoding="utf-8") as f: writer = csv.DictWriter(f, fieldnames=cur_row.keys()) # Write the header if the file is new or doesn't have matching headers diff --git a/mujoco_playground/experimental/madrona_benchmarking/make_plots.py b/mujoco_playground/experimental/madrona_benchmarking/make_plots.py index d8c20af..8b7cc71 100644 --- a/mujoco_playground/experimental/madrona_benchmarking/make_plots.py +++ b/mujoco_playground/experimental/madrona_benchmarking/make_plots.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================== +# ================================================================================= +# pylint: skip-file """ -Generate plots comparing Madrona MJX with ManiSkill3 and Isaac Lab for the CartpoleBalance environment. +Generate plots comparing Madrona MJX with ManiSkill3 and Isaac Lab for the +CartpoleBalance environment. 
""" import matplotlib.pyplot as plt @@ -24,7 +26,7 @@ # Custom function for scientific notation formatting -def scientific_notation_formatter(x, pos): +def scientific_notation_formatter(x, _): if x == 0: return "0" else: @@ -32,12 +34,12 @@ def scientific_notation_formatter(x, pos): # Format the y-axis as powers of 2 -def log2_to_exp_formatter(x, pos): +def log2_to_exp_formatter(x, _): return f"${{2^{{{int(x)}}}}}$" # Custom formatter function -def format_power_of_10(num, pos): +def format_power_of_10(num, _): # num = 2**num """Format a number into scientific notation as, for example, 2.3 x 10^4.""" if num == 0: @@ -51,7 +53,7 @@ def format_power_of_10(num, pos): return r"$10^{%d}$" % exponent else: # Otherwise, show something like 2.3 x 10^4 - return r"${:.1f}\times 10^{{{}}}$".format(base, exponent) + return rf"${base:.1f}\times 10^{{{exponent}}}$" def load_maniskill_result(name, state=False): @@ -69,14 +71,14 @@ def load_maniskill_result(name, state=False): (df_filtered["num_cameras"] == 1) & (df_filtered["obs_mode"] == "rgb") ] - _df = { + df_dict = { "num_envs": df_filtered["num_envs"], "fps": df_filtered["env.step/fps"], "source_file": name, } if not state: - _df["camera_size"] = df_filtered["camera_width"] - return pd.DataFrame(_df) + df_dict["camera_size"] = df_filtered["camera_width"] + return pd.DataFrame(df_dict) def load_madrona_mjx_result(name, state=False): @@ -87,13 +89,13 @@ def load_madrona_mjx_result(name, state=False): df = df[df["img_size"] == 0] else: df = df[df["img_size"] > 0] - _df = { + df_dict = { "num_envs": df["num_envs"], "fps": df["fps"], "source_file": name, "camera_size": df["img_size"], } - return pd.DataFrame(_df) + return pd.DataFrame(df_dict) # --- Vision --- @@ -258,7 +260,8 @@ def configure_plotting_sn_params( ax.set_xlabel("Batch Size") ax.set_ylabel("FPS") ax.yaxis.set_major_formatter(FuncFormatter(scientific_notation_formatter)) - # If you only want a single legend overall, remove the per-axes legend and place one globally at
the end + # If you only want a single legend overall, remove the per-axes legend and + # place one globally at the end ax.legend().remove() # Optionally add a single legend for the entire figure @@ -450,14 +453,16 @@ def log2_to_exp_formatter(x, pos): df_seaborn["camera_size"] = df_seaborn["camera_size"].astype(int) # Add a column for bar stack order (state on bottom, render on top). - # Seaborn doesn't support having bars on top of eachother so we plot total in back and state on top of it. + # Seaborn doesn't support having bars on top of eachother so we plot total + # in back and state on top of it. df_seaborn["stack_order"] = df_seaborn["time_component"].map( {"tps_state": 0, "tps_total": 1} ) def plot_stacked_bars_seaborn(data, camera_sizes, simulators, sim_colors): """ - Generate a FacetGrid of stacked bar charts comparing rendering and state times. + Generate a FacetGrid of stacked bar charts comparing rendering and state + times. """ # Configure Seaborn global sn diff --git a/mujoco_playground/experimental/madrona_benchmarking/print_tables.py b/mujoco_playground/experimental/madrona_benchmarking/print_tables.py index ad0537f..ee1a201 100644 --- a/mujoco_playground/experimental/madrona_benchmarking/print_tables.py +++ b/mujoco_playground/experimental/madrona_benchmarking/print_tables.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Print tables to analyze bottlenecks for the Cartpole and Franka Pixel environments.""" +# pylint: skip-file +"""Print tables to analyze bottlenecks for the Cartpole and Franka Pixel +environments.""" -import enum from pathlib import Path import pandas as pd @@ -25,7 +26,7 @@ fpath = Path(__file__).parent / "data" / fname df = pd.read_csv(fpath) # Select the correct mode and image size. 
-df = df[(df["bottleneck_mode"] == True) & (df["img_size"] == 64)] +df = df[df["bottleneck_mode"].eq(True) & (df["img_size"] == 64)] fname_train = "madrona.csv" fpath_train = Path(__file__).parent.parent / "data" / fname_train @@ -37,7 +38,8 @@ def calculate_average_fps(df, env): Calculate the average FPS for a given env across all seeds. Parameters: - df (pd.DataFrame): The input dataframe with columns 'step', 'seed', 'training/walltime', 'env'. + df (pd.DataFrame): The input dataframe with columns 'step', 'seed', + 'training/walltime', 'env'. env (str): The environment to filter by. Returns: @@ -48,7 +50,7 @@ def calculate_average_fps(df, env): # Group by seed and calculate FPS for each seed fps_values = [] - for seed, seed_df in env_df.groupby("seed"): + for _, seed_df in env_df.groupby("seed"): # Sort by step to ensure chronological order seed_df = seed_df.sort_values("step") @@ -76,9 +78,12 @@ def calculate_average_fps(df, env): fps_train = calculate_average_fps(df_train, env_name) tm_train = 1 / fps_train df_env = df[df["env_name"] == env_name] - # tm_winf = 1/df[df["bottleneck_mode"] == MeasurementMode.STEP_WINFERENCE.value]["fps"].mean() - # tm_wvis = 1/df[df["bottleneck_mode"] == MeasurementMode.STEP_WVISION.value]["fps"].mean() - # tm_state = 1/df[df["bottleneck_mode"] == MeasurementMode.STEP_STATE.value]["fps"].mean() + # tm_winf = 1/df[df["bottleneck_mode"] == \ + # MeasurementMode.STEP_WINFERENCE.value]["fps"].mean() + # tm_wvis = 1/df[df["bottleneck_mode"] == \ + # MeasurementMode.STEP_WVISION.value]["fps"].mean() + # tm_state = 1/df[df["bottleneck_mode"] == \ + # MeasurementMode.STEP_STATE.value]["fps"].mean() tm_winf = ( 1 / df_env[ diff --git a/mujoco_playground/experimental/sim2sim/gamepad_reader.py b/mujoco_playground/experimental/sim2sim/gamepad_reader.py index a4e613d..8cbd2fb 100644 --- a/mujoco_playground/experimental/sim2sim/gamepad_reader.py +++ b/mujoco_playground/experimental/sim2sim/gamepad_reader.py @@ -12,6 +12,7 @@ # See the License
for the specific language governing permissions and # limitations under the License. # ============================================================================== +# pylint: disable=line-too-long """Logitech F710 Gamepad class that uses HID under the hood. Adapted from motion_imitation: https://github.com/erwincoumans/motion_imitation/tree/master/motion_imitation/robots/gamepad/gamepad_reader.py. @@ -31,6 +32,7 @@ def _interpolate(value, old_max, new_scale, deadzone=0.01): class Gamepad: + """Gamepad class that reads from a Logitech F710 gamepad.""" def __init__( self, @@ -63,10 +65,11 @@ def _connect_device(self): self._device.set_nonblocking(True) print( "Connected to" - f" {self._device.get_manufacturer_string()} {self._device.get_product_string()}" + f" {self._device.get_manufacturer_string()} " + f"{self._device.get_product_string()}" ) return True - except Exception as e: + except (hid.HIDException, OSError) as e: print(f"Error connecting to device: {e}") return False @@ -80,7 +83,7 @@ def read_loop(self): data = self._device.read(64) if data: self.update_command(data) - except Exception as e: + except (hid.HIDException, OSError) as e: print(f"Error reading from device: {e}") self._device.close() diff --git a/mujoco_playground/experimental/sim2sim/play_bh_joystick.py b/mujoco_playground/experimental/sim2sim/play_bh_joystick.py index 500537d..31b9a3d 100644 --- a/mujoco_playground/experimental/sim2sim/play_bh_joystick.py +++ b/mujoco_playground/experimental/sim2sim/play_bh_joystick.py @@ -29,6 +29,7 @@ class OnnxController: + """ONNX controller for the Berkeley humanoid.""" def __init__( self, diff --git a/mujoco_playground/experimental/sim2sim/play_g1_joystick.py b/mujoco_playground/experimental/sim2sim/play_g1_joystick.py index 42f1583..f76c345 100644 --- a/mujoco_playground/experimental/sim2sim/play_g1_joystick.py +++ b/mujoco_playground/experimental/sim2sim/play_g1_joystick.py @@ -29,6 +29,7 @@ class OnnxController: + """ONNX controller for the G1
robot.""" def __init__( self, diff --git a/mujoco_playground/experimental/sim2sim/play_go1_joystick.py b/mujoco_playground/experimental/sim2sim/play_go1_joystick.py index 73975ad..3433017 100644 --- a/mujoco_playground/experimental/sim2sim/play_go1_joystick.py +++ b/mujoco_playground/experimental/sim2sim/play_go1_joystick.py @@ -29,6 +29,7 @@ class OnnxController: + """ONNX controller for the Go-1 robot.""" def __init__( self, diff --git a/mujoco_playground/experimental/sim2sim/play_leap_reorient.py b/mujoco_playground/experimental/sim2sim/play_leap_reorient.py index 3da9f48..db16489 100644 --- a/mujoco_playground/experimental/sim2sim/play_leap_reorient.py +++ b/mujoco_playground/experimental/sim2sim/play_leap_reorient.py @@ -30,6 +30,7 @@ class OnnxController: + """ONNX controller for the Leap hand.""" def __init__( self, @@ -58,7 +59,7 @@ def __init__( self._counter = 0 self._n_substeps = n_substeps - def get_obs(self, model, data) -> np.ndarray: + def get_obs(self, model, data) -> np.ndarray: # pylint: disable=unused-argument joint_angles = data.qpos[self._hand_qids] qpos_error = joint_angles - self._motor_targets cube_pos_error = ( From 819e41049f7508a219a711fa9b40f4f671a0b483 Mon Sep 17 00:00:00 2001 From: Sam Holt Date: Sat, 18 Jan 2025 02:14:38 +0100 Subject: [PATCH 3/6] fix: Updated linting errors to zero pre-commit errors --- .../franka_emika_panda/randomize_vision.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/mujoco_playground/_src/manipulation/franka_emika_panda/randomize_vision.py b/mujoco_playground/_src/manipulation/franka_emika_panda/randomize_vision.py index 044382c..5343ea2 100644 --- a/mujoco_playground/_src/manipulation/franka_emika_panda/randomize_vision.py +++ b/mujoco_playground/_src/manipulation/franka_emika_panda/randomize_vision.py @@ -60,9 +60,9 @@ def domain_randomize( ) -> Tuple[mjx.Model, mjx.Model]: """Tile the necessary axes for the Madrona BatchRenderer.""" mj_model = 
pick_cartesian.PandaPickCubeCartesian().mj_model - FLOOR_GEOM_ID = mj_model.geom('floor').id - BOX_GEOM_ID = mj_model.geom('box').id - STRIP_GEOM_ID = mj_model.geom('init_space').id + floor_geom_id = mj_model.geom('floor').id + box_geom_id = mj_model.geom('box').id + strip_geom_id = mj_model.geom('init_space').id in_axes = jax.tree_util.tree_map(lambda x: None, mjx_model) in_axes = in_axes.tree_replace({ @@ -93,16 +93,16 @@ def rand(rng: jax.Array, light_position: jax.Array): rgba = jp.array( [jax.random.uniform(key_box, (), minval=0.5, maxval=1.0), 0.0, 0.0, 1.0] ) - geom_rgba = mjx_model.geom_rgba.at[BOX_GEOM_ID].set(rgba) + geom_rgba = mjx_model.geom_rgba.at[box_geom_id].set(rgba) strip_white = jax.random.uniform(key_strip, (), minval=0.8, maxval=1.0) - geom_rgba = mjx_model.geom_rgba.at[STRIP_GEOM_ID].set( + geom_rgba = geom_rgba.at[strip_geom_id].set( jp.array([strip_white, strip_white, strip_white, 1.0]) ) # Sample a shade of gray gray_scale = jax.random.uniform(key_floor, (), minval=0.0, maxval=0.25) - geom_rgba = geom_rgba.at[FLOOR_GEOM_ID].set( + geom_rgba = geom_rgba.at[floor_geom_id].set( jp.array([gray_scale, gray_scale, gray_scale, 1.0]) ) @@ -112,11 +112,11 @@ def rand(rng: jax.Array, light_position: jax.Array): jax.random.randint(key_matid, shape=(num_geoms,), minval=0, maxval=10) + mat_offset ) - geom_matid = geom_matid.at[BOX_GEOM_ID].set( + geom_matid = geom_matid.at[box_geom_id].set( -2 ) # Use the above randomized colors - geom_matid = geom_matid.at[FLOOR_GEOM_ID].set(-2) - geom_matid = geom_matid.at[STRIP_GEOM_ID].set(-2) + geom_matid = geom_matid.at[floor_geom_id].set(-2) + geom_matid = geom_matid.at[strip_geom_id].set(-2) #### Cameras #### key_pos, key_ori, key = jax.random.split(key, 3) @@ -134,7 +134,7 @@ def rand(rng: jax.Array, light_position: jax.Array): assert ( nlight == 1 ), f'Sim2Real was trained with a single light source, got {nlight}' - key_lsha, key_ldir, key_ldct, key = jax.random.split(key, 4) + key_lsha, key_ldir, 
key = jax.random.split(key, 3) # Direction shine_at = jp.array([0.661, -0.001, 0.179]) # Gripper starting position From 650d470e00116226c6af8a46c1cff6e270c07dca Mon Sep 17 00:00:00 2001 From: Sam Holt Date: Mon, 20 Jan 2025 19:22:19 +0000 Subject: [PATCH 4/6] feat: PyLint improvement updates --- mujoco_playground/config/locomotion_params.py | 8 ++++---- mujoco_playground/config/manipulation_params.py | 8 ++++---- .../experimental/madrona_benchmarking/benchmark.py | 3 ++- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/mujoco_playground/config/locomotion_params.py b/mujoco_playground/config/locomotion_params.py index 502d3d1..22c63b4 100644 --- a/mujoco_playground/config/locomotion_params.py +++ b/mujoco_playground/config/locomotion_params.py @@ -143,8 +143,8 @@ def rsl_rl_config(env_name: str) -> config_dict.ConfigDict: init_noise_std=1.0, actor_hidden_dims=[512, 256, 128], critic_hidden_dims=[512, 256, 128], - activation="elu", # can be elu, relu, selu, crelu, lrelu, tanh, - # sigmoid + # can be elu, relu, selu, crelu, lrelu, tanh, sigmoid + activation="elu", class_name="ActorCritic", ), algorithm=config_dict.create( @@ -154,8 +154,8 @@ def rsl_rl_config(env_name: str) -> config_dict.ConfigDict: clip_param=0.2, entropy_coef=0.001, num_learning_epochs=5, - num_mini_batches=4, # mini batch size = \ - # num_envs*nsteps / nminibatches + # mini batch size = num_envs*nsteps / nminibatches + num_mini_batches=4, learning_rate=3.0e-4, # 5.e-4 schedule="fixed", # could be adaptive, fixed gamma=0.99, diff --git a/mujoco_playground/config/manipulation_params.py b/mujoco_playground/config/manipulation_params.py index 66b1390..f5faaa0 100644 --- a/mujoco_playground/config/manipulation_params.py +++ b/mujoco_playground/config/manipulation_params.py @@ -189,8 +189,8 @@ def rsl_rl_config(env_name: str) -> config_dict.ConfigDict: # pylint: disable=u init_noise_std=1.0, actor_hidden_dims=[512, 256, 128], critic_hidden_dims=[512, 256, 128], - activation="elu", # can be 
elu, relu, selu, crelu, lrelu, tanh, \ - # sigmoid + # can be elu, relu, selu, crelu, lrelu, tanh, sigmoid + activation="elu", class_name="ActorCritic", ), algorithm=config_dict.create( @@ -200,8 +200,8 @@ def rsl_rl_config(env_name: str) -> config_dict.ConfigDict: # pylint: disable=u clip_param=0.2, entropy_coef=0.001, num_learning_epochs=5, - num_mini_batches=4, # mini batch size = \ - # num_envs*nsteps / nminibatches + # mini batch size = num_envs*nsteps / nminibatches + num_mini_batches=4, learning_rate=3.0e-4, # 5.e-4 schedule="adaptive", # could be adaptive, fixed gamma=0.99, diff --git a/mujoco_playground/experimental/madrona_benchmarking/benchmark.py b/mujoco_playground/experimental/madrona_benchmarking/benchmark.py index 31864e1..3af7ed7 100644 --- a/mujoco_playground/experimental/madrona_benchmarking/benchmark.py +++ b/mujoco_playground/experimental/madrona_benchmarking/benchmark.py @@ -99,8 +99,9 @@ def unvmap(x, ind): if vision: # Load the compiled rendering backend to save time! 
- # os.environ["MADRONA_MWGPU_KERNEL_CACHE"] = \ + # os.environ["MADRONA_MWGPU_KERNEL_CACHE"] = ( # "/madrona_mjx/build/cache" + # ) # Coordinate between Jax and the Madrona rendering backend def limit_jax_mem(limit): From 4f029d202b42faf2ebbb2b7ba9d223c60c80795e Mon Sep 17 00:00:00 2001 From: Sam Holt Date: Mon, 20 Jan 2025 19:27:38 +0000 Subject: [PATCH 5/6] fix: PyLint skip PyLint on experimental folder --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index bdfc7cf..6a0a18b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -142,6 +142,9 @@ exclude = [ output = ".pytype" report_errors = true +[tool.pylint] +ignore-paths = 'mujoco_menagerie/experimental/**$' + [tool.hatch.build] include = [ "mujoco_playground/__init__.py", From 4075c5cdb5d12133087800ce9da10a3bf0559e0a Mon Sep 17 00:00:00 2001 From: Sam Holt Date: Tue, 21 Jan 2025 16:57:34 +0000 Subject: [PATCH 6/6] fix: Updated PyLint changes to fix repo --- learning/train_jax_ppo.py | 18 ++++++++---------- pyproject.toml | 2 +- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/learning/train_jax_ppo.py b/learning/train_jax_ppo.py index 316c369..b2e7682 100644 --- a/learning/train_jax_ppo.py +++ b/learning/train_jax_ppo.py @@ -12,21 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -# pylint: disable=wrong-import-position """Train a PPO agent using JAX on the specified environment.""" -import os - -xla_flags = os.environ.get("XLA_FLAGS", "") -xla_flags += " --xla_gpu_triton_gemm_any=True" -os.environ["XLA_FLAGS"] = xla_flags -os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" -os.environ["MUJOCO_GL"] = "egl" - from datetime import datetime import functools import json +import os import time +import warnings from absl import app from absl import flags @@ -52,11 +45,16 @@ from mujoco_playground.config import locomotion_params from mujoco_playground.config import manipulation_params +xla_flags = os.environ.get("XLA_FLAGS", "") +xla_flags += " --xla_gpu_triton_gemm_any=True" +os.environ["XLA_FLAGS"] = xla_flags +os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" +os.environ["MUJOCO_GL"] = "egl" + # Ignore the info logs from brax logging.set_verbosity(logging.WARNING) # Suppress warnings -import warnings # Suppress RuntimeWarnings from JAX warnings.filterwarnings("ignore", category=RuntimeWarning, module="jax") diff --git a/pyproject.toml b/pyproject.toml index 6a0a18b..6137519 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -143,7 +143,7 @@ output = ".pytype" report_errors = true [tool.pylint] -ignore-paths = 'mujoco_menagerie/experimental/**$' +ignore-paths = 'mujoco_playground/experimental/**$' [tool.hatch.build] include = [