
draft_retrace #695 (Open)

zhuboli wants to merge 10 commits into base: pytorch

Conversation
@zhuboli zhuboli commented Sep 29, 2020

This change modifies value_ops and td_loss. The default value for train_info is None. If the train_info parameter is given and lambda is not equal to 1 or 0, the retrace method is used. This way, other people who do not want the retrace method do not need to change sac_algorithm or sarsa_algorithm.
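The gating described in the PR description can be sketched as follows. This is a minimal illustration, not the actual ALF code: the class name, the string placeholders, and the exact condition are assumptions based on the description above.

```python
# Hypothetical sketch of the gating described in the PR description:
# retrace is used only when train_info is provided AND lambda is
# strictly between 0 and 1; otherwise the existing loss path runs.


class TDLossSketch:
    def __init__(self, td_lambda=0.95):
        self._lambda = td_lambda

    def forward(self, experience, value, target_value, train_info=None):
        if train_info is not None and 0.0 < self._lambda < 1.0:
            return "retrace"            # placeholder for the retrace target
        return "one_or_multi_step"      # placeholder for the existing path
```

With this gating, callers that never pass train_info keep the old behavior unchanged, which is the compatibility argument made in the description.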

@emailweixu emailweixu (Contributor) left a comment

@@ -99,15 +102,37 @@ def forward(self, experience, value, target_value):
values=target_value,
step_types=experience.step_type,
discounts=experience.discount * self._gamma)
else:
elif train_info == None:

Instead of checking whether train_info is None, you should add an argument in __init__ to indicate whether to use retrace.
You should also change SarsaAlgorithm and SacAlgorithm to pass in train_info.
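The reviewer's suggestion amounts to deciding at construction time rather than inferring intent from train_info being None. A minimal sketch of that constructor flag (the class name, flag name, and return placeholders are illustrative, not from the PR):

```python
# Hypothetical sketch of the suggested design: an explicit use_retrace
# flag set in __init__, instead of branching on train_info is None.


class TDLossRetraceSketch:
    def __init__(self, td_lambda=0.95, use_retrace=False):
        self._lambda = td_lambda
        self._use_retrace = use_retrace

    def forward(self, experience, value, target_value, train_info=None):
        if self._use_retrace:
            # train_info is still needed to compute importance ratios,
            # but it no longer doubles as the on/off switch.
            assert train_info is not None, (
                "train_info is required when use_retrace=True")
            return "retrace"            # placeholder for the retrace target
        return "one_or_multi_step"      # placeholder for the existing path
```

This makes the configuration explicit and lets SarsaAlgorithm and SacAlgorithm always pass train_info without silently changing the loss.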

@@ -255,3 +255,36 @@ def generalized_advantage_estimation(rewards,
advs = advs.transpose(0, 1)

return advs.detach()
####### add for the retrace method
def generalized_advantage_estimation_retrace(importance_ratio, discounts, rewards, td_lambda, time_major, values, target_value,step_types):

Please add comments following the style of the other functions.


Also need a unit test for this function.


  1. line too long
  2. add space after ,
  3. comments for the function need to be added

@Haichao-Zhang Haichao-Zhang mentioned this pull request Oct 19, 2020
@@ -435,7 +435,7 @@ def calc_loss(self, experience, info: SarsaInfo):
target_critic = tensor_utils.tensor_prepend_zero(
info.target_critics)
loss_info = self._critic_losses[i](shifted_experience, critic,
target_critic)
target_critic,info)

add space after ,

@@ -31,6 +31,10 @@ def __init__(self,
td_error_loss_fn=element_wise_squared_loss,
td_lambda=0.95,
normalize_target=False,
some-feature-retrace

need to be removed

some-feature-retrace
use_retrace=0,

pytorch

need to be removed

@@ -76,8 +80,13 @@ def __init__(self,
self._debug_summaries = debug_summaries
self._normalize_target = normalize_target
self._target_normalizer = None
some-feature-retrace

Remove; these seem to be leftover tags from a merge.


def forward(self, experience, value, target_value):
pytorch

remove

else:
scope = alf.summary.scope(self.__class__.__name__)
importance_ratio,importance_ratio_clipped = value_ops.action_importance_ratio(
action_distribution=train_info.action_distribution,

format, line is too long


Not fixed?

@Haichao-Zhang Haichao-Zhang (Contributor) left a comment

There seem to be many format issues. You may need to follow the workflow here to set up the formatting tools and also get a reference for the coding standard:
https://alf.readthedocs.io/en/latest/contributing.html#workflow

@@ -46,7 +50,7 @@ def __init__(self,
:math:`G_t^\lambda = \hat{A}^{GAE}_t + V(s_t)`
where the generalized advantage estimation is defined as:
:math:`\hat{A}^{GAE}_t = \sum_{i=t}^{T-1}(\gamma\lambda)^{i-t}(R_{i+1} + \gamma V(s_{i+1}) - V(s_i))`

use_retrace = 0 means one step or multi_step loss, use_retrace = 1 means retrace loss

Can change use_retrace to use a bool value.


Need to update comment


else:
scope = alf.summary.scope(self.__class__.__name__)
importance_ratio,importance_ratio_clipped = value_ops.action_importance_ratio(

add space after ,

@@ -255,3 +255,36 @@ def generalized_advantage_estimation(rewards,
advs = advs.transpose(0, 1)

return advs.detach()
####### add for the retrace method
def generalized_advantage_estimation_retrace(importance_ratio, discounts, rewards, td_lambda, time_major, values, target_value,step_types):

  1. line too long
  2. add space after ,
  3. comments for the function need to be added

@@ -170,7 +170,32 @@ def test_generalized_advantage_estimation(self):
discounts=discounts,
td_lambda=td_lambda,
expected=expected)

class GeneralizedAdvantage_retrace_Test(unittest.TestCase):
"""Tests for alf.utils.value_ops

The comments are not correct.

@@ -46,7 +50,7 @@ def __init__(self,
:math:`G_t^\lambda = \hat{A}^{GAE}_t + V(s_t)`
where the generalized advantage estimation is defined as:
:math:`\hat{A}^{GAE}_t = \sum_{i=t}^{T-1}(\gamma\lambda)^{i-t}(R_{i+1} + \gamma V(s_{i+1}) - V(s_i))`

use_retrace = 0 means one step or multi_step loss, use_retrace = 1 means retrace loss

Need to update comment

log_prob_clipping=0.0,
scope=scope,
check_numerics=False,
debug_summaries=True)

debug_summaries= debug_summaries



####### add for the retrace method
def generalized_advantage_estimation_retrace(importance_ratio, discounts,

This function can be merged with generalized_advantage_estimation function
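One way such a merge could look: add an optional importance_ratio argument so that plain GAE is the special case where every ratio is 1. This is only a sketch on plain Python lists; the real ALF function operates on time-major tensors and has more arguments, and the exact recursion here is an assumption based on the standard Retrace(lambda)-style correction.

```python
def generalized_advantage_estimation(rewards, values, discounts,
                                     td_lambda, importance_ratio=None):
    """Backward-recursive GAE; with importance_ratio it becomes a
    retrace-style estimate.

    advs[t] = delta[t] + c[t+1] * discounts[t+1] * td_lambda * advs[t+1]
    where delta[t] = rewards[t+1] + discounts[t+1] * values[t+1] - values[t]
    and c is the (clipped) importance ratio, 1 everywhere for plain GAE.
    """
    T = len(values)
    if importance_ratio is None:
        importance_ratio = [1.0] * T  # plain GAE as the special case
    advs = [0.0] * T
    for t in reversed(range(T - 1)):
        delta = rewards[t + 1] + discounts[t + 1] * values[t + 1] - values[t]
        advs[t] = delta + (importance_ratio[t + 1] * discounts[t + 1]
                           * td_lambda * advs[t + 1])
    return advs[:-1]
```

Folding the two functions together this way keeps one code path to test and makes the retrace correction a single extra multiplicative factor in the recursion.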

else:
scope = alf.summary.scope(self.__class__.__name__)
importance_ratio,importance_ratio_clipped = value_ops.action_importance_ratio(
action_distribution=train_info.action_distribution,

Not fixed?

@@ -91,6 +97,8 @@ def forward(self, experience, value, target_value):
target_value (torch.Tensor): the time-major tensor for the value at
each time step. This is used to calculate return. ``target_value``
can be same as ``value``.
train_info (sarsa info, sac info): information used to calcuate importance_ratio

What is sarsa info, sac info here? Can this function be used with other algorithms beyond sac and sarsa?

@emailweixu emailweixu mentioned this pull request May 11, 2022