
Error loading loop from a checkpoint #1790

Open
nwbvt opened this issue Oct 27, 2023 · 1 comment

nwbvt commented Oct 27, 2023

Description

I'm getting an error while trying to load a training loop from a checkpoint.

The full traceback is reproduced under "Error logs" below; it bottoms out in jax with:

TypeError: unsupported operand type(s) for ==: 'ArrayImpl' and 'tuple'
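If I'm reading the trace right, `tree_unflatten` in trax/fastmath/numpy.py evaluates `tree in copy_from_tree` with `copy_from_tree=[None, ()]`. When `tree` is a jax array leaf, the list membership test falls back to `==`, and jax 0.4.19 refuses to compare an array with a tuple. A trax-free sketch of the failure mode (only jax.numpy assumed):

```python
import jax.numpy as jnp

copy_from_tree = [None, ()]  # the sentinel list trax passes to tree_unflatten
leaf = jnp.zeros(3)          # stands in for an optimizer-slot leaf

try:
    # `in` compares leaf against each element with ==; the None element
    # compares False, but == against the empty tuple raises in jax.
    _ = leaf in copy_from_tree
except TypeError as e:
    print(e)  # unsupported operand type(s) for ==: 'ArrayImpl' and 'tuple'
```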

Environment information

OS: Ubuntu

$ pip freeze | grep trax
trax==1.4.1

$ pip freeze | grep tensor
tensorboard==2.14.1
tensorboard-data-server==0.7.1
tensorflow==2.14.0
tensorflow-datasets==4.9.3
tensorflow-estimator==2.14.0
tensorflow-hub==0.15.0
tensorflow-io-gcs-filesystem==0.34.0
tensorflow-metadata==1.14.0
tensorflow-text==2.14.0

$ pip freeze | grep jax
jax==0.4.19
jaxlib==0.4.19+cuda12.cudnn89

$ python3 -V
Python 3.10.12

For bugs: reproduction and error logs

# Steps to reproduce:
Create a training loop, give it an output dir, and train the model. Then create a new training loop with the same output dir: it will try to load from the checkpoint but fail. A trimmed-down sketch:
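(The model, data, and hyperparameters below are toy stand-ins, not my real code; any optimizer with slots, e.g. Adam, should trigger it.)

```python
import numpy as np
import trax
from trax import layers as tl

def toy_batches():
    # endless stream of (inputs, targets, weights) batches
    while True:
        x = np.random.rand(8, 4).astype(np.float32)
        y = x.sum(axis=1, keepdims=True)
        yield x, y, np.ones_like(y)

def make_loop(output_dir):
    model = tl.Serial(tl.Dense(16), tl.Relu(), tl.Dense(1))
    train_task = trax.supervised.training.TrainTask(
        labeled_data=toy_batches(),
        loss_layer=tl.L2Loss(),
        optimizer=trax.optimizers.Adam(0.01),
        n_steps_per_checkpoint=10)
    eval_task = trax.supervised.training.EvalTask(
        labeled_data=toy_batches(), metrics=[tl.L2Loss()])
    return trax.supervised.training.Loop(
        model, train_task, eval_tasks=[eval_task], output_dir=output_dir)

loop = make_loop("/tmp/ckpt_repro")   # fresh dir, trains fine
loop.run(20)                          # writes checkpoints to output_dir

loop2 = make_loop("/tmp/ckpt_repro")  # same dir: crashes in load_checkpoint()
```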
# Error logs:
/home/nick/.local/lib/python3.10/site-packages/trax/supervised/training.py:1388: SyntaxWarning: "is not" with a literal. Did you mean "!="?
  return [f for f in flat if f is not None and f is not ()]  # pylint: disable=literal-comparison
(the two lines above are repeated five times in the output)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[57], line 1
----> 1 training_loop = make_trainer(model, train_data, test_data,
      2                              trax.layers.Fn("ScaledLoss", ScaledLoss),
      3                              [trax.layers.Fn("ScaledLoss", ScaledLoss),
      4                               trax.layers.Fn("InRange", InRange), 
      5                               trax.layers.Fn("ClippedMAE", ClippedMAE)],
      6                              "model_dff=4096")

File ~/srrf/srrf.py:104, in make_trainer(model, train_data, test_data, loss, metrics, name, optimizer, schedule, steps_per_cp, eval_batch_size)
     94 train_task = trax.supervised.training.TrainTask(labeled_data=train_data,
     95                                                 loss_layer=loss,
     96                                                 optimizer=optimizer,
     97                                                 lr_schedule=schedule,
     98                                                 n_steps_per_checkpoint=500)
    100 eval_task = trax.supervised.training.EvalTask(labeled_data=test_data,
    101                                               metrics=metrics,
    102                                               n_eval_batches=eval_batch_size)
--> 104 return trax.supervised.training.Loop(model, train_task, eval_tasks=[eval_task], output_dir=name)

File ~/.local/lib/python3.10/site-packages/trax/supervised/training.py:294, in Loop.__init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    289       layer.weights, layer.state = tl.on_cpu(self._unreplicate(
    290           _make_weights_and_state_same_across_hosts(
    291               self._for_n_devices(weights_and_state))))
    293 # Load checkpoint if it exists.
--> 294 self.load_checkpoint()
    296 # Prepare eval components.
    297 self._eval_at = eval_at or default_at

File ~/.local/lib/python3.10/site-packages/trax/supervised/training.py:944, in Loop.load_checkpoint(self, directory, filename)
    940 for (trainer, slots) in zip(self._trainer_per_task, d['slots_per_task']):
    941   matched_flat_slots = _match_by_shape(
    942       self._to_bits(_flatten_and_remove_empty(trainer.slots)),
    943       _flatten_and_remove_empty(slots))
--> 944   matched_slots, _ = fastmath.tree_unflatten(
    945       self._from_bits(matched_flat_slots),
    946       trainer.slots, copy_from_tree=[None, ()])
    947   trainer.slots = matched_slots
    948 self._step = d['step']

File ~/.local/lib/python3.10/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree)
    242 new_tree, rest = [], flat
    243 for t in tree:
--> 244   new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree)
    245   new_tree.append(new_t)
    246 new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree

File ~/.local/lib/python3.10/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree)
    242 new_tree, rest = [], flat
    243 for t in tree:
--> 244   new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree)
    245   new_tree.append(new_t)
    246 new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree

File ~/.local/lib/python3.10/site-packages/trax/fastmath/numpy.py:239, in tree_unflatten(flat, tree, copy_from_tree)
    216 def tree_unflatten(flat, tree, copy_from_tree=None):
    217   """Unflatten a list into a tree given the tree shape as second argument.
    218 
    219   Args:
   (...)
    237     more were provided than the number of leaves of tree (useful for recursion).
    238   """
--> 239   if copy_from_tree is not None and tree in copy_from_tree:
    240     return tree, flat
    241   if isinstance(tree, (list, tuple)):

File ~/.local/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:260, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
    257 # Note: don't use isinstance here, because we don't want to raise for
    258 # subclasses, e.g. NamedTuple objects that may override operators.
    259 if type(other) in _rejected_binop_types:
--> 260   raise TypeError(f"unsupported operand type(s) for {opchar}: "
    261                   f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")
    262 return NotImplemented

TypeError: unsupported operand type(s) for ==: 'ArrayImpl' and 'tuple'
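For what it's worth, two spots in trax line up with this. The SyntaxWarning points at _flatten_and_remove_empty (training.py:1388) comparing against the literal () with `is not`, and the crash itself is the `tree in copy_from_tree` membership test in tree_unflatten (fastmath/numpy.py:239). Note that the "obvious" fix the warning suggests (f != ()) would also blow up on jax arrays, since jax rejects == against tuples; a type-guarded comparison avoids == entirely. An untested sketch of what I mean, not a vetted patch:

```python
def _is_empty_tuple(x):
    # True only for the () sentinel; never calls == on a jax array
    return isinstance(x, tuple) and not x

def _matches_sentinel(tree, sentinels):
    # drop-in for `tree in sentinels`: list membership falls back to ==,
    # which jax arrays reject against tuples, so check types first
    return any(tree is s or (type(tree) is type(s) and tree == s)
               for s in sentinels)

# training.py:1388 would become:  f is not None and not _is_empty_tuple(f)
# numpy.py:239 would become:      copy_from_tree is not None and
#                                 _matches_sentinel(tree, copy_from_tree)
```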

nwbvt commented Oct 27, 2023

Note I can load the model itself with model.init_from_file(f"{outputdir}/model.pkl.gz") just fine. It only fails when the training loop is created.
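So a possible stopgap, reusing the names from make_trainer above (untested beyond my setup, and it discards the optimizer slots and step counter):

```python
# hypothetical stopgap: restore the weights by hand, skip Loop's auto-load
model.init_from_file(f"{outputdir}/model.pkl.gz", weights_only=True)

loop = trax.supervised.training.Loop(
    model, train_task, eval_tasks=[eval_task])  # no output_dir: nothing to load
```

Training resumes from the restored weights but restarts at step 0, and without an output_dir this loop writes no new checkpoints.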
