
Issue with Optimizer Update in A2C Network with Optax #4391

Open
Tomato-toast opened this issue Nov 20, 2024 · 6 comments

Comments

@Tomato-toast

Hello everyone,

I've encountered a problem while implementing an A2C (Advantage Actor-Critic) network involving Flax and Optax. My network includes policy_network and value_network, each containing policy_head and torso. When attempting to use optimizer.update(grad), I received the following error:

ValueError: Mismatch custom node data: ('policy_head', 'torso') != ('policy_network', 'value_network');

The error message indicates that the expected keys are ('policy_network', 'value_network'), but the actual provided keys are ('policy_head', 'torso'). The structure of my model parameters is as follows:

State({
  'policy_network': {
    'policy_head': {...},
    'torso': {...},
  },
  'value_network': {
    'policy_head': {...},
    'torso': {...},
  },
})

I have tried to combine the model parameters and pass them to the optimizer, like this:

params = {'w1': model1_params, 'w2': model2_params}

However, this approach did not resolve the issue. I'm wondering if there is another way to correctly initialize and update the A2C network's parameters using Optax in Flax.
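
For reference, the general Optax pattern I am trying to follow looks roughly like this minimal sketch (the placeholder parameter tree below is just for illustration, not my actual code):

    import jax
    import jax.numpy as jnp
    import optax

    # Placeholder parameter tree mirroring my model's structure (illustrative only).
    params = {
        'policy_network': {'policy_head': jnp.zeros(3), 'torso': jnp.zeros(3)},
        'value_network': {'policy_head': jnp.zeros(3), 'torso': jnp.zeros(3)},
    }

    tx = optax.adam(learning_rate=1e-3)
    opt_state = tx.init(params)  # the optimizer state is built from this exact tree structure

    # The gradients passed to tx.update must have the same tree structure as `params`,
    # otherwise Optax raises a structure-mismatch error like the one above.
    grads = jax.tree_util.tree_map(jnp.zeros_like, params)
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

My understanding is that the mismatch happens because the gradients and the optimizer state were built from differently structured trees.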

If you have any suggestions or need more information, please let me know. Thank you very much for your help!

@cgarciae
Collaborator

Hi @Tomato-toast, can you post some pseudo-code of how you are constructing the optimizer and gradients?

@Tomato-toast
Author

Tomato-toast commented Nov 21, 2024


Below is a pseudo-code example of how the Optimizer and gradients are constructed and applied:

    class ConnectorTorso(nnx.Module):
        def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs):
            self.rngs = rngs
            self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
        def __call__(self, x):
            x = self.linear(x)
            return x

    def make_actor_network_connector(rngs = nnx.Rngs(0)):
        class PolicyNetwork(nnx.Module):
            def __init__(self, in_features, out_features):
                self.torso = ConnectorTorso(in_features, out_features, rngs=rngs)
    
            def __call__(self, x):
                return self.torso(x)
    
        return PolicyNetwork

    def make_critic_network_connector(rngs = nnx.Rngs(0)):
        class CriticNetwork(nnx.Module):
            def __init__(self, in_features, out_features):
                self.torso = ConnectorTorso(in_features, out_features, rngs=rngs)
    
            def __call__(self, x):
                return self.torso(x)
    
        return CriticNetwork
            
    class A2CAgent:
        def __init__(self):
            self.optimizer = nnx.Optimizer(
                optax.adam(learning_rate=1e-3)
            )
        # Initialize the policy network and value network parameters.
        def init_params(self, key: chex.PRNGKey) -> ParamsState:
            _, policy_params, _= nnx.split(self.actor_critic_networks.policy_network, nnx.Param, ...)
            _, critic_params, _ = nnx.split(self.actor_critic_networks.value_network, nnx.Param, ...)
      
            params = ActorCriticParams(
                actor = policy_params,
                critic = critic_params,
            )
            params_state = ParamsState(
                params=params,
                opt_state=opt_state,
                update_count=jnp.array(0, float),
            )
            return params_state
    
        def a2c_loss(self, policy_network, params, observations, actions, returns):
            # Calculate the outputs of the policy and value networks
            policy_output = policy_network(params.actor, observations)
            critic_output = policy_network(params.critic, observations)
    
            # Policy Loss: Based on the Advantage Function
            advantages = returns - critic_output
            policy_loss = -jnp.mean(jnp.log(policy_output) * advantages)
    
            # Value Loss: Mean Squared Error (MSE)
            critic_loss = jnp.mean((critic_output - returns) ** 2)
    
            # Entropy Loss: Encouraging Policy Exploration
            entropy_loss = -jnp.mean(policy_output * jnp.log(policy_output + 1e-8))
    
            # total loss
            return policy_loss + critic_loss - 0.01 * entropy_loss  # entropy loss coefficient is 0.01
    
    
        # Execute a training epoch and update the parameters.
        def run_epoch(self, training_state: TrainingState) -> Tuple[TrainingState, Dict]:
            grad_fn = nnx.value_and_grad(self.a2c_loss, has_aux=True)
            ((policy_output, critic_output), metrics), grad = grad_fn(actor_critic_networks.policy_network, params, training_state.acting_state)
            grad = optax.clip_by_global_norm(grad, max_norm=1.0)
            updates, opt_state = self.optimizer.update(grad, training_state)
            params = optax.apply_updates(training_state.params_state.params, updates)
            training_state = TrainingState(
                params_state=ParamsState(
                    params=params,  
                    opt_state=opt_state,           
                    update_count=training_state.params_state.update_count + 1,
                ),
                acting_state=acting_state,  
            )
            return training_state, metrics

Thanks!

@stergiosba

stergiosba commented Nov 22, 2024

I would just offer my input here and some suggestions based on my relatively short experience with NNX.

I noticed you are using the flax.linen-style TrainState and that you also split the graphdef and parameters with nnx.split, so I am going to assume you need backwards compatibility with Linen. If that is not the case, you should be happy to know that in flax.nnx you don't have to do this anymore, at least for this simple example. That said, splitting the parameters and using TrainState is a perfectly fine, working option (I don't like it personally, which is why I recently switched from Linen to NNX).

Ok, to the matter at hand, the problem here is with the a2c_loss function definition and the way it's transformed with nnx.value_and_grad.

You have:

def a2c_loss(self, policy_network, params, observations, actions, returns):

When you transform with nnx.value_and_grad, the derivative is taken with respect to the argument selected by argnums, as seen in the signature of nnx.value_and_grad:

flax.nnx.value_and_grad(f=<class 'flax.typing.Missing'>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())

By default argnums=0, and you leave it at the default in your code, as seen here:

grad_fn = nnx.value_and_grad(self.a2c_loss, has_aux=True)

Thus you are taking the derivative with respect to self in a2c_loss. Now you can change it to the following and tell us what you get:

grad_fn = nnx.value_and_grad(self.a2c_loss, argnums=1, has_aux=True)

This will take the derivative with respect to params. (Corrected after @cgarciae correctly pointed out a missed detail on my part.)

On a more general note, you can simplify your code significantly. For example, ConnectorTorso and the two heads (actor and critic) could be combined into a single module, which will probably lead to faster compilation if your aim is to build a more complex model in the future. Just some food for thought.
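
To make that concrete, here is a rough sketch of what a combined module could look like (the class name, layer sizes, and relu activation are just illustrative choices, not taken from your code):

    import jax
    import optax
    from flax import nnx

    class ActorCritic(nnx.Module):
        """Shared torso with separate policy and value heads (illustrative sketch)."""

        def __init__(self, in_features: int, hidden: int, num_actions: int, rngs: nnx.Rngs):
            self.torso = nnx.Linear(in_features, hidden, rngs=rngs)
            self.policy_head = nnx.Linear(hidden, num_actions, rngs=rngs)
            self.value_head = nnx.Linear(hidden, 1, rngs=rngs)

        def __call__(self, x):
            h = jax.nn.relu(self.torso(x))
            return self.policy_head(h), self.value_head(h)

    model = ActorCritic(in_features=8, hidden=64, num_actions=4, rngs=nnx.Rngs(0))
    optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # one optimizer over the whole module

With a single module like this you keep one nnx.Optimizer over all the parameters instead of juggling two separate parameter trees.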

@cgarciae
Collaborator

cgarciae commented Nov 22, 2024

@Tomato-toast how are the policy_network functions implemented?

policy_output = policy_network(params.actor, observations)
critic_output = policy_network(params.critic, observations)

They seem to be Modules that take in their params, which is a bit peculiar.

Since you are using a functional-style training loop, I'd recommend storing the graphdefs and using nnx.merge to reconstruct the Modules inside the loss function before calling them. Check out examples/nnx_toy_examples/03_train_state.py, which shows how to use NNX with TrainState.

Regarding the argnums situation, a small correction to what @stergiosba pointed out: yes, you should change argnums to match the position of params, but because self is passed via the bound method it doesn't count, so it should be argnums=1.
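
Putting both points together, a minimal sketch of the split/merge pattern might look like this (the tiny Linear model and dummy data are placeholders, not the code from the linked example):

    import jax
    import jax.numpy as jnp
    import optax
    from flax import nnx

    # A tiny stand-in model; your actor/critic modules would go here instead.
    model = nnx.Linear(4, 2, rngs=nnx.Rngs(0))
    graphdef, params = nnx.split(model, nnx.Param)

    tx = optax.adam(1e-3)
    opt_state = tx.init(params)

    def loss_fn(params, x, y):
        # Reconstruct the Module from the stored graphdef and the current params.
        module = nnx.merge(graphdef, params)
        preds = module(x)
        return jnp.mean((preds - y) ** 2)

    x = jnp.ones((8, 4))
    y = jnp.zeros((8, 2))

    # Here params is the first positional argument, so argnums=0 works;
    # on your bound a2c_loss method, self does not count, which is why argnums=1 applies there.
    grads = jax.grad(loss_fn)(params, x, y)
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)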

@Tomato-toast
Author


I would like to extend my heartfelt gratitude for your previous assistance and suggestions. Following your guidance, I have attempted to modify the code from

    grad_fn = nnx.value_and_grad(self.a2c_loss, has_aux=True)

to

    grad_fn = nnx.value_and_grad(self.a2c_loss, argnums=1, has_aux=True)

However, I have encountered a new error with the following message:

    ValueError: Expected named tuple, got State({
      'policy_network': {
        'policy_head': {...},
        'torso': {...},
      },
      'value_network': {
        'policy_head': {...},
        'torso': {...},
      },
    })

This indicates that the issue may not be related to the argnums parameter. Therefore, I suspect that the error might stem from another part of the code. Could you please provide any additional solutions or suggestions that might assist me in further diagnosing and resolving this issue?

Moreover, I have taken your previous advice regarding code optimization under serious consideration and will incorporate it into my future work plans. Thank you once again for your valuable insights and support.

@Tomato-toast
Author


Thank you very much for your valuable suggestion! Following your guidance, I referred to the example in examples/nnx_toy_examples/03_train_state.py and attempted to use the nnx.merge function to reconstruct the module inside the loss function. However, I encountered the same error message as when I only modified argnums=1.

Here are the details of my attempt and the error message:

Example code

    grad_fn = nnx.value_and_grad(self.a2c_loss, argnums=0, has_aux=True)

    class TrainState(train_state.TrainState):
        counts: nnx.State
        graphdef: nnx.GraphDef

    graphdef, params, counts = nnx.split(self.actor_critic_networks.value_network, nnx.Param, ...)

    state = TrainState.create(
        apply_fn=None,
        graphdef=graphdef,
        params=params,
        tx=optax.sgd(0.1),
        counts=counts,
    )
    del params, counts

    (acting_state, metrics), grad = grad_fn(params1, training_state.acting_state, state)

    ...

    def a2c_loss(
        self,
        params: ActorCriticParams,
        acting_state: ActingState,
        state,
    ) -> Tuple[float, Tuple[ActingState, Dict]]:
        parametric_action_distribution = (
            self.actor_critic_networks.parametric_action_distribution
        )
        # value_model = self.actor_critic_networks.value_network
        value_model = nnx.merge(state.graphdef, state.params, state.counts)
        ...

Error message

    ValueError: Expected named tuple, got State({
      'policy_network': {
        'policy_head': {...},
        'torso': {...},
      },
      'value_network': {
        'policy_head': {...},
        'torso': {...},
      },
    })

At this point, I'm unsure whether the issue could be related to:

1. The way graphdef is defined: does it require specific attention?
2. The nnx.merge usage: could I be missing something or misconfiguring it?
3. The initialization of TrainState: is there a potential problem here?

If possible, could you provide further guidance or clarify how to properly combine these components to achieve the desired outcome? I deeply appreciate your time and assistance!
