# Trax: creating a training Loop with an existing output_dir fails to load the checkpoint (TypeError in fastmath.tree_unflatten)
# Steps to reproduce:
Create a training loop, give it an output dir, and then train the model.
Then create a new training loop, give it that same output dir. It will try to load from a checkpoint but fail.
# Error logs:
/home/nick/.local/lib/python3.10/site-packages/trax/supervised/training.py:1388: SyntaxWarning: "is not" with a literal. Did you mean "!="?
return [f for f in flat if f is not None and f is not ()] # pylint: disable=literal-comparison
/home/nick/.local/lib/python3.10/site-packages/trax/supervised/training.py:1388: SyntaxWarning: "is not" with a literal. Did you mean "!="?
return [f for f in flat if f is not None and f is not ()] # pylint: disable=literal-comparison
/home/nick/.local/lib/python3.10/site-packages/trax/supervised/training.py:1388: SyntaxWarning: "is not" with a literal. Did you mean "!="?
return [f for f in flat if f is not None and f is not ()] # pylint: disable=literal-comparison
/home/nick/.local/lib/python3.10/site-packages/trax/supervised/training.py:1388: SyntaxWarning: "is not" with a literal. Did you mean "!="?
return [f for f in flat if f is not None and f is not ()] # pylint: disable=literal-comparison
/home/nick/.local/lib/python3.10/site-packages/trax/supervised/training.py:1388: SyntaxWarning: "is not" with a literal. Did you mean "!="?
return [f for f in flat if f is not None and f is not ()] # pylint: disable=literal-comparison
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[57], line 1
----> 1 training_loop = make_trainer(model, train_data, test_data,
2 trax.layers.Fn("ScaledLoss", ScaledLoss),
3 [trax.layers.Fn("ScaledLoss", ScaledLoss),
4 trax.layers.Fn("InRange", InRange),
5 trax.layers.Fn("ClippedMAE", ClippedMAE)],
6 "model_dff=4096")
File ~/srrf/srrf.py:104, in make_trainer(model, train_data, test_data, loss, metrics, name, optimizer, schedule, steps_per_cp, eval_batch_size)
94 train_task = trax.supervised.training.TrainTask(labeled_data=train_data,
95 loss_layer=loss,
96 optimizer=optimizer,
97 lr_schedule=schedule,
98 n_steps_per_checkpoint=500)
100 eval_task = trax.supervised.training.EvalTask(labeled_data=test_data,
101 metrics=metrics,
102 n_eval_batches=eval_batch_size)
--> 104 return trax.supervised.training.Loop(model, train_task, eval_tasks=[eval_task], output_dir=name)
File ~/.local/lib/python3.10/site-packages/trax/supervised/training.py:294, in Loop.__init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
289 layer.weights, layer.state = tl.on_cpu(self._unreplicate(
290 _make_weights_and_state_same_across_hosts(
291 self._for_n_devices(weights_and_state))))
293 # Load checkpoint if it exists.
--> 294 self.load_checkpoint()
296 # Prepare eval components.
297 self._eval_at = eval_at or default_at
File ~/.local/lib/python3.10/site-packages/trax/supervised/training.py:944, in Loop.load_checkpoint(self, directory, filename)
940 for (trainer, slots) in zip(self._trainer_per_task, d['slots_per_task']):
941 matched_flat_slots = _match_by_shape(
942 self._to_bits(_flatten_and_remove_empty(trainer.slots)),
943 _flatten_and_remove_empty(slots))
--> 944 matched_slots, _ = fastmath.tree_unflatten(
945 self._from_bits(matched_flat_slots),
946 trainer.slots, copy_from_tree=[None, ()])
947 trainer.slots = matched_slots
948 self._step = d['step']
File ~/.local/lib/python3.10/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree)
242 new_tree, rest = [], flat
243 for t in tree:
--> 244 new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree)
245 new_tree.append(new_t)
246 new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree
File ~/.local/lib/python3.10/site-packages/trax/fastmath/numpy.py:244, in tree_unflatten(flat, tree, copy_from_tree)
242 new_tree, rest = [], flat
243 for t in tree:
--> 244 new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree)
245 new_tree.append(new_t)
246 new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree
File ~/.local/lib/python3.10/site-packages/trax/fastmath/numpy.py:239, in tree_unflatten(flat, tree, copy_from_tree)
216 def tree_unflatten(flat, tree, copy_from_tree=None):
217 """Unflatten a list into a tree given the tree shape as second argument.
218
219 Args:
(...)
237 more were provided than the number of leaves of tree (useful for recursion).
238 """
--> 239 if copy_from_tree is not None and tree in copy_from_tree:
240 return tree, flat
241 if isinstance(tree, (list, tuple)):
File ~/.local/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:260, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
257 # Note: don't use isinstance here, because we don't want to raise for
258 # subclasses, e.g. NamedTuple objects that may override operators.
259 if type(other) in _rejected_binop_types:
--> 260 raise TypeError(f"unsupported operand type(s) for {opchar}: "
261 f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")
262 return NotImplemented
TypeError: unsupported operand type(s) for ==: 'ArrayImpl' and 'tuple'
The text was updated successfully, but these errors were encountered:
Note that I can load the model itself with model.init_from_file(f"{outputdir}/model.pkl.gz") just fine. It's only when the training loop is created that it fails.
Description
I'm getting an error while trying to load a training loop from a checkpoint
Environment information
For bugs: reproduction and error logs
The text was updated successfully, but these errors were encountered: